mirror of https://github.com/microsoft/autogen.git
Update OptunaSearch (#1106)
* update optuna * update setup * fix dependencies * fix bugs in test * fix bugs, web format --------- Co-authored-by: “skzhang1” <“shaokunzhang529@gmail.com”> Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
parent
4f1dfe6676
commit
dd9202bb01
|
@ -19,6 +19,7 @@ import time
|
||||||
import functools
|
import functools
|
||||||
import warnings
|
import warnings
|
||||||
import copy
|
import copy
|
||||||
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional, Union, List, Tuple, Callable
|
from typing import Any, Dict, Optional, Union, List, Tuple, Callable
|
||||||
import pickle
|
import pickle
|
||||||
|
@ -33,6 +34,7 @@ from ..sample import (
|
||||||
Uniform,
|
Uniform,
|
||||||
)
|
)
|
||||||
from ..trial import flatten_dict, unflatten_dict
|
from ..trial import flatten_dict, unflatten_dict
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -259,14 +261,11 @@ except ImportError:
|
||||||
OptunaTrialState = None
|
OptunaTrialState = None
|
||||||
OptunaTrial = None
|
OptunaTrial = None
|
||||||
|
|
||||||
# (Optional) Default (anonymous) metric when using tune.report(x)
|
|
||||||
DEFAULT_METRIC = "_metric"
|
DEFAULT_METRIC = "_metric"
|
||||||
|
|
||||||
# (Auto-filled) The index of this training iteration.
|
|
||||||
TRAINING_ITERATION = "training_iteration"
|
TRAINING_ITERATION = "training_iteration"
|
||||||
|
|
||||||
# print a warning if define by run function takes longer than this to execute
|
DEFINE_BY_RUN_WARN_THRESHOLD_S = 1
|
||||||
DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary
|
|
||||||
|
|
||||||
|
|
||||||
def validate_warmstart(
|
def validate_warmstart(
|
||||||
|
@ -309,6 +308,7 @@ def validate_warmstart(
|
||||||
|
|
||||||
class _OptunaTrialSuggestCaptor:
|
class _OptunaTrialSuggestCaptor:
|
||||||
"""Utility to capture returned values from Optuna's suggest_ methods.
|
"""Utility to capture returned values from Optuna's suggest_ methods.
|
||||||
|
|
||||||
This will wrap around the ``optuna.Trial` object and decorate all
|
This will wrap around the ``optuna.Trial` object and decorate all
|
||||||
`suggest_` callables with a function capturing the returned value,
|
`suggest_` callables with a function capturing the returned value,
|
||||||
which will be saved in the ``captured_values`` dict.
|
which will be saved in the ``captured_values`` dict.
|
||||||
|
@ -338,80 +338,192 @@ class _OptunaTrialSuggestCaptor:
|
||||||
|
|
||||||
class OptunaSearch(Searcher):
|
class OptunaSearch(Searcher):
|
||||||
"""A wrapper around Optuna to provide trial suggestions.
|
"""A wrapper around Optuna to provide trial suggestions.
|
||||||
[Optuna](https://optuna.org/)
|
|
||||||
is a hyperparameter optimization library.
|
|
||||||
In contrast to other libraries, it employs define-by-run style
|
|
||||||
hyperparameter definitions.
|
|
||||||
This Searcher is a thin wrapper around Optuna's search algorithms.
|
|
||||||
You can pass any Optuna sampler, which will be used to generate
|
|
||||||
hyperparameter suggestions.
|
|
||||||
Args:
|
|
||||||
space (dict|Callable): Hyperparameter search space definition for
|
|
||||||
Optuna's sampler. This can be either a class `dict` with
|
|
||||||
parameter names as keys and ``optuna.distributions`` as values,
|
|
||||||
or a Callable - in which case, it should be a define-by-run
|
|
||||||
function using ``optuna.trial`` to obtain the hyperparameter
|
|
||||||
values. The function should return either a class `dict` of
|
|
||||||
constant values with names as keys, or None.
|
|
||||||
For more information, see
|
|
||||||
[tutorial](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html).
|
|
||||||
Warning - No actual computation should take place in the define-by-run
|
|
||||||
function. Instead, put the training logic inside the function
|
|
||||||
or class trainable passed to tune.run.
|
|
||||||
metric (str): The training result objective value attribute. If None
|
|
||||||
but a mode was passed, the anonymous metric `_metric` will be used
|
|
||||||
per default.
|
|
||||||
mode (str): One of {min, max}. Determines whether objective is
|
|
||||||
minimizing or maximizing the metric attribute.
|
|
||||||
points_to_evaluate (list): Initial parameter suggestions to be run
|
|
||||||
first. This is for when you already have some good parameters
|
|
||||||
you want to run first to help the algorithm make better suggestions
|
|
||||||
for future parameters. Needs to be a list of dicts containing the
|
|
||||||
configurations.
|
|
||||||
sampler (optuna.samplers.BaseSampler): Optuna sampler used to
|
|
||||||
draw hyperparameter configurations. Defaults to ``TPESampler``.
|
|
||||||
seed (int): Seed to initialize sampler with. This parameter is only
|
|
||||||
used when ``sampler=None``. In all other cases, the sampler
|
|
||||||
you pass should be initialized with the seed already.
|
|
||||||
evaluated_rewards (list): If you have previously evaluated the
|
|
||||||
parameters passed in as points_to_evaluate you can avoid
|
|
||||||
re-running those trials by passing in the reward attributes
|
|
||||||
as a list so the optimiser can be told the results without
|
|
||||||
needing to re-compute the trial. Must be the same length as
|
|
||||||
points_to_evaluate.
|
|
||||||
|
|
||||||
Tune automatically converts search spaces to Optuna's format:
|
`Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
|
||||||
|
In contrast to other libraries, it employs define-by-run style
|
||||||
|
hyperparameter definitions.
|
||||||
|
|
||||||
````python
|
This Searcher is a thin wrapper around Optuna's search algorithms.
|
||||||
from ray.tune.suggest.optuna import OptunaSearch # ray version < 2
|
You can pass any Optuna sampler, which will be used to generate
|
||||||
config = { "a": tune.uniform(6, 8),
|
hyperparameter suggestions.
|
||||||
"b": tune.loguniform(1e-4, 1e-2)}
|
|
||||||
optuna_search = OptunaSearch(metric="loss", mode="min")
|
|
||||||
tune.run(trainable, config=config, search_alg=optuna_search)
|
|
||||||
````
|
|
||||||
|
|
||||||
If you would like to pass the search space manually, the code would
|
Multi-objective optimization is supported.
|
||||||
look like this:
|
|
||||||
|
Args:
|
||||||
|
space: Hyperparameter search space definition for
|
||||||
|
Optuna's sampler. This can be either a dict with
|
||||||
|
parameter names as keys and ``optuna.distributions`` as values,
|
||||||
|
or a Callable - in which case, it should be a define-by-run
|
||||||
|
function using ``optuna.trial`` to obtain the hyperparameter
|
||||||
|
values. The function should return either a dict of
|
||||||
|
constant values with names as keys, or None.
|
||||||
|
For more information, see https://optuna.readthedocs.io\
|
||||||
|
/en/stable/tutorial/10_key_features/002_configurations.html.
|
||||||
|
|
||||||
|
Warning - No actual computation should take place in the define-by-run
|
||||||
|
function. Instead, put the training logic inside the function
|
||||||
|
or class trainable passed to ``tune.run``.
|
||||||
|
|
||||||
|
metric: The training result objective value attribute. If
|
||||||
|
None but a mode was passed, the anonymous metric ``_metric``
|
||||||
|
will be used per default. Can be a list of metrics for
|
||||||
|
multi-objective optimization.
|
||||||
|
mode: One of {min, max}. Determines whether objective is
|
||||||
|
minimizing or maximizing the metric attribute. Can be a list of
|
||||||
|
modes for multi-objective optimization (corresponding to
|
||||||
|
``metric``).
|
||||||
|
points_to_evaluate: Initial parameter suggestions to be run
|
||||||
|
first. This is for when you already have some good parameters
|
||||||
|
you want to run first to help the algorithm make better suggestions
|
||||||
|
for future parameters. Needs to be a list of dicts containing the
|
||||||
|
configurations.
|
||||||
|
sampler: Optuna sampler used to
|
||||||
|
draw hyperparameter configurations. Defaults to ``MOTPESampler``
|
||||||
|
for multi-objective optimization with Optuna<2.9.0, and
|
||||||
|
``TPESampler`` in every other case.
|
||||||
|
|
||||||
|
Warning: Please note that with Optuna 2.10.0 and earlier
|
||||||
|
default ``MOTPESampler``/``TPESampler`` suffer
|
||||||
|
from performance issues when dealing with a large number of
|
||||||
|
completed trials (approx. >100). This will manifest as
|
||||||
|
a delay when suggesting new configurations.
|
||||||
|
This is an Optuna issue and may be fixed in a future
|
||||||
|
Optuna release.
|
||||||
|
|
||||||
|
seed: Seed to initialize sampler with. This parameter is only
|
||||||
|
used when ``sampler=None``. In all other cases, the sampler
|
||||||
|
you pass should be initialized with the seed already.
|
||||||
|
evaluated_rewards: If you have previously evaluated the
|
||||||
|
parameters passed in as points_to_evaluate you can avoid
|
||||||
|
re-running those trials by passing in the reward attributes
|
||||||
|
as a list so the optimiser can be told the results without
|
||||||
|
needing to re-compute the trial. Must be the same length as
|
||||||
|
points_to_evaluate.
|
||||||
|
|
||||||
|
Warning - When using ``evaluated_rewards``, the search space ``space``
|
||||||
|
must be provided as a dict with parameter names as
|
||||||
|
keys and ``optuna.distributions`` instances as values. The
|
||||||
|
define-by-run search space definition is not yet supported with
|
||||||
|
this functionality.
|
||||||
|
|
||||||
|
Tune automatically converts search spaces to Optuna's format:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ray.tune.suggest.optuna import OptunaSearch # ray version < 2
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"a": tune.uniform(6, 8)
|
||||||
|
"b": tune.loguniform(1e-4, 1e-2)
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
|
tune.run(trainable, config=config, search_alg=optuna_search)
|
||||||
|
```
|
||||||
|
|
||||||
|
If you would like to pass the search space manually, the code would
|
||||||
|
look like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
import optuna
|
import optuna
|
||||||
config = { "a": optuna.distributions.UniformDistribution(6, 8),
|
|
||||||
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2)}
|
space = {
|
||||||
optuna_search = OptunaSearch(space,metric="loss",mode="min")
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
tune.run(trainable, search_alg=optuna_search)
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
|
||||||
# Equivalent Optuna define-by-run function approach:
|
# Equivalent Optuna define-by-run function approach:
|
||||||
|
|
||||||
def define_search_space(trial: optuna.Trial):
|
def define_search_space(trial: optuna.Trial):
|
||||||
trial.suggest_float("a", 6, 8)
|
trial.suggest_float("a", 6, 8)
|
||||||
trial.suggest_float("b", 1e-4, 1e-2, log=True)
|
trial.suggest_float("b", 1e-4, 1e-2, log=True)
|
||||||
# training logic goes into trainable, this is just
|
# training logic goes into trainable, this is just
|
||||||
# for search space definition
|
# for search space definition
|
||||||
|
|
||||||
optuna_search = OptunaSearch(
|
optuna_search = OptunaSearch(
|
||||||
define_search_space,
|
define_search_space,
|
||||||
metric="loss",
|
metric="loss",
|
||||||
mode="min")
|
mode="min")
|
||||||
|
|
||||||
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
```
|
||||||
|
|
||||||
|
Multi-objective optimization is supported:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
space = {
|
||||||
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Note you have to specify metric and mode here instead of
|
||||||
|
# in tune.run
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
metric=["loss1", "loss2"],
|
||||||
|
mode=["min", "max"])
|
||||||
|
|
||||||
|
# Do not specify metric and mode here!
|
||||||
|
tune.run(
|
||||||
|
trainable,
|
||||||
|
search_alg=optuna_search
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can pass configs that will be evaluated first using
|
||||||
|
``points_to_evaluate``:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
space = {
|
||||||
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
```
|
||||||
|
|
||||||
|
Avoid re-running evaluated trials by passing the rewards together with
|
||||||
|
`points_to_evaluate`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
space = {
|
||||||
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
||||||
|
evaluated_rewards=[0.89, 0.42]
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
tune.run(trainable, search_alg=optuna_search)
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
.. versionadded:: 0.8.8
|
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -425,15 +537,15 @@ class OptunaSearch(Searcher):
|
||||||
Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
|
Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
|
||||||
]
|
]
|
||||||
] = None,
|
] = None,
|
||||||
metric: Optional[str] = None,
|
metric: Optional[Union[str, List[str]]] = None,
|
||||||
mode: Optional[str] = None,
|
mode: Optional[Union[str, List[str]]] = None,
|
||||||
points_to_evaluate: Optional[List[Dict]] = None,
|
points_to_evaluate: Optional[List[Dict]] = None,
|
||||||
sampler: Optional["BaseSampler"] = None,
|
sampler: Optional["BaseSampler"] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
evaluated_rewards: Optional[List] = None,
|
evaluated_rewards: Optional[List] = None,
|
||||||
):
|
):
|
||||||
assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
|
assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
|
||||||
super(OptunaSearch, self).__init__(metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None)
|
super(OptunaSearch, self).__init__(metric=metric, mode=mode)
|
||||||
|
|
||||||
if isinstance(space, dict) and space:
|
if isinstance(space, dict) and space:
|
||||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||||
|
@ -457,33 +569,60 @@ class OptunaSearch(Searcher):
|
||||||
"`seed` parameter has to be passed to the sampler directly "
|
"`seed` parameter has to be passed to the sampler directly "
|
||||||
"and will be ignored."
|
"and will be ignored."
|
||||||
)
|
)
|
||||||
|
elif sampler:
|
||||||
|
assert isinstance(sampler, BaseSampler), (
|
||||||
|
"You can only pass an instance of " "`optuna.samplers.BaseSampler` " "as a sampler to `OptunaSearcher`."
|
||||||
|
)
|
||||||
|
|
||||||
self._sampler = sampler or ot.samplers.TPESampler(seed=seed)
|
self._sampler = sampler
|
||||||
|
self._seed = seed
|
||||||
|
|
||||||
assert isinstance(self._sampler, BaseSampler), (
|
self._completed_trials = set()
|
||||||
"You can only pass an instance of `optuna.samplers.BaseSampler` " "as a sampler to `OptunaSearcher`."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._ot_trials = {}
|
self._ot_trials = {}
|
||||||
self._ot_study = None
|
self._ot_study = None
|
||||||
if self._space:
|
if self._space:
|
||||||
self._setup_study(mode)
|
self._setup_study(mode)
|
||||||
|
|
||||||
def _setup_study(self, mode: str):
|
def _setup_study(self, mode: Union[str, list]):
|
||||||
if self._metric is None and self._mode:
|
if self._metric is None and self._mode:
|
||||||
|
if isinstance(self._mode, list):
|
||||||
|
raise ValueError(
|
||||||
|
"If ``mode`` is a list (multi-objective optimization " "case), ``metric`` must be defined."
|
||||||
|
)
|
||||||
# If only a mode was passed, use anonymous metric
|
# If only a mode was passed, use anonymous metric
|
||||||
self._metric = DEFAULT_METRIC
|
self._metric = DEFAULT_METRIC
|
||||||
|
|
||||||
pruner = ot.pruners.NopPruner()
|
pruner = ot.pruners.NopPruner()
|
||||||
storage = ot.storages.InMemoryStorage()
|
storage = ot.storages.InMemoryStorage()
|
||||||
|
try:
|
||||||
|
from packaging import version
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("To use BlendSearch, run: pip install flaml[blendsearch]")
|
||||||
|
if self._sampler:
|
||||||
|
sampler = self._sampler
|
||||||
|
elif isinstance(mode, list) and version.parse(ot.__version__) < version.parse("2.9.0"):
|
||||||
|
# MOTPESampler deprecated in Optuna>=2.9.0
|
||||||
|
sampler = ot.samplers.MOTPESampler(seed=self._seed)
|
||||||
|
else:
|
||||||
|
sampler = ot.samplers.TPESampler(seed=self._seed)
|
||||||
|
|
||||||
|
if isinstance(mode, list):
|
||||||
|
study_direction_args = dict(
|
||||||
|
directions=["minimize" if m == "min" else "maximize" for m in mode],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
study_direction_args = dict(
|
||||||
|
direction="minimize" if mode == "min" else "maximize",
|
||||||
|
)
|
||||||
|
|
||||||
self._ot_study = ot.study.create_study(
|
self._ot_study = ot.study.create_study(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=self._sampler,
|
sampler=sampler,
|
||||||
pruner=pruner,
|
pruner=pruner,
|
||||||
study_name=self._study_name,
|
study_name=self._study_name,
|
||||||
direction="minimize" if mode == "min" else "maximize",
|
|
||||||
load_if_exists=True,
|
load_if_exists=True,
|
||||||
|
**study_direction_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._points_to_evaluate:
|
if self._points_to_evaluate:
|
||||||
|
@ -500,7 +639,7 @@ class OptunaSearch(Searcher):
|
||||||
for point in self._points_to_evaluate:
|
for point in self._points_to_evaluate:
|
||||||
self._ot_study.enqueue_trial(point)
|
self._ot_study.enqueue_trial(point)
|
||||||
|
|
||||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str], config: Dict) -> bool:
|
def set_search_properties(self, metric: Optional[str], mode: Optional[str], config: Dict, **spec) -> bool:
|
||||||
if self._space:
|
if self._space:
|
||||||
return False
|
return False
|
||||||
space = self.convert_search_space(config)
|
space = self.convert_search_space(config)
|
||||||
|
@ -510,7 +649,7 @@ class OptunaSearch(Searcher):
|
||||||
if mode:
|
if mode:
|
||||||
self._mode = mode
|
self._mode = mode
|
||||||
|
|
||||||
self._setup_study(mode)
|
self._setup_study(self._mode)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _suggest_from_define_by_run_func(
|
def _suggest_from_define_by_run_func(
|
||||||
|
@ -553,21 +692,8 @@ class OptunaSearch(Searcher):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode)
|
UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode)
|
||||||
)
|
)
|
||||||
|
if callable(self._space):
|
||||||
if isinstance(self._space, list):
|
# Define-by-run case
|
||||||
# Keep for backwards compatibility
|
|
||||||
# Deprecate: 1.5
|
|
||||||
if trial_id not in self._ot_trials:
|
|
||||||
self._ot_trials[trial_id] = self._ot_study.ask()
|
|
||||||
|
|
||||||
ot_trial = self._ot_trials[trial_id]
|
|
||||||
|
|
||||||
# getattr will fetch the trial.suggest_ function on Optuna trials
|
|
||||||
params = {
|
|
||||||
args[0] if len(args) > 0 else kwargs["name"]: getattr(ot_trial, fn)(*args, **kwargs)
|
|
||||||
for (fn, args, kwargs) in self._space
|
|
||||||
}
|
|
||||||
elif callable(self._space):
|
|
||||||
if trial_id not in self._ot_trials:
|
if trial_id not in self._ot_trials:
|
||||||
self._ot_trials[trial_id] = self._ot_study.ask()
|
self._ot_trials[trial_id] = self._ot_study.ask()
|
||||||
|
|
||||||
|
@ -584,15 +710,36 @@ class OptunaSearch(Searcher):
|
||||||
return unflatten_dict(params)
|
return unflatten_dict(params)
|
||||||
|
|
||||||
def on_trial_result(self, trial_id: str, result: Dict):
|
def on_trial_result(self, trial_id: str, result: Dict):
|
||||||
|
if isinstance(self.metric, list):
|
||||||
|
# Optuna doesn't support incremental results
|
||||||
|
# for multi-objective optimization
|
||||||
|
return
|
||||||
|
if trial_id in self._completed_trials:
|
||||||
|
logger.warning(
|
||||||
|
f"Received additional result for trial {trial_id}, but " f"it already finished. Result: {result}"
|
||||||
|
)
|
||||||
|
return
|
||||||
metric = result[self.metric]
|
metric = result[self.metric]
|
||||||
step = result[TRAINING_ITERATION]
|
step = result[TRAINING_ITERATION]
|
||||||
ot_trial = self._ot_trials[trial_id]
|
ot_trial = self._ot_trials[trial_id]
|
||||||
ot_trial.report(metric, step)
|
ot_trial.report(metric, step)
|
||||||
|
|
||||||
def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False):
|
def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False):
|
||||||
|
if trial_id in self._completed_trials:
|
||||||
|
logger.warning(
|
||||||
|
f"Received additional completion for trial {trial_id}, but " f"it already finished. Result: {result}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
ot_trial = self._ot_trials[trial_id]
|
ot_trial = self._ot_trials[trial_id]
|
||||||
|
|
||||||
val = result.get(self.metric, None) if result else None
|
if result:
|
||||||
|
if isinstance(self.metric, list):
|
||||||
|
val = [result.get(metric, None) for metric in self.metric]
|
||||||
|
else:
|
||||||
|
val = result.get(self.metric, None)
|
||||||
|
else:
|
||||||
|
val = None
|
||||||
ot_trial_state = OptunaTrialState.COMPLETE
|
ot_trial_state = OptunaTrialState.COMPLETE
|
||||||
if val is None:
|
if val is None:
|
||||||
if error:
|
if error:
|
||||||
|
@ -601,9 +748,11 @@ class OptunaSearch(Searcher):
|
||||||
ot_trial_state = OptunaTrialState.PRUNED
|
ot_trial_state = OptunaTrialState.PRUNED
|
||||||
try:
|
try:
|
||||||
self._ot_study.tell(ot_trial, val, state=ot_trial_state)
|
self._ot_study.tell(ot_trial, val, state=ot_trial_state)
|
||||||
except ValueError as exc:
|
except Exception as exc:
|
||||||
logger.warning(exc) # E.g. if NaN was reported
|
logger.warning(exc) # E.g. if NaN was reported
|
||||||
|
|
||||||
|
self._completed_trials.add(trial_id)
|
||||||
|
|
||||||
def add_evaluated_point(
|
def add_evaluated_point(
|
||||||
self,
|
self,
|
||||||
parameters: Dict,
|
parameters: Dict,
|
||||||
|
@ -618,6 +767,13 @@ class OptunaSearch(Searcher):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode)
|
UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__, metric=self._metric, mode=self._mode)
|
||||||
)
|
)
|
||||||
|
if callable(self._space):
|
||||||
|
raise TypeError(
|
||||||
|
"Define-by-run function passed in `space` argument is not "
|
||||||
|
"yet supported when using `evaluated_rewards`. Please provide "
|
||||||
|
"an `OptunaDistribution` dict or pass a Ray Tune "
|
||||||
|
"search space to `tune.run()`."
|
||||||
|
)
|
||||||
|
|
||||||
ot_trial_state = OptunaTrialState.COMPLETE
|
ot_trial_state = OptunaTrialState.COMPLETE
|
||||||
if error:
|
if error:
|
||||||
|
|
5
setup.py
5
setup.py
|
@ -91,7 +91,10 @@ setuptools.setup(
|
||||||
"joblib<1.3.0", # temp solution for joblib 1.3.0 issue, no need once https://github.com/joblib/joblib-spark/pull/48 is merged
|
"joblib<1.3.0", # temp solution for joblib 1.3.0 issue, no need once https://github.com/joblib/joblib-spark/pull/48 is merged
|
||||||
],
|
],
|
||||||
"catboost": ["catboost>=0.26"],
|
"catboost": ["catboost>=0.26"],
|
||||||
"blendsearch": ["optuna==2.8.0"],
|
"blendsearch": [
|
||||||
|
"optuna==2.8.0",
|
||||||
|
"packaging",
|
||||||
|
],
|
||||||
"ray": [
|
"ray": [
|
||||||
"ray[tune]~=1.13",
|
"ray[tune]~=1.13",
|
||||||
],
|
],
|
||||||
|
|
|
@ -70,8 +70,8 @@ def test_searchers():
|
||||||
searcher = OptunaSearch(["a", config["a"]], metric="m", mode="max")
|
searcher = OptunaSearch(["a", config["a"]], metric="m", mode="max")
|
||||||
try:
|
try:
|
||||||
searcher.suggest("t0")
|
searcher.suggest("t0")
|
||||||
except ValueError:
|
except AttributeError:
|
||||||
# not enough values to unpack (expected 3, got 1)
|
# 'list' object has no attribute 'items'
|
||||||
pass
|
pass
|
||||||
searcher = OptunaSearch(
|
searcher = OptunaSearch(
|
||||||
config,
|
config,
|
||||||
|
@ -221,6 +221,21 @@ def test_searchers():
|
||||||
upper={"root": [{"a": 0.9}, {"a": 0.8}]},
|
upper={"root": [{"a": 0.9}, {"a": 0.8}]},
|
||||||
space={"root": config1},
|
space={"root": config1},
|
||||||
)
|
)
|
||||||
|
searcher = OptunaSearch(
|
||||||
|
define_search_space,
|
||||||
|
points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||||
|
metric=["a", "b"],
|
||||||
|
mode=["max", "max"],
|
||||||
|
)
|
||||||
|
searcher.set_search_properties("m", "min", config)
|
||||||
|
searcher.suggest("t1")
|
||||||
|
searcher.on_trial_complete("t1", None, False)
|
||||||
|
searcher.suggest("t2")
|
||||||
|
searcher.on_trial_complete("t2", None, True)
|
||||||
|
searcher.suggest("t3")
|
||||||
|
searcher.on_trial_complete("t3", {"m": np.nan})
|
||||||
|
searcher.save("test/tune/optuna.pkl")
|
||||||
|
searcher.restore("test/tune/optuna.pkl")
|
||||||
searcher = CFO(
|
searcher = CFO(
|
||||||
metric="m",
|
metric="m",
|
||||||
mode="min",
|
mode="min",
|
||||||
|
|
Loading…
Reference in New Issue