import dataclasses
import functools
import jax
import jax.numpy as jnp
import optax
from hyperoptax import acquisition as acq
from hyperoptax import base, kernels
from hyperoptax import spaces as sp
# Variance written onto the kernel-matrix diagonal for padded (masked-out) rows
# in ``BayesianSearch._gp_fit``. Those rows have their off-diagonal covariances
# zeroed, so this large diagonal keeps the matrix positive definite while
# making the padded rows contribute (almost) nothing to the fit.
MASK_VARIANCE = 1e12 # large diagonal added to masked rows to isolate them from GP fit
[docs]
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BayesianSearchState(base.OptimizerState):
    """Immutable, fixed-capacity observation store for :class:`BayesianSearch`.

    Every buffer is allocated once with ``n_max`` rows (see
    :meth:`BayesianSearch.init`) so that array shapes stay static under JAX
    tracing. Unused rows are zero-filled and flagged as invalid in ``mask``.

    Attributes:
        X: Evaluated parameter vectors, shape ``(n_max, n_params)``; rows
            without an observation are zeros.
        y: Objective values as reported by the user, shape ``(n_max,)``.
            Values are stored un-negated whatever ``maximize`` is; the sign
            flip for minimisation happens when the GP reads them.
        mask: Boolean array of shape ``(n_max,)``; ``True`` marks rows that
            hold a real observation.
        log_length_scale: Log of the per-dimension (ARD) kernel length
            scales, shape ``(n_params,)``; re-tuned with Adam when new
            observations are recorded.
    """

    X: jax.Array  # (n_max, n_params); zero rows where nothing was observed
    y: jax.Array  # (n_max,); raw results, never negated
    mask: jax.Array  # (n_max,) bool; True where a row is populated
    log_length_scale: jax.Array  # (n_params,); log ARD length scales
[docs]
@dataclasses.dataclass
class BayesianSearch(base.Optimizer):
    """Gaussian-Process based Bayesian optimiser.

    The objective is modelled with a GP surrogate (Matérn kernel with
    ``nu=0.5`` unless overridden). Each call proposes ``n_parallel`` points by
    maximising an acquisition function (PI unless overridden): random
    candidates are scored, the best ones seed L-BFGS refinement, and after
    each pick a hallucinated observation is added so that the next pick in
    the same batch accounts for it (Kriging Believer). The per-dimension
    length scales are re-tuned with Adam whenever observations are recorded.

    Attributes:
        jitter: Value added to the kernel-matrix diagonal for numerical
            stability (default ``1e-6``).
        kernel: Covariance function (default
            :class:`~hyperoptax.kernels.Matern` with ``length_scale=1.0`` and
            ``nu=0.5``).
        acquisition: Acquisition function (default
            :class:`~hyperoptax.acquisition.PI` with ``xi=0.01``).
        n_candidates: How many random candidates are sampled per proposal
            round (default ``1000``).
        n_restarts: How many of the top-scoring candidates are refined with
            L-BFGS (default ``2``).
        n_lbfgs_steps: L-BFGS update steps per restart (default ``10``).
        n_hparam_steps: Adam steps spent on ``log_length_scale`` per state
            update (default ``20``); ``0`` turns tuning off.
        n_warmup: Observation count below which proposals are purely random;
            the GP takes over once at least this many observations have been
            recorded (default ``1``).
        maximize: ``True`` to maximise the objective, ``False`` to minimise
            it (default ``True``).
        n_parallel: Size of each proposed batch (default ``4``).
        hallucination: Rule producing the pseudo-observation for each pick
            within a batch (default
            :class:`~hyperoptax.acquisition.SampleHallucination`).
    """

    jitter: float = 1e-6
    kernel: kernels.BaseKernel = dataclasses.field(
        default_factory=functools.partial(kernels.Matern, length_scale=1.0, nu=0.5)
    )
    acquisition: acq.BaseAcquisition = dataclasses.field(
        default_factory=functools.partial(acq.PI, xi=0.01)
    )
    n_candidates: int = 1000  # size of the random candidate pool per round
    n_restarts: int = 2  # top candidates refined with L-BFGS
    n_lbfgs_steps: int = 10  # L-BFGS steps applied to each seed
    n_hparam_steps: int = 20  # Adam steps on log_length_scale per update
    n_warmup: int = 1  # observations required before the GP drives selection
    maximize: bool = True  # False flips the objective to minimisation
    n_parallel: int = 4
    hallucination: acq.BaseHallucination = dataclasses.field(
        default_factory=acq.SampleHallucination
    )
