Optimizers#
Hyperoptax provides several optimization algorithms for hyperparameter tuning.
Base Optimizer#
- class hyperoptax.base.OptimizerState(space)[source]#
Bases: object
Base optimizer state: a JAX pytree holding the search space definition.
- Parameters:
space (PyTree)
- space: PyTree#
- __init__(space)#
- Parameters:
space (PyTree)
- Return type:
None
- class hyperoptax.base.Optimizer[source]#
Bases: object
- optimize(state, key, func, n_iterations)[source]#
High-level API for optimizing a function over a space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
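The high-level loop can be pictured as repeated calls to get_next_params and update_state. The sketch below is plain Python and is not Hyperoptax's implementation: ToySweep and run_loop are hypothetical names, lists stand in for JAX arrays, and the key is passed through unused where a real loop would presumably split a PRNG key each iteration. It only mirrors the method names and argument order documented on this page.

```python
class ToySweep:
    """Hypothetical optimizer: proposes x = i * step for consecutive i."""

    def __init__(self, n_parallel=2, step=0.25):
        self.n_parallel = n_parallel
        self.step = step

    def get_next_params(self, state, key, params=None, results=None):
        # Batched "pytree": every leaf is a list of length n_parallel.
        start = state["idx"]
        return {"x": [(start + i) * self.step for i in range(self.n_parallel)]}

    def update_state(self, state, key, results, params=None):
        # Return a new state instead of mutating the old one.
        return {**state, "idx": state["idx"] + self.n_parallel}


def run_loop(optimizer, state, key, func, n_iterations):
    """Ask for a batch, evaluate it, report the results, repeat."""
    params_hist, results_hist = [], []
    params, results = None, None  # None on the first call
    for _ in range(n_iterations):
        params = optimizer.get_next_params(state, key, params, results)
        # func is called once per parallel slot and returns a scalar.
        results = [func({"x": x}) for x in params["x"]]
        state = optimizer.update_state(state, key, results, params)
        params_hist.append(params)
        results_hist.append(results)
    return state, (params_hist, results_hist)


state, (params_hist, results_hist) = run_loop(
    ToySweep(n_parallel=2), {"idx": 0}, None,
    lambda p: -((p["x"] - 0.5) ** 2), n_iterations=3,
)
```

Writing the loop by hand like this is the usual route when the evaluations need custom scheduling, which is the case the note above steers away from optimize for.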
- optimize_scan(state, key, func, n_iterations)[source]#
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
- update_state(state, key, results, params=None)[source]#
Updates the optimizer state based on the results of the function.
- Parameters:
state (OptimizerState)
key (PRNGKey)
results (Array)
params (PyTree | None)
- Return type:
- get_next_params(state, key, params=None, results=None)[source]#
Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).
- Parameters:
state (OptimizerState)
key (PRNGKey)
params (PyTree | None)
results (Array | None)
- Return type:
PyTree
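A batched pytree keeps one entry per parallel slot in every leaf. As a rough illustration in plain Python (a dict of lists stands in for a pytree of JAX arrays, and unbatch is a hypothetical helper, not part of Hyperoptax), the batch can be split into one configuration per trial:

```python
def unbatch(batched):
    """Split a dict whose leaves have length n_parallel into per-trial dicts."""
    n_parallel = len(next(iter(batched.values())))
    return [
        {name: leaf[i] for name, leaf in batched.items()}
        for i in range(n_parallel)
    ]


# Three parallel slots, two hyperparameters.
batched_params = {"lr": [1e-3, 1e-2, 1e-1], "depth": [2, 4, 8]}
trials = unbatch(batched_params)
```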
Bayesian Optimizer#
- class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)[source]#
Bases: OptimizerState
State for BayesianSearch.
All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.
- y#
Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.
- Type:
- mask#
Boolean validity mask, shape (n_max,); True for slots that contain real observations.
- Type:
- log_length_scale#
Per-dimension ARD length scales in log space, shape (n_params,). Tuned by Adam each iteration.
- Type:
- __init__(space, X, y, mask, log_length_scale)#
- space: PyTree#
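The fixed-size buffer plus validity mask pattern can be sketched in plain Python. This is only an illustration of why the mask matters: N_MAX, write_batch and masked_best are hypothetical names, and lists stand in for the fixed-shape JAX arrays the state actually holds.

```python
N_MAX = 6  # fixed capacity chosen up front (hypothetical value)


def write_batch(y, mask, start, results):
    """Write results into slots [start, start + len(results)) and mark them valid."""
    y, mask = list(y), list(mask)
    for offset, value in enumerate(results):
        y[start + offset] = value
        mask[start + offset] = True
    return y, mask


def masked_best(y, mask, maximize=True):
    """Best value over valid slots only; zero padding is ignored."""
    valid = [v for v, m in zip(y, mask) if m]
    return max(valid) if maximize else min(valid)


y = [0.0] * N_MAX
mask = [False] * N_MAX
y, mask = write_batch(y, mask, 0, [-3.0, -1.5])
y, mask = write_batch(y, mask, 2, [-2.0, -4.0])
best = masked_best(y, mask, maximize=True)
```

Here every real observation is negative, so a plain max over y would return the padding value 0.0; restricting the reduction to masked slots gives the true best of -1.5.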
- class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)[source]#
Bases: Optimizer
Bayesian optimisation with a Gaussian Process surrogate.
Uses a GP (Matérn 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (PI by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.
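For orientation, the two ingredients named above can be written out in a few lines of plain Python. This is a sketch under assumptions, not the library's code: it takes "Matérn 0.5" to mean smoothness ν = 1/2 (the exponential kernel), takes PI to mean probability of improvement without an exploration margin, and uses the hypothetical function names matern12_ard and probability_of_improvement.

```python
import math


def matern12_ard(x, y, log_length_scale):
    """Matern nu=1/2 (exponential) kernel with one length scale per dimension."""
    r2 = sum(
        ((a - b) / math.exp(ls)) ** 2
        for a, b, ls in zip(x, y, log_length_scale)
    )
    return math.exp(-math.sqrt(r2))


def probability_of_improvement(mean, std, best):
    """P(f > best) under a Gaussian posterior N(mean, std**2)."""
    z = (mean - best) / std
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


k_same = matern12_ard([0.2, 0.5], [0.2, 0.5], [0.0, 0.0])  # identical points
k_far = matern12_ard([0.0, 0.0], [1.0, 0.0], [0.0, 0.0])   # one length scale apart
pi_even = probability_of_improvement(mean=1.0, std=0.5, best=1.0)
```

Storing length scales in log space, as the state does, keeps them positive after unconstrained gradient updates; a larger log length scale for a dimension makes the kernel less sensitive to that dimension.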
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- jitter#
Small diagonal added to the kernel matrix for numerical stability (default 1e-6).
- Type:
- n_candidates#
Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).
- Type:
- n_hparam_steps#
Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.
- Type:
- hallucination#
Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).
- kernel: BaseKernel#
- acquisition: BaseAcquisition#
- hallucination: BaseHallucination#
- best_result(state)[source]#
Return the best observed raw result (max if maximize, min if minimize).
- Parameters:
state (BayesianSearchState)
- Return type:
- best_params(state)[source]#
Return the parameter pytree that achieved the best observed result.
- Parameters:
state (BayesianSearchState)
- get_next_params(state, key, params=None, results=None)[source]#
Select the next batch of n_parallel candidates.
During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS with Kriging Believer hallucination for the parallel slots.
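The Kriging Believer idea can be illustrated with a tiny one-dimensional GP in plain Python. The sketch makes several simplifying assumptions that are not taken from the library: it scores a fixed list of candidates instead of refining them with L-BFGS, it uses a fixed length scale, and it hallucinates the posterior mean (the classic Kriging Believer choice), whereas the default SampleHallucination may pick the believed value differently. All function names are hypothetical.

```python
import math


def kernel(a, b, length_scale=0.3):
    # Exponential (Matern nu=1/2) kernel in one dimension.
    return math.exp(-abs(a - b) / length_scale)


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        tail = sum(M[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (M[i][n] - tail) / M[i][i]
    return x


def posterior(X, y, x_new, jitter=1e-6):
    """Zero-mean GP posterior mean and standard deviation at x_new."""
    K = [[kernel(a, b) + (jitter if i == j else 0.0)
          for j, b in enumerate(X)] for i, a in enumerate(X)]
    k_star = [kernel(a, x_new) for a in X]
    mean = sum(k * w for k, w in zip(k_star, solve(K, y)))
    var = kernel(x_new, x_new) - sum(
        k * v for k, v in zip(k_star, solve(K, k_star)))
    return mean, math.sqrt(max(var, 1e-12))


def prob_improvement(mean, std, best):
    return 0.5 * (1.0 + math.erf((mean - best) / (std * math.sqrt(2.0))))


def kriging_believer_batch(X, y, candidates, n_parallel):
    """Pick n_parallel points, pretending each pick was already observed."""
    X, y, batch = list(X), list(y), []
    for _ in range(n_parallel):
        best = max(y)
        scored = [(prob_improvement(*posterior(X, y, c), best), c)
                  for c in candidates]
        _, pick = max(scored)
        batch.append(pick)
        # "Believe" the surrogate: record its mean as a fake observation,
        # which collapses the uncertainty at the picked point.
        believed, _ = posterior(X, y, pick)
        X.append(pick)
        y.append(believed)
    return batch


candidates = [0.0, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 1.0]
batch = kriging_believer_batch(
    [0.1, 0.5, 0.9], [0.2, 1.0, 0.3], candidates, n_parallel=3)
```

Because each hallucinated observation removes the uncertainty around the point just chosen, later slots in the same batch are pushed to other regions instead of repeating the first pick.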
- update_state(state, key, results, params)[source]#
Record new observations and update ARD length scales.
Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.
- Parameters:
state – Current BayesianSearchState.
key – PRNG key (unused but kept for API consistency).
results – Array of shape (n_parallel,) with observed objective values.
params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
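For reference, the textbook log marginal likelihood of a zero-mean GP with n valid observations is shown below. This is the standard expression, given here as an assumption about what the Adam loop maximises with respect to the log length scales; the page does not state whether the library normalises the targets, negates them when maximize=False, or how it excludes masked slots.

```latex
\log p(\mathbf{y} \mid X, \boldsymbol{\ell})
  = -\tfrac{1}{2}\,\mathbf{y}^{\top} K_{\boldsymbol{\ell}}^{-1}\,\mathbf{y}
    - \tfrac{1}{2}\,\log\det K_{\boldsymbol{\ell}}
    - \tfrac{n}{2}\,\log 2\pi,
\qquad
K_{\boldsymbol{\ell}} = k_{\boldsymbol{\ell}}(X, X) + \mathrm{jitter}\cdot I
```

The first term rewards fitting the observations, the second penalises overly flexible kernels, and the jitter term is the small diagonal described under the jitter attribute.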
- optimize(state, key, func, n_iterations=None)[source]#
High-level API for optimizing a function over a space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- optimize_scan(state, key, func, n_iterations=None)[source]#
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
- __init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)#
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- Return type:
None
Grid Search#
- class hyperoptax.grid.GridSearchState(space, grid, grid_idx)[source]#
Bases: OptimizerState
State for GridSearch.
- grid#
Array of shape (n_total, n_params) containing all parameter combinations, pre-truncated to a multiple of n_parallel.
- Type:
- grid_idx#
Current position in grid; incremented by n_parallel after each call to update_state.
- Type:
- __init__(space, grid, grid_idx)#
- space: PyTree#
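The grid layout and index bookkeeping can be sketched in plain Python. The sketch assumes that "pre-truncated" means rounding the number of rows down to a multiple of n_parallel, so combinations beyond the last full batch are dropped; build_grid and next_batch are hypothetical helpers, and nested lists stand in for the (n_total, n_params) array.

```python
import itertools


def build_grid(space, n_parallel):
    """Cartesian product of all value lists, truncated to a multiple of n_parallel."""
    names = list(space)
    grid = [list(combo)
            for combo in itertools.product(*(space[n] for n in names))]
    n_total = (len(grid) // n_parallel) * n_parallel
    return names, grid[:n_total]


def next_batch(grid, grid_idx, n_parallel):
    """Rows [grid_idx, grid_idx + n_parallel) plus the advanced index."""
    return grid[grid_idx:grid_idx + n_parallel], grid_idx + n_parallel


space = {"lr": [1e-3, 1e-2, 1e-1], "batch_size": [32, 64, 128]}  # 9 combinations
names, grid = build_grid(space, n_parallel=4)                     # 9 -> 8 rows
batch, grid_idx = next_batch(grid, 0, n_parallel=4)
```

Under this reading, choosing n_parallel as a divisor of the number of combinations is the way to make sure every combination is visited.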
- class hyperoptax.grid.GridSearch(shuffle=False, n_parallel=1)[source]#
Bases: Optimizer
Exhaustive grid search over a discrete search space.
Iterates through every combination of the provided DiscreteSpace values in order (or randomly if shuffle=True). All spaces in the search space must be DiscreteSpace.
- shuffle#
If True, randomise the traversal order during init. Pass an explicit key to init for reproducibility.
- Type:
- classmethod init(space, key=None, **kwargs)[source]#
Initialise the grid search.
- Parameters:
space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.
key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.
**kwargs – Forwarded to GridSearch constructor (e.g. n_parallel, shuffle).
- Returns:
(state, optimizer) tuple.
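The reproducibility point about the key can be illustrated without JAX. In the sketch below, Python's random.Random plays the role of the PRNG key, and the seed=None branch mimics the documented fallback to a fixed default; shuffled_order is a hypothetical helper, not the library's shuffling code.

```python
import random


def shuffled_order(n_total, seed=None):
    """Permutation of range(n_total); the same seed gives the same order."""
    if seed is None:
        seed = 0  # fixed fallback, analogous to the PRNGKey(0) default
    order = list(range(n_total))
    random.Random(seed).shuffle(order)
    return order


order_default = shuffled_order(8)
order_explicit = shuffled_order(8, seed=0)
```

Because the fallback is a fixed value, omitting the key still gives the same traversal order on every run; passing a different key is what changes the order.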
- get_next_params(state, key, params=None, results=None)[source]#
Return the next n_parallel parameter combinations from the grid.
- Parameters:
state (GridSearchState)
- Return type:
PyTree
- update_state(state, key, results, params=None)[source]#
Advance the grid index by n_parallel.
- Parameters:
state (GridSearchState)
- Return type:
Random Search#
- class hyperoptax.random.RandomSearch(n_parallel=1)[source]#
Bases: Optimizer
Stateless random search: samples each space independently each iteration.
No model is fitted and no history is maintained, so this is the cheapest optimizer and useful as a strong baseline.
- Parameters:
n_parallel (int)
- get_next_params(state, key, params=None, results=None)[source]#
Sample n_parallel independent configurations from the search space.
- Parameters:
state (OptimizerState)
key (PRNGKey)
- Return type:
PyTree
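Independent per-leaf sampling can be sketched as follows. The space encoding is an assumption made for the example (a tuple is a continuous range, a list is a set of discrete choices) and does not reflect Hyperoptax's space classes; sample_batch is a hypothetical helper and random.Random stands in for a JAX PRNG key.

```python
import random


def sample_batch(space, rng, n_parallel):
    """Draw n_parallel independent values for every leaf of the space."""
    batch = {}
    for name, spec in space.items():
        if isinstance(spec, tuple):   # (low, high) continuous range
            low, high = spec
            batch[name] = [rng.uniform(low, high) for _ in range(n_parallel)]
        else:                         # list of discrete choices
            batch[name] = [rng.choice(spec) for _ in range(n_parallel)]
    return batch


space = {"lr": (1e-4, 1e-1), "activation": ["relu", "tanh", "gelu"]}
batch = sample_batch(space, random.Random(0), n_parallel=4)
```

Nothing from earlier iterations feeds into the draw, which is why the optimizer needs no history and costs almost nothing per iteration.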
- update_state(state, key, results, params=None)[source]#
RandomSearch is memoryless, so there is no state to update.
- Parameters:
state (OptimizerState)
key (PRNGKey)
results (Array)
- Return type: