Optimizers

Hyperoptax provides several optimization algorithms for hyperparameter tuning.

Base Optimizer

class hyperoptax.base.OptimizerState(space)

Bases: object

Base optimizer state: a JAX pytree holding the search space definition.

Parameters:

space (PyTree)

space: PyTree
replace(**kwargs)

Return a copy of the state with the given fields replaced.

Return type:

OptimizerState

__init__(space)
Parameters:

space (PyTree)

Return type:

None

class hyperoptax.base.Optimizer

Bases: object

n_parallel: int = 1
classmethod init(space, **kwargs)
Return type:

OptimizerState

optimize(state, key, func, n_iterations)

High-level API for optimizing a function over a space. For custom parallel evaluation, drive the loop yourself with get_next_params() and update_state() instead.

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

Parameters:
  • state – Current optimizer state.

  • key – PRNG key.

  • func – Objective function to evaluate; must return a scalar.

  • n_iterations – Number of iterations to run.

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]
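To make that trade-off concrete, the following plain-Python sketch shows the ask/tell loop that optimize presumably wraps: propose a batch, evaluate it, report the results back. ToyState, ToyRandomSearch and manual_optimize are hypothetical stand-ins that only share the method names documented here. random.Random replaces the JAX PRNG key so the sketch runs without hyperoptax or JAX. With the real library, the key would normally be split on every iteration.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    # Hypothetical state: search bounds plus the best value seen so far.
    low: float
    high: float
    best: float = float("-inf")


@dataclass(frozen=True)
class ToyRandomSearch:
    # Hypothetical optimizer with the same get_next_params/update_state split.
    n_parallel: int = 4

    def get_next_params(self, state, rng, params=None, results=None):
        # One uniformly drawn candidate per parallel slot.
        return [rng.uniform(state.low, state.high) for _ in range(self.n_parallel)]

    def update_state(self, state, rng, results, params=None):
        # Return a new state instead of mutating the old one.
        return replace(state, best=max(state.best, max(results)))


def manual_optimize(opt, state, rng, func, n_iterations):
    """Ask/tell loop: propose a batch, evaluate each point, report back."""
    params_hist, results_hist = [], []
    params, results = None, None  # nothing to report on the first call
    for _ in range(n_iterations):
        params = opt.get_next_params(state, rng, params, results)
        results = [func(p) for p in params]  # func returns one scalar per point
        state = opt.update_state(state, rng, results, params)
        params_hist.append(params)
        results_hist.append(results)
    return state, (params_hist, results_hist)


state, (params_hist, results_hist) = manual_optimize(
    ToyRandomSearch(n_parallel=4),
    ToyState(low=-1.0, high=1.0),
    random.Random(0),
    lambda x: -(x ** 2),  # maximised at x = 0
    n_iterations=5,
)
print(state.best)
```

Writing the loop by hand like this is what gives room for custom parallel evaluation, for example dispatching each batch to separate devices.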

optimize_scan(state, key, func, n_iterations)

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

Parameters:

Same as optimize.

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]
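The stacked shapes follow from how jax.lax.scan works. The loop body returns a (carry, y) pair, and the per-step outputs y are stacked along a new leading axis. The sketch below mirrors the pure-Python reference semantics given in the JAX documentation for lax.scan. The step function is hypothetical and is not hyperoptax's loop body. scan traces its body once into a single compiled loop, which is why func has to be traceable.

```python
def scan_reference(f, init, xs=None, length=None):
    """Pure-Python model of jax.lax.scan: thread a carry, collect per-step outputs."""
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys  # jax.lax.scan would stack ys along a new leading axis


N_PARALLEL = 3


def step(carry, _):
    # Hypothetical loop body: the carry is an iteration counter standing in for
    # the optimizer state; each step emits one batch of N_PARALLEL results.
    i = carry
    batch = [float(i * N_PARALLEL + j) for j in range(N_PARALLEL)]
    return i + 1, batch


final_carry, results_hist = scan_reference(step, 0, length=5)
# Outer axis: n_iterations (5); inner axis: n_parallel (3).
print(final_carry, len(results_hist), len(results_hist[0]))  # 5 5 3
```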

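update_state(state, key, results, params=None)

Updates the optimizer state based on the results of the function.

Parameters:
  • state – Current optimizer state.

  • key – PRNG key.

  • results – Objective values for the current batch, shape (n_parallel,).

  • params – The batch of parameters that produced results (optional).

Return type: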

OptimizerState

get_next_params(state, key, params=None, results=None)

Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).

Return type:

PyTree
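In the batched layout, the parallel axis comes first on every leaf. In this sketch a flat dict of lists stands in for the pytree, and the hyperparameter names are hypothetical. With JAX arrays, jax.tree_util.tree_map(lambda leaf: leaf[i], batched) would extract candidate i from an arbitrarily nested pytree in the same way.

```python
def unbatch(batched):
    """Split {name: [v_0, ..., v_{P-1}]} into P per-candidate dicts."""
    lengths = {len(leaf) for leaf in batched.values()}
    if len(lengths) != 1:
        raise ValueError("every leaf must share the leading n_parallel axis")
    (n_parallel,) = lengths
    return [{name: leaf[i] for name, leaf in batched.items()} for i in range(n_parallel)]


# Hypothetical batch for n_parallel = 3 over two hyperparameters.
batched = {"learning_rate": [1e-3, 3e-3, 1e-2], "dropout": [0.1, 0.2, 0.3]}
for config in unbatch(batched):
    print(config)
```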

Bayesian Optimizer

class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)

Bases: OptimizerState

State for BayesianSearch.

All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.

Parameters:
  • space (PyTree)

  • X (Array)

  • y (Array)

  • mask (Array)

  • log_length_scale (Array)

X

Observation inputs, shape (n_max, n_params), zero-padded.

Type:

jax.Array

y

Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.

Type:

jax.Array

mask

Boolean validity mask, shape (n_max,); True for slots that contain real observations.

Type:

jax.Array

log_length_scale

Per-dimension ARD length scales in log space, shape (n_params,). Tuned by Adam each iteration unless n_hparam_steps is 0.

Type:

jax.Array

X: Array
y: Array
mask: Array
log_length_scale: Array
__init__(space, X, y, mask, log_length_scale)
Parameters:
  • space (PyTree)

  • X (Array)

  • y (Array)

  • mask (Array)

  • log_length_scale (Array)

Return type:

None

space: PyTree
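Fixed-size buffers plus a mask let observations accumulate without changing array shapes. The plain-Python sketch below writes a batch into the first free slots and computes a masked statistic. The helper names are hypothetical. Filling slots front to back is an assumption, and so is raising on overflow: this section does not say how hyperoptax handles a full buffer.

```python
def write_batch(X, y, mask, new_X, new_y):
    """Write a batch of observations into the first free slots.

    X holds n_max rows of n_params floats, y holds n_max floats and mask holds
    n_max bools. New lists are returned and their lengths never change, which
    mirrors the static shapes JAX needs.
    """
    start = sum(mask)  # assumption: slots are filled front to back
    if start + len(new_y) > len(y):
        raise ValueError("buffer full: n_max is too small")
    X = [list(row) for row in X]
    y, mask = list(y), list(mask)
    for offset, (row, value) in enumerate(zip(new_X, new_y)):
        X[start + offset] = list(row)
        y[start + offset] = value
        mask[start + offset] = True
    return X, y, mask


def masked_mean(y, mask):
    # Padding slots hold zeros, so they must be excluded explicitly.
    valid = [v for v, m in zip(y, mask) if m]
    return sum(valid) / len(valid)


n_max, n_params = 4, 2
X = [[0.0] * n_params for _ in range(n_max)]
y = [0.0] * n_max
mask = [False] * n_max
X, y, mask = write_batch(X, y, mask, [[0.1, 0.2], [0.3, 0.4]], [1.0, 2.0])
print(mask, masked_mean(y, mask))  # [True, True, False, False] 1.5
```

In JAX, the same writes would typically use functional index updates such as X.at[idx].set(rows), so the buffer keeps shape (n_max, n_params) under jit.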
class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)

Bases: Optimizer

Bayesian optimisation with a Gaussian Process surrogate.

Uses a GP (Matérn 0.5 kernel by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (probability of improvement, PI, by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.

jitter

Small diagonal added to the kernel matrix for numerical stability (default 1e-6).

Type:

float

kernel

Kernel function (default Matern with nu=0.5).

Type:

hyperoptax.kernels.BaseKernel
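For reference, the Matérn kernel with ν = 0.5 is the exponential kernel. ARD gives each input dimension its own length scale. The sketch below writes this kernel in plain Python, with length scales stored in log space in the same way as log_length_scale in BayesianSearchState. The function name is hypothetical. The sketch assumes unit output variance, which may differ from hyperoptax's Matern class. Storing log l rather than l is a common way to keep length scales positive while Adam updates them without constraints.

```python
import math


def matern12_ard(x1, x2, log_length_scale):
    """Matérn nu=0.5 kernel with per-dimension (ARD) length scales.

    k(x, x') = exp(-r),  r = sqrt(sum_d ((x_d - x'_d) / l_d) ** 2),  l_d = exp(log_l_d)
    Unit output variance is assumed.
    """
    r_squared = sum(
        ((a - b) / math.exp(log_l)) ** 2
        for a, b, log_l in zip(x1, x2, log_length_scale)
    )
    return math.exp(-math.sqrt(r_squared))


log_ls = [0.0, 0.0]  # length scale 1 in both dimensions
print(matern12_ard([0.0, 0.0], [1.0, 0.0], log_ls))  # exp(-1) ≈ 0.3679
# A longer length scale in dimension 0 makes distance along it matter less.
print(matern12_ard([0.0, 0.0], [1.0, 0.0], [math.log(2.0), 0.0]))  # exp(-0.5) ≈ 0.6065
```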

acquisition

Acquisition function (default PI with xi=0.01).

Type:

hyperoptax.acquisition.BaseAcquisition
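PI scores a candidate by the posterior probability that it beats the best observation by at least a margin xi. The textbook form for maximisation is sketched below. The function name is hypothetical. It is an assumption that hyperoptax's PI uses exactly this form, with minimisation presumably handled by negating the stored y, since y is kept raw.

```python
import math


def probability_of_improvement(mean, std, best, xi=0.01):
    """PI for maximisation: P(f(x) > best + xi) when f(x) ~ N(mean, std**2)."""
    if std <= 0.0:
        # No posterior uncertainty: improvement is either certain or impossible.
        return 1.0 if mean > best + xi else 0.0
    z = (mean - best - xi) / std
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard normal CDF


print(probability_of_improvement(mean=1.0, std=1.0, best=0.0, xi=0.0))  # Phi(1) ≈ 0.8413
```

A larger xi demands a bigger margin before a point counts as an improvement, which pushes the search toward less certain regions.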

n_candidates

Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).

Type:

int

n_restarts

Number of L-BFGS restarts seeded from the top candidates (default 2).

Type:

int

n_lbfgs_steps

L-BFGS iterations per restart (default 10).

Type:

int

n_hparam_steps

Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.

Type:

int

n_warmup

Number of pure-random iterations before the GP is used (default 1).

Type:

int

maximize

Set to False to minimise the objective (default True).

Type:

bool

n_parallel

Number of parallel candidates per iteration (default 4).

Type:

int

hallucination

Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).

Type:

hyperoptax.acquisition.BaseHallucination
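Kriging Believer builds a batch greedily. It chooses the acquisition maximiser and adds that point to the data with the GP's own posterior mean as a pretend observation. It then refits and repeats, so later picks avoid the region already claimed. The sketch below implements the classic mean-based variant on a 1-D problem in plain Python. All names are hypothetical. It also departs from hyperoptax in several ways:

- It uses a fixed length scale.
- It takes the argmax over a candidate list instead of running L-BFGS.
- It uses PI as the acquisition.

The default SampleHallucination presumably sets the pretend value differently (its name suggests a posterior sample). Treat this as an illustration of the idea, not of the library's code.

```python
import math


def kernel(a, b, length_scale=1.0):
    # 1-D Matérn nu=0.5 (exponential) kernel with unit variance.
    return math.exp(-abs(a - b) / length_scale)


def cholesky(A):
    # Lower-triangular L with L @ L.T == A (A symmetric positive definite).
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def solve_lower(L, b):
    # Forward substitution for L z = b.
    z = []
    for i in range(len(b)):
        z.append((b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    return z


def solve_upper_t(L, z):
    # Back substitution for L.T x = z.
    n = len(z)
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (z[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gp_posterior(X, y, xs, jitter=1e-6):
    # Noise-free GP posterior (mean, std) at each point in xs.
    K = [[kernel(a, b) + (jitter if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    L = cholesky(K)
    alpha = solve_upper_t(L, solve_lower(L, y))  # K^{-1} y
    out = []
    for x in xs:
        k = [kernel(x, a) for a in X]
        mean = sum(ki * ai for ki, ai in zip(k, alpha))
        v = solve_lower(L, k)  # ||v||^2 == k^T K^{-1} k
        var = max(kernel(x, x) - sum(vi * vi for vi in v), 1e-12)
        out.append((mean, math.sqrt(var)))
    return out


def pi(mean, std, best, xi=0.01):
    return 0.5 * (1.0 + math.erf((mean - best - xi) / (std * math.sqrt(2.0))))


def kriging_believer(X, y, candidates, n_parallel):
    # Pick one point at a time; after each pick, pretend its outcome equals
    # the posterior mean, refit, and pick again.
    X, y, batch = list(X), list(y), []
    for _ in range(n_parallel):
        post = gp_posterior(X, y, candidates)
        best = max(y)
        scores = [pi(m, s, best) for m, s in post]
        i = max(range(len(candidates)), key=scores.__getitem__)
        batch.append(candidates[i])
        X.append(candidates[i])
        y.append(post[i][0])  # the "believed" observation
    return batch


X_obs, y_obs = [0.0, 1.0, 2.0], [0.0, 1.0, 0.5]
candidates = [0.25 * i for i in range(13)]  # 0.0, 0.25, ..., 3.0
print(kriging_believer(X_obs, y_obs, candidates, n_parallel=3))
```

After a point is hallucinated, its posterior standard deviation collapses to roughly the jitter level and its mean no longer exceeds the best value. Its PI therefore drops to about zero, and the next pick lands elsewhere.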

jitter: float = 1e-06
kernel: BaseKernel
acquisition: BaseAcquisition
n_candidates: int = 1000
n_restarts: int = 2
n_lbfgs_steps: int = 10
n_hparam_steps: int = 20
n_warmup: int = 1
maximize: bool = True
n_parallel: int = 4
hallucination: BaseHallucination
classmethod init(space, n_max=200, **kwargs)

Initialise the Bayesian search. n_max sets the size of the fixed observation buffers in BayesianSearchState, i.e. the maximum number of observations that can be recorded.

best_result(state)

Return the best observed raw result (the maximum if maximize=True, otherwise the minimum).

Parameters:

state (BayesianSearchState)

Return type:

Array

best_params(state)

Return the parameter pytree that achieved the best observed result.

Parameters:

state (BayesianSearchState)
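Because y is zero-padded, a naive min or max over the whole buffer can return a padding slot, for example when minimising positive losses. The sketch below shows the masked selection that best_result and best_params need, written in plain Python over lists. The function name is hypothetical. It assumes best_params maps the chosen row of X back to the parameter pytree, a step omitted here.

```python
def best_index(y, mask, maximize=True):
    """Index of the best real observation; padded slots (mask False) are ignored."""
    valid = [i for i, m in enumerate(mask) if m]
    if not valid:
        raise ValueError("no observations recorded yet")
    pick = max if maximize else min
    return pick(valid, key=lambda i: y[i])


# Hypothetical buffers with n_max = 4 and two real observations.
X = [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0], [0.0, 0.0]]
y = [3.0, 5.0, 0.0, 0.0]
mask = [True, True, False, False]

i = best_index(y, mask, maximize=False)
print(y[i], X[i])  # 3.0 [0.1, 0.2]; an unmasked min would return a padding 0.0
```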

get_next_params(state, key, params=None, results=None)

Select the next batch of n_parallel candidates.

During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS with Kriging Believer hallucination for the parallel slots.

update_state(state, key, results, params)

Record new observations and update ARD length scales.

Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.

Parameters:
  • state – Current BayesianSearchState.

  • key – PRNG key (unused but kept for API consistency).

  • results – Array of shape (n_parallel,) with observed objective values.

  • params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
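The two accepted forms of params carry the same numbers in different layouts. The plain-Python sketch below converts a flat dict of per-leaf lists into the (n_parallel, n_params) row layout. The helper name is hypothetical. It assumes the columns follow the leaf order jax.tree_util uses for dicts (sorted keys), which may not match hyperoptax's internal ordering. Passing the batched pytree returned by get_next_params() avoids depending on that order.

```python
def flatten_batched(batched):
    """{name: [v_0, ..., v_{P-1}]} -> P rows, one column per name (sorted by key)."""
    names = sorted(batched)  # jax.tree_util orders dict leaves by sorted key
    n_parallel = len(batched[names[0]])
    return [[float(batched[name][i]) for name in names] for i in range(n_parallel)]


# Hypothetical batched params for n_parallel = 2 and n_params = 2.
batched = {"lr": [0.01, 0.1], "dropout": [0.2, 0.5]}
print(flatten_batched(batched))  # [[0.2, 0.01], [0.5, 0.1]]
```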

optimize(state, key, func, n_iterations=None)

High-level API for optimizing a function over a space. For custom parallel evaluation, drive the loop yourself with get_next_params() and update_state() instead.

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

optimize_scan(state, key, func, n_iterations=None)

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

__init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)
Return type:

None

Grid Search

class hyperoptax.grid.GridSearchState(space, grid, grid_idx)

Bases: OptimizerState

State for GridSearch.

Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

grid

Array of shape (n_total, n_params) containing all parameter combinations, pre-truncated to a multiple of n_parallel.

Type:

jax.Array
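Because n_total is rounded down to a multiple of n_parallel, some combinations can be dropped. The sketch below builds such a grid in plain Python. The helper name and the dict-based space are hypothetical, and random.Random stands in for the JAX key. Whether hyperoptax truncates before or after shuffling is an assumption. This sketch shuffles first, so the dropped combinations are random ones.

```python
import itertools
import random


def build_grid(values, n_parallel, shuffle=False, seed=0):
    """All combinations of the discrete values, truncated to a multiple of n_parallel."""
    names = sorted(values)
    grid = [list(combo) for combo in itertools.product(*(values[n] for n in names))]
    if shuffle:
        random.Random(seed).shuffle(grid)  # stands in for a JAX permutation
    n_total = (len(grid) // n_parallel) * n_parallel
    return names, grid[:n_total]


# Hypothetical discrete space: 3 x 2 = 6 combinations.
values = {"depth": [2, 4, 8], "width": [64, 128]}
names, grid = build_grid(values, n_parallel=4)
print(names, len(grid))  # ['depth', 'width'] 4
```

With these values, n_parallel=4 keeps 4 of the 6 combinations, while n_parallel=3 keeps all 6.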

grid_idx

Current position in grid; incremented by n_parallel after each call to update_state.

Type:

int

grid: Array
grid_idx: int
__init__(space, grid, grid_idx)
Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

Return type:

None

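space: PyTree
class hyperoptax.grid.GridSearch(shuffle=False, n_parallel=1)

Bases: Optimizer

Exhaustive grid search over a discrete search space.

Iterates through the combinations of the provided DiscreteSpace values in order (or in random order if shuffle=True). All spaces in the search space must be DiscreteSpace. Because the grid is truncated to a multiple of n_parallel, up to n_parallel - 1 combinations are skipped when the number of combinations is not divisible by n_parallel.

Parameters:
  • shuffle (bool)

  • n_parallel (int)

shuffle

If True, randomise the traversal order during init. Pass an explicit key to init to control the shuffle; otherwise PRNGKey(0) is used, so the shuffled order is the same on every run.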

Type:

bool

n_parallel

Number of grid points evaluated per iteration.

Type:

int

shuffle: bool = False
n_parallel: int = 1
classmethod init(space, key=None, **kwargs)

Initialise the grid search.

Parameters:
  • space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.

  • key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.

  • **kwargs – Forwarded to GridSearch constructor (e.g. n_parallel, shuffle).

Returns:

(state, optimizer) tuple.

get_next_params(state, key, params=None, results=None)

Return the next n_parallel parameter combinations from the grid.

Parameters:

state (GridSearchState)

Return type:

PyTree

update_state(state, key, results, params=None)

Advance the grid index by n_parallel.

Parameters:

state (GridSearchState)

Return type:

GridSearchState
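Together, get_next_params and update_state walk the grid in fixed-size strides. A plain-Python sketch of that traversal follows; the names are hypothetical. Every stride has exactly n_parallel rows, so a jitted implementation can slice with a static size (for example with jax.lax.dynamic_slice). That is one plausible reason for the pre-truncation, although this section does not state it.

```python
def iterate_grid(grid, n_parallel):
    """Yield (grid_idx, batch) pairs, mimicking get_next_params/update_state."""
    if len(grid) % n_parallel:
        raise ValueError("grid must already be truncated to a multiple of n_parallel")
    grid_idx = 0
    while grid_idx < len(grid):
        yield grid_idx, grid[grid_idx:grid_idx + n_parallel]  # get_next_params
        grid_idx += n_parallel  # update_state


grid = [[d, w] for d in (2, 4, 8) for w in (64, 128)]
for grid_idx, batch in iterate_grid(grid, n_parallel=2):
    print(grid_idx, batch)
```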

__init__(shuffle=False, n_parallel=1)
Parameters:
  • shuffle (bool)

  • n_parallel (int)

Return type:

None

Random Search

class hyperoptax.random.RandomSearch(n_parallel=1)

Bases: Optimizer

Stateless random search: samples each space independently at every iteration.

No model is fitted and no history is maintained, so this is the cheapest optimizer and useful as a strong baseline.

Parameters:

n_parallel (int)

n_parallel

Number of random configurations evaluated per iteration.

Type:

int

n_parallel: int = 1
classmethod init(space, **kwargs)
get_next_params(state, key, params=None, results=None)

Sample n_parallel independent configurations from the search space.

Return type:

PyTree
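Each call draws every leaf independently, with no dependence on earlier results. A plain-Python sketch follows. The (kind, argument) space encoding, the helper name, and the use of random.Random in place of a JAX key are all hypothetical and do not reflect hyperoptax's space classes.

```python
import random


def sample_batch(space, n_parallel, rng):
    """Draw n_parallel independent configurations as a batched dict of lists."""
    batched = {name: [] for name in space}
    for _ in range(n_parallel):
        for name, (kind, arg) in space.items():
            if kind == "uniform":
                low, high = arg
                batched[name].append(rng.uniform(low, high))
            elif kind == "choice":
                batched[name].append(rng.choice(arg))
            else:
                raise ValueError(f"unknown space kind: {kind}")
    return batched


# Hypothetical search space encoded as (kind, argument) pairs.
space = {"dropout": ("uniform", (0.0, 0.5)), "batch_size": ("choice", [32, 64, 128])}
batch = sample_batch(space, n_parallel=4, rng=random.Random(0))
print(batch)
```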

update_state(state, key, results, params=None)

RandomSearch is memoryless, so there is no state to update.

Return type:

OptimizerState

__init__(n_parallel=1)
Parameters:

n_parallel (int)

Return type:

None