[docs]
@classmethod
def init(cls, space, n_max=200, **kwargs):
    """Create an optimizer together with an empty state sized for ``space``.

    Args:
        space: Pytree whose leaves are :class:`~hyperoptax.spaces.Space`
            objects; the number of leaves fixes ``n_params``.
        n_max: Number of rows allocated in the observation buffers.
        **kwargs: Passed straight to the class constructor.

    Returns:
        A ``(state, optimizer)`` tuple.
    """
    # Build the optimizer first: its kernel provides the starting length scale.
    optimizer = cls(**kwargs)
    n_params = len(
        jax.tree.leaves(space, is_leaf=lambda node: isinstance(node, sp.Space))
    )
    initial_scale = float(optimizer.kernel.length_scale)
    empty_state = BayesianSearchState(
        space=space,
        X=jnp.zeros((n_max, n_params)),
        y=jnp.zeros(n_max),
        mask=jnp.zeros(n_max, dtype=bool),
        log_length_scale=jnp.log(jnp.ones(n_params) * initial_scale),
    )
    return empty_state, optimizer
# ------------------------------------------------------------------
# Convenience accessors
# ------------------------------------------------------------------
[docs]
def best_result(self, state: BayesianSearchState) -> jax.Array:
    """Return the best raw objective value among the valid observations.

    The maximum is returned when ``maximize`` is set, otherwise the minimum.
    With no valid observation the reduction's initial value is returned
    (``-inf`` when maximising, ``+inf`` when minimising).
    """
    reducer, sentinel = (
        (jnp.max, -jnp.inf) if self.maximize else (jnp.min, jnp.inf)
    )
    return reducer(state.y, where=state.mask, initial=sentinel)
[docs]
def best_params(self, state: BayesianSearchState):
    """Return the parameter pytree of the best valid observation.

    Padded rows are excluded by replacing their result with ``-inf`` (or
    ``+inf`` when minimising) before taking the arg-extremum.

    The row index is kept as a JAX scalar rather than converted to a Python
    ``int``: the selected row is the same, but no device-to-host sync is
    forced and the method can also be called inside ``jax.jit``.

    Returns:
        A pytree with the structure of ``state.space`` whose leaves are
        scalar arrays (shape ``()``), one per parameter.
    """
    if self.maximize:
        best_n = jnp.argmax(jnp.where(state.mask, state.y, -jnp.inf))
    else:
        best_n = jnp.argmin(jnp.where(state.mask, state.y, jnp.inf))
    x_best = state.X[best_n]
    _, treedef = jax.tree.flatten(
        state.space, is_leaf=lambda x: isinstance(x, sp.Space)
    )
    # One scalar leaf per parameter, in the flattening order of the space.
    return treedef.unflatten([x_best[i] for i in range(treedef.num_leaves)])
# ------------------------------------------------------------------
# Space helpers
# ------------------------------------------------------------------
def _sample_candidates(self, space, key, n):
    """Draw ``n`` random points from ``space``.

    ``key`` is split once per leaf of the space, and each per-leaf key is
    split again into ``n`` keys so every draw uses its own key.

    Returns:
        Array of shape ``(n, n_params)``; column ``j`` holds the draws for
        the ``j``-th leaf.
    """
    leaves = jax.tree.leaves(space, is_leaf=lambda node: isinstance(node, sp.Space))
    leaf_keys = jax.random.split(key, len(leaves))
    columns = []
    for idx, leaf in enumerate(leaves):
        draw_keys = jax.random.split(leaf_keys[idx], n)

        def draw_one(k):
            return leaf.sample(k).squeeze()

        columns.append(jax.vmap(draw_one)(draw_keys))
    return jnp.stack(columns, axis=-1)  # (n, n_params)
