hyperoptax.base

Classes

class hyperoptax.base.OptimizerState(space)

Base optimizer state: a JAX pytree that holds the search space definition.

Parameters:

space (PyTree)

space: PyTree
replace(**kwargs)
Return type:

OptimizerState

__init__(space)#
Parameters:

space (PyTree)

Return type:

None
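
Because OptimizerState is a pytree, it can be passed straight through jax.jit and inspected with jax.tree_util. The sketch below uses a hypothetical stand-in, SketchState (not a hyperoptax class), built from a frozen dataclass registered with jax.tree_util.register_pytree_node_class. It assumes replace behaves like dataclasses.replace, returning a new state with the given fields swapped and leaving the original untouched.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SketchState:
    """Hypothetical stand-in for OptimizerState: a pytree holding the space."""

    space: dict

    def replace(self, **kwargs):
        # Return a new state with the given fields swapped; self is unchanged.
        return dataclasses.replace(self, **kwargs)

    def tree_flatten(self):
        # The space is the only child; there is no static auxiliary data.
        return (self.space,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def mean_lr(s):
    # The state is flattened into its array leaves at the jit boundary.
    return jnp.mean(s.space["lr"])


state = SketchState(space={"lr": jnp.linspace(1e-4, 1e-1, 5)})
new_state = state.replace(space={"lr": jnp.linspace(1e-3, 1.0, 5)})
print(mean_lr(state), mean_lr(new_state))
```

Keeping the state immutable and swapping fields through replace is what lets it serve as a carry in jitted or scanned code.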

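class hyperoptax.base.Optimizer
n_parallel: int = 1
Number of parameter sets proposed and evaluated per iteration; it is the size of the leading batch axis in the output of get_next_params.
classmethod init(space, **kwargs)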
Return type:

OptimizerState

optimize(state, key, func, n_iterations)

High-level API for optimizing a function over a search space. It is not recommended when you need fine-grained control over parallel computation.

func must return a scalar, i.e. an array of shape (). If your function returns shape (1,), call .squeeze() inside func before returning.
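
For example, an objective that ends up with a trailing length-1 axis can be fixed by squeezing before it returns. The two functions below are hypothetical objectives over a hypothetical params dict; only the second meets the shape () requirement.

```python
import jax.numpy as jnp


def loss_wrong(params):
    # Returns shape (1,): the scalar sum is reshaped into a length-1 vector.
    return jnp.reshape(jnp.sum(params["x"] ** 2), (1,))


def loss_ok(params):
    # Same value, squeezed down to a true scalar with shape ().
    return jnp.reshape(jnp.sum(params["x"] ** 2), (1,)).squeeze()


params = {"x": jnp.array([1.0, 2.0])}
print(loss_wrong(params).shape)  # (1,)
print(loss_ok(params).shape)     # ()
```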

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]
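
Conceptually, a high-level optimize call alternates proposal, evaluation, and update steps in a Python loop and collects per-iteration histories in lists. The sketch below is an assumption about that control flow, not hyperoptax's source: propose and update are hypothetical stand-ins for get_next_params and update_state, the state is a plain dict holding the best value seen, and func is vectorized over the batch with jax.vmap.

```python
import jax
import jax.numpy as jnp

N_PARALLEL = 4  # hypothetical batch size, mirroring Optimizer.n_parallel


def propose(state, key):
    # Hypothetical get_next_params: uniform samples in [-2, 2], one per slot.
    return {"x": jax.random.uniform(key, (N_PARALLEL,), minval=-2.0, maxval=2.0)}


def update(state, key, results, params):
    # Hypothetical update_state: keep the lowest result seen so far.
    return {"best": jnp.minimum(state["best"], results.min())}


def func(params):
    # Scalar objective with shape (); vmap maps it over the batch axis.
    return (params["x"] - 0.5) ** 2


def optimize_loop(state, key, n_iterations):
    params_hist, results_hist = [], []
    for _ in range(n_iterations):
        key, propose_key, update_key = jax.random.split(key, 3)
        params = propose(state, propose_key)
        results = jax.vmap(func)(params)  # shape (N_PARALLEL,)
        state = update(state, update_key, results, params)
        params_hist.append(params)
        results_hist.append(results)
    return state, (params_hist, results_hist)


state, (p_hist, r_hist) = optimize_loop({"best": jnp.inf}, jax.random.PRNGKey(0), 10)
```

Because the loop runs in Python, func here could contain untraceable code, at the cost of per-iteration dispatch overhead.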

optimize_scan(state, key, func, n_iterations)

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar, i.e. an array of shape (). If your function returns shape (1,), call .squeeze() inside func before returning.

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]
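
A small jax.lax.scan loop shows where the stacked shapes come from: scan stacks each step's outputs along a new leading axis, so a per-iteration leaf of shape (n_parallel,) becomes (n_iterations, n_parallel). The proposal, the dict state, and the objective below are hypothetical stand-ins, not hyperoptax internals. Everything inside step is traced, which is why func has to be jit-compatible.

```python
import jax
import jax.numpy as jnp

N_PARALLEL = 4  # hypothetical batch size
N_ITERATIONS = 10


def func(params):
    # Traceable scalar objective: no Python side effects or host callbacks.
    return (params["x"] - 0.5) ** 2


def step(carry, _):
    state, key = carry
    key, sample_key = jax.random.split(key)
    # Hypothetical proposal step: uniform samples, one per parallel slot.
    params = {"x": jax.random.uniform(sample_key, (N_PARALLEL,), minval=-2.0, maxval=2.0)}
    results = jax.vmap(func)(params)
    # Hypothetical update step: track the best (lowest) result.
    state = {"best": jnp.minimum(state["best"], results.min())}
    return (state, key), (params, results)


init_state = {"best": jnp.asarray(jnp.inf, dtype=jnp.float32)}
(final_state, _), (params_hist, results_hist) = jax.lax.scan(
    step, (init_state, jax.random.PRNGKey(0)), xs=None, length=N_ITERATIONS
)
print(params_hist["x"].shape)  # (10, 4)
print(results_hist.shape)      # (10, 4)
```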

update_state(state, key, results, params=None)

Updates the optimizer state based on the function results and, optionally, the parameters that produced them.

Return type:

OptimizerState
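
As an illustration of the kind of bookkeeping an update step can perform, the sketch below keeps a running best (lowest) result and the parameter value that produced it. BestState and update_best are hypothetical; a real optimizer would likely store richer state, but the functional pattern of returning a new state instead of mutating the old one carries over. Unlike the documented default, this sketch requires params.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BestState(NamedTuple):
    # NamedTuples are pytrees, so this state can also cross jit boundaries.
    best_value: jax.Array
    best_x: jax.Array


def update_best(state, key, results, params):
    # key is unused here; it only mirrors the update_state signature.
    # params is required in this sketch (the documented default is None).
    i = jnp.argmin(results)
    improved = results[i] < state.best_value
    return BestState(
        best_value=jnp.where(improved, results[i], state.best_value),
        best_x=jnp.where(improved, params["x"][i], state.best_x),
    )


state = BestState(best_value=jnp.array(jnp.inf), best_x=jnp.array(jnp.nan))
state = update_best(state, None, jnp.array([3.0, 1.0, 2.0]), {"x": jnp.array([0.3, 0.1, 0.2])})
state = update_best(state, None, jnp.array([5.0, 4.0, 6.0]), {"x": jnp.array([0.5, 0.4, 0.6])})
```

After the second call the best stays at 1.0 (from x = 0.1), since no result in the second batch improves on it.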

get_next_params(state, key, params=None, results=None)

Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).

Return type:

PyTree
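
A minimal random-search proposer illustrates the batched output contract: every leaf of the returned pytree gains a leading n_parallel axis. The sketch assumes, hypothetically, that the space is a dict mapping parameter names to 1-D arrays of candidate values; RandomSearchSketch, RandomSearchState, and that space format are illustrative assumptions, not hyperoptax's own classes. Because random search ignores history, params and results go unused, exactly as on a first call.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RandomSearchState(NamedTuple):
    space: dict  # hypothetical format: name -> 1-D array of candidate values


class RandomSearchSketch:
    n_parallel: int = 3  # hypothetical override of the default of 1

    @classmethod
    def init(cls, space, **kwargs):
        # Classmethod constructor: build the initial state from the space.
        return RandomSearchState(space=space)

    def get_next_params(self, state, key, params=None, results=None):
        # Random search ignores history, so params/results are unused.
        leaves, treedef = jax.tree_util.tree_flatten(state.space)
        keys = jax.random.split(key, len(leaves))
        # Draw n_parallel candidates per leaf so each leaf has shape (n_parallel,).
        samples = [
            jax.random.choice(k, leaf, shape=(self.n_parallel,))
            for k, leaf in zip(keys, leaves)
        ]
        return jax.tree_util.tree_unflatten(treedef, samples)


space = {"lr": jnp.array([1e-3, 1e-2, 1e-1]), "depth": jnp.array([2, 4, 8])}
opt = RandomSearchSketch()
state = RandomSearchSketch.init(space)
batch = opt.get_next_params(state, jax.random.PRNGKey(0))
```

The returned batch can be fed directly to jax.vmap(func), since every leaf shares the same leading axis.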