Source code for hyperoptax.acquisition

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class BaseAcquisition:
    """Base class for acquisition functions.

    Subclasses implement ``__call__`` to score candidate points from a
    posterior mean and standard deviation. ``get_argmax`` then returns the
    indices of the best-scoring candidates that are not flagged as seen.
    """

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Compute the acquisition value for a given mean and standard deviation.

        Args:
            mean (N,): The mean of the Gaussian process.
            std (N,): The standard deviation of the Gaussian process.
            y_max: Optional pre-computed reference value (e.g. best observed
                mean). Used by EI/PI to ensure consistency when evaluating a
                single point.

        Returns:
            (N,): The acquisition value for the given mean and standard
            deviation.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    @staticmethod
    def _resolve_y_max(mean: jax.Array, y_max) -> jax.Array:
        """Return y_max if provided, else fall back to max(mean).

        Note: the fallback is only appropriate for standalone/exploratory
        use. Always pass y_max explicitly when observed data is available.
        """
        if y_max is None:
            return jnp.max(mean)
        return y_max

    def _sort_acq_vals(
        self, mean: jax.Array, std: jax.Array, seen_mask: jax.Array, y_max=None
    ):
        """Return candidate indices sorted by ascending acquisition value.

        Entries where ``seen_mask`` is true are given a value of ``-inf`` so
        they sort before every candidate with a finite score.

        Args:
            mean (N,): Posterior mean per candidate.
            std (N,): Posterior standard deviation per candidate.
            seen_mask (N,): Boolean mask; true marks candidates to exclude.
            y_max: Optional reference value forwarded to ``__call__``.

        Returns:
            (N,): Indices ordered from lowest to highest masked score.
        """
        # Forward y_max only when it was supplied, so a subclass whose
        # __call__ takes just (mean, std) keeps working as before.
        if y_max is None:
            acq_vals = self(mean, std)  # shape (N,)
        else:
            acq_vals = self(mean, std, y_max=y_max)  # shape (N,)
        masked_acq = jnp.where(seen_mask, -jnp.inf, acq_vals)
        return jnp.argsort(masked_acq)

    def get_argmax(
        self,
        mean: jax.Array,
        std: jax.Array,
        seen_mask: jax.Array,
        n_points: int = 1,
        y_max=None,
    ):
        """Return the indices of the ``n_points`` best unseen candidates.

        Args:
            mean (N,): Posterior mean per candidate.
            std (N,): Posterior standard deviation per candidate.
            seen_mask (N,): Boolean mask; true marks candidates to exclude.
            n_points: Number of indices to return. Values ``<= 0`` give an
                empty result.
            y_max: Optional reference value forwarded to the acquisition
                function; when omitted the acquisition is called without it.

        Returns:
            Indices in ascending order of acquisition value, so the best
            candidate is last. If ``n_points`` exceeds the number of unseen
            candidates, masked (seen) indices fill the front of the result.
        """
        order = self._sort_acq_vals(mean, std, seen_mask, y_max=y_max)
        # order[-0:] would be the whole array, so handle "no points" explicitly.
        if n_points <= 0:
            return order[:0]
        return order[-n_points:]
class UCB(BaseAcquisition):
    """Upper Confidence Bound acquisition function.

    Scores each candidate as ``mean + kappa * std``.
    """

    def __init__(self, kappa: float = 2.0):
        """Store ``kappa``, the multiplier applied to the standard deviation."""
        self.kappa = kappa

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Return the upper confidence bound for each candidate.

        Args:
            mean: Posterior mean.
            std: Posterior standard deviation.
            y_max: Accepted for interface compatibility; not used.

        Returns:
            ``mean + kappa * std``.
        """
        exploration_bonus = self.kappa * std
        return mean + exploration_bonus
class EI(BaseAcquisition):
    """Expected Improvement acquisition function.

    With ``a = mean - xi - y_max`` and ``z = a / std`` the score is
    ``a * cdf(z) + std * pdf(z)``, using the standard normal cdf and pdf.
    """

    def __init__(self, xi: float = 0.01):
        """Store ``xi``, the margin subtracted from the improvement."""
        self.xi = xi

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Return the expected improvement for each candidate.

        Args:
            mean: Posterior mean.
            std: Posterior standard deviation.
            y_max: Optional reference value to improve upon. When omitted,
                the inherited ``_resolve_y_max`` falls back to ``max(mean)``.

        Returns:
            The expected improvement. Where ``std`` is exactly zero the
            result is the limit ``max(a, 0)`` rather than a NaN from 0/0.
        """
        _y_max = self._resolve_y_max(mean, y_max)
        a = mean - self.xi - _y_max

        # Substitute a harmless divisor where std == 0 so no NaN/inf is ever
        # produced (this also keeps gradients finite), then pick the exact
        # zero-variance limit for those entries afterwards.
        nonzero = std != 0
        safe_std = jnp.where(nonzero, std, 1.0)
        z = a / safe_std
        ei = a * norm.cdf(z) + std * norm.pdf(z)
        return jnp.where(nonzero, ei, jnp.maximum(a, 0.0))