def _space_bounds(self, space):
    """Collect the per-parameter bounds of ``space``.

    Returns:
        ``(lowers, uppers)``: two arrays of shape ``(n_params,)`` built from
        each leaf's ``lower_bound`` and ``upper_bound``.
    """
    leaves = jax.tree.leaves(space, is_leaf=lambda node: isinstance(node, sp.Space))
    bounds = [(leaf.lower_bound, leaf.upper_bound) for leaf in leaves]
    lowers = jnp.array([low for low, _ in bounds])
    uppers = jnp.array([high for _, high in bounds])
    return lowers, uppers
# ------------------------------------------------------------------
# GP helpers
# ------------------------------------------------------------------
def _effective_y(self, state: BayesianSearchState) -> jax.Array:
    """Return the stored results oriented so that larger is always better.

    The raw values are returned when maximising and their negation when
    minimising.
    """
    if self.maximize:
        return state.y
    return -state.y
def _gp_fit(self, X, y, mask, length_scale):
    """Fit the GP on the valid rows of a padded data set.

    Args:
        X: Padded inputs, shape ``(n, n_params)``.
        y: Padded targets, shape ``(n,)``.
        mask: Boolean validity mask, shape ``(n,)``.
        length_scale: Kernel length scale(s) forwarded to ``self.kernel``.

    Returns:
        ``(L, alpha, ymean)``: the lower Cholesky factor of the masked
        kernel matrix, the solution of ``K @ alpha = y_centered``, and the
        mean of the valid targets.
    """
    valid = mask.astype(float)
    # jnp.mean(y, where=mask) is NaN when no row is valid, which would make
    # alpha NaN as well. Divide by at least one so an empty mask gives
    # ymean == 0 instead.
    n_valid = jnp.maximum(jnp.sum(valid), 1.0)
    ymean = jnp.sum(jnp.where(mask, y, 0.0)) / n_valid
    y_centered = (y - ymean) * mask
    K = self.kernel(X, X, length_scale=length_scale)
    # Zero every covariance that involves a padded row, then add jitter.
    K = K * jnp.outer(valid, valid) + self.jitter * jnp.eye(X.shape[0])
    # Padded rows get a huge variance so they stay decoupled from the fit
    # while the matrix remains positive definite.
    K = K + jnp.diag((1.0 - valid) * MASK_VARIANCE)
    L = jnp.linalg.cholesky(K)
    alpha = jax.scipy.linalg.cho_solve((L, True), y_centered)
    return L, alpha, ymean
def _gp_predict(self, X_test, L, alpha, ymean, X_train, length_scale):
    """Evaluate the GP posterior at ``X_test`` from an existing fit.

    Args:
        X_test: Query points, shape ``(m, n_params)``.
        L: Lower Cholesky factor of the training kernel matrix.
        alpha: Solution of ``K @ alpha = y_centered`` from the fit.
        ymean: Target mean that was subtracted before fitting.
        X_train: Training inputs used for the fit, shape ``(n, n_params)``.
        length_scale: Kernel length scale(s) forwarded to ``self.kernel``.

    Returns:
        ``(mean, std)``, each of shape ``(m,)``. The variance is computed as
        ``1 - k*^T K^-1 k*`` and clipped at zero before the square root.
    """
    cross_cov = self.kernel(X_test, X_train, length_scale=length_scale)  # (m, n)
    posterior_mean = cross_cov @ alpha + ymean
    solved = jax.scipy.linalg.cho_solve((L, True), cross_cov.T)  # (n, m)
    explained = jnp.sum(cross_cov * solved.T, axis=1)
    posterior_var = jnp.clip(1.0 - explained, 0.0)
    return posterior_mean, jnp.sqrt(posterior_var)
