"""Exhaustive grid-search optimizer (module ``hyperoptax.grid``)."""
import dataclasses
import jax
import jax.numpy as jnp
from jaxtyping import PyTree
from hyperoptax import base
from hyperoptax import spaces as sp
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class GridSearchState(base.OptimizerState):
    """State for :class:`GridSearch`.

    Attributes:
        grid: Array of shape ``(n_total, n_params)`` containing all parameter
            combinations, pre-truncated to a multiple of ``n_parallel``.
        grid_idx: Current position in ``grid``; incremented by ``n_parallel``
            after each call to ``update_state``. Starts as a Python ``int``,
            but inside ``lax.scan`` it is an abstract tracer (and may be a
            JAX array after other JAX transformations).
    """

    grid: jax.Array
    # Python int at init; a traced/JAX array once the state is transformed.
    grid_idx: int | jax.Array
@dataclasses.dataclass
class GridSearch(base.Optimizer):
    """Exhaustive grid search over a discrete search space.

    Iterates through every combination of the provided ``DiscreteSpace`` values
    in order (or in a random order if ``shuffle=True``). All spaces in the
    search space must be :class:`~hyperoptax.spaces.DiscreteSpace`.

    Attributes:
        shuffle: If ``True``, randomise the traversal order during ``init``
            using the ``key`` passed to ``init``. When no key is given the
            fixed ``PRNGKey(0)`` is used, so the default order is deterministic.
        n_parallel: Number of grid points evaluated per iteration. Must be
            between 1 and the total number of grid points. Trailing points
            that do not fill a whole batch of ``n_parallel`` are dropped.
    """

    shuffle: bool = False
    n_parallel: int = 1

    @staticmethod
    def _is_space(x):
        """Return ``True`` for search-space objects so flattening stops at them."""
        return isinstance(x, sp.Space)

    @classmethod
    def init(cls, space, key=None, **kwargs):
        """Initialise the grid search.

        Args:
            space: A pytree of :class:`~hyperoptax.spaces.DiscreteSpace` objects.
                All leaves must be ``DiscreteSpace``; mixed spaces are not
                supported.
            key: Optional PRNG key used when ``shuffle=True``. Falls back to
                ``PRNGKey(0)`` when ``None``.
            **kwargs: Forwarded to the :class:`GridSearch` constructor (e.g.
                ``n_parallel``, ``shuffle``).

        Returns:
            ``(state, optimizer)`` tuple.

        Raises:
            ValueError: If any leaf of ``space`` is not a ``DiscreteSpace``,
                if ``n_parallel < 1``, if ``space`` has no leaves, or if
                ``n_parallel`` exceeds the total number of grid points.
        """
        leaves = jax.tree.leaves(space, is_leaf=cls._is_space)
        if not all(isinstance(leaf, sp.DiscreteSpace) for leaf in leaves):
            raise ValueError("GridSearch requires all spaces to be DiscreteSpace.")

        optimizer = cls(**kwargs)
        # n_parallel == 0 would otherwise surface as a ZeroDivisionError in the
        # truncation below, and a negative value would produce a bogus slice.
        if optimizer.n_parallel < 1:
            raise ValueError(
                f"n_parallel must be at least 1, got {optimizer.n_parallel}."
            )
        # An empty space would otherwise fail inside jnp.stack with an
        # unhelpful "need at least one array" message.
        if not leaves:
            raise ValueError("GridSearch requires at least one DiscreteSpace.")

        values_list = [jnp.array(leaf.values) for leaf in leaves]
        # TODO: use indexes so that we don't generate the full grid.
        grids = jnp.meshgrid(*values_list, indexing="ij")
        # Flatten into (n_total, n_leaves) so grid[i] is the i-th param
        # combination. NOTE: jnp.stack casts every column to one common dtype,
        # so integer values become floats when mixed with float-valued spaces.
        grid = jnp.stack([g.ravel() for g in grids], axis=-1)

        n_total = grid.shape[0]
        if optimizer.n_parallel > n_total:
            raise ValueError(
                f"n_parallel ({optimizer.n_parallel}) exceeds the number of "
                f"grid points ({n_total})."
            )

        if optimizer.shuffle:
            # key=None falls back to the fixed PRNGKey(0): the shuffled order
            # is then identical on every call unless a different key is given.
            if key is None:
                key = jax.random.PRNGKey(0)
            # Permutes rows (axis 0), keeping each parameter combination intact.
            grid = jax.random.permutation(key, grid)

        # Drop the trailing partial batch so that every get_next_params call
        # can take exactly n_parallel rows.
        n_usable = (n_total // optimizer.n_parallel) * optimizer.n_parallel
        state = GridSearchState(
            space=space,
            grid=grid[:n_usable],
            grid_idx=0,
        )
        return state, optimizer

    def get_next_params(
        self, state: GridSearchState, key, params=None, results=None
    ) -> PyTree:
        """Return the next ``n_parallel`` parameter combinations from the grid.

        Args:
            state: Current grid-search state.
            key: Not used by grid search.
            params: Not used by grid search.
            results: Not used by grid search.

        Returns:
            A pytree with the same structure as ``state.space`` whose leaves
            are arrays of shape ``(n_parallel,)``.

        Raises:
            ValueError: When called eagerly (``grid_idx`` is concrete) and
                fewer than ``n_parallel`` grid points remain.
        """
        # Only check eagerly; inside lax.scan grid_idx is an abstract tracer.
        # Under tracing, lax.dynamic_slice clamps an out-of-range start index,
        # so over-running the grid re-yields the final rows instead of failing.
        if not isinstance(state.grid_idx, jax.core.Tracer):
            if int(state.grid_idx) + self.n_parallel > state.grid.shape[0]:
                raise ValueError(
                    f"Not enough grid points remaining "
                    f"(grid_idx={int(state.grid_idx)}, n_parallel={self.n_parallel}, "
                    f"grid_size={state.grid.shape[0]})."
                )
        # Extract n_parallel rows; use dynamic slice for scan compatibility.
        rows = jax.lax.dynamic_slice_in_dim(
            state.grid, state.grid_idx, self.n_parallel, axis=0
        )  # (n_parallel, n_leaves)
        _, treedef = jax.tree.flatten(state.space, is_leaf=self._is_space)
        # Column i of the grid corresponds to leaf i of the flattened space.
        return treedef.unflatten([rows[:, i] for i in range(treedef.num_leaves)])

    def update_state(
        self, state: GridSearchState, key, results, params=None
    ) -> GridSearchState:
        """Advance the grid index by ``n_parallel``.

        ``key``, ``results`` and ``params`` are not used by grid search.
        """
        return state.replace(grid_idx=state.grid_idx + self.n_parallel)