hyperoptax.bayesian
Classes
- class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)
State for BayesianSearch.
All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.
- y
Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.
- mask
Boolean validity mask, shape (n_max,); True for slots that contain real observations.
- log_length_scale
Per-dimension ARD length scales in log space, shape (n_params,). Tuned by Adam each iteration.
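The fixed-size buffer plus validity mask pattern described above can be sketched in plain JAX. This is an illustration of the general technique, not the library's implementation: init_buffers, write_batch and N_MAX are hypothetical names, and the write position is assumed to be the number of slots already marked valid.

```python
import jax
import jax.numpy as jnp

N_MAX = 8  # hypothetical buffer capacity, playing the role of n_max


def init_buffers(n_max):
    """Zero-padded result buffer plus an all-False validity mask."""
    return jnp.zeros(n_max), jnp.zeros(n_max, dtype=bool)


@jax.jit
def write_batch(y, mask, results):
    """Write results into the next free slots; array shapes never change."""
    start = jnp.sum(mask)  # number of slots already filled
    idx = start + jnp.arange(results.shape[0])
    y = y.at[idx].set(results)
    mask = mask.at[idx].set(True)
    return y, mask


y, mask = init_buffers(N_MAX)
y, mask = write_batch(y, mask, jnp.array([0.3, -1.2]))
y, mask = write_batch(y, mask, jnp.array([2.0, 0.5]))
```

Because y and mask keep the shape (N_MAX,) across calls, write_batch is compiled once and reused for every batch of the same size.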
- class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)
Bayesian optimisation with a Gaussian Process surrogate.
Uses a GP (Matérn 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (probability of improvement, PI, by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.
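As a rough illustration of the surrogate described above, the sketch below computes a zero-mean GP posterior with a Matérn 1/2 (exponential) ARD kernel over padded buffers. The function names, the unit signal variance and the way padded rows are neutralised are assumptions made for this example; the library's own kernel and posterior code may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def matern12(xa, xb, log_length_scale):
    """Matérn 1/2 (exponential) kernel with per-dimension length scales."""
    ls = jnp.exp(log_length_scale)
    diff = (xa[:, None, :] - xb[None, :, :]) / ls
    return jnp.exp(-jnp.sqrt(jnp.sum(diff ** 2, axis=-1)))


def gp_posterior(X, y, mask, Xq, log_length_scale, jitter=1e-6):
    """Posterior mean and variance at Xq, ignoring padded (masked-out) rows."""
    valid = mask[:, None] & mask[None, :]
    K = jnp.where(valid, matern12(X, X, log_length_scale), 0.0)
    # Valid rows receive the jitter; padded rows receive a unit diagonal so
    # that K stays positive definite and padding decouples from real data.
    K = K + jnp.diag(jnp.where(mask, jitter, 1.0))
    Kq = jnp.where(mask[:, None], matern12(X, Xq, log_length_scale), 0.0)
    chol = cho_factor(K, lower=True)
    mean = Kq.T @ cho_solve(chol, jnp.where(mask, y, 0.0))
    var = 1.0 - jnp.sum(Kq * cho_solve(chol, Kq), axis=0)
    return mean, jnp.maximum(var, 0.0)


# Two real observations followed by two padded slots.
X = jnp.array([[0.0], [1.0], [0.0], [0.0]])
y = jnp.array([1.0, -1.0, 0.0, 0.0])
mask = jnp.array([True, True, False, False])
Xq = jnp.array([[0.0], [1.0], [100.0]])
mean, var = gp_posterior(X, y, mask, Xq, jnp.zeros(1))
```

At the two observed inputs the posterior mean reproduces the observations up to the jitter, while the far-away query at 100.0 falls back to the prior (mean near 0, variance near 1).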
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- jitter
Small diagonal added to the kernel matrix for numerical stability (default 1e-6).
- n_candidates
Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).
- n_hparam_steps
Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.
- hallucination
Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).
- kernel: BaseKernel
- acquisition: BaseAcquisition
- hallucination: BaseHallucination
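One plausible reading of the discrete pre-selection step is: score the n_candidates random points with the acquisition function and keep the best n_restarts as starting points for local refinement. The sketch below shows that idea with a textbook probability-of-improvement score. The function names and the exploration margin xi are assumptions for illustration and are not taken from the library.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def probability_of_improvement(mean, std, best, xi=0.01):
    """PI = Phi((mean - best - xi) / std); xi is a hypothetical margin."""
    z = (mean - best - xi) / jnp.maximum(std, 1e-9)
    return norm.cdf(z)


def preselect(mean, std, best, n_restarts):
    """Indices of the n_restarts candidates with the highest PI score."""
    scores = probability_of_improvement(mean, std, best)
    _, idx = jax.lax.top_k(scores, n_restarts)
    return idx


# Toy posterior over four candidates, with a best observed value of 1.0.
mean = jnp.array([0.0, 1.0, 2.0, 0.5])
std = jnp.ones(4)
start_idx = preselect(mean, std, best=1.0, n_restarts=2)
```

With equal uncertainty everywhere, the two candidates with the highest posterior mean (indices 2 and 1) are selected.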
- best_result(state)
Return the best observed raw result (the maximum if maximize is True, otherwise the minimum).
- Parameters:
state (BayesianSearchState)
- best_params(state)
Return the parameter pytree that achieved the best observed result.
- Parameters:
state (BayesianSearchState)
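Since results are stored raw and zero-padded, picking the best entry has to respect both the mask and the optimisation direction. A minimal sketch of that lookup, assuming a flat (n_max, n_params) parameter buffer X and a hypothetical helper best_index:

```python
import jax.numpy as jnp


def best_index(y, mask, maximize=True):
    """Index of the best valid observation; padded slots can never win."""
    sign = 1.0 if maximize else -1.0
    scores = jnp.where(mask, sign * y, -jnp.inf)
    return jnp.argmax(scores)


# Three real observations followed by two zero-padded slots.
X = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0], [0.0, 0.0]])
y = jnp.array([0.3, 1.2, 2.0, 0.0, 0.0])
mask = jnp.array([True, True, True, False, False])

i_max = best_index(y, mask, maximize=True)   # best when maximising
i_min = best_index(y, mask, maximize=False)  # best when minimising
best_result_max, best_params_max = y[i_max], X[i_max]
best_result_min, best_params_min = y[i_min], X[i_min]
```

Without the mask, the minimising case would wrongly pick one of the padded zeros, which is smaller than every real observation in this example.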
- get_next_params(state, key, params=None, results=None)
Select the next batch of n_parallel candidates.
During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS, with Kriging Believer hallucination for the parallel slots.
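The Kriging Believer idea can be sketched as a loop: pick the acquisition maximiser, pretend its outcome has been observed, refit, and repeat until the batch is full. The example below is a simplified illustration under several assumptions: it is one-dimensional, it maximises PI over a fixed candidate grid instead of running L-BFGS, and it hallucinates the posterior mean (the classic variant) rather than whatever SampleHallucination does. All function names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def kernel(a, b, length_scale=0.3):
    """1-D Matérn 1/2 (exponential) kernel."""
    return jnp.exp(-jnp.abs(a[:, None] - b[None, :]) / length_scale)


def posterior(x_obs, y_obs, x_query, jitter=1e-6):
    """Zero-mean GP posterior mean and standard deviation at x_query."""
    K = kernel(x_obs, x_obs) + jitter * jnp.eye(x_obs.shape[0])
    Kq = kernel(x_obs, x_query)
    mean = Kq.T @ jnp.linalg.solve(K, y_obs)
    var = 1.0 - jnp.sum(Kq * jnp.linalg.solve(K, Kq), axis=0)
    return mean, jnp.sqrt(jnp.maximum(var, 1e-12))


def kriging_believer_batch(x_obs, y_obs, candidates, n_parallel, xi=0.01):
    """Build a batch by repeatedly trusting the surrogate's own prediction."""
    batch = []
    for _ in range(n_parallel):
        mean, std = posterior(x_obs, y_obs, candidates)
        pi = norm.cdf((mean - jnp.max(y_obs) - xi) / std)
        i = int(jnp.argmax(pi))
        batch.append(candidates[i])
        # "Believe" the model: treat the posterior mean as a real observation.
        x_obs = jnp.append(x_obs, candidates[i])
        y_obs = jnp.append(y_obs, mean[i])
    return jnp.stack(batch)


x_obs = jnp.array([0.1, 0.9])
y_obs = jnp.array([0.2, 0.6])
candidates = jnp.linspace(0.0, 1.0, 21)
batch = kriging_believer_batch(x_obs, y_obs, candidates, n_parallel=3)
```

Each hallucinated observation collapses the posterior uncertainty around the chosen point, which pushes later picks in the same batch towards other regions.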
- update_state(state, key, results, params)
Record new observations and update ARD length scales.
Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.
- Parameters:
state – Current BayesianSearchState.
key – PRNG key (unused but kept for API consistency).
results – Array of shape (n_parallel,) with observed objective values.
params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
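Tuning log_length_scale by marginal-likelihood maximisation amounts to minimising the negative log marginal likelihood with a gradient-based optimiser. The sketch below uses a zero-mean GP with a Matérn 1/2 ARD kernel and a hand-written Adam update to stay dependency-free. The function names, learning rate and Adam constants are assumptions for illustration; the library's actual loss and optimiser settings are not documented here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def neg_log_marginal_likelihood(log_length_scale, X, y, mask, jitter=1e-6):
    """NLML of a zero-mean GP; padded rows contribute nothing."""
    ls = jnp.exp(log_length_scale)
    diff = (X[:, None, :] - X[None, :, :]) / ls
    # The small constant keeps the square root differentiable at distance 0.
    K = jnp.exp(-jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12))
    K = jnp.where(mask[:, None] & mask[None, :], K, 0.0)
    K = K + jnp.diag(jnp.where(mask, jitter, 1.0))
    y = jnp.where(mask, y, 0.0)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    n = jnp.sum(mask)
    return (0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L)))
            + 0.5 * n * jnp.log(2.0 * jnp.pi))


def tune_length_scales(log_ls, X, y, mask, n_steps=20, lr=0.05,
                       b1=0.9, b2=0.999, eps=1e-8):
    """Run n_steps of a minimal Adam update on log_ls."""
    grad_fn = jax.jit(jax.grad(neg_log_marginal_likelihood))
    m = jnp.zeros_like(log_ls)
    v = jnp.zeros_like(log_ls)
    for t in range(1, n_steps + 1):
        g = grad_fn(log_ls, X, y, mask)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g ** 2
        m_hat = m / (1 - b1 ** t)  # bias-corrected first moment
        v_hat = v / (1 - b2 ** t)  # bias-corrected second moment
        log_ls = log_ls - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return log_ls


# Four real observations in two dimensions, followed by two padded slots.
X = jnp.array([[0.0, 0.0], [0.5, 1.0], [1.0, 0.2], [0.2, 0.7],
               [0.0, 0.0], [0.0, 0.0]])
y = jnp.array([0.1, 0.8, -0.3, 0.5, 0.0, 0.0])
mask = jnp.array([True, True, True, True, False, False])
init_log_ls = jnp.zeros(2)
tuned_log_ls = tune_length_scales(init_log_ls, X, y, mask)
```

Working in log space keeps the length scales positive without any explicit constraint on the optimiser.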
- optimize(state, key, func, n_iterations=None)
High-level API for optimizing a function over a space. Not recommended if you need custom control over parallel computation.
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
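The difference between a shape-(1,) result and a true scalar is easy to miss. The toy objectives below (hypothetical names and a hypothetical single-array signature, not the library's calling convention) show the issue and the .squeeze() fix:

```python
import jax.numpy as jnp


def objective_raw(x):
    """Returns shape (1,) because keepdims=True keeps the reduced axis."""
    return jnp.sum(x ** 2, keepdims=True)


def objective(x):
    """Same value, squeezed down to a true scalar of shape ()."""
    return jnp.sum(x ** 2, keepdims=True).squeeze()


x = jnp.array([1.0, 2.0])
raw_shape = objective_raw(x).shape
scalar_shape = objective(x).shape
```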
- optimize_scan(state, key, func, n_iterations=None)
Like optimize, but uses jax.lax.scan for the inner loop.
Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).
func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
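The stacked output shapes follow directly from how jax.lax.scan collects per-step outputs. The sketch below replaces the GP proposal step with uniform random sampling, purely to show the loop structure and the resulting shapes. The names run, step and objective are hypothetical, and this is not how optimize_scan itself proposes candidates.

```python
import jax
import jax.numpy as jnp


def objective(x):
    """Toy JAX-traceable objective returning a scalar of shape ()."""
    return -(x - 0.3) ** 2


def run(key, n_iterations, n_parallel):
    def step(key, _):
        key, sub = jax.random.split(key)
        # Stand-in for the proposal step: uniform random candidates.
        params = jax.random.uniform(sub, (n_parallel,))
        results = jax.vmap(objective)(params)
        return key, (params, results)

    # scan stacks each step's outputs along a new leading axis.
    _, (params_hist, results_hist) = jax.lax.scan(
        step, key, None, length=n_iterations)
    return params_hist, results_hist


params_hist, results_hist = run(jax.random.PRNGKey(0),
                                n_iterations=5, n_parallel=4)
```

Both histories come back with shape (5, 4), that is (n_iterations, n_parallel), without any Python-side list handling.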
- __init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)
- Parameters:
jitter (float)
kernel (BaseKernel)
acquisition (BaseAcquisition)
n_candidates (int)
n_restarts (int)
n_lbfgs_steps (int)
n_hparam_steps (int)
n_warmup (int)
maximize (bool)
n_parallel (int)
hallucination (BaseHallucination)
- Return type:
None