def _gp_posterior(self, X, y, mask, X_test, length_scale):
    """Fit the GP on ``(X, y, mask)`` and predict at ``X_test`` in one step.

    Returns:
        The ``(mean, std)`` pair produced by :meth:`_gp_predict`.
    """
    fitted = self._gp_fit(X, y, mask, length_scale)  # (L, alpha, ymean)
    return self._gp_predict(X_test, *fitted, X, length_scale)
@functools.cached_property
def _tune_hparams_fn(self):
    """Lazily build and cache the jitted length-scale tuner.

    The returned function takes ``(X, y, mask, log_length_scale)`` as
    explicit array arguments, so one compiled program serves every call with
    the same buffer shapes. It runs ``n_hparam_steps`` Adam steps (learning
    rate ``0.1``) on ``0.5 * y_c^T alpha + sum(log(diag(L)))`` and returns
    the updated ``log_length_scale``.
    """
    n_steps = self.n_hparam_steps

    def tune(X, y, mask, log_length_scale):
        def neg_log_ml(log_ls):
            L, alpha, ymean = self._gp_fit(X, y, mask, jnp.exp(log_ls))
            residual = (y - ymean) * mask
            data_fit = 0.5 * residual @ alpha
            complexity = jnp.sum(jnp.log(jnp.diag(L)))
            return data_fit + complexity

        grad_fn = jax.grad(neg_log_ml)
        adam = optax.adam(0.1)

        def step(carry, _):
            log_ls, opt_state = carry
            updates, opt_state = adam.update(grad_fn(log_ls), opt_state)
            return (optax.apply_updates(log_ls, updates), opt_state), None

        initial_carry = (log_length_scale, adam.init(log_length_scale))
        (tuned_log_ls, _), _ = jax.lax.scan(
            step, initial_carry, None, length=n_steps
        )
        return tuned_log_ls

    return jax.jit(tune)
def _tune_hparams(self, state: BayesianSearchState) -> jax.Array:
    """Run the cached tuner on ``state`` and return the new ``log_length_scale``.

    Targets are passed in "higher is better" orientation.
    """
    targets = self._effective_y(state)
    return self._tune_hparams_fn(
        state.X, targets, state.mask, state.log_length_scale
    )
# ------------------------------------------------------------------
# Parameter selection
# ------------------------------------------------------------------
def _random_select(self, state, key, X_cands):
    """Pick ``n_parallel`` distinct rows of ``X_cands`` uniformly at random.

    The index range is taken from ``X_cands`` itself rather than from
    ``self.n_candidates``. The draw is unchanged when the two agree, and it
    can no longer produce out-of-range indices (which JAX indexing would
    silently clamp, yielding duplicate rows) if they ever differ.

    Args:
        state: Unused; kept so both selection branches share a signature.
        key: PRNG key for the draw.
        X_cands: Candidate pool, shape ``(n_pool, n_params)``.

    Returns:
        Array of shape ``(n_parallel, n_params)``.
    """
    n_pool = X_cands.shape[0]
    idxs = jax.random.choice(key, n_pool, (self.n_parallel,), replace=False)
    return X_cands[idxs]  # (n_parallel, n_params)
