hyperoptax.bayesian

Classes

class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)

State for BayesianSearch.

All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.
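One way to append a batch to such buffers without changing their shapes is sketched below. The helper write_batch is hypothetical. It assumes that real observations fill the buffers contiguously from slot 0, so the next free slot is mask.sum(). The library's own update_state may write the buffers differently.

```python
import jax
import jax.numpy as jnp


def write_batch(X, y, mask, new_X, new_y):
    """Hypothetical helper: append one batch to fixed-size, zero-padded buffers.

    Shapes: X (n_max, n_params), y (n_max,), mask (n_max,) bool,
    new_X (n_parallel, n_params), new_y (n_parallel,).
    Assumes real observations occupy slots 0..mask.sum()-1 contiguously.
    """
    start = jnp.sum(mask).astype(jnp.int32)  # first free slot, may be a tracer
    zero = jnp.int32(0)
    # dynamic_update_slice accepts traced start indices, so every shape stays static.
    # If the batch would overflow n_max, JAX clamps the start index and
    # overwrites the last slots instead of raising an error.
    X = jax.lax.dynamic_update_slice(X, new_X, (start, zero))
    y = jax.lax.dynamic_update_slice(y, new_y, (start,))
    mask = jax.lax.dynamic_update_slice(mask, jnp.ones(new_y.shape, dtype=bool), (start,))
    return X, y, mask


write_batch_jit = jax.jit(write_batch)  # shapes never change, so one compilation serves every call

n_max, n_params, n_parallel = 8, 2, 4
X = jnp.zeros((n_max, n_params))
y = jnp.zeros((n_max,))
mask = jnp.zeros((n_max,), dtype=bool)

X, y, mask = write_batch_jit(X, y, mask, jnp.ones((n_parallel, n_params)), jnp.arange(4.0))
X, y, mask = write_batch_jit(X, y, mask, jnp.full((n_parallel, n_params), 2.0), jnp.arange(4.0, 8.0))
# mask is now all True and y equals arange(8.0)
```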

Attributes:
X

Observation inputs, shape (n_max, n_params), zero-padded.

Type:

jax.Array

y

Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.

Type:

jax.Array

mask

Boolean validity mask, shape (n_max,); True for slots that contain real observations.

Type:

jax.Array

log_length_scale

Per-dimension ARD length scales in log space, shape (n_params,). Tuned with Adam in update_state() when n_hparam_steps > 0.

Type:

jax.Array
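A Matérn ν = 0.5 (exponential) kernel with ARD length scales stored in log space might look like the sketch below. The name matern12_ard and the unit signal variance are assumptions, not the library's kernel API. Storing log length scales means any real value Adam produces maps to a positive length scale.

```python
import jax.numpy as jnp


def matern12_ard(X1, X2, log_length_scale):
    """Matern nu=0.5 (exponential) kernel with per-dimension (ARD) length scales.

    Sketch only: assumes unit signal variance.
    X1: (n, d), X2: (m, d), log_length_scale: (d,). Returns (n, m).
    """
    ls = jnp.exp(log_length_scale)  # log space keeps length scales positive
    diff = (X1[:, None, :] - X2[None, :, :]) / ls
    # The small epsilon keeps sqrt differentiable when two points coincide.
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)
    return jnp.exp(-r)


X = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = matern12_ard(X, X, jnp.zeros(2))                          # all length scales = 1
K_long = matern12_ard(X, X, jnp.log(jnp.array([1.0, 4.0])))   # longer scale on dim 1
```

A longer length scale in a dimension makes the kernel less sensitive to distance along it, which is how ARD down-weights irrelevant parameters.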

X: Array
y: Array
mask: Array
log_length_scale: Array
__init__(space, X, y, mask, log_length_scale)
Return type:

None

class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)

Bayesian optimisation with a Gaussian Process surrogate.

Uses a GP (Matérn ν = 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (probability of improvement, PI, by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated with the Kriging Believer hallucination strategy.
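The sketch below shows a zero-mean GP posterior and the PI acquisition for maximisation in plain JAX. It assumes unit signal variance. It handles the zero-padded buffers by decoupling padded slots from the real observations. The helper names (kernel, gp_posterior, probability_of_improvement) are hypothetical, and the library's internals may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import norm


def kernel(A, B, log_ls):
    # Matern nu=0.5 with ARD length scales and unit signal variance (assumed).
    d = (A[:, None, :] - B[None, :, :]) / jnp.exp(log_ls)
    return jnp.exp(-jnp.sqrt(jnp.sum(d**2, -1) + 1e-12))


def gp_posterior(X, y, mask, Xq, log_ls, jitter=1e-6):
    """Zero-mean GP posterior mean and std at query points Xq.

    Padded slots are decoupled: their kernel rows/columns are replaced by the
    identity and their targets and cross-covariances are zeroed, so they do
    not influence the result (an assumed scheme, not necessarily the library's).
    """
    m = mask.astype(X.dtype)
    K = kernel(X, X, log_ls) * (m[:, None] * m[None, :])
    K = K + jnp.diag(1.0 - m) + jitter * jnp.eye(X.shape[0])
    Ks = kernel(X, Xq, log_ls) * m[:, None]             # (n_max, n_query)
    c = cho_factor(K, lower=True)
    mean = Ks.T @ cho_solve(c, y * m)
    var = 1.0 - jnp.sum(Ks * cho_solve(c, Ks), axis=0)  # unit prior variance
    return mean, jnp.sqrt(jnp.maximum(var, 1e-12))


def probability_of_improvement(mean, std, best, xi=0.01):
    # PI for maximisation: P(f(x) > best + xi) under the Gaussian posterior.
    return norm.cdf((mean - best - xi) / std)
```

For minimisation, the same code can be applied to negated targets, since the state keeps y in raw form.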

Parameters:
jitter

Small diagonal added to the kernel matrix for numerical stability (default 1e-6).

Type:

float

kernel

Kernel function (default Matern with nu=0.5).

Type:

hyperoptax.kernels.BaseKernel

acquisition

Acquisition function (default PI with xi=0.01).

Type:

hyperoptax.acquisition.BaseAcquisition

n_candidates

Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).

Type:

int

n_restarts

Number of L-BFGS restarts seeded from the top candidates (default 2).

Type:

int

n_lbfgs_steps

Number of L-BFGS iterations per restart (default 10).

Type:

int

n_hparam_steps

Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.

Type:

int

n_warmup

Number of pure-random iterations before the GP is used (default 1).

Type:

int

maximize

Set to False to minimise the objective (default True).

Type:

bool

n_parallel

Number of parallel candidates per iteration (default 4).

Type:

int

hallucination

Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).

Type:

hyperoptax.acquisition.BaseHallucination

jitter: float = 1e-06
kernel: BaseKernel
acquisition: BaseAcquisition
n_candidates: int = 1000
n_restarts: int = 2
n_lbfgs_steps: int = 10
n_hparam_steps: int = 20
n_warmup: int = 1
maximize: bool = True
n_parallel: int = 4
hallucination: BaseHallucination
classmethod init(space, n_max=200, **kwargs)
best_result(state)

Return the best observed raw result (the maximum if maximize is True, otherwise the minimum).

Parameters:

state (BayesianSearchState)

Return type:

Array

best_params(state)

Return the parameter pytree that achieved the best observed result.

Parameters:

state (BayesianSearchState)
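Because y is stored raw and zero-padded, locating the best entry has to respect both maximize and mask. Here is a minimal sketch with a hypothetical best_index helper. Mapping the winning row of X back to a parameter pytree, as best_params does, depends on the space object and is not shown.

```python
import jax.numpy as jnp


def best_index(y, mask, maximize=True):
    """Hypothetical helper: index of the best real observation.

    y holds raw values, so the direction is applied here; padded slots are
    filled with -inf (or +inf) so they can never be selected.
    """
    if maximize:  # maximize is a Python bool, so this branch is static
        return jnp.argmax(jnp.where(mask, y, -jnp.inf))
    return jnp.argmin(jnp.where(mask, y, jnp.inf))


y = jnp.array([-0.3, -1.2, -0.1, 0.0, 0.0])  # last two slots are padding
mask = jnp.array([True, True, True, False, False])
X = jnp.array([[0.1], [0.2], [0.3], [0.0], [0.0]])

i_max = best_index(y, mask, maximize=True)   # 2; an unmasked argmax would return 3
i_min = best_index(y, mask, maximize=False)  # 1
best_x = X[i_max]                            # flat row of the best parameters
```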

get_next_params(state, key, params=None, results=None)

Select the next batch of n_parallel candidates.

During the first n_warmup iterations, candidates are chosen uniformly at random. After that, the acquisition function is maximised under the GP posterior. n_candidates random points are scored, the best n_restarts of them seed L-BFGS, and Kriging Believer hallucination fills the parallel slots.
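The Kriging Believer idea can be sketched as follows. For each slot, the sketch scores random candidates with PI and picks the best one. It then treats the posterior mean at that point as if it had been observed, so that later slots move elsewhere. This is the classic posterior-mean variant. The library's default SampleHallucination and its L-BFGS refinement are not reproduced. The unit-cube search space, the helper names, and maximisation are all assumptions of this sketch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import norm


def _kernel(A, B, log_ls):
    # Matern nu=0.5 kernel with ARD length scales and unit variance (assumed).
    d = (A[:, None, :] - B[None, :, :]) / jnp.exp(log_ls)
    return jnp.exp(-jnp.sqrt(jnp.sum(d**2, -1) + 1e-12))


def _posterior(X, y, mask, Xq, log_ls, jitter=1e-6):
    # Zero-mean GP posterior; padded slots are decoupled via the mask.
    m = mask.astype(X.dtype)
    K = _kernel(X, X, log_ls) * (m[:, None] * m[None, :])
    K = K + jnp.diag(1.0 - m) + jitter * jnp.eye(X.shape[0])
    Ks = _kernel(X, Xq, log_ls) * m[:, None]
    c = cho_factor(K, lower=True)
    mean = Ks.T @ cho_solve(c, y * m)
    var = jnp.maximum(1.0 - jnp.sum(Ks * cho_solve(c, Ks), axis=0), 1e-12)
    return mean, jnp.sqrt(var)


def kriging_believer_batch(key, X, y, mask, log_ls, n_parallel=4,
                           n_candidates=1000, xi=0.01):
    """Choose n_parallel points in the unit cube, assuming maximisation.

    For each slot: score fresh random candidates with PI, keep the argmax,
    then "believe" the posterior mean there and write it into the buffers as
    a fake observation, so later slots are steered away from it.
    Assumes the buffers have at least n_parallel free slots.
    """
    chosen = []
    for _ in range(n_parallel):
        key, sub = jax.random.split(key)
        cand = jax.random.uniform(sub, (n_candidates, X.shape[1]))
        best = jnp.max(jnp.where(mask, y, -jnp.inf))  # includes fake observations
        mean, std = _posterior(X, y, mask, cand, log_ls)
        x_new = cand[jnp.argmax(norm.cdf((mean - best - xi) / std))]
        mu_new, _ = _posterior(X, y, mask, x_new[None, :], log_ls)
        slot = jnp.sum(mask)  # next free slot in the fixed-size buffers
        X = X.at[slot].set(x_new)
        y = y.at[slot].set(mu_new[0])  # hallucinated observation
        mask = mask.at[slot].set(True)
        chosen.append(x_new)
    return jnp.stack(chosen)
```

Hallucinating a value collapses the posterior variance near the chosen point. This is what spreads the batch out without evaluating the objective between slots.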

update_state(state, key, results, params)

Record new observations and update ARD length scales.

Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.

Parameters:
  • state – Current BayesianSearchState.

  • key – PRNG key (unused but kept for API consistency).

  • results – Array of shape (n_parallel,) with observed objective values.

  • params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
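The length-scale tuning step can be sketched as Adam on the negative log marginal likelihood of a zero-mean GP. Adam is hand-written here to avoid assuming an optimiser library. The learning rate, the helper names, and the masking scheme are assumptions. As in the n_hparam_steps description, n_steps=0 leaves the length scales unchanged.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def neg_log_marginal_likelihood(log_ls, X, y, mask, jitter=1e-6):
    """Zero-mean GP negative log marginal likelihood, Matern-1/2 ARD kernel.

    Padded slots get an identity block, which adds a term that does not
    depend on log_ls, so it does not affect the gradient.
    """
    m = mask.astype(X.dtype)
    d = (X[:, None, :] - X[None, :, :]) / jnp.exp(log_ls)
    K = jnp.exp(-jnp.sqrt(jnp.sum(d**2, -1) + 1e-12)) * (m[:, None] * m[None, :])
    K = K + jnp.diag(1.0 - m) + jitter * jnp.eye(X.shape[0])
    L, lower = cho_factor(K, lower=True)
    alpha = cho_solve((L, lower), y * m)
    # 0.5 y^T K^-1 y + 0.5 log|K| + (n/2) log(2 pi), with log|K| = 2 sum log diag(L)
    return (0.5 * (y * m) @ alpha + jnp.sum(jnp.log(jnp.diag(L)))
            + 0.5 * jnp.sum(m) * jnp.log(2.0 * jnp.pi))


def tune_length_scales(log_ls, X, y, mask, n_steps=20, lr=0.05,
                       b1=0.9, b2=0.999, eps=1e-8):
    """Hand-written Adam on the NLL; lr is an assumed value."""
    grad_fn = jax.grad(neg_log_marginal_likelihood)

    def step(t, carry):
        p, m, v = carry
        g = grad_fn(p, X, y, mask)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        m_hat = m / (1 - b1 ** (t + 1.0))  # bias-corrected first moment
        v_hat = v / (1 - b2 ** (t + 1.0))  # bias-corrected second moment
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v

    init = (log_ls, jnp.zeros_like(log_ls), jnp.zeros_like(log_ls))
    log_ls, _, _ = jax.lax.fori_loop(0, n_steps, step, init)
    return log_ls
```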

optimize(state, key, func, n_iterations=None)

High-level API for optimizing func over the search space. Not recommended if you need fine-grained control over parallel evaluation. In that case, drive the loop yourself with get_next_params() and update_state().

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
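The scalar requirement can be met with a thin wrapper. How func receives its parameters is not spelled out here, so the dict pytree {"lr": ...} and both function names are assumptions. The jax.vmap line only illustrates that a scalar-valued func maps cleanly over a batched pytree. It does not describe how optimize evaluates func internally.

```python
import jax
import jax.numpy as jnp


def objective_vec(params):
    # Hypothetical objective that returns shape (1,) rather than a scalar.
    return jnp.reshape(-(params["lr"] - 0.1) ** 2, (1,))


def objective(params):
    # Squeeze to shape () before returning.
    return objective_vec(params).squeeze()


single = objective({"lr": jnp.array(0.3)})         # shape ()
batch = {"lr": jnp.array([0.05, 0.1, 0.3, 0.5])}  # each leaf: (n_parallel,)
batched = jax.vmap(objective)(batch)               # shape (4,)
```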

optimize_scan(state, key, func, n_iterations=None)

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.
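jax.lax.scan stacks each iteration's outputs along a new leading axis, which explains the stacked history shapes. The toy loop below is plain random search, not Bayesian optimisation. It shows only that stacking behaviour. The objective and random_search_scan are hypothetical.

```python
import jax
import jax.numpy as jnp


def objective(params):
    # Hypothetical scalar-valued, JAX-traceable objective.
    return -(params["x"] - 0.3) ** 2


def random_search_scan(key, n_iterations=5, n_parallel=4):
    """Toy stand-in for an optimisation loop (random search, not BO)."""

    def body(best, key):
        params = {"x": jax.random.uniform(key, (n_parallel,))}  # leaf: (n_parallel,)
        results = jax.vmap(objective)(params)                   # (n_parallel,)
        return jnp.maximum(best, results.max()), (params, results)

    keys = jax.random.split(key, n_iterations)
    init = jnp.array(-jnp.inf, dtype=jnp.float32)
    best, (params_hist, results_hist) = jax.lax.scan(body, init, keys)
    # scan stacks per-step outputs along a new leading axis:
    # params_hist["x"] has shape (n_iterations, n_parallel), as does results_hist.
    return best, params_hist, results_hist


best, params_hist, results_hist = random_search_scan(jax.random.PRNGKey(0))
```

Because the whole loop is traced once, func must be traceable, which matches the requirement stated above.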

__init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)
Return type:

None