hyperoptax.base
Classes
- class hyperoptax.base.OptimizerState(space)
Base optimizer state: a JAX pytree that holds the search-space definition.
- Parameters:
space (PyTree)
- space: PyTree
- __init__(space)
- Parameters:
space (PyTree)
- Return type:
None
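Because OptimizerState is a pytree, `jax.tree_util` can flatten it and it can be passed directly into jit-compiled functions. The sketch below illustrates this with a hypothetical `ToyState` NamedTuple standing in for the real class, whose internals may differ. The dict-of-arrays layout of `space` and its keys are also assumptions made for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for OptimizerState; NamedTuples are pytrees out of the box.
    space: dict


# Hypothetical search space: one array of candidate values per hyperparameter.
space = {
    "learning_rate": jnp.logspace(-4, -1, 10),
    "dropout": jnp.linspace(0.0, 0.5, 6),
}
state = ToyState(space)

# The state flattens to the space's leaves (dict keys are visited in sorted order).
leaves, treedef = jax.tree_util.tree_flatten(state)
print([leaf.shape for leaf in leaves])  # [(6,), (10,)]
restored = jax.tree_util.tree_unflatten(treedef, leaves)


@jax.jit
def space_bounds(s):
    # Because the state is a pytree, it can be passed straight into jit.
    return jax.tree_util.tree_map(lambda leaf: (leaf.min(), leaf.max()), s.space)


bounds = space_bounds(state)
```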
- class hyperoptax.base.Optimizer
- optimize(state, key, func, n_iterations)
High-level API for optimizing a function over the search space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
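A minimal sketch of the scalar-return requirement. It assumes func receives a single parameter set as a dict with a hypothetical key "x"; `raw_objective` and `objective` are illustrative names. The commented-out call assumes an `optimizer` and `state` constructed elsewhere and simply follows the documented signature.

```python
import jax
import jax.numpy as jnp


def raw_objective(params):
    # Indexing with [None] gives shape (1,), which violates the scalar contract.
    return ((params["x"] - 3.0) ** 2)[None]


def objective(params):
    # .squeeze() drops the size-1 axis so the result has shape ().
    return raw_objective(params).squeeze()


p = {"x": jnp.array(1.0)}
print(raw_objective(p).shape)  # (1,)
print(objective(p).shape)      # ()

# Assuming `optimizer` and `state` were constructed elsewhere (not shown):
# state, (params_hist, results_hist) = optimizer.optimize(
#     state, jax.random.PRNGKey(0), objective, 50)
```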
- optimize_scan(state, key, func, n_iterations)
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
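The stacked-history layout described for optimize_scan can be reproduced with a plain `jax.lax.scan` loop. This is a toy random search, not hyperoptax's implementation: the sampler, the x/y parameters, and the `N_PARALLEL`/`N_ITERATIONS` constants are hypothetical. It shows why every params_hist leaf and results_hist come out with a leading (n_iterations, n_parallel) shape, and how to locate the best evaluation in them.

```python
import jax
import jax.numpy as jnp

N_PARALLEL = 4     # hypothetical number of points evaluated per iteration
N_ITERATIONS = 10  # hypothetical number of scan steps


def objective(params):
    # Must be traceable and return shape ().
    return (params["x"] - 0.3) ** 2 + (params["y"] + 0.1) ** 2


def sample(key):
    # Hypothetical sampler: uniform draws on [-1, 1] for each parameter.
    kx, ky = jax.random.split(key)
    return {
        "x": jax.random.uniform(kx, (N_PARALLEL,), minval=-1.0, maxval=1.0),
        "y": jax.random.uniform(ky, (N_PARALLEL,), minval=-1.0, maxval=1.0),
    }


def step(carry, key):
    params = sample(key)                   # each leaf: (N_PARALLEL,)
    results = jax.vmap(objective)(params)  # shape: (N_PARALLEL,)
    return carry, (params, results)        # scan stacks these per iteration


keys = jax.random.split(jax.random.PRNGKey(0), N_ITERATIONS)
_, (params_hist, results_hist) = jax.jit(
    lambda ks: jax.lax.scan(step, None, ks)
)(keys)

print(params_hist["x"].shape)  # (10, 4)
print(results_hist.shape)      # (10, 4)

# Locate the best evaluation across all iterations and parallel slots.
i, j = jnp.unravel_index(jnp.argmin(results_hist), results_hist.shape)
best_params = jax.tree_util.tree_map(lambda leaf: leaf[i, j], params_hist)
```

Scanning over pre-split keys keeps the entire loop traceable, which is why func must be jit-compilable in this mode.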
- update_state(state, key, results, params=None)
Updates the optimizer state based on the results of the function.
- Parameters:
state (OptimizerState)
key (PRNGKey)
results (Array)
params (PyTree | None)
- Return type:
- get_next_params(state, key, params=None, results=None)
Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).
- Parameters:
state (OptimizerState)
key (PRNGKey)
params (PyTree | None)
results (Array | None)
- Return type:
PyTree
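get_next_params and update_state can be combined into a manual ask-and-tell loop, which is one way to take control of how func is evaluated. `ToyState` and `ToyRandomSearch` below are hypothetical classes that only mimic the documented signatures (sampling uniformly from a discrete grid and tracking the best point). Real hyperoptax optimizers will behave differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical state: the search space plus the incumbent best point.
    space: dict
    best_value: jax.Array
    best_params: dict


class ToyRandomSearch:
    # Hypothetical optimizer that only mimics the documented method signatures.
    def __init__(self, n_parallel):
        self.n_parallel = n_parallel

    def get_next_params(self, state, key, params=None, results=None):
        # params/results are ignored by this toy sampler.
        # Draw n_parallel grid indices per leaf, so each leaf has shape (n_parallel,).
        leaves, treedef = jax.tree_util.tree_flatten(state.space)
        keys = jax.random.split(key, len(leaves))
        picked = [
            leaf[jax.random.randint(k, (self.n_parallel,), 0, leaf.shape[0])]
            for leaf, k in zip(leaves, keys)
        ]
        return jax.tree_util.tree_unflatten(treedef, picked)

    def update_state(self, state, key, results, params=None):
        # key is unused here. Keep the batch's lowest result if it beats the incumbent.
        i = jnp.argmin(results)
        improved = results[i] < state.best_value
        best_value = jnp.where(improved, results[i], state.best_value)
        best_params = jax.tree_util.tree_map(
            lambda new, old: jnp.where(improved, new[i], old),
            params,
            state.best_params,
        )
        return ToyState(state.space, best_value, best_params)


def objective(p):
    # Scalar-valued, as the documented contract requires.
    return (p["x"] - 0.5) ** 2


space = {"x": jnp.linspace(0.0, 1.0, 11)}
state = ToyState(space, jnp.array(jnp.inf), {"x": jnp.array(0.0)})
opt = ToyRandomSearch(n_parallel=4)

key = jax.random.PRNGKey(0)
params, results = None, None  # None on the first call, as documented
for _ in range(20):
    key, k_ask, k_tell = jax.random.split(key, 3)
    params = opt.get_next_params(state, k_ask, params, results)
    results = jax.vmap(objective)(params)  # shape (4,)
    state = opt.update_state(state, k_tell, results, params)
```

Because every leaf returned by get_next_params carries a leading (n_parallel,) axis, `jax.vmap` over func is a natural way to evaluate the whole batch in one call.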