def _gp_select(self, state, key, X_cands, lowers, uppers, length_scale):
    """Choose ``n_parallel`` points sequentially with GP hallucination.

    For each slot in the batch: the GP is fitted on the real observations
    plus the pseudo-observations added so far, every candidate is scored
    with the acquisition function, the top ``n_restarts`` candidates are
    refined with ``n_lbfgs_steps`` L-BFGS steps (iterates are clipped to
    ``[lowers, uppers]``), and the best point found is appended together
    with a hallucinated result.

    Differences from the previous implementation:

    * The candidate pre-screening now passes the same ``y_max`` to the
      acquisition function as the refinement step, so the value of the best
      raw candidate and the values of refined points are comparable.
    * ``n_restarts == 0`` now means "no refinement" (the best raw candidate
      is returned); previously the ``[-0:]`` slice selected every candidate
      as a seed.
    * The solver and the seed count are created once instead of once per
      slot, and the acquisition objective reuses :meth:`_gp_predict`.

    Args:
        state: Current optimizer state.
        key: PRNG key; split once per slot for the hallucination.
        X_cands: Candidate pool, shape ``(n_pool, n_params)``.
        lowers: Per-parameter lower bounds, shape ``(n_params,)``.
        uppers: Per-parameter upper bounds, shape ``(n_params,)``.
        length_scale: Kernel length scale(s).

    Returns:
        Array of shape ``(n_parallel, n_params)`` with the selected points.
    """
    eff_y = self._effective_y(state)
    n_max, n_params = state.X.shape

    # Append n_parallel empty rows that will receive the hallucinated points.
    X_ext = jnp.concatenate(
        [state.X, jnp.zeros((self.n_parallel, n_params))], axis=0
    )
    y_ext = jnp.concatenate([eff_y, jnp.zeros(self.n_parallel)], axis=0)
    mask_ext = jnp.concatenate(
        [state.mask, jnp.zeros(self.n_parallel, dtype=bool)], axis=0
    )

    # Loop-invariant pieces.
    solver = optax.lbfgs()
    n_seeds = max(0, min(self.n_restarts, X_cands.shape[0]))

    xs_raw_list = []
    for i in range(self.n_parallel):
        key, key_liar = jax.random.split(key)
        L, alpha, ymean = self._gp_fit(X_ext, y_ext, mask_ext, length_scale)
        y_max = jnp.max(y_ext, where=mask_ext, initial=-jnp.inf)

        # The closures below are traced within this iteration, i.e. before
        # X_ext / y_ext / mask_ext are rebound at the bottom of the loop.
        def acq_at(X_query):
            mean, std = self._gp_predict(
                X_query, L, alpha, ymean, X_ext, length_scale
            )
            return self.acquisition(mean, std, y_max=y_max)

        acq_vals = acq_at(X_cands)
        order = jnp.argsort(acq_vals)  # ascending: best candidate is last
        best_idx = order[-1]
        # Explicit start index: a plain ``[-n_seeds:]`` would return the
        # whole array when n_seeds == 0.
        seed_idxs = order[order.shape[0] - n_seeds:]
        seeds = X_cands[seed_idxs]  # (n_seeds, n_params)

        def neg_acq(x):
            return -acq_at(x[None])[0]

        def lbfgs_step(carry, _):
            x, s = carry
            val, grad = jax.value_and_grad(neg_acq)(x)
            updates, new_s = solver.update(
                grad, s, x, value=val, grad=grad, value_fn=neg_acq
            )
            return (
                jnp.clip(optax.apply_updates(x, updates), lowers, uppers),
                new_s,
            ), None

        def _lbfgs_restart(carry, x0):
            best_x, best_val = carry
            (x_refined, _), _ = jax.lax.scan(
                lbfgs_step,
                (x0, solver.init(x0)),
                None,
                length=self.n_lbfgs_steps,
            )
            val = acq_at(x_refined[None])[0]
            # jnp.where keeps the "keep the better one" choice traceable.
            improved = val > best_val
            best_x = jnp.where(improved, x_refined, best_x)
            best_val = jnp.where(improved, val, best_val)
            return (best_x, best_val), None

        # Start from the best raw candidate; a refined point replaces it only
        # if its acquisition value is strictly higher.
        (best_x, _), _ = jax.lax.scan(
            _lbfgs_restart,
            (X_cands[best_idx], acq_vals[best_idx]),
            seeds,
        )

        # Hallucinate a result for the chosen point so the next slot's fit
        # takes it into account.
        mean_i, std_i = self._gp_predict(
            best_x[None], L, alpha, ymean, X_ext, length_scale
        )
        X_ext = X_ext.at[n_max + i].set(best_x)
        y_ext = y_ext.at[n_max + i].set(
            self.hallucination(mean_i, std_i, key_liar, y_max)
        )
        mask_ext = mask_ext.at[n_max + i].set(True)
        xs_raw_list.append(best_x)
    return jnp.stack(xs_raw_list)  # (n_parallel, n_params)
