import dataclasses
import inspect
import warnings
from dataclasses import dataclass
from typing import Callable
import jax
import jax.numpy as jnp
from jaxtyping import PyTree
def _validate_func(func):
try:
sig = inspect.signature(func)
positional = [
p
for p in sig.parameters.values()
if p.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
if len(positional) < 2:
raise TypeError(
f"func must have signature fn(key, config) — "
f"received a function with {len(positional)} positional parameter(s). "
"Did you forget the key argument?"
)
except (ValueError, TypeError) as e:
if "func must have" in str(e):
raise
warnings.warn(
"Can't introspect function signature - ensure that the "
"function has a signature fn(key, config)."
)
return
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class OptimizerState:
    """Base optimizer state — a JAX pytree holding the search space definition.

    The class is a frozen dataclass registered with
    ``jax.tree_util.register_dataclass``. Instances are immutable; use
    :meth:`replace` to derive a modified copy.
    """

    # Search space definition handed to the optimizer at construction time.
    space: PyTree

    def replace(self, **kwargs) -> "OptimizerState":
        """Return a copy of this state with the given fields replaced.

        Args:
            **kwargs: Field names mapped to their new values, forwarded to
                ``dataclasses.replace``.

        Returns:
            A new instance of the same class; ``self`` is left untouched.
        """
        return dataclasses.replace(self, **kwargs)
class Optimizer:
    """Base class for optimizers that propose parameters and learn from results.

    Subclasses implement :meth:`get_next_params` and :meth:`update_state`;
    :meth:`optimize` and :meth:`optimize_scan` drive the propose/evaluate/update
    loop on top of them.
    """

    # Number of parameter sets evaluated per iteration; it sets how many PRNG
    # keys are split for ``func`` and therefore the vmapped batch size.
    n_parallel: int = 1

    @classmethod
    def init(cls, space, **kwargs) -> tuple[OptimizerState, "Optimizer"]:
        """Create an initial state for ``space`` together with an optimizer.

        Args:
            space: Search space definition stored on the returned state.
            **kwargs: Accepted but unused by this base implementation
                (presumably a hook for subclasses — confirm against them).

        Returns:
            A ``(state, optimizer)`` tuple.
        """
        return OptimizerState(space=space), cls()

    def optimize(
        self,
        state: OptimizerState,
        key: jax.Array,  # () PRNG key
        func: Callable,  # (key, config) -> () scalar result
        n_iterations: int,
    ) -> tuple[OptimizerState, tuple[list[PyTree], list[jax.Array]]]:
        """
        High Level API for optimizing a function over a space.

        Not recommended if you want to do fancy things
        with parallel computation.

        ``func`` must return a scalar (``()`` shape). If your function returns
        shape ``(1,)``, call ``.squeeze()`` inside ``func`` before returning.

        Args:
            state: Optimizer state to start from.
            key: PRNG key; split four ways on every iteration.
            func: Objective called as ``func(key, config)`` under ``jax.vmap``.
            n_iterations: Number of propose/evaluate/update rounds.

        Returns:
            ``(final_state, (params_hist, results_hist))`` where both
            histories are Python lists with one entry per iteration.
        """
        _validate_func(func)
        params_hist, results_hist = [], []
        # Previous iteration's proposal and outcome; both None on the first call.
        prev_params, prev_results = None, None
        for _ in range(n_iterations):
            key, key_get, key_funcs, key_update = jax.random.split(key, 4)
            params = self.get_next_params(state, key_get, prev_params, prev_results)
            # params: pytree, each leaf shape (n_parallel, ...)
            # func_keys: (n_parallel, 2)
            # batch_results: (n_parallel,)
            func_keys = jax.random.split(key_funcs, self.n_parallel)
            batch_results = jax.vmap(func)(func_keys, params)  # (n_parallel,)
            state = self.update_state(state, key_update, batch_results, params)
            params_hist.append(params)
            results_hist.append(batch_results)
            # Feed this round's results into the next proposal, matching what
            # optimize_scan carries between steps.
            prev_params, prev_results = params, batch_results
        return state, (params_hist, results_hist)

    def optimize_scan(
        self,
        state: OptimizerState,
        key: jax.Array,  # () PRNG key
        func: Callable,  # (key, config) -> () scalar result
        n_iterations: int,
    ) -> tuple[OptimizerState, tuple[PyTree, jax.Array]]:
        """
        Like optimize, but uses jax.lax.scan for the inner loop.

        Requires func to be JAX-traceable (jit-compilable). Returns stacked
        arrays instead of lists: params_hist is a pytree where each leaf has
        shape (n_iterations, n_parallel, ...), and results_hist has shape
        (n_iterations, n_parallel).

        ``func`` must return a scalar (``()`` shape). If your function returns
        shape ``(1,)``, call ``.squeeze()`` inside ``func`` before returning.

        Args:
            state: Optimizer state to start from.
            key: PRNG key; split four ways on every step.
            func: Objective called as ``func(key, config)`` under ``jax.vmap``.
            n_iterations: Total number of steps; must be at least 1.

        Returns:
            ``(final_state, (params_hist, results_hist))``.

        Raises:
            ValueError: If ``n_iterations`` is less than 1.
        """
        # The first step always runs outside the scan and the scan covers the
        # remaining n_iterations - 1 steps, so fewer than one iteration cannot
        # be honoured (it would request a negative scan length).
        if n_iterations < 1:
            raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")
        _validate_func(func)

        # Run one step outside scan to determine pytree structure and result shape.
        key, key_get, key_funcs, key_update = jax.random.split(key, 4)
        params0 = self.get_next_params(state, key_get, None, None)
        # params0: pytree, each leaf shape (n_parallel, ...)
        # results0: (n_parallel,)
        func_keys0 = jax.random.split(key_funcs, self.n_parallel)
        results0 = jax.vmap(func)(func_keys0, params0)  # (n_parallel,)
        state = self.update_state(state, key_update, results0, params0)

        def step(carry, _):
            # carry holds the running state and key plus the previous
            # proposal/results that get_next_params conditions on.
            state, key, params, results = carry
            key, key_get, key_funcs, key_update = jax.random.split(key, 4)
            params = self.get_next_params(state, key_get, params, results)
            func_keys = jax.random.split(key_funcs, self.n_parallel)
            batch_results = jax.vmap(func)(func_keys, params)  # (n_parallel,)
            state = self.update_state(state, key_update, batch_results, params)
            return (state, key, params, batch_results), (params, batch_results)

        (final_state, _, _, _), (params_hist, results_hist) = jax.lax.scan(
            step,
            (state, key, params0, results0),
            None,
            length=n_iterations - 1,
        )

        # Prepend step 0 so the output has n_iterations total entries.
        params_hist = jax.tree.map(
            lambda first, rest: jnp.concatenate([first[None], rest]),
            params0,
            params_hist,
        )
        results_hist = jnp.concatenate([results0[None], results_hist])
        return final_state, (params_hist, results_hist)

    def update_state(
        self,
        state: OptimizerState,
        key: jax.Array,
        results: jax.Array,
        params: PyTree | None = None,
    ) -> OptimizerState:
        """
        Updates the optimizer state based on the results of the function.

        Must be implemented by subclasses; the base implementation raises
        ``NotImplementedError``.

        Args:
            state: Current optimizer state.
            key: PRNG key for this update.
            results: Batched results of evaluating ``func`` on ``params``.
            params: The parameters that produced ``results``.

        Returns:
            The updated optimizer state.
        """
        raise NotImplementedError

    def get_next_params(
        self,
        state: OptimizerState,
        key: jax.Array,
        params: PyTree | None = None,
        results: jax.Array | None = None,
    ) -> PyTree:
        """
        Gets the next parameters to sample from the space.

        Returns a batched pytree where every leaf has shape (n_parallel, ...).
        params and results are the previous iteration's values (None on first call).

        Must be implemented by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError