🚧 WORK IN PROGRESS - This documentation is currently under development 🚧
Hyperoptax Documentation#
Welcome to Hyperoptax - a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions.
Hyperoptax provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across parameter spaces in parallel - all while staying in pure JAX.
Quick Start#
import jax
from hyperoptax import BayesianSearch, LogSpace, LinearSpace
def train_nn(key, params):
learning_rate = params["learning_rate"]
final_lr_pct = params["final_lr_pct"]
...
return val_loss # scalar, lower is better
search_space = {
"learning_rate": LogSpace(1e-5, 1e-1),
"final_lr_pct": LinearSpace(0.01, 0.5),
}
state, optimizer = BayesianSearch.init(
search_space,
n_max=100,
maximize=False,
)
state, (params_hist, results_hist) = optimizer.optimize(
state, jax.random.PRNGKey(0), train_nn
)
print(optimizer.best_params(state))