def _select_next_x(self, state, key):
    """Propose the next batch and return it as a pytree shaped like the space.

    A pool of ``n_candidates`` random points is sampled first. While fewer
    than ``n_warmup`` observations are recorded the batch is a random subset
    of that pool; otherwise it comes from :meth:`_gp_select`. Each selected
    coordinate is then passed through its leaf's ``transform``.

    Returns:
        Pytree with the structure of ``state.space``; every leaf has shape
        ``(n_parallel,)``.
    """
    key_sample, key_rest = jax.random.split(key)
    leaves, treedef = jax.tree.flatten(
        state.space, is_leaf=lambda node: isinstance(node, sp.Space)
    )
    lowers, uppers = self._space_bounds(state.space)
    length_scale = jnp.exp(state.log_length_scale)
    X_cands = self._sample_candidates(
        state.space, key_sample, self.n_candidates
    ).astype(jnp.float32)

    def warmup_branch(k):
        return self._random_select(state, k, X_cands)

    def gp_branch(k):
        return self._gp_select(state, k, X_cands, lowers, uppers, length_scale)

    # lax.cond (not a Python ``if``) keeps this traceable under scan / vmap.
    still_warming_up = state.mask.sum() < self.n_warmup
    xs_raw = jax.lax.cond(still_warming_up, warmup_branch, gp_branch, key_rest)

    # Per-leaf transforms (e.g. rounding for quantised spaces), applied one
    # coordinate at a time.
    rows = []
    for j in range(self.n_parallel):
        coords = [
            leaf.transform(xs_raw[j, i : i + 1]).squeeze()
            for i, leaf in enumerate(leaves)
        ]
        rows.append(jnp.stack(coords))
    xs_out = jnp.stack(rows)  # (n_parallel, n_params)

    return treedef.unflatten([xs_out[:, i] for i in range(treedef.num_leaves)])
[docs]
def get_next_params(self, state, key, params=None, results=None):
    """Propose the next batch of ``n_parallel`` parameter settings.

    While fewer than ``n_warmup`` observations have been recorded the batch
    is drawn at random. After that the GP posterior drives the choice: the
    acquisition function is maximised with L-BFGS, and Kriging Believer
    hallucination is used for the parallel slots.

    Args:
        state: Current optimizer state.
        key: PRNG key.
        params: Not used by this method.
        results: Not used by this method.

    Returns:
        Pytree shaped like ``state.space`` with leaves of shape
        ``(n_parallel,)``.
    """
    next_batch = self._select_next_x(state, key)
    return next_batch
def _write_observation_batch(self, state, x_new, results, n):
    """Write a batch of observations into the padded buffers from slot ``n``.

    A ``fori_loop`` is used rather than a Python loop so the XLA graph holds
    a single conditional write instead of one per batch element. Rows that
    would land beyond the buffer capacity are dropped.

    Inputs are cast to the buffer dtypes first: ``lax.dynamic_update_slice``
    requires operand and update to share a dtype, so integer-valued results
    or parameters would otherwise raise a ``TypeError``.

    Args:
        state: State whose ``X``, ``y`` and ``mask`` buffers are updated.
        x_new: New inputs, shape ``(n_parallel, n_params)``.
        results: New objective values, shape ``(n_parallel,)``.
        n: Index of the first free slot (may be a traced scalar).

    Returns:
        The updated state.
    """
    n_max, n_params = state.X.shape
    n_parallel = results.shape[0]
    x_new = jnp.asarray(x_new, dtype=state.X.dtype)
    results = jnp.asarray(results, dtype=state.y.dtype)

    def body(i, s):
        slot = n + i
        x_row = jax.lax.dynamic_slice(x_new, (i, 0), (1, n_params))
        y_scalar = jax.lax.dynamic_slice(results, (i,), (1,))
        return jax.lax.cond(
            slot < n_max,
            lambda s: s.replace(
                X=jax.lax.dynamic_update_slice(s.X, x_row, (slot, 0)),
                y=jax.lax.dynamic_update_slice(s.y, y_scalar, (slot,)),
                mask=jax.lax.dynamic_update_slice(
                    s.mask, jnp.ones(1, dtype=bool), (slot,)
                ),
            ),
            # Buffer full: leave the state untouched.
            lambda s: s,
            s,
        )

    return jax.lax.fori_loop(0, n_parallel, body, state)
