Source code for hyperoptax.random

import dataclasses

import jax
import jax.numpy as jnp
from jaxtyping import PyTree

from hyperoptax import base, utils
from hyperoptax import spaces as sp


@dataclasses.dataclass
class RandomSearch(base.Optimizer):
    """Stateless random search — samples each space independently each iteration.

    No model is fitted and no history is maintained, so this is the cheapest
    optimizer and useful as a strong baseline.

    Attributes:
        n_parallel: Number of random configurations evaluated per iteration.
            Must be a positive integer.
    """

    n_parallel: int = 1

    @classmethod
    def init(cls, space, **kwargs) -> tuple[base.OptimizerState, "RandomSearch"]:
        """Build the initial optimizer state and a configured optimizer.

        Args:
            space: PyTree of ``sp.Space`` leaves describing the search space.
            **kwargs: Forwarded to the ``RandomSearch`` constructor
                (e.g. ``n_parallel``).

        Returns:
            A ``(state, optimizer)`` pair, where ``state`` is an
            ``OptimizerState`` constructed with only ``space``.
        """
        return base.OptimizerState(space=space), cls(**kwargs)

    def get_next_params(
        self,
        state: base.OptimizerState,
        key: jax.Array,
        params=None,
        results=None,
    ) -> PyTree:
        """Sample ``n_parallel`` independent configurations from the search space.

        Args:
            state: Optimizer state whose ``space`` PyTree is sampled.
            key: PRNG key; split into one key per configuration.
            params: Unused; accepted for interface compatibility.
            results: Unused; accepted for interface compatibility.

        Returns:
            A PyTree with the same structure as ``state.space`` in which every
            leaf is stacked along a new leading axis of length ``n_parallel``.

        Raises:
            ValueError: If ``n_parallel`` is less than 1.
        """
        if self.n_parallel < 1:
            # n_parallel == 0 would yield an empty sample list, and
            # jax.tree.map(..., *[]) fails with an opaque TypeError; negative
            # values fail deep inside jax.random.split. Fail early and clearly.
            raise ValueError(
                f"n_parallel must be a positive integer, got {self.n_parallel}"
            )

        def sample_once(k):
            # utils.make_key_tree presumably yields one subkey per Space leaf
            # (matching state.space's structure) — confirm in utils.
            subkeys = utils.make_key_tree(state.space, k)
            sample = jax.tree.map(
                lambda x, sk: x.sample(sk),
                state.space,
                subkeys,
                is_leaf=lambda x: isinstance(x, sp.Space),
            )
            # Squeeze (1,) per-leaf values to scalars for stacking
            return jax.tree.map(lambda leaf: leaf.squeeze(), sample)

        keys = jax.random.split(key, self.n_parallel)
        samples = [sample_once(k) for k in keys]
        return jax.tree.map(lambda *leaves: jnp.stack(leaves), *samples)

    def update_state(
        self,
        state: base.OptimizerState,
        key: jax.Array,
        results: jax.Array,
        params=None,
    ) -> base.OptimizerState:
        """Return ``state`` unchanged.

        RandomSearch is memoryless, so there is no state to update; ``key``,
        ``results`` and ``params`` are accepted only for interface
        compatibility and are ignored.
        """
        return state