class PI(BaseAcquisition):
    """Probability of Improvement acquisition function.

    With ``a = mean - xi - y_max`` the score is ``cdf(a / std)``, using the
    standard normal cdf.
    """

    def __init__(self, xi: float = 0.01):
        """Store ``xi``, the margin subtracted from the improvement."""
        self.xi = xi

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Return the probability of improvement for each candidate.

        Args:
            mean: Posterior mean.
            std: Posterior standard deviation.
            y_max: Optional reference value to improve upon. When omitted,
                the inherited ``_resolve_y_max`` falls back to ``max(mean)``.

        Returns:
            The probability of improvement. Where ``std`` is exactly zero the
            result is 1 if ``a > 0`` and 0 otherwise, rather than a NaN from
            0/0 when ``a`` is also zero.
        """
        _y_max = self._resolve_y_max(mean, y_max)
        a = mean - self.xi - _y_max

        # Substitute a harmless divisor where std == 0 so the division never
        # yields NaN, then use the step-function limit for those entries.
        nonzero = std != 0
        safe_std = jnp.where(nonzero, std, 1.0)
        prob = norm.cdf(a / safe_std)
        certain = jnp.where(a > 0, 1.0, 0.0)
        return jnp.where(nonzero, prob, certain)
class BaseHallucination:
    """Common parent for Kriging Believer hallucination strategies.

    Inheriting from this class is optional: a hallucination strategy is any
    callable with the signature ``(mean, std, key, y_max) -> scalar``.
    """
class MeanHallucination(BaseHallucination):
    """Classical Kriging Believer: hallucinate with GP posterior mean."""

    def __call__(self, mean, std, key, y_max):
        """Return the first entry of ``mean``.

        ``std``, ``key`` and ``y_max`` are accepted to match the common
        hallucination signature and are not used.
        """
        believed_value = mean[0]
        return believed_value
class SampleHallucination(BaseHallucination):
    """Randomized Kriging Believer (RKB): hallucinate with a posterior sample.

    arXiv 2603.01470.
    """

    def __call__(self, mean, std, key, y_max):
        """Return ``mean[0]`` plus ``std[0]`` times a standard normal draw.

        Args:
            mean: Posterior mean; only the first entry is used.
            std: Posterior standard deviation; only the first entry is used.
            key: JAX PRNG key passed to ``jax.random.normal``.
            y_max: Accepted for interface compatibility; not used.
        """
        noise = jax.random.normal(key)
        scaled_noise = std[0] * noise
        return mean[0] + scaled_noise
class UCBHallucination(BaseHallucination):
    """Optimistic hallucination: mean + kappa * std."""

    def __init__(self, kappa: float = 2.0):
        """Store ``kappa``, the multiplier applied to the standard deviation."""
        self.kappa = kappa

    def __call__(self, mean, std, key, y_max):
        """Return ``mean[0] + kappa * std[0]``.

        ``key`` and ``y_max`` are accepted to match the common hallucination
        signature and are not used.
        """
        first_mean = mean[0]
        optimism = self.kappa * std[0]
        return first_mean + optimism
class ConstantHallucination(BaseHallucination):
    """Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.

    When ``value`` is None the hallucinated observation is the ``y_max``
    passed at call time (the current best observed value). Otherwise the
    fixed ``value`` is used regardless of observations.
    """

    def __init__(self, value: float | None = None):
        """Store the fixed hallucination value, or None to use ``y_max``."""
        self.value = value

    def __call__(self, mean, std, key, y_max):
        """Return the configured constant, or ``y_max`` if none was set.

        Args:
            mean: Posterior mean; only its dtype is used, to cast the
                fixed value.
            std: Accepted for interface compatibility; not used.
            key: Accepted for interface compatibility; not used.
            y_max: Returned unchanged when no fixed value was configured.
        """
        fixed = self.value
        if fixed is not None:
            return jnp.asarray(fixed, dtype=mean.dtype)
        return y_max