hyperoptax.grid

Classes

GridSearch([shuffle, n_parallel])

Exhaustive grid search over a discrete search space.

class hyperoptax.grid.GridSearchState(space, grid, grid_idx)

State for GridSearch.

Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

grid

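Array of shape (n_total, n_params) with one row per parameter combination. The row count n_total is truncated to a multiple of n_parallel, so some combinations may be dropped when the total number of combinations is not a multiple of n_parallel.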

Type:

jax.Array
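A minimal sketch of how a grid with this shape could be built, assuming each parameter is given as a plain list of values and that rows follow row-major order (first parameter varying slowest). `build_grid` is a hypothetical helper, not part of hyperoptax; the library's actual ordering and dtype handling may differ.

```python
import jax.numpy as jnp


def build_grid(value_lists, n_parallel=1):
    """Hypothetical helper: Cartesian product of discrete values, one row per combination."""
    # One axis per parameter; indexing="ij" makes the first parameter vary slowest.
    mesh = jnp.meshgrid(*[jnp.asarray(v) for v in value_lists], indexing="ij")
    # Flatten each axis and stack them as columns -> shape (n_combinations, n_params).
    grid = jnp.stack([m.ravel() for m in mesh], axis=-1)
    # Drop trailing rows so the row count is a multiple of n_parallel.
    n_total = (grid.shape[0] // n_parallel) * n_parallel
    return grid[:n_total]


grid = build_grid([[0.1, 0.01, 0.001], [16, 32, 64, 128]], n_parallel=5)
print(grid.shape)  # (10, 2): 12 combinations truncated to a multiple of 5
```

With 3 × 4 = 12 combinations and n_parallel=5, this sketch keeps 10 rows, so the last two combinations are never proposed.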

grid_idx

Current position in grid; incremented by n_parallel after each call to update_state.

Type:

int
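The stepping rule can be traced by hand. The sketch below uses made-up sizes and plain Python integers; it assumes, as the truncation of grid suggests, that the search ends when grid_idx reaches n_total, although the documentation does not say what happens past the last row.

```python
# Hypothetical sizes, chosen only to illustrate the arithmetic.
n_total = 12      # rows in the (already truncated) grid
n_parallel = 4    # grid points proposed per iteration

grid_idx = 0
batches = []
while grid_idx < n_total:  # assumed stop condition
    # Each iteration covers rows grid_idx .. grid_idx + n_parallel - 1.
    batches.append((grid_idx, grid_idx + n_parallel))
    grid_idx += n_parallel  # the documented effect of update_state

print(batches)  # [(0, 4), (4, 8), (8, 12)]
```

A full sweep therefore takes n_total // n_parallel iterations, here 3.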

grid: Array
grid_idx: int
__init__(space, grid, grid_idx)
Parameters:
  • space (PyTree)

  • grid (Array)

  • grid_idx (int)

Return type:

None

class hyperoptax.grid.GridSearch(shuffle=False, n_parallel=1)

Exhaustive grid search over a discrete search space.

Iterates through every combination (the Cartesian product) of the provided DiscreteSpace values, in order, or in a random order if shuffle=True. Every leaf of the search space must be a DiscreteSpace.
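The all-DiscreteSpace requirement can be pictured as a check over the pytree leaves. In this sketch `DiscreteSpace` is a hypothetical stand-in class that only stores a list of values (the real hyperoptax class is not reproduced here), and `check_all_discrete` is a hypothetical helper that also returns the number of combinations, i.e. the product of the value counts.

```python
import math

import jax


class DiscreteSpace:
    """Hypothetical stand-in for hyperoptax's DiscreteSpace; only stores its values."""

    def __init__(self, values):
        self.values = list(values)


def check_all_discrete(space):
    # Unregistered classes are pytree leaves, so each DiscreteSpace is one leaf.
    leaves = jax.tree_util.tree_leaves(space)
    bad = [leaf for leaf in leaves if not isinstance(leaf, DiscreteSpace)]
    if bad:
        raise TypeError(f"GridSearch needs DiscreteSpace leaves, got {bad!r}")
    # One grid row per combination: the product of the value counts.
    return math.prod(len(leaf.values) for leaf in leaves)


space = {"lr": DiscreteSpace([1e-1, 1e-2, 1e-3]), "batch_size": DiscreteSpace([16, 32, 64, 128])}
print(check_all_discrete(space))  # 12
```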

Parameters:
  • shuffle (bool)

  • n_parallel (int)

shuffle

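If True, randomise the traversal order during init. Pass an explicit key to init to choose the order; with key=None, init falls back to PRNGKey(0), so the default shuffled order is the same on every run.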

Type:

bool
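A sketch of a row shuffle with the documented key fallback, assuming the grid is shuffled as whole rows so each parameter combination stays intact. `shuffle_rows` is hypothetical, and `jax.random.permutation` with `axis=0` is a standard JAX call rather than confirmed library internals.

```python
import jax
import jax.numpy as jnp


def shuffle_rows(grid, key=None):
    """Hypothetical helper: permute grid rows, defaulting to PRNGKey(0) like GridSearch.init."""
    if key is None:
        key = jax.random.PRNGKey(0)
    # Permute along axis 0 only, so each row (one parameter combination) stays intact.
    return jax.random.permutation(key, grid, axis=0)


grid = jnp.arange(12.0).reshape(6, 2)
a = shuffle_rows(grid)
b = shuffle_rows(grid, jax.random.PRNGKey(0))
print(bool(jnp.array_equal(a, b)))  # True: same key, same order
```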

n_parallel

Number of grid points evaluated per iteration.

Type:

int
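Because n_parallel points arrive together, they can be scored in one vectorised call. This sketch assumes the batch is a dict pytree whose leaves carry a leading axis of length n_parallel; the objective and the key names `lr` and `batch_size` are hypothetical.

```python
import jax
import jax.numpy as jnp


def objective(params):
    # Hypothetical objective over one grid point given as a dict of scalars.
    return (jnp.log10(params["lr"]) + 2.0) ** 2 + params["batch_size"] / 1000.0


# Assumed batch layout: each leaf has a leading axis of length n_parallel = 3.
batch = {
    "lr": jnp.array([0.1, 0.01, 0.001]),
    "batch_size": jnp.array([16.0, 32.0, 64.0]),
}
scores = jax.vmap(objective)(batch)  # one score per grid point
best = int(jnp.argmin(scores))
print(scores.shape, best)  # (3,) 1
```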

shuffle: bool = False
n_parallel: int = 1
classmethod init(space, key=None, **kwargs)

Initialise the grid search.

Parameters:
  • space – A pytree of DiscreteSpace objects. All leaves must be DiscreteSpace; mixed spaces are not supported.

  • key – Optional PRNG key used when shuffle=True. Falls back to PRNGKey(0) when None.

  • **kwargs – Forwarded to the GridSearch constructor (e.g. n_parallel, shuffle).

Returns:

A (state, optimizer) tuple: the initial GridSearchState and the GridSearch instance built from **kwargs.
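To show how the returned pair is meant to be used, here is a self-contained toy that mirrors the documented interface (init returning (state, optimizer), then alternating get_next_params and update_state). `ToyGridSearch` and `ToyGridState` are hypothetical; the space is a dict of value lists instead of DiscreteSpace leaves, get_next_params returns a raw array rather than a PyTree, and the stop condition is an assumption.

```python
from dataclasses import dataclass, replace
from typing import Any

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyGridState:
    """Hypothetical mirror of GridSearchState's documented fields."""

    space: Any
    grid: jax.Array
    grid_idx: int


@dataclass(frozen=True)
class ToyGridSearch:
    """Hypothetical, simplified mirror of the documented GridSearch interface."""

    shuffle: bool = False
    n_parallel: int = 1

    @classmethod
    def init(cls, space, key=None, **kwargs):
        opt = cls(**kwargs)  # e.g. n_parallel / shuffle go to the constructor
        # Assumption: space is a dict of value lists, not DiscreteSpace leaves.
        names = sorted(space)
        axes = [jnp.asarray(space[n], dtype=jnp.float32) for n in names]
        mesh = jnp.meshgrid(*axes, indexing="ij")
        grid = jnp.stack([m.ravel() for m in mesh], axis=-1)
        if opt.shuffle:
            key = jax.random.PRNGKey(0) if key is None else key
            grid = jax.random.permutation(key, grid, axis=0)
        # Keep a multiple of n_parallel rows, as documented for GridSearchState.grid.
        n_total = (grid.shape[0] // opt.n_parallel) * opt.n_parallel
        return ToyGridState(space=space, grid=grid[:n_total], grid_idx=0), opt

    def get_next_params(self, state, key, params=None, results=None):
        # The next n_parallel rows (a raw array here; hyperoptax returns a PyTree).
        return state.grid[state.grid_idx : state.grid_idx + self.n_parallel]

    def update_state(self, state, key, results, params=None):
        # Functional update: return a new state with grid_idx advanced.
        return replace(state, grid_idx=state.grid_idx + self.n_parallel)


state, opt = ToyGridSearch.init({"lr": [0.1, 0.01], "width": [32, 64, 128]}, n_parallel=2)
while state.grid_idx < state.grid.shape[0]:  # assumed stop condition
    batch = opt.get_next_params(state, key=None)
    results = batch.sum(axis=1)  # stand-in for real evaluations
    state = opt.update_state(state, key=None, results=results)
print(state.grid_idx)  # 6
```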

get_next_params(state, key, params=None, results=None)

Return the next n_parallel parameter combinations from the grid.

Parameters:

state (GridSearchState)

Return type:

PyTree
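If grid_idx is a traced value under `jax.jit`, Python slicing such as `grid[i:i + n]` fails because slice bounds must be static. A common workaround is `jax.lax.dynamic_slice`, which takes a traced start index and a static size. This is a sketch of that technique with a hypothetical `next_batch` helper, not a description of how get_next_params is written.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=2)
def next_batch(grid, grid_idx, n_parallel):
    # grid_idx may be traced, so use dynamic_slice: traced start, static size.
    return jax.lax.dynamic_slice(grid, (grid_idx, 0), (n_parallel, grid.shape[1]))


grid = jnp.arange(16.0).reshape(8, 2)
print(next_batch(grid, 2, 2))  # rows 2 and 3: [[4. 5.] [6. 7.]]
```

`dynamic_slice` clamps the start index so the slice stays in bounds; a start of 7 returns rows 6 and 7. A partial final batch would therefore repeat earlier rows, which is one plausible reason, not confirmed by the documentation, for truncating grid to a multiple of n_parallel.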

update_state(state, key, results, params=None)

Advance the grid index by n_parallel.

Parameters:

state (GridSearchState)

Return type:

GridSearchState
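Since update_state only advances grid_idx by n_parallel, the whole propose, evaluate, update cycle can be compiled into one loop. The sketch below assumes the grid is already truncated to a multiple of n_parallel and uses a hypothetical quadratic objective; `sweep` is illustrative and not part of hyperoptax.

```python
import jax
import jax.numpy as jnp


def objective(row):
    # Hypothetical objective on one grid row, minimised at (1.0, 2.0).
    return jnp.sum((row - jnp.array([1.0, 2.0], dtype=row.dtype)) ** 2)


def sweep(grid, n_parallel):
    # Assumes grid.shape[0] is a multiple of n_parallel (the documented truncation).
    n_steps = grid.shape[0] // n_parallel

    def body(step, carry):
        grid_idx, best_score, best_row = carry
        # get_next_params analogue: the next n_parallel rows.
        batch = jax.lax.dynamic_slice(grid, (grid_idx, 0), (n_parallel, grid.shape[1]))
        scores = jax.vmap(objective)(batch)
        i = jnp.argmin(scores)
        better = scores[i] < best_score
        best_score = jnp.where(better, scores[i], best_score)
        best_row = jnp.where(better, batch[i], best_row)
        # update_state analogue: advance by n_parallel.
        return grid_idx + n_parallel, best_score, best_row

    init = (
        jnp.int32(0),
        jnp.array(jnp.inf, dtype=grid.dtype),
        jnp.zeros(grid.shape[1], dtype=grid.dtype),
    )
    return jax.lax.fori_loop(0, n_steps, body, init)


grid = jnp.stack(jnp.meshgrid(jnp.arange(4.0), jnp.arange(4.0), indexing="ij"), -1).reshape(-1, 2)
final_idx, best_score, best_row = jax.jit(sweep, static_argnums=1)(grid, 4)
print(int(final_idx), best_row)  # 16 [1. 2.]
```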

__init__(shuffle=False, n_parallel=1)
Parameters:
  • shuffle (bool)

  • n_parallel (int)

Return type:

None