# Copyright 2021 Sean Robertson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for optimizing param parameters via Optuna
See Also
--------
:ref:`Hyperparameter Optimization with Optuna`
A tutorial on how to use this module
"""
from __future__ import annotations
import abc
import warnings
import collections.abc
from typing import Collection, Optional, Set, Type, TypeVar
from collections import OrderedDict
from copy import deepcopy
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import param
from pydrobert.param.abc import AbstractParameterized
# Public names of this module (what ``from <module> import *`` exports)
__all__ = [
    "get_param_dict_tunable",
    "parameterized_class_from_tunable",
    "suggest_param_dict",
    "TunableParameterized",
]
[docs]
class TunableParameterized(AbstractParameterized):
    """Interface through which Optuna can tune Parameterized instances

    Two class methods make up the :class:`TunableParameterized` interface:

    - :func:`get_tunable`
    - :func:`suggest_params`

    Membership is structural, in the manner of :mod:`collections.abc`: any class
    providing both methods passes :func:`isinstance` and :func:`issubclass` checks
    against :class:`TunableParameterized` without subclassing it. Subclassing
    :class:`TunableParameterized` directly additionally ensures the class inherits
    from :class:`param.parameterized.Parameterized`
    """

    __abstract = True  # this is how param handles abstract classes for now
    __slots__ = ()

    @classmethod
    @abc.abstractmethod
    def get_tunable(cls) -> Set[str]:
        """Return the names of the parameters that may be tuned

        Each element is intended to be the name of a parameter and should not
        contain :obj:`'.'`. This base implementation returns an empty set.
        """
        return set()

    @classmethod
    @abc.abstractmethod
    def suggest_params(
        cls,
        trial,
        base: Optional[TunableParameterized] = None,
        only: Optional[Collection[str]] = None,
        prefix: str = "",
    ) -> TunableParameterized:
        """Fill an instance of this class with values sampled from a trial

        Parameters
        ----------
        trial : optuna.trial.Trial
            The current optuna trial, from which parameter values are sampled
        base
            The instance to load sampled values into. If :obj:`None`, a new
            instance of this class is created instead
        only
            Names of the parameters to sample. If :obj:`None`, everything in
            :func:`get_tunable()` is sampled
        prefix
            A string prepended to each name in `only` when that parameter is
            sampled from `trial`

        Returns
        -------
        TunableParameterized
            `base` when it is not :obj:`None`, otherwise a new instance of this
            class. This base implementation samples nothing itself.
        """
        if base is not None:
            return base
        return cls()

    @classmethod
    def __subclasshook__(cls, C):
        # Structural check applies to this interface only; subclasses fall back
        # to the regular ABC machinery.
        if cls is not TunableParameterized:
            return NotImplemented
        return _check_methods(C, "get_tunable", "suggest_params")
[docs]
def get_param_dict_tunable(
    param_dict: dict, on_decimal: Literal["ignore", "warn", "raise"] = "warn"
) -> Set[str]:
    """Return a set of all the tunable parameters in a parameter dictionary

    This function crawls through a (possibly nested) dictionary of objects, looks for
    any that implement the :class:`TunableParameterized` interface, collects the results
    of calls to :func:`get_tunable`, and returns the set `tunable`.

    Elements of `tunable` are strings with the format
    ``"<key_0>.<key_1>.<...>.<parameter_name>"``, where ``parameter_name`` is a
    parameter from ``param_dict[<key_0>][<key_1>][...].get_tunable()``

    Parameters
    ----------
    param_dict
        The (possibly nested) dictionary to crawl
    on_decimal
        :obj:`'.'` can produce ambiguous parameters in `tunable`. When one is found as a
        key in `param_dict` or as a tunable parameter: "raise" means a
        :class:`ValueError` will be raised; "warn" means a warning will be issued via
        :mod:`warnings`; and "ignore" just ignores it

    Returns
    -------
    tunable : set of str

    Raises
    ------
    ValueError
        If `on_decimal` is not one of the three accepted values, or if it is
        "raise" and a key or tunable parameter name contains :obj:`'.'`
    """
    # Tuple membership (rather than set membership) so that an unhashable
    # argument is still reported with the ValueError below.
    if on_decimal not in ("ignore", "warn", "raise"):
        raise ValueError("on_decimal must be 'ignore', 'warn', or 'raise'")
    tunable_params = _tunable_params_from_param_dict(param_dict, on_decimal)
    tunable = set()
    for prefix, params in tunable_params.items():
        new_tunable = params.get_tunable()
        if on_decimal != "ignore":
            decimal_tunable = tuple(x for x in new_tunable if "." in x)
            if decimal_tunable:
                msg = (
                    "Found parameters in param_dict{} with '.' in their name: "
                    "{}. These can lead to ambiguities in suggest_param_dict "
                    "and should be avoided".format(
                        _to_multikey(prefix), decimal_tunable
                    )
                )
                if on_decimal == "raise":
                    raise ValueError(msg)
                warnings.warn(msg)
        # qualify each parameter name with the key path of its owner
        tunable.update(".".join((prefix, x)) for x in new_tunable)
    return tunable
# Type variable tying the ``base`` argument of parameterized_class_from_tunable
# to its return type; bound to param.Parameterized
P = TypeVar("P", bound=param.Parameterized)
[docs]
def parameterized_class_from_tunable(
    tunable: Collection, base: Type[P] = param.Parameterized, default: list = [],
) -> Type[P]:
    """Construct a Parameterized class to store parameters to optimize

    This function creates a subclass of `base` that adds one parameter: `only`.
    `only` is a :class:`param.ListSelector` that allows values from `tunable`

    Parameters
    ----------
    tunable
        The values `only` may select from
    base
        The parent class of the returned class
    default
        The default value for the `only` parameter. A list is copied, so later
        changes to the caller's list do not affect the returned class

    Returns
    -------
    Derived : base
        The derived class

    Examples
    --------
    Group what hyperparameters you wish to optimize in the same param dict as
    the rest of your parameters

    >>> class ModelParams(param.Parameterized):
    ...     lr = param.Number(1e-4, bounds=(1e-8, None))
    ...     num_layers = param.Integer(3, bounds=(1, None))
    ...     @classmethod
    ...     def get_tunable(cls):
    ...         return {'num_layers', 'lr'}
    ...     @classmethod
    ...     def suggest_params(cls, trial, base=None, only=None, prefix=None):
    ...         pass  # do this
    ...
    >>> param_dict = {'model': ModelParams()}
    >>> tunable = get_param_dict_tunable(param_dict)
    >>> OptimParams = parameterized_class_from_tunable(tunable)
    >>> param_dict['hyperparameter_optimization'] = OptimParams()
    """
    # The signature's default list object is shared across calls; hand each
    # generated class its own copy so they never alias one another (or the
    # caller's list). Non-list values are passed through untouched.
    only_default = list(default) if isinstance(default, list) else default

    class Derived(base):
        only = param.ListSelector(
            only_default,
            objects=list(tunable),
            doc="When performing hyperparameter optimization, only optimize "
            "these parameters",
        )

    return Derived
[docs]
def suggest_param_dict(
    trial,
    global_dict: dict,
    only: Optional[Collection[str]] = None,
    on_decimal: Literal["ignore", "warn", "raise"] = "warn",
    warn: bool = True,
) -> dict:
    """Use Optuna trial to sample values for TunableParameterized in dict

    This function creates a deep copy of the dictionary `global_dict`. Then, for every
    :class:`TunableParameterized` it finds in the copy, it calls that instance's
    :func:`suggest_params` to optimize an appropriate subset of parameters.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The trial from an Optuna experiment. This is passed along to each
        :class:`TunableParameterized` in `global_dict`
    global_dict
        A (possibly nested) dictionary containing some :class:`TunableParameterized` as
        values
    only
        A collection of parameter names to optimize. Names are formatted
        ``"<key_0>.<key_1>.<...>.<parameter_name>"``, where ``parameter_name`` is a
        parameter from ``global_dict[<key_0>][<key_1>][...].get_tunable()``. If
        :obj:`None`, the entire set returned by :func:`get_param_dict_tunable` is
        used. The collection passed in is not modified
    on_decimal
        '.' can produce ambiguous parameters in `only`. When one is found as a key in
        `global_dict` or as a tunable parameter: "raise" means a :class:`ValueError`
        will be raised; "warn" means a warning will be issued via :mod:`warnings`; and
        "ignore" just ignores it
    warn
        If `warn` is :obj:`True` and any elements of `only` do not match this
        description, a warning will be raised via :mod:`warnings`

    Returns
    -------
    param_dict : dict
        The deep copy of `global_dict`, after :func:`suggest_params` has been
        called on each :class:`TunableParameterized` found in it

    Raises
    ------
    ValueError
        If `on_decimal` is not one of the three accepted values, or if it is
        "raise" and an offending '.' is found
    """
    # Validate up front so that an invalid value is rejected regardless of
    # whether `only` was supplied (matches get_param_dict_tunable).
    if on_decimal not in ("ignore", "warn", "raise"):
        raise ValueError("on_decimal must be 'ignore', 'warn', or 'raise'")
    if only is None:
        only = get_param_dict_tunable(global_dict, on_decimal)
        # get_param_dict_tunable already crawled the dictionary and reported
        # any '.'; don't report the same keys a second time below
        crawl_on_decimal = "ignore"
    else:
        only = set(only)  # in case a list, and also allows us to modify
        crawl_on_decimal = on_decimal
    param_dict = deepcopy(global_dict)
    tunable_params = _tunable_params_from_param_dict(param_dict, crawl_on_decimal)
    for prefix, param_ in tunable_params.items():
        prefix = prefix + "."
        # names under this prefix, stripped of it, that the instance can tune
        prefix_only = {x[len(prefix) :] for x in only if x.startswith(prefix)}
        prefix_only = prefix_only & param_.get_tunable()
        # whatever remains in `only` after the loop matched nothing
        only -= {prefix + x for x in prefix_only}
        param_.suggest_params(trial, base=param_, only=prefix_only, prefix=prefix)
    if warn and only:
        warnings.warn(
            '"only" contained extra parameters: {}. To suppress this warning, '
            "set warn=False".format(only)
        )
    return param_dict
def _to_multikey(s):
    # Render a '.'-delimited key path as chained subscripts, e.g.
    # 'a.b.c' -> '["a"]["b"]["c"]'
    return "".join('["{}"]'.format(part) for part in s.split("."))
def _tunable_params_from_param_dict(param_dict, on_decimal, prefix=""):
    # Walk a possibly nested mapping and collect every TunableParameterized
    # value. The result maps the '.'-joined path of keys leading to each such
    # value onto the value itself, in traversal order. Keys containing '.'
    # are reported according to on_decimal ("ignore", "warn" or "raise").
    found = OrderedDict()
    for name, entry in param_dict.items():
        if "." in name and on_decimal != "ignore":
            location = " at param_dict" + _to_multikey(prefix) if prefix else ""
            msg = (
                "Found key{} with '.' in its name: '{}'. This can lead to "
                "ambiguities in suggest_param_dict and should be avoided"
            ).format(location, name)
            if on_decimal == "raise":
                raise ValueError(msg)
            warnings.warn(msg)
        path = ".".join((prefix, name) if prefix else (name,))
        if isinstance(entry, TunableParameterized):
            found[path] = entry
        elif isinstance(entry, collections.abc.Mapping):
            # descend, carrying the path so far as the new prefix
            nested = _tunable_params_from_param_dict(entry, on_decimal, path)
            found.update(nested)
    return found
# from
# https://github.com/python/cpython/blob/2085bd0877e17ad4d98a4586d5eabb6faecbb190/Lib/_collections_abc.py
# combined with
# https://github.com/python/cpython/blob/1a7c3571c789d704503135fe7c20d6e6f78aec86/Lib/_abcoll.py
def _check_methods(C, *methods):
    # Return True if C provides every named method, NotImplemented otherwise.
    # For objects with an MRO, the first class in the MRO that defines a name
    # decides: the name must be present and not set to None. Objects without
    # the needed attributes fall back to a plain getattr check.
    try:
        lineage = C.__mro__
        for name in methods:
            owner = next((B for B in lineage if name in B.__dict__), None)
            if owner is None or owner.__dict__[name] is None:
                return NotImplemented
    except AttributeError:
        if any(getattr(C, name, None) is None for name in methods):
            return NotImplemented
    return True