[docs]
def update_state(self, state, key, results, params):
    """Record new observations and re-tune the ARD length scales.

    The batch is written into the fixed-size buffers. If
    ``n_hparam_steps > 0`` and at least two observations will then be
    available, ``log_length_scale`` is updated by the Adam-based tuner.

    ``params`` is normalised to shape ``(n_parallel, n_params)`` with a
    reshape, using ``n_params`` from the state buffer. Besides the two
    documented layouts this also handles a 1-D ``(n_parallel,)`` array for a
    single-parameter space, pytrees with scalar leaves, and NumPy arrays,
    which the previous ``atleast_2d`` / ``stack`` logic mis-shaped.

    Args:
        state: Current :class:`BayesianSearchState`.
        key: PRNG key (unused; kept for API consistency).
        results: Observed objective values, shape ``(n_parallel,)``.
        params: Either the batched pytree returned by
            :meth:`get_next_params` (each leaf of shape ``(n_parallel,)``)
            or a raw ``(n_parallel, n_params)`` array.

    Returns:
        The updated state.
    """
    results = jnp.atleast_1d(jnp.squeeze(results))
    n_parallel = results.shape[0]  # static Python int
    n_params = state.X.shape[1]

    if isinstance(params, jax.Array):
        flat = params
    else:
        flat = jnp.stack(jax.tree.leaves(params), axis=-1)
    x_new = jnp.reshape(flat, (-1, n_params))

    n = state.mask.sum()  # dynamic JAX scalar: observations stored so far
    state = self._write_observation_batch(state, x_new, results, n)

    if self.n_hparam_steps > 0:
        # Tune only once two or more observations are available.
        log_ls = jax.lax.cond(
            n + n_parallel >= 2,
            self._tune_hparams,
            lambda s: s.log_length_scale,
            state,
        )
        state = state.replace(log_length_scale=log_ls)
    return state
def _n_iterations(self, state):
    """Return how many batches are needed to fill the remaining buffer slots.

    This is the free capacity divided by ``n_parallel``, rounded up.
    ``int(state.mask.sum())`` forces a device sync, which is acceptable
    because the value is computed outside of traced code.
    """
    capacity = state.X.shape[0]
    n_seen = int(state.mask.sum())
    n_full, leftover = divmod(capacity - n_seen, self.n_parallel)
    if leftover > 0:
        return n_full + 1
    return n_full
[docs]
def optimize(self, state, key, func, n_iterations=None):
    """Run the base-class ``optimize`` loop.

    When ``n_iterations`` is ``None`` it defaults to the number of batches
    needed to fill the remaining buffer capacity.
    """
    n_iters = self._n_iterations(state) if n_iterations is None else n_iterations
    return super().optimize(state, key, func, n_iters)
[docs]
def optimize_scan(self, state, key, func, n_iterations=None):
    """Run the base-class ``optimize_scan`` loop.

    When ``n_iterations`` is ``None`` it defaults to the number of batches
    needed to fill the remaining buffer capacity.
    """
    n_iters = self._n_iterations(state) if n_iterations is None else n_iterations
    return super().optimize_scan(state, key, func, n_iters)