hyperoptax package

class hyperoptax.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)[source]#

Bases: Optimizer

Bayesian optimisation with a Gaussian Process surrogate.

Uses a GP (Matérn kernel with nu=0.5 by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (Probability of Improvement, PI, by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.

Parameters:
jitter#

Small diagonal added to the kernel matrix for numerical stability (default 1e-6).

Type:

float

kernel#

Kernel function (default Matern with nu=0.5).

Type:

hyperoptax.kernels.BaseKernel

acquisition#

Acquisition function (default PI with xi=0.01).

Type:

hyperoptax.acquisition.BaseAcquisition

n_candidates#

Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).

Type:

int

n_restarts#

Number of L-BFGS restarts seeded from the top candidates (default 2).

Type:

int

n_lbfgs_steps#

L-BFGS iterations per restart (default 10).

Type:

int

n_hparam_steps#

Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.

Type:

int

n_warmup#

Number of pure-random iterations before the GP is used (default 1).

Type:

int

maximize#

Set False to minimise the objective (default True).

Type:

bool

n_parallel#

Number of parallel candidates per iteration (default 4).

Type:

int

hallucination#

Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).

Type:

hyperoptax.acquisition.BaseHallucination
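How n_candidates, n_restarts and n_lbfgs_steps fit together can be pictured with the standalone sketch below. It does not call hyperoptax. The names toy_acquisition and propose, the unit interval, and the step size are assumptions made for illustration, and plain finite-difference gradient ascent is swapped in for L-BFGS to keep the example free of dependencies.

```python
import random


def toy_acquisition(x):
    """Hypothetical smooth acquisition surface with its peak at x = 0.7."""
    return -(x - 0.7) ** 2


def propose(acq, n_candidates=1000, n_restarts=2, n_steps=10, lr=0.1, seed=0):
    """Random pre-selection followed by local refinement of the best seeds."""
    rng = random.Random(seed)

    # 1. Discrete pre-selection: score many random points in [0, 1).
    candidates = [rng.random() for _ in range(n_candidates)]
    seeds = sorted(candidates, key=acq, reverse=True)[:n_restarts]

    # 2. Local refinement of each seed (gradient ascent stands in for L-BFGS).
    eps = 1e-4
    refined = []
    for x in seeds:
        for _ in range(n_steps):
            grad = (acq(x + eps) - acq(x - eps)) / (2 * eps)
            x = min(max(x + lr * grad, 0.0), 1.0)  # stay inside the bounds
        refined.append(x)

    # 3. Keep the refined point with the highest acquisition value.
    return max(refined, key=acq)


best = propose(toy_acquisition)
```

The random pre-selection only needs to land near a good region; the restarts then polish the most promising seeds, which is why a small n_restarts can be enough when n_candidates is large.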

jitter: float = 1e-06#
kernel: BaseKernel#
acquisition: BaseAcquisition#
n_candidates: int = 1000#
n_restarts: int = 2#
n_lbfgs_steps: int = 10#
n_hparam_steps: int = 20#
n_warmup: int = 1#
maximize: bool = True#
n_parallel: int = 4#
hallucination: BaseHallucination#
classmethod init(space, n_max=200, **kwargs)[source]#
best_result(state)[source]#

Return the best observed raw result (the maximum if maximize=True, otherwise the minimum).

Parameters:

state (BayesianSearchState)

Return type:

Array
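The state documentation describes y as a zero-padded buffer with a validity mask, so the best value has to be taken over the valid slots only. The sketch below shows that idea on plain Python lists; masked_best is a hypothetical helper, not the library's implementation.

```python
def masked_best(y, mask, maximize=True):
    """Best value among the slots whose mask entry is True."""
    valid = [value for value, is_valid in zip(y, mask) if is_valid]
    if not valid:
        raise ValueError("no observations recorded yet")
    return max(valid) if maximize else min(valid)


# Three real observations followed by two zero-padded, unused slots.
y = [-3.0, -1.5, -2.0, 0.0, 0.0]
mask = [True, True, True, False, False]

best_max = masked_best(y, mask, maximize=True)
best_min = masked_best(y, mask, maximize=False)
naive_max = max(y)  # picks up the padding value instead of a real result
```

With an all-negative objective the naive maximum returns the padding value 0.0, which is why the mask matters.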

best_params(state)[source]#

Return the parameter pytree that achieved the best observed result.

Parameters:

state (BayesianSearchState)

get_next_params(state, key, params=None, results=None)[source]#

Select the next batch of n_parallel candidates.

During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS with Kriging Believer hallucination for the parallel slots.
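The Kriging Believer idea can be sketched in a few lines of dependency-free Python. This is an illustration under stated assumptions, not the library's code: it uses a one-dimensional zero-mean GP with a Matérn 0.5 kernel, picks each slot by scanning a fixed candidate list instead of running L-BFGS, scores candidates with UCB rather than the default PI, and hallucinates with the posterior mean (the classical variant). All function names are hypothetical.

```python
import math


def kernel(a, b, length_scale=0.3):
    """Matern-0.5 (exponential) kernel on scalars."""
    return math.exp(-abs(a - b) / length_scale)


def cholesky(A):
    """Lower-triangular Cholesky factor of a small SPD matrix."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def solve_lower(L, b):
    """Forward substitution: solve L z = b."""
    z = []
    for i in range(len(b)):
        z.append((b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    return z


def posterior(X, y, x_new, jitter=1e-6):
    """Zero-mean GP posterior mean and standard deviation at x_new."""
    K = [[kernel(a, b) + (jitter if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    L = cholesky(K)
    v = solve_lower(L, [kernel(a, x_new) for a in X])  # L^-1 k_star
    w = solve_lower(L, y)                              # L^-1 y
    mean = sum(vi * wi for vi, wi in zip(v, w))
    var = kernel(x_new, x_new) - sum(vi * vi for vi in v)
    return mean, math.sqrt(max(var, 0.0))


def kriging_believer_batch(X, y, candidates, n_parallel=4, kappa=2.0):
    """Pick n_parallel points, believing the posterior mean after each pick."""
    X, y = list(X), list(y)  # work on copies
    batch = []
    for _ in range(n_parallel):
        scored = []
        for c in candidates:
            if c in X:  # skip observed and already selected points
                continue
            mean, std = posterior(X, y, c)
            scored.append((mean + kappa * std, mean, c))
        _, believed, chosen = max(scored)
        batch.append(chosen)
        X.append(chosen)    # pretend the point was evaluated ...
        y.append(believed)  # ... and returned its posterior mean
    return batch


X_obs, y_obs = [0.1, 0.5, 0.9], [0.2, 1.0, 0.4]
candidates = [i / 20 for i in range(21)]
batch = kriging_believer_batch(X_obs, y_obs, candidates)
```

Each hallucinated observation collapses the posterior uncertainty around the chosen point, which pushes the following slots towards other regions and keeps the batch diverse.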

update_state(state, key, results, params)[source]#

Record new observations and update ARD length scales.

Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.

Parameters:
  • state – Current BayesianSearchState.

  • key – PRNG key (unused but kept for API consistency).

  • results – Array of shape (n_parallel,) with observed objective values.

  • params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.
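Writing a batch into fixed-size buffers can be sketched as follows. The example uses plain Python lists and assumes that slots are filled contiguously from the front and that an overflow raises an error; neither detail is confirmed by the documentation above, and write_batch is a hypothetical name.

```python
def write_batch(X, y, mask, new_X, new_y):
    """Copy a batch into the first free slots and mark them as valid."""
    X, y, mask = list(X), list(y), list(mask)  # leave the inputs untouched
    start = sum(mask)                          # number of slots already used
    if start + len(new_y) > len(y):
        raise ValueError("buffer is full; a larger n_max would be needed")
    for offset, (row, value) in enumerate(zip(new_X, new_y)):
        X[start + offset] = list(row)
        y[start + offset] = value
        mask[start + offset] = True
    return X, y, mask


n_max, n_params = 6, 2
X = [[0.0] * n_params for _ in range(n_max)]  # zero-padded inputs
y = [0.0] * n_max                             # zero-padded results
mask = [False] * n_max                        # nothing observed yet

X, y, mask = write_batch(X, y, mask, [[0.1, 0.2], [0.3, 0.4]], [1.0, 2.0])
X, y, mask = write_batch(X, y, mask, [[0.5, 0.6], [0.7, 0.8]], [3.0, 4.0])
```

Because the buffer shapes never change, the same compiled update can be reused at every iteration; only the mask records how much of the buffer is meaningful.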

optimize(state, key, func, n_iterations=None)[source]#

High-level API for optimizing a function over a space. Not recommended when you need fine-grained control over how the parallel evaluations are executed.

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

optimize_scan(state, key, func, n_iterations=None)[source]#

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

__init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)#
Return type:

None

class hyperoptax.GridSearch(shuffle=False, n_parallel=1)[source]#

Bases: Optimizer

Exhaustive grid search over a discrete search space.

Iterates through every combination of the provided DiscreteSpace values in order (or in random order if shuffle=True). Every leaf of the search space must be a DiscreteSpace.

Parameters:
shuffle#

If True, randomise the traversal order during init. Pass an explicit key to init for reproducibility.

Type:

bool

n_parallel#

Number of grid points evaluated per iteration.

Type:

int

shuffle: bool = False#
n_parallel: int = 1#
classmethod init(space, key=None, **kwargs)[source]#

Initialise the grid search.

Parameters:
  • space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.

  • key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.

  • **kwargs – Forwarded to GridSearch constructor (e.g. n_parallel, shuffle).

Returns:

(state, optimizer) tuple.

get_next_params(state, key, params=None, results=None)[source]#

Return the next n_parallel parameter combinations from the grid.

Parameters:

state (GridSearchState)

Return type:

PyTree

update_state(state, key, results, params=None)[source]#

Advance the grid index by n_parallel.

Parameters:

state (GridSearchState)

Return type:

GridSearchState

__init__(shuffle=False, n_parallel=1)#
Return type:

None

class hyperoptax.RandomSearch(n_parallel=1)[source]#

Bases: Optimizer

Stateless random search — samples each space independently each iteration.

No model is fitted and no history is maintained, so this is the cheapest optimizer and useful as a strong baseline.

Parameters:

n_parallel (int)

n_parallel#

Number of random configurations evaluated per iteration.

Type:

int

n_parallel: int = 1#
classmethod init(space, **kwargs)[source]#
get_next_params(state, key, params=None, results=None)[source]#

Sample n_parallel independent configurations from the search space.

Return type:

PyTree

update_state(state, key, results, params=None)[source]#

RandomSearch is memoryless, so there is no state to update.

Return type:

OptimizerState

__init__(n_parallel=1)#
Parameters:

n_parallel (int)

Return type:

None

class hyperoptax.DiscreteSpace(values)[source]#

Bases: Space

Discrete space over a fixed set of values.

Samples uniformly from values. transform snaps any continuous value to the nearest element, which is useful when discrete candidates are generated via continuous optimization (e.g. in BayesianSearch).

Parameters:

values (tuple)

values#

Tuple of candidate values to sample from.

Type:

tuple
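The snapping behaviour described for transform can be illustrated with a minimal standalone sketch. snap_to_nearest is a hypothetical helper written for this example; how the library breaks exact ties is not documented here, and this version returns the first of two equally close values.

```python
def snap_to_nearest(value, values):
    """Return the element of `values` closest to `value`."""
    return min(values, key=lambda v: abs(v - value))


batch_sizes = (16, 32, 64, 128)
snapped = [snap_to_nearest(x, batch_sizes) for x in (10.0, 47.9, 48.1, 500.0)]
# snapped == [16, 32, 64, 128]
```

Values outside the range of the tuple are pulled to the nearest end, so a continuous optimizer can propose any real number and still end up with a valid discrete choice.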

values: tuple#
property lower_bound: float#
property upper_bound: float#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

float

transform(value)[source]#
Return type:

Array

__init__(values)#
Parameters:

values (tuple)

Return type:

None

class hyperoptax.LinearSpace(lower_bound, upper_bound)[source]#

Bases: Space

Uniform continuous space over the half-open interval [lower_bound, upper_bound).

Parameters:
lower_bound#

Inclusive lower bound of the interval.

Type:

float

upper_bound#

Exclusive upper bound of the interval.

Type:

float

lower_bound: float#
upper_bound: float#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

float

__init__(lower_bound, upper_bound)#
Return type:

None

class hyperoptax.LogSpace(lower_bound, upper_bound, base=10)[source]#

Bases: LinearSpace

Log-uniform continuous space over the half-open interval [lower_bound, upper_bound).

Samples uniformly in log space so that each order of magnitude receives equal probability mass. Useful for learning rates and other scale parameters that span several orders of magnitude.

Parameters:
lower_bound#

Inclusive lower bound (in original scale, e.g. 1e-5).

Type:

float

upper_bound#

Exclusive upper bound (in original scale, e.g. 1e-1).

Type:

float

base#

Logarithm base (default 10). Must be greater than 1.

Type:

float
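Log-uniform sampling can be sketched with the standard library alone: draw an exponent uniformly between the logarithms of the bounds, then exponentiate. The helper sample_log_uniform is hypothetical and only mirrors the description above; it is not the library's implementation.

```python
import math
import random


def sample_log_uniform(rng, lower_bound, upper_bound, base=10):
    """Draw one log-uniform sample between the two bounds."""
    log_low = math.log(lower_bound, base)
    log_high = math.log(upper_bound, base)
    exponent = rng.uniform(log_low, log_high)  # uniform in log space
    return base ** exponent


rng = random.Random(0)
samples = [sample_log_uniform(rng, 1e-5, 1e-1) for _ in range(10_000)]

# Count the samples per decade: [1e-5, 1e-4), [1e-4, 1e-3), and so on.
decade_counts = [0, 0, 0, 0]
for s in samples:
    index = int(math.floor(math.log10(s))) + 5
    decade_counts[max(0, min(index, 3))] += 1  # clamp float edge cases
```

Each of the four decades receives roughly a quarter of the samples, whereas a plain uniform draw over the same range would put about 90% of them in the top decade.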

base: float = 10#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

Array

__init__(lower_bound, upper_bound, base=10)#
Return type:

None

lower_bound: float#
upper_bound: float#
class hyperoptax.QLinearSpace(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)[source]#

Bases: LinearSpace

Quantized (integer) variant of LinearSpace.

Samples uniformly from [lower_bound, upper_bound) and rounds to the nearest integer. Use this for discrete integer hyperparameters with a uniform prior (e.g. number of layers, batch size).

Parameters:
lower_bound#

Inclusive lower bound.

Type:

float

upper_bound#

Exclusive upper bound.

Type:

float

datatype#

Integer dtype used after rounding (default jnp.int32).

Type:

type

datatype#

alias of int32
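One consequence of rounding a uniform draw is worth keeping in mind. If the rounding is applied directly to the raw sample, as the description suggests, the two end values collect only about half the probability mass of the interior integers. The sketch below demonstrates this with the standard library; the round-half-up rule and the helper name sample_quantized are assumptions, and the library may handle the bounds differently.

```python
import math
import random
from collections import Counter


def sample_quantized(rng, lower_bound, upper_bound):
    """Uniform draw between the bounds, rounded half up to an integer."""
    x = rng.uniform(lower_bound, upper_bound)
    return int(math.floor(x + 0.5))


rng = random.Random(0)
counts = Counter(sample_quantized(rng, 0, 4) for _ in range(20_000))
```

Here the integers 1, 2 and 3 each cover an interval of width 1, while 0 and 4 cover only width 0.5. Widening the bounds by 0.5 on each side is one way to even this out if a strictly uniform integer prior matters.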

transform(value)[source]#
Return type:

Array

__init__(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)#
Return type:

None

class hyperoptax.QLogSpace(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)[source]#

Bases: LogSpace

Quantized (integer) variant of LogSpace.

Samples in log space and rounds to the nearest integer. Use this for integer hyperparameters whose scale spans orders of magnitude (e.g. number of hidden units, number of warmup steps).

Parameters:
datatype#

alias of int32

transform(value)[source]#
Return type:

Array

__init__(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)#
Return type:

None

class hyperoptax.EI(xi=0.01)[source]#

Bases: BaseAcquisition

Expected Improvement acquisition function.

Parameters:

xi (float)

__init__(xi=0.01)[source]#
Parameters:

xi (float)

class hyperoptax.PI(xi=0.01)[source]#

Bases: BaseAcquisition

Probability of Improvement acquisition function.

Parameters:

xi (float)

__init__(xi=0.01)[source]#
Parameters:

xi (float)

class hyperoptax.UCB(kappa=2.0)[source]#

Bases: BaseAcquisition

Upper Confidence Bound acquisition function.

Parameters:

kappa (float)

__init__(kappa=2.0)[source]#
Parameters:

kappa (float)
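For orientation, the textbook forms of the three acquisition functions (EI, PI and UCB) for a maximisation problem are sketched below in plain Python. They take a posterior mean and standard deviation at one point; y_best denotes the best value observed so far. These are the standard definitions, shown under the assumption std > 0, and the library's implementation may differ in details such as sign conventions or masking of seen points.

```python
import math


def normal_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def probability_of_improvement(mean, std, y_best, xi=0.01):
    """Probability that the value exceeds y_best by at least xi."""
    return normal_cdf((mean - y_best - xi) / std)


def expected_improvement(mean, std, y_best, xi=0.01):
    """Expected amount by which the value exceeds y_best + xi."""
    improvement = mean - y_best - xi
    z = improvement / std
    return improvement * normal_cdf(z) + std * normal_pdf(z)


def upper_confidence_bound(mean, std, kappa=2.0):
    """Optimistic estimate: mean plus kappa standard deviations."""
    return mean + kappa * std


pi_value = probability_of_improvement(mean=1.0, std=0.5, y_best=0.99)
ei_value = expected_improvement(mean=1.0, std=0.5, y_best=0.99)
ucb_value = upper_confidence_bound(mean=1.0, std=0.5)
```

Larger xi or kappa shifts the balance towards exploration: xi raises the bar a candidate has to clear, and kappa increases the reward for uncertainty.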

class hyperoptax.BaseHallucination[source]#

Bases: object

Base class for Kriging Believer hallucination strategies.

Any callable with signature (mean, std, key, y_max) -> scalar can be used as a hallucination strategy — subclassing is optional.
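A custom strategy following that signature could look like the sketch below. The function pessimistic_hallucination is hypothetical and exists only to illustrate the shape of the callable; presumably it would be supplied through the hallucination argument of BayesianSearch, although that usage is an assumption here.

```python
def pessimistic_hallucination(mean, std, key, y_max):
    """Hypothetical strategy: believe one standard deviation below the mean."""
    # `key` and `y_max` are accepted to match the signature but unused here.
    return mean - std


believed = pessimistic_hallucination(mean=0.8, std=0.1, key=None, y_max=1.0)
```

Only arithmetic operators are used, so the same function works on Python floats and on array scalars.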

class hyperoptax.MeanHallucination[source]#

Bases: BaseHallucination

Classical Kriging Believer: hallucinate with GP posterior mean.

class hyperoptax.SampleHallucination[source]#

Bases: BaseHallucination

Randomized Kriging Believer (RKB): hallucinate with a posterior sample.

Reference: arXiv:2603.01470.

class hyperoptax.UCBHallucination(kappa=2.0)[source]#

Bases: BaseHallucination

Optimistic hallucination: mean + kappa * std.

Parameters:

kappa (float)

__init__(kappa=2.0)[source]#
Parameters:

kappa (float)

class hyperoptax.ConstantHallucination(value=None)[source]#

Bases: BaseHallucination

Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.

If value is None, uses the current best observed value (y_max). Otherwise uses the fixed value regardless of observations.

Parameters:

value (float | None)

__init__(value=None)[source]#
Parameters:

value (float | None)
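The four built-in strategies differ only in the value they pretend to have observed. The following sketch restates their one-line descriptions as plain functions with the (mean, std, key, y_max) signature. It is an illustration, not the library's code: a random.Random instance stands in for a JAX PRNG key, and the extra keyword arguments mirror the constructor parameters.

```python
import random


def mean_hallucination(mean, std, key, y_max):
    return mean                      # classical Kriging Believer


def sample_hallucination(mean, std, key, y_max):
    return key.gauss(mean, std)      # one draw from N(mean, std^2)


def ucb_hallucination(mean, std, key, y_max, kappa=2.0):
    return mean + kappa * std        # optimistic belief


def constant_hallucination(mean, std, key, y_max, value=None):
    return y_max if value is None else value


rng = random.Random(0)               # stand-in for a PRNG key (assumption)
args = dict(mean=0.5, std=0.25, key=rng, y_max=0.9)
beliefs = {
    "mean": mean_hallucination(**args),
    "sample": sample_hallucination(**args),
    "ucb": ucb_hallucination(**args),
    "constant": constant_hallucination(**args),
}
```

An optimistic belief makes the surrogate expect high values near the chosen point, while the posterior mean leaves the predicted surface unchanged and only reduces the uncertainty there.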

class hyperoptax.Matern(length_scale=1.0, nu=2.5)[source]#

Bases: BaseKernel

Matern kernel family.

Parameters:
  • length_scale (float, default = 1.0) – Characteristic length scale.

  • nu (float, default = 2.5) – Controls smoothness (nu ∈ {0.5, 1.5, 2.5, ∞}).
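The closed forms of the Matérn kernel for the four listed values of nu are sketched below as a function of the distance between two points. These are the standard textbook expressions, with unit variance assumed; the function name matern is used for illustration and is not the library class.

```python
import math


def matern(distance, length_scale=1.0, nu=2.5):
    """Matern covariance for a non-negative distance."""
    r = distance / length_scale
    if nu == 0.5:
        return math.exp(-r)
    if nu == 1.5:
        a = math.sqrt(3.0) * r
        return (1.0 + a) * math.exp(-a)
    if nu == 2.5:
        a = math.sqrt(5.0) * r
        return (1.0 + a + a * a / 3.0) * math.exp(-a)
    if nu == math.inf:
        return math.exp(-0.5 * r * r)  # squared-exponential (RBF) limit
    raise ValueError("nu must be one of 0.5, 1.5, 2.5 or inf")


values = {nu: matern(0.5, nu=nu) for nu in (0.5, 1.5, 2.5, math.inf)}
```

Smaller nu gives rougher sample paths: nu=0.5 yields continuous but non-differentiable functions, while the limit nu=∞ coincides with the RBF kernel and yields infinitely smooth ones.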

__init__(length_scale=1.0, nu=2.5)[source]#
class hyperoptax.RBF(length_scale=1.0)[source]#

Bases: BaseKernel

Radial basis function (RBF) / squared-exponential kernel.

Parameters:

length_scale (float)

__init__(length_scale=1.0)[source]#
Parameters:

length_scale (float)

Submodules#

hyperoptax.acquisition module#

class hyperoptax.acquisition.BaseAcquisition[source]#

Bases: object

Base class for acquisition functions.

get_argmax(mean, std, seen_mask, n_points=1)[source]#
class hyperoptax.acquisition.UCB(kappa=2.0)[source]#

Bases: BaseAcquisition

Upper Confidence Bound acquisition function.

Parameters:

kappa (float)

__init__(kappa=2.0)[source]#
Parameters:

kappa (float)

class hyperoptax.acquisition.EI(xi=0.01)[source]#

Bases: BaseAcquisition

Expected Improvement acquisition function.

Parameters:

xi (float)

__init__(xi=0.01)[source]#
Parameters:

xi (float)

class hyperoptax.acquisition.PI(xi=0.01)[source]#

Bases: BaseAcquisition

Probability of Improvement acquisition function.

Parameters:

xi (float)

__init__(xi=0.01)[source]#
Parameters:

xi (float)

class hyperoptax.acquisition.BaseHallucination[source]#

Bases: object

Base class for Kriging Believer hallucination strategies.

Any callable with signature (mean, std, key, y_max) -> scalar can be used as a hallucination strategy — subclassing is optional.

class hyperoptax.acquisition.MeanHallucination[source]#

Bases: BaseHallucination

Classical Kriging Believer: hallucinate with GP posterior mean.

class hyperoptax.acquisition.SampleHallucination[source]#

Bases: BaseHallucination

Randomized Kriging Believer (RKB): hallucinate with a posterior sample.

Reference: arXiv:2603.01470.

class hyperoptax.acquisition.UCBHallucination(kappa=2.0)[source]#

Bases: BaseHallucination

Optimistic hallucination: mean + kappa * std.

Parameters:

kappa (float)

__init__(kappa=2.0)[source]#
Parameters:

kappa (float)

class hyperoptax.acquisition.ConstantHallucination(value=None)[source]#

Bases: BaseHallucination

Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.

If value is None, uses the current best observed value (y_max). Otherwise uses the fixed value regardless of observations.

Parameters:

value (float | None)

__init__(value=None)[source]#
Parameters:

value (float | None)

hyperoptax.base module#

class hyperoptax.base.OptimizerState(space)[source]#

Bases: object

Base optimizer state — a JAX pytree holding the search space definition.

Parameters:

space (PyTree)

space: PyTree#
replace(**kwargs)[source]#
Return type:

OptimizerState

__init__(space)#
Parameters:

space (PyTree)

Return type:

None

class hyperoptax.base.Optimizer[source]#

Bases: object

n_parallel: int = 1#
classmethod init(space, **kwargs)[source]#
Return type:

OptimizerState

optimize(state, key, func, n_iterations)[source]#

High-level API for optimizing a function over a space. Not recommended when you need fine-grained control over how the parallel evaluations are executed.

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]

optimize_scan(state, key, func, n_iterations)[source]#

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

Return type:

tuple[OptimizerState, tuple[PyTree, Array]]

update_state(state, key, results, params=None)[source]#

Updates the optimizer state based on the results of the function.

Return type:

OptimizerState

get_next_params(state, key, params=None, results=None)[source]#

Gets the next parameters to sample from the space. Returns a batched pytree where every leaf has shape (n_parallel, …). params and results are the previous iteration’s values (None on first call).

Return type:

PyTree

hyperoptax.bayesian module#

class hyperoptax.bayesian.BayesianSearchState(space, X, y, mask, log_length_scale)[source]#

Bases: OptimizerState

State for BayesianSearch.

All arrays are fixed-size (shape determined by n_max at init time) to satisfy JAX’s static-shape requirement. The mask field tracks which entries have been written.

Parameters:
X#

Observation inputs, shape (n_max, n_params), zero-padded.

Type:

jax.Array

y#

Observed results, shape (n_max,), zero-padded, stored as raw (un-negated) values regardless of maximize.

Type:

jax.Array

mask#

Boolean validity mask, shape (n_max,); True for slots that contain real observations.

Type:

jax.Array

log_length_scale#

Per-dimension ARD length scales in log space, shape (n_params,). Tuned by Adam each iteration.

Type:

jax.Array
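What per-dimension (ARD) length scales do to distances can be shown with a short standalone sketch. Each coordinate difference is divided by its own length scale before the Euclidean norm is taken, and the length scale is recovered from its logarithm with exp. The helper ard_distance is hypothetical; how the library combines the scales with its kernels is not shown in this reference.

```python
import math


def ard_distance(x, y, log_length_scale):
    """Euclidean distance after dividing each dimension by its length scale."""
    total = 0.0
    for xi, yi, log_ls in zip(x, y, log_length_scale):
        scaled = (xi - yi) / math.exp(log_ls)
        total += scaled * scaled
    return math.sqrt(total)


x, y = [0.0, 0.0], [1.0, 1.0]
isotropic = ard_distance(x, y, [0.0, 0.0])               # both scales are 1
anisotropic = ard_distance(x, y, [0.0, math.log(10.0)])  # second scale is 10
```

A long length scale makes a dimension nearly irrelevant to the distance, and storing the logarithm keeps the length scale positive no matter how an unconstrained optimizer such as Adam moves the parameter.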

X: Array#
y: Array#
mask: Array#
log_length_scale: Array#
__init__(space, X, y, mask, log_length_scale)#
Return type:

None

space: PyTree#
class hyperoptax.bayesian.BayesianSearch(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)[source]#

Bases: Optimizer

Bayesian optimisation with a Gaussian Process surrogate.

Uses a GP (Matérn kernel with nu=0.5 by default) to model the objective and selects the next batch of candidates by maximising an acquisition function (Probability of Improvement, PI, by default). ARD length scales are tuned with Adam each iteration. Parallel batches are generated via the Kriging Believer hallucination strategy.

Parameters:
jitter#

Small diagonal added to the kernel matrix for numerical stability (default 1e-6).

Type:

float

kernel#

Kernel function (default Matern with nu=0.5).

Type:

hyperoptax.kernels.BaseKernel

acquisition#

Acquisition function (default PI with xi=0.01).

Type:

hyperoptax.acquisition.BaseAcquisition

n_candidates#

Number of random candidates sampled per iteration for the discrete pre-selection step (default 1000).

Type:

int

n_restarts#

Number of L-BFGS restarts seeded from the top candidates (default 2).

Type:

int

n_lbfgs_steps#

L-BFGS iterations per restart (default 10).

Type:

int

n_hparam_steps#

Adam steps used to tune log_length_scale each iteration (default 20). Set to 0 to disable.

Type:

int

n_warmup#

Number of pure-random iterations before the GP is used (default 1).

Type:

int

maximize#

Set False to minimise the objective (default True).

Type:

bool

n_parallel#

Number of parallel candidates per iteration (default 4).

Type:

int

hallucination#

Hallucination strategy for Kriging Believer parallel selection (default SampleHallucination).

Type:

hyperoptax.acquisition.BaseHallucination

jitter: float = 1e-06#
kernel: BaseKernel#
acquisition: BaseAcquisition#
n_candidates: int = 1000#
n_restarts: int = 2#
n_lbfgs_steps: int = 10#
n_hparam_steps: int = 20#
n_warmup: int = 1#
maximize: bool = True#
n_parallel: int = 4#
hallucination: BaseHallucination#
classmethod init(space, n_max=200, **kwargs)[source]#
best_result(state)[source]#

Return the best observed raw result (the maximum if maximize=True, otherwise the minimum).

Parameters:

state (BayesianSearchState)

Return type:

Array

best_params(state)[source]#

Return the parameter pytree that achieved the best observed result.

Parameters:

state (BayesianSearchState)

get_next_params(state, key, params=None, results=None)[source]#

Select the next batch of n_parallel candidates.

During the first n_warmup iterations, candidates are chosen uniformly at random. Afterwards, the GP posterior is used to maximise the acquisition function via L-BFGS with Kriging Believer hallucination for the parallel slots.

update_state(state, key, results, params)[source]#

Record new observations and update ARD length scales.

Writes the batch of results into the fixed-size state buffers and, if n_hparam_steps > 0, runs a short Adam loop to tune log_length_scale via marginal-likelihood maximisation.

Parameters:
  • state – Current BayesianSearchState.

  • key – PRNG key (unused but kept for API consistency).

  • results – Array of shape (n_parallel,) with observed objective values.

  • params – Either the batched params pytree returned by get_next_params() (each leaf shape (n_parallel,)), or a raw (n_parallel, n_params) flat array.

optimize(state, key, func, n_iterations=None)[source]#

High-level API for optimizing a function over a space. Not recommended when you need fine-grained control over how the parallel evaluations are executed.

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

optimize_scan(state, key, func, n_iterations=None)[source]#

Like optimize, but uses jax.lax.scan for the inner loop.

Requires func to be JAX-traceable (jit-compilable). Returns stacked arrays instead of lists: params_hist is a pytree where each leaf has shape (n_iterations, n_parallel, …), and results_hist has shape (n_iterations, n_parallel).

func must return a scalar (shape ()). If your function returns shape (1,), call .squeeze() inside func before returning.

__init__(jitter=1e-06, kernel=<factory>, acquisition=<factory>, n_candidates=1000, n_restarts=2, n_lbfgs_steps=10, n_hparam_steps=20, n_warmup=1, maximize=True, n_parallel=4, hallucination=<factory>)#
Return type:

None

hyperoptax.grid module#

class hyperoptax.grid.GridSearchState(space, grid, grid_idx)[source]#

Bases: OptimizerState

State for GridSearch.

Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

grid#

Array of shape (n_total, n_params) containing all parameter combinations, pre-truncated to a multiple of n_parallel.

Type:

jax.Array

grid_idx#

Current position in grid; incremented by n_parallel after each call to update_state.

Type:

int
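The roles of grid and grid_idx can be sketched with the standard library. The example builds every combination with itertools.product, truncates the list to a multiple of n_parallel, and then walks through it in batches. build_grid and the dictionary-based space are assumptions for illustration; the library stores the grid as an array of shape (n_total, n_params).

```python
import itertools
import random


def build_grid(space, n_parallel=1, shuffle=False, seed=0):
    """All combinations of the discrete values, truncated to full batches."""
    names = sorted(space)
    grid = [dict(zip(names, combo))
            for combo in itertools.product(*(space[name] for name in names))]
    if shuffle:
        random.Random(seed).shuffle(grid)
    n_total = (len(grid) // n_parallel) * n_parallel  # drop the ragged tail
    return grid[:n_total]


space = {"lr": (1e-3, 1e-2, 1e-1), "batch_size": (32, 64, 128)}
grid = build_grid(space, n_parallel=2)

grid_idx = 0
batches = []
while grid_idx < len(grid):
    batches.append(grid[grid_idx:grid_idx + 2])  # next n_parallel points
    grid_idx += 2                                # advance as update_state does
```

In this sketch the nine combinations are cut to eight, so one combination is never evaluated. If truncation works the same way in the library, choosing n_parallel as a divisor of the grid size avoids losing points.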

grid: Array#
grid_idx: int#
__init__(space, grid, grid_idx)#
Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

Return type:

None

space: PyTree#
class hyperoptax.grid.GridSearch(shuffle=False, n_parallel=1)[source]#

Bases: Optimizer

Exhaustive grid search over a discrete search space.

Iterates through every combination of the provided DiscreteSpace values in order (or in random order if shuffle=True). Every leaf of the search space must be a DiscreteSpace.

Parameters:
shuffle#

If True, randomise the traversal order during init. Pass an explicit key to init for reproducibility.

Type:

bool

n_parallel#

Number of grid points evaluated per iteration.

Type:

int

shuffle: bool = False#
n_parallel: int = 1#
classmethod init(space, key=None, **kwargs)[source]#

Initialise the grid search.

Parameters:
  • space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.

  • key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.

  • **kwargs – Forwarded to GridSearch constructor (e.g. n_parallel, shuffle).

Returns:

(state, optimizer) tuple.

get_next_params(state, key, params=None, results=None)[source]#

Return the next n_parallel parameter combinations from the grid.

Parameters:

state (GridSearchState)

Return type:

PyTree

update_state(state, key, results, params=None)[source]#

Advance the grid index by n_parallel.

Parameters:

state (GridSearchState)

Return type:

GridSearchState

__init__(shuffle=False, n_parallel=1)#
Return type:

None

hyperoptax.kernels module#

hyperoptax.kernels.cdist(x, y)[source]#

Pairwise Euclidean distance (cdist) between two 2-D arrays.

Parameters:
  • x (jax.Array) – Array of shape (N, D).

  • y (jax.Array) – Array of shape (M, D).

Returns:

A distance matrix of shape (N, M).

Return type:

jax.Array
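The shape contract of a pairwise distance function is easy to check on a small example. The following plain-Python sketch mirrors the description above using nested lists instead of JAX arrays; it is an illustration rather than the library's vectorised implementation.

```python
import math


def cdist(x, y):
    """Pairwise Euclidean distances between rows of x (N, D) and y (M, D)."""
    return [[math.sqrt(sum((a - b) ** 2 for a, b in zip(row_x, row_y)))
             for row_y in y]
            for row_x in x]


x = [[0.0, 0.0], [3.0, 4.0]]              # N = 2, D = 2
y = [[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]]  # M = 3, D = 2
distances = cdist(x, y)
# distances == [[0.0, 3.0, 4.0], [5.0, 4.0, 3.0]]
```

Entry (i, j) of the result is the distance between row i of x and row j of y, which gives the (N, M) matrix that a kernel then maps to covariances.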

class hyperoptax.kernels.BaseKernel[source]#

Bases: ABC

Abstract base class for positive-definite kernels.

class hyperoptax.kernels.RBF(length_scale=1.0)[source]#

Bases: BaseKernel

Radial basis function (RBF) / squared-exponential kernel.

Parameters:

length_scale (float)

__init__(length_scale=1.0)[source]#
Parameters:

length_scale (float)

class hyperoptax.kernels.Matern(length_scale=1.0, nu=2.5)[source]#

Bases: BaseKernel

Matern kernel family.

Parameters:
  • length_scale (float, default = 1.0) – Characteristic length scale.

  • nu (float, default = 2.5) – Controls smoothness (nu ∈ {0.5, 1.5, 2.5, ∞}).

__init__(length_scale=1.0, nu=2.5)[source]#

hyperoptax.spaces module#

hyperoptax.spaces.log_transform(x, base)[source]#
Return type:

float

class hyperoptax.spaces.Space[source]#

Bases: ABC

Abstract base class for hyperparameter search spaces.

abstract sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

Array

transform(value)[source]#
__init__()#
Return type:

None

class hyperoptax.spaces.LinearSpace(lower_bound, upper_bound)[source]#

Bases: Space

Uniform continuous space over the half-open interval [lower_bound, upper_bound).

Parameters:
lower_bound#

Inclusive lower bound of the interval.

Type:

float

upper_bound#

Exclusive upper bound of the interval.

Type:

float

lower_bound: float#
upper_bound: float#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

float

__init__(lower_bound, upper_bound)#
Return type:

None

class hyperoptax.spaces.DiscreteSpace(values)[source]#

Bases: Space

Discrete space over a fixed set of values.

Samples uniformly from values. transform snaps any continuous value to the nearest element, which is useful when discrete candidates are generated via continuous optimization (e.g. in BayesianSearch).

Parameters:

values (tuple)

values#

Tuple of candidate values to sample from.

Type:

tuple

values: tuple#
property lower_bound: float#
property upper_bound: float#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

float

transform(value)[source]#
Return type:

Array

__init__(values)#
Parameters:

values (tuple)

Return type:

None

class hyperoptax.spaces.LogSpace(lower_bound, upper_bound, base=10)[source]#

Bases: LinearSpace

Log-uniform continuous space over the half-open interval [lower_bound, upper_bound).

Samples uniformly in log space so that each order of magnitude receives equal probability mass. Useful for learning rates and other scale parameters that span several orders of magnitude.

Parameters:
lower_bound#

Inclusive lower bound (in original scale, e.g. 1e-5).

Type:

float

upper_bound#

Exclusive upper bound (in original scale, e.g. 1e-1).

Type:

float

base#

Logarithm base (default 10). Must be greater than 1.

Type:

float

base: float = 10#
sample(key)[source]#
Parameters:

key (PRNGKey)

Return type:

Array

__init__(lower_bound, upper_bound, base=10)#
Return type:

None

lower_bound: float#
upper_bound: float#
class hyperoptax.spaces.QLinearSpace(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)[source]#

Bases: LinearSpace

Quantized (integer) variant of LinearSpace.

Samples uniformly from [lower_bound, upper_bound) and rounds to the nearest integer. Use this for discrete integer hyperparameters with a uniform prior (e.g. number of layers, batch size).

Parameters:
lower_bound#

Inclusive lower bound.

Type:

float

upper_bound#

Exclusive upper bound.

Type:

float

datatype#

Integer dtype used after rounding (default jnp.int32).

Type:

type

datatype#

alias of int32

transform(value)[source]#
Return type:

Array

__init__(lower_bound, upper_bound, datatype=<class 'jax.numpy.int32'>)#
Return type:

None

class hyperoptax.spaces.QLogSpace(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)[source]#

Bases: LogSpace

Quantized (integer) variant of LogSpace.

Samples in log space and rounds to the nearest integer. Use this for integer hyperparameters whose scale spans orders of magnitude (e.g. number of hidden units, number of warmup steps).

Parameters:
datatype#

alias of int32

transform(value)[source]#
Return type:

Array

__init__(lower_bound, upper_bound, base=10, datatype=<class 'jax.numpy.int32'>)#
Return type:

None