hyperoptax package#
- class hyperoptax.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)[source]#
Bases: Optimizer
Bayesian optimisation with a Gaussian Process surrogate.
Uses a GP (Matérn 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (PI by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- jitter#
Small diagonal added to the kernel matrix for numerical stability (default 1e-6).
- Type:
- n_candidates#
Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).
- Type:
- n_hparam_steps#
Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.
- Type:
- hallucination#
Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).
- kernel: BaseKernel#
- acquisition: BaseAcquisition#
- hallucination: BaseHallucination#
- best_result(state)[source]#
Return the best observed raw result (max if maximize, min if minimize).
- Parameters:
state (BayesianSearchState)
- Return type:
- best_params(state)[source]#
Return the parameter pytree that achieved the best observed result.
- Parameters:
state (BayesianSearchState)
- get_next_params(state, key, params=None, results=None)[source]#
Select the next batch of n_parallel candidates.
During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS, with Kriging Believer hallucination for the parallel slots.
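The batch selection described above can be illustrated with a small, self-contained sketch of the Kriging Believer idea. This is not the hyperoptax implementation: it assumes a 1-D input, a fixed RBF kernel, a UCB acquisition, a discrete candidate list instead of L-BFGS, and mean hallucination, and every function name here is hypothetical.

```python
import math

def rbf(a, b, length_scale=0.3):
    """Squared-exponential kernel between two scalars."""
    return math.exp(-0.5 * ((a - b) / length_scale) ** 2)

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= f * M[c][k]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][k] * x[k] for k in range(r + 1, n))) / M[r][r]
    return x

def posterior(xs, ys, x_new, jitter=1e-6):
    """GP posterior mean and std at x_new given observations (xs, ys)."""
    K = [[rbf(a, b) + (jitter if i == j else 0.0) for j, b in enumerate(xs)]
         for i, a in enumerate(xs)]
    k_star = [rbf(a, x_new) for a in xs]
    alpha = solve(K, ys)          # K^-1 y
    v = solve(K, k_star)          # K^-1 k_star
    mean = sum(k * a for k, a in zip(k_star, alpha))
    var = max(rbf(x_new, x_new) - sum(k * w for k, w in zip(k_star, v)), 0.0)
    return mean, math.sqrt(var)

def kriging_believer_batch(xs, ys, candidates, n_parallel, kappa=2.0):
    """Pick n_parallel points; after each pick, pretend its posterior mean was observed."""
    xs, ys = list(xs), list(ys)
    batch = []
    for _ in range(n_parallel):
        def ucb(x):
            m, s = posterior(xs, ys, x)
            return m + kappa * s
        best = max(candidates, key=ucb)
        m, _ = posterior(xs, ys, best)
        batch.append(best)
        xs.append(best)   # hallucinated input
        ys.append(m)      # hallucinated output (the "believed" mean)
    return batch

candidates = [i / 20 for i in range(21)]
batch = kriging_believer_batch([0.1, 0.9], [0.2, 0.5], candidates, n_parallel=3)
```

Because each hallucinated point collapses the posterior uncertainty around it, later picks in the same batch are pushed toward unexplored regions rather than repeating the first choice.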
- update_state(state, key, results, params)[source]#
Record new observations and update ARD length scales.
Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.
- Parameters:
state – Current BayesianSearchState.
key – PRNG key (unused but kept for API consistency).
results – Array of shape (n_parallel,) with observed objective values.
params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
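The fixed-size buffer update can be pictured with plain Python lists. The sketch below is only an illustration of the buffer-plus-mask pattern, not the library's code; the helper name write_batch and the buffer sizes are assumptions. In JAX the same write would typically be an indexed update such as `x.at[idx].set(value)`.

```python
def write_batch(X, y, mask, new_X, new_y):
    """Write a batch into the first free slots of fixed-size buffers."""
    start = sum(mask)  # number of slots already filled
    X, y, mask = [row[:] for row in X], list(y), list(mask)
    for i, (x_row, y_val) in enumerate(zip(new_X, new_y)):
        X[start + i] = list(x_row)
        y[start + i] = y_val
        mask[start + i] = True
    return X, y, mask

n_max, n_params = 6, 2                      # hypothetical buffer sizes
X = [[0.0] * n_params for _ in range(n_max)]
y = [0.0] * n_max                           # zero-padded results
mask = [False] * n_max                      # no real observations yet

X, y, mask = write_batch(X, y, mask, [[0.1, 0.2], [0.3, 0.4]], [1.5, -0.5])

# Only masked entries count as observations, e.g. when looking up the best result.
valid_y = [v for v, m in zip(y, mask) if m]
best = max(valid_y)
```

The mask is what keeps the zero padding from being mistaken for real observations when the best result is computed.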
- optimize(state, key, func, n_iterations=None)[source]#
High-level API for optimizing a function over a space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- optimize_scan(state, key, func, n_iterations=None)[source]#
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- __init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)#
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- Return type:
None
- class hyperoptax.GridSearch(shuffle=False, n_parallel=1)[source]#
Bases: Optimizer
Exhaustive grid search over a discrete search space.
Iterates through every combination of the provided DiscreteSpace values in order (or randomly if shuffle=True). All spaces in the search space must be DiscreteSpace.
- shuffle#
If True, randomise the traversal order during init. Pass an explicit key to init for reproducibility.
- Type:
- classmethod init(space, key=None, **kwargs)[source]#
Initialise the grid search.
- Parameters:
space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.
key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.
**kwargs – Forwarded to the GridSearch constructor (e.g. n_parallel, shuffle).
- Returns:
(state, optimizer) tuple.
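As a rough illustration of what grid initialisation involves, the sketch below enumerates every combination with itertools.product, optionally shuffles, and truncates to a multiple of n_parallel (the truncation is described under GridSearchState). The function build_grid, the dict-based space, and the integer seed are assumptions made for the example, not the hyperoptax API.

```python
import itertools
import random

def build_grid(space, n_parallel=1, shuffle=False, seed=0):
    """Enumerate all combinations of a dict of value tuples."""
    names = sorted(space)
    combos = itertools.product(*(space[name] for name in names))
    grid = [dict(zip(names, combo)) for combo in combos]
    if shuffle:
        random.Random(seed).shuffle(grid)  # seeded for reproducibility
    n_total = (len(grid) // n_parallel) * n_parallel  # drop the ragged tail
    return grid[:n_total]

space = {"lr": (0.1, 0.01, 0.001), "layers": (1, 2)}
grid = build_grid(space, n_parallel=4)

# Each iteration consumes the next n_parallel rows of the grid.
batches = [grid[i:i + 4] for i in range(0, len(grid), 4)]
```

With six combinations and n_parallel=4, only one full batch survives the truncation, so choosing n_parallel to divide the grid size avoids silently skipping combinations in a scheme like this one.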
- get_next_params(state, key, params=None, results=None)[source]#
Return the next n_parallel parameter combinations from the grid.
- Parameters:
state (GridSearchState)
- Return type:
PyTree
- update_state(state, key, results, params=None)[source]#
Advance the grid index by n_parallel.
- Parameters:
state (GridSearchState)
- Return type:
- class hyperoptax.RandomSearch(n_parallel=1)[source]#
Bases: Optimizer
Stateless random search: samples each space independently each iteration.
No model is fitted and no history is maintained, so this is the cheapest optimizer and useful as a strong baseline.
- Parameters:
n_parallel (int)
- get_next_params(state, key, params=None, results=None)[source]#
Sample n_parallel independent configurations from the search space.
- Parameters:
state (OptimizerState)
key (PRNGKey)
- Return type:
PyTree
- update_state(state, key, results, params=None)[source]#
RandomSearch is memoryless, so there is no state to update.
- Parameters:
state (OptimizerState)
key (PRNGKey)
results (Array)
- Return type:
- class hyperoptax.DiscreteSpace(values)[source]#
Bases: Space
Discrete space over a fixed set of values.
Samples uniformly from values; transform snaps any continuous value to the nearest element, which is useful when discrete candidates are generated via continuous optimization (e.g. in BayesianSearch).
- Parameters:
values (tuple)
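The snapping behaviour can be sketched in one line of plain Python. The helper snap_to_nearest is hypothetical and only illustrates the idea of mapping a continuous proposal onto the allowed values; it is not the library's transform.

```python
def snap_to_nearest(value, values):
    """Return the element of `values` closest to `value`."""
    return min(values, key=lambda v: abs(v - value))

allowed = (0.1, 0.25, 0.5, 1.0)
snapped = snap_to_nearest(0.37, allowed)       # closer to 0.25 than to 0.5
clipped = snap_to_nearest(2.0, allowed)        # out-of-range values snap to the edge
```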
- class hyperoptax.LinearSpace(lower_bound, upper_bound)[source]#
Bases: Space
Uniform continuous space over [lower_bound, upper_bound].
- class hyperoptax.LogSpace(lower_bound, upper_bound, base=10)[source]#
Bases: LinearSpace
Log-uniform continuous space over [lower_bound, upper_bound].
Samples uniformly in log space so that each order of magnitude receives equal probability mass. Useful for learning rates and other scale parameters that span several orders of magnitude.
- __init__(lower_bound, upper_bound, base=10)#
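A minimal sketch of log-uniform sampling, assuming the usual construction: draw an exponent uniformly between the logs of the bounds, then exponentiate. The function name and the use of Python's random module are illustrative choices, not the library's implementation.

```python
import math
import random

def sample_log_uniform(rng, lower, upper, base=10):
    """Draw one value whose logarithm is uniform on [log(lower), log(upper)]."""
    lo, hi = math.log(lower, base), math.log(upper, base)
    return base ** rng.uniform(lo, hi)

rng = random.Random(0)
samples = [sample_log_uniform(rng, 1e-5, 1e-1) for _ in range(10_000)]

# Half of the four decades lie below 1e-3, so about half the samples should too.
frac_below_1e3 = sum(s < 1e-3 for s in samples) / len(samples)
```

A plain uniform draw over the same bounds would put roughly 99% of its samples above 1e-3, which is why a log-uniform prior suits learning rates.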
- class hyperoptax.QLinearSpace(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)[source]#
Bases: LinearSpace
Quantized (integer) variant of LinearSpace.
Samples uniformly from [lower_bound, upper_bound] and rounds to the nearest integer. Use this for discrete integer hyperparameters with a uniform prior (e.g. number of layers, batch size).
- class hyperoptax.QLogSpace(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)[source]#
Bases: LogSpace
Quantized (integer) variant of LogSpace.
Samples in log space and rounds to the nearest integer. Use this for integer hyperparameters whose scale spans orders of magnitude (e.g. number of hidden units, number of warmup steps).
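The two quantized spaces can be sketched as "sample, then round". The helpers below are hypothetical stand-ins written with the standard library; they assume rounding is applied directly to the continuous draw, as the descriptions suggest.

```python
import math
import random
from collections import Counter

def sample_qlinear(rng, lower, upper):
    """Uniform draw on [lower, upper], rounded to the nearest integer."""
    return int(round(rng.uniform(lower, upper)))

def sample_qlog(rng, lower, upper, base=10):
    """Log-uniform draw on [lower, upper], rounded to the nearest integer."""
    lo, hi = math.log(lower, base), math.log(upper, base)
    return int(round(base ** rng.uniform(lo, hi)))

rng = random.Random(0)
layers = [sample_qlinear(rng, 1, 8) for _ in range(70_000)]
units = [sample_qlog(rng, 16, 1024, base=2) for _ in range(1_000)]

counts = Counter(layers)
```

One consequence of rounding a uniform draw in this sketch is that the two endpoint integers receive about half the probability of the interior ones, since only half a unit of the interval rounds to each of them.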
- class hyperoptax.EI(xi=0.01)[source]#
Bases: BaseAcquisition
Expected Improvement acquisition function.
- Parameters:
xi (float)
- class hyperoptax.PI(xi=0.01)[source]#
Bases: BaseAcquisition
Probability of Improvement acquisition function.
- Parameters:
xi (float)
- class hyperoptax.UCB(kappa=2.0)[source]#
Bases: BaseAcquisition
Upper Confidence Bound acquisition function.
- Parameters:
kappa (float)
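For orientation, the three acquisition functions are commonly written as below for a maximisation problem, given the GP posterior mean and standard deviation at a point. These are the textbook forms; whether hyperoptax uses exactly these expressions (for example, how xi enters) is an assumption, and the function names are illustrative.

```python
import math

def norm_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    """Standard normal cumulative distribution."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def upper_confidence_bound(mean, std, kappa=2.0):
    """Optimistic estimate: mean plus kappa standard deviations."""
    return mean + kappa * std

def probability_of_improvement(mean, std, y_best, xi=0.01):
    """P(f(x) > y_best + xi) under the posterior; assumes std > 0."""
    return norm_cdf((mean - y_best - xi) / std)

def expected_improvement(mean, std, y_best, xi=0.01):
    """E[max(f(x) - y_best - xi, 0)] under the posterior; assumes std > 0."""
    improvement = mean - y_best - xi
    z = improvement / std
    return improvement * norm_cdf(z) + std * norm_pdf(z)
```

Larger kappa or xi values shift each rule toward exploration, because they reward uncertainty or demand a bigger margin over the current best.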
- class hyperoptax.BaseHallucination[source]#
Bases: object
Base class for Kriging Believer hallucination strategies.
Any callable with signature (mean, std, key, y_max) -> scalar can be used as a hallucination strategy; subclassing is optional.
- class hyperoptax.MeanHallucination[source]#
Bases: BaseHallucination
Classical Kriging Believer: hallucinate with GP posterior mean.
- class hyperoptax.SampleHallucination[source]#
Bases: BaseHallucination
Randomized Kriging Believer (RKB): hallucinate with a posterior sample.
arXiv 2603.01470.
- class hyperoptax.UCBHallucination(kappa=2.0)[source]#
Bases: BaseHallucination
Optimistic hallucination: mean + kappa * std.
- Parameters:
kappa (float)
- class hyperoptax.ConstantHallucination(value=None)[source]#
Bases: BaseHallucination
Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.
If value is None, uses the current best observed value (y_max). Otherwise uses the fixed value regardless of observations.
- Parameters:
value (float | None)
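Since any callable with signature (mean, std, key, y_max) -> scalar is accepted, the four strategies can be sketched as plain functions. This is an illustration only: the function names are made up, and key is a random.Random instance standing in for a JAX PRNG key.

```python
import random

def mean_hallucination(mean, std, key, y_max):
    """Classical Kriging Believer: believe the posterior mean."""
    return mean

def sample_hallucination(mean, std, key, y_max):
    """Randomized variant: believe one draw from N(mean, std^2)."""
    return key.gauss(mean, std)  # `key` is a random.Random stand-in here

def make_ucb_hallucination(kappa=2.0):
    """Optimistic variant: believe mean + kappa * std."""
    def hallucinate(mean, std, key, y_max):
        return mean + kappa * std
    return hallucinate

def make_constant_hallucination(value=None):
    """Believe y_max when value is None, otherwise a fixed constant."""
    def hallucinate(mean, std, key, y_max):
        return y_max if value is None else value
    return hallucinate

key = random.Random(0)
args = dict(mean=1.0, std=0.5, key=key, y_max=3.0)
believed = {
    "mean": mean_hallucination(**args),
    "ucb": make_ucb_hallucination(kappa=2.0)(**args),
    "constant_best": make_constant_hallucination()(**args),
    "constant_zero": make_constant_hallucination(0.0)(**args),
    "sample": sample_hallucination(**args),
}
```

Only the sampling strategy uses the key; the others ignore it but keep the argument so all strategies share one call signature.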
- class hyperoptax.Matern(length_scale=1.0, nu=2.5)[source]#
Bases: BaseKernel
Matern kernel family.
- Parameters:
- class hyperoptax.RBF(length_scale=1.0)[source]#
Bases: BaseKernel
Radial basis function (RBF) / squared-exponential kernel.
- Parameters:
length_scale (float)
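For reference, the standard closed forms of these kernels as functions of the distance r between two points are sketched below. The formulas are the usual textbook ones for nu in {0.5, 1.5, 2.5}; which nu values the library supports, and how it handles ARD length scales, is not stated here, so treat the function names and the restriction to three nu values as assumptions.

```python
import math

def rbf(r, length_scale=1.0):
    """Squared-exponential kernel: exp(-r^2 / (2 * l^2))."""
    return math.exp(-0.5 * (r / length_scale) ** 2)

def matern(r, length_scale=1.0, nu=2.5):
    """Matern kernel for the three half-integer smoothness values."""
    s = abs(r) / length_scale
    if nu == 0.5:
        return math.exp(-s)
    if nu == 1.5:
        return (1.0 + math.sqrt(3.0) * s) * math.exp(-math.sqrt(3.0) * s)
    if nu == 2.5:
        return (1.0 + math.sqrt(5.0) * s + 5.0 * s * s / 3.0) * math.exp(-math.sqrt(5.0) * s)
    raise ValueError("this sketch only covers nu in {0.5, 1.5, 2.5}")

# Correlation at one length scale of separation, from roughest to smoothest.
values = [matern(1.0, nu=0.5), matern(1.0, nu=1.5), matern(1.0, nu=2.5), rbf(1.0)]
```

Smaller nu gives rougher sample paths and faster-decaying correlation; as nu grows the Matern kernel approaches the RBF kernel.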
Submodules#
hyperoptax.acquisition module#
- class hyperoptax.acquisition.BaseAcquisition[source]#
Bases: object
Base class for acquisition functions.
- class hyperoptax.acquisition.UCB(kappa=2.0)[source]#
Bases: BaseAcquisition
Upper Confidence Bound acquisition function.
- Parameters:
kappa (float)
- class hyperoptax.acquisition.EI(xi=0.01)[source]#
Bases: BaseAcquisition
Expected Improvement acquisition function.
- Parameters:
xi (float)
- class hyperoptax.acquisition.PI(xi=0.01)[source]#
Bases: BaseAcquisition
Probability of Improvement acquisition function.
- Parameters:
xi (float)
- class hyperoptax.acquisition.BaseHallucination[source]#
Bases: object
Base class for Kriging Believer hallucination strategies.
Any callable with signature (mean, std, key, y_max) -> scalar can be used as a hallucination strategy; subclassing is optional.
- class hyperoptax.acquisition.MeanHallucination[source]#
Bases: BaseHallucination
Classical Kriging Believer: hallucinate with GP posterior mean.
- class hyperoptax.acquisition.SampleHallucination[source]#
Bases: BaseHallucination
Randomized Kriging Believer (RKB): hallucinate with a posterior sample.
arXiv 2603.01470.
- class hyperoptax.acquisition.UCBHallucination(kappa=2.0)[source]#
Bases: BaseHallucination
Optimistic hallucination: mean + kappa * std.
- Parameters:
kappa (float)
- class hyperoptax.acquisition.ConstantHallucination(value=None)[source]#
Bases: BaseHallucination
Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.
If value is None, uses the current best observed value (y_max). Otherwise uses the fixed value regardless of observations.
- Parameters:
value (float | None)
hyperoptax.base module#
- class hyperoptax.base.OptimizerState(space)[source]#
Bases: object
Base optimizer state: a JAX pytree holding the search space definition.
- Parameters:
space (PyTree)
- space: PyTree#
- __init__(space)#
- Parameters:
space (PyTree)
- Return type:
None
- class hyperoptax.base.Optimizer[source]#
Bases: object
- optimize(state, key, func, n_iterations)[source]#
High-level API for optimizing a function over a space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
- optimize_scan(state, key, func, n_iterations)[source]#
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- Parameters:
state (OptimizerState)
key (Array)
func (Callable)
n_iterations (int)
- Return type:
tuple[OptimizerState, tuple[PyTree, Array]]
- update_state(state, key, results, params=None)[source]#
Updates the optimizer state based on the results of the function.
- Parameters:
state (OptimizerState)
key (PRNGKey)
results (Array)
params (PyTree | None)
- Return type:
- get_next_params(state, key, params=None, results=None)[source]#
Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).
- Parameters:
state (OptimizerState)
key (PRNGKey)
params (PyTree | None)
results (Array | None)
- Return type:
PyTree
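The two methods above form an ask/tell protocol, and a driver loop over them might look like the sketch below. ToyRandomSearch and run are hypothetical stand-ins written in plain Python to show the call order and the shapes of the returned history; they are not the hyperoptax classes, and key is a random.Random object standing in for a JAX PRNG key (which would normally be split each iteration).

```python
import random

class ToyRandomSearch:
    """Hypothetical optimizer exposing the same two-method protocol."""

    def __init__(self, n_parallel=2):
        self.n_parallel = n_parallel

    def get_next_params(self, state, key, params=None, results=None):
        # Ask: propose a batch where every leaf has length n_parallel.
        lo, hi = state["space"]["x"]
        return {"x": [key.uniform(lo, hi) for _ in range(self.n_parallel)]}

    def update_state(self, state, key, results, params=None):
        # Tell: a memoryless optimizer has nothing to record.
        return state

def run(optimizer, state, key, func, n_iterations):
    """Alternate ask and tell, collecting the history of each iteration."""
    params_hist, results_hist = [], []
    params = results = None  # None on the first call
    for _ in range(n_iterations):
        params = optimizer.get_next_params(state, key, params, results)
        results = [func(x) for x in params["x"]]  # func returns a scalar
        state = optimizer.update_state(state, key, results, params)
        params_hist.append(params)
        results_hist.append(results)
    return state, (params_hist, results_hist)

state = {"space": {"x": (-1.0, 1.0)}}
state, (params_hist, results_hist) = run(
    ToyRandomSearch(n_parallel=2), state, random.Random(0),
    func=lambda x: -(x ** 2), n_iterations=5,
)
best = max(max(batch) for batch in results_hist)
```

Writing the loop by hand like this is what gives control over how each batch is evaluated, for example dispatching the n_parallel evaluations to separate devices.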
hyperoptax.bayesian module#
- class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)[source]#
Bases: OptimizerState
State for BayesianSearch.
All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.
- y#
Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.
- Type:
- mask#
Boolean validity mask, shape (n_max,); True for slots that contain real observations.
- Type:
- log_length_scale#
Per-dimension ARD length scales in log space, shape (n_params,). Tuned by Adam each iteration.
- Type:
- __init__(space, X, y, mask, log_length_scale)#
- space: PyTree#
- class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)[source]#
Bases: Optimizer
Bayesian optimisation with a Gaussian Process surrogate.
Uses a GP (Matérn 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (PI by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- jitter#
Small diagonal added to the kernel matrix for numerical stability (default 1e-6).
- Type:
- n_candidates#
Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).
- Type:
- n_hparam_steps#
Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.
- Type:
- hallucination#
Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).
- kernel: BaseKernel#
- acquisition: BaseAcquisition#
- hallucination: BaseHallucination#
- best_result(state)[source]#
Return the best observed raw result (max if maximize, min if minimize).
- Parameters:
state (BayesianSearchState)
- Return type:
- best_params(state)[source]#
Return the parameter pytree that achieved the best observed result.
- Parameters:
state (BayesianSearchState)
- get_next_params(state, key, params=None, results=None)[source]#
Select the next batch of n_parallel candidates.
During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS, with Kriging Believer hallucination for the parallel slots.
- update_state(state, key, results, params)[source]#
Record new observations and update ARD length scales.
Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.
- Parameters:
state – Current BayesianSearchState.
key – PRNG key (unused but kept for API consistency).
results – Array of shape (n_parallel,) with observed objective values.
params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
- optimize(state, key, func, n_iterations=None)[source]#
High-level API for optimizing a function over a space. Not recommended if you need fine-grained control over parallel computation.
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- optimize_scan(state, key, func, n_iterations=None)[source]#
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (() shape). If your function returns shape (1,), call .squeeze() inside func before returning.
- __init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)#
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- Return type:
None
hyperoptax.grid module#
- class hyperoptax.grid.GridSearchState(space, grid, grid_idx)[source]#
Bases: OptimizerState
State for GridSearch.
- grid#
Array of shape (n_total, n_params) containing all parameter combinations, pre-truncated to a multiple of n_parallel.
- Type:
- grid_idx#
Current position in grid; incremented by n_parallel after each call to update_state.
- Type:
- __init__(space, grid, grid_idx)#
- space: PyTree#
- class hyperoptax.grid.GridSearch(shuffle=False, n_parallel=1)[source]#
Bases: Optimizer
Exhaustive grid search over a discrete search space.
Iterates through every combination of the provided DiscreteSpace values in order (or randomly if shuffle=True). All spaces in the search space must be DiscreteSpace.
- shuffle#
If True, randomise the traversal order during init. Pass an explicit key to init for reproducibility.
- Type:
- classmethod init(space, key=None, **kwargs)[source]#
Initialise the grid search.
- Parameters:
space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.
key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.
**kwargs – Forwarded to the GridSearch constructor (e.g. n_parallel, shuffle).
- Returns:
(state, optimizer) tuple.
- get_next_params(state, key, params=None, results=None)[source]#
Return the next n_parallel parameter combinations from the grid.
- Parameters:
state (GridSearchState)
- Return type:
PyTree
- update_state(state, key, results, params=None)[source]#
Advance the grid index by n_parallel.
- Parameters:
state (GridSearchState)
- Return type:
hyperoptax.kernels module#
- class hyperoptax.kernels.BaseKernel[source]#
Bases: ABC
Abstract base class for positive-definite kernels.
- class hyperoptax.kernels.RBF(length_scale=1.0)[source]#
Bases: BaseKernel
Radial basis function (RBF) / squared-exponential kernel.
- Parameters:
length_scale (float)
- class hyperoptax.kernels.Matern(length_scale=1.0, nu=2.5)[source]#
Bases: BaseKernel
Matern kernel family.
- Parameters:
hyperoptax.spaces module#
- class hyperoptax.spaces.Space[source]#
Bases: ABC
Abstract base class for hyperparameter search spaces.
- __init__()#
- Return type:
None
- class hyperoptax.spaces.LinearSpace(lower_bound, upper_bound)[source]#
Bases: Space
Uniform continuous space over [lower_bound, upper_bound].
- class hyperoptax.spaces.DiscreteSpace(values)[source]#
Bases: Space
Discrete space over a fixed set of values.
Samples uniformly from values; transform snaps any continuous value to the nearest element, which is useful when discrete candidates are generated via continuous optimization (e.g. in BayesianSearch).
- Parameters:
values (tuple)
- class hyperoptax.spaces.LogSpace(lower_bound, upper_bound, base=10)[source]#
Bases: LinearSpace
Log-uniform continuous space over [lower_bound, upper_bound].
Samples uniformly in log space so that each order of magnitude receives equal probability mass. Useful for learning rates and other scale parameters that span several orders of magnitude.
- __init__(lower_bound, upper_bound, base=10)#
- class hyperoptax.spaces.QLinearSpace(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)[source]#
Bases: LinearSpace
Quantized (integer) variant of LinearSpace.
Samples uniformly from [lower_bound, upper_bound] and rounds to the nearest integer. Use this for discrete integer hyperparameters with a uniform prior (e.g. number of layers, batch size).
- class hyperoptax.spaces.QLogSpace(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)[source]#
Bases: LogSpace
Quantized (integer) variant of LogSpace.
Samples in log space and rounds to the nearest integer. Use this for integer hyperparameters whose scale spans orders of magnitude (e.g. number of hidden units, number of warmup steps).