import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
class BaseAcquisition:
    """Base class for acquisition functions.

    Subclasses implement ``__call__`` to score every candidate from the GP
    posterior (mean, std). ``get_argmax`` then returns the indices of the
    highest-scoring candidates that are not flagged in ``seen_mask``.
    """

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """
        Compute the acquisition value for a given mean and standard deviation.

        Args:
            mean (N,): The mean of the Gaussian process.
            std (N,): The standard deviation of the Gaussian process.
            y_max: Optional pre-computed reference value (e.g. best observed mean).
                Used by EI/PI to ensure consistency when evaluating a single point.

        Returns:
            (N,): The acquisition value for the given mean and standard deviation.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    @staticmethod
    def _resolve_y_max(mean: jax.Array, y_max) -> jax.Array:
        """Return y_max if provided, else fall back to max(mean).

        Note: the fallback is only appropriate for standalone/exploratory use.
        Always pass y_max explicitly when observed data is available.
        """
        return jnp.max(mean) if y_max is None else y_max

    def _sort_acq_vals(
        self, mean: jax.Array, std: jax.Array, seen_mask: jax.Array, y_max=None
    ) -> jax.Array:
        """Return candidate indices sorted by ascending acquisition value.

        Candidates flagged in ``seen_mask`` and candidates whose acquisition
        value is NaN are scored ``-inf``, so they come first (ranked worst).

        Args:
            mean (N,): GP posterior mean per candidate.
            std (N,): GP posterior standard deviation per candidate.
            seen_mask (N,): Truthy where the candidate was already evaluated.
            y_max: Optional reference value forwarded to ``__call__``.

        Returns:
            (N,): Candidate indices ordered from worst to best.
        """
        # Only forward y_max when it was given, so subclasses whose __call__
        # does not accept a third argument keep working.
        if y_max is None:
            acq_vals = self(mean, std)
        else:
            acq_vals = self(mean, std, y_max)
        # argsort places NaN after +inf; without this a NaN score would be
        # picked as the best candidate.
        invalid = jnp.logical_or(seen_mask, jnp.isnan(acq_vals))
        masked_acq = jnp.where(invalid, -jnp.inf, acq_vals)
        return jnp.argsort(masked_acq)

    def get_argmax(
        self,
        mean: jax.Array,
        std: jax.Array,
        seen_mask: jax.Array,
        n_points: int = 1,
        y_max=None,
    ) -> jax.Array:
        """Return the indices of the ``n_points`` best unseen candidates.

        Args:
            mean (N,): GP posterior mean per candidate.
            std (N,): GP posterior standard deviation per candidate.
            seen_mask (N,): Truthy where the candidate was already evaluated.
            n_points: Number of indices to return. Must be a non-negative
                Python int, since it fixes the output shape. If it exceeds
                the number of unseen, non-NaN candidates, the remaining slots
                are filled with seen/NaN candidates.
            y_max: Optional reference value (e.g. best observed value)
                forwarded to the acquisition function. EI/PI fall back to
                ``max(mean)`` when it is omitted.

        Returns:
            (min(n_points, N),): Indices ordered from worst to best, so the
            single best candidate is the last element.

        Raises:
            ValueError: if ``n_points`` is negative.
        """
        if n_points < 0:
            raise ValueError(f"n_points must be non-negative, got {n_points}")
        order = self._sort_acq_vals(mean, std, seen_mask, y_max=y_max)
        # Explicit start index: the slice order[-n_points:] would return
        # *every* index when n_points == 0.
        start = max(order.shape[0] - n_points, 0)
        return order[start:]
class UCB(BaseAcquisition):
    """Upper Confidence Bound acquisition function.

    Scores each candidate as ``mean + kappa * std``. A larger ``kappa`` puts
    more weight on the posterior uncertainty.
    """

    def __init__(self, kappa: float = 2.0):
        # Multiplier applied to the posterior standard deviation.
        self.kappa = kappa

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Return ``mean + kappa * std`` elementwise.

        ``y_max`` is accepted for interface compatibility with the other
        acquisition functions and is ignored.
        """
        exploration = self.kappa * std
        return mean + exploration
class EI(BaseAcquisition):
    """Expected Improvement acquisition function.

    With ``a = mean - xi - y_max`` and ``z = a / std``, returns
    ``a * Phi(z) + std * phi(z)``, where Phi/phi are the standard normal
    CDF/PDF. Where ``std == 0`` the value is the zero-variance limit
    ``max(a, 0)``.
    """

    def __init__(self, xi: float = 0.01):
        # Margin subtracted from the improvement over y_max.
        self.xi = xi

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Compute expected improvement per candidate.

        Args:
            mean (N,): The mean of the Gaussian process.
            std (N,): The standard deviation of the Gaussian process.
            y_max: Optional reference value; defaults to ``max(mean)``.

        Returns:
            (N,): Expected improvement, finite wherever ``mean`` is finite.
        """
        _y_max = self._resolve_y_max(mean, y_max)
        a = mean - self.xi - _y_max
        has_std = std > 0
        # a / std is 0/0 = NaN when std == 0 and a == 0. Dividing by a
        # placeholder keeps both the value and its gradient finite in the
        # branch that jnp.where discards.
        safe_std = jnp.where(has_std, std, 1.0)
        z = a / safe_std
        ei = a * norm.cdf(z) + safe_std * norm.pdf(z)
        # With zero variance the improvement is deterministic: max(a, 0).
        return jnp.where(has_std, ei, jnp.maximum(a, 0.0))
class PI(BaseAcquisition):
    """Probability of Improvement acquisition function.

    Returns ``Phi((mean - xi - y_max) / std)``, where Phi is the standard
    normal CDF. Where ``std == 0`` the value is 1 if
    ``mean - xi - y_max > 0`` and 0 otherwise.
    """

    def __init__(self, xi: float = 0.01):
        # Margin subtracted from the improvement over y_max.
        self.xi = xi

    def __call__(self, mean: jax.Array, std: jax.Array, y_max=None):
        """Compute probability of improvement per candidate.

        Args:
            mean (N,): The mean of the Gaussian process.
            std (N,): The standard deviation of the Gaussian process.
            y_max: Optional reference value; defaults to ``max(mean)``.

        Returns:
            (N,): Probability of improvement in [0, 1].
        """
        _y_max = self._resolve_y_max(mean, y_max)
        a = mean - self.xi - _y_max
        has_std = std > 0
        # Placeholder std avoids 0/0 = NaN (value and gradient) when
        # std == 0 and a == 0.
        safe_std = jnp.where(has_std, std, 1.0)
        pi = norm.cdf(a / safe_std)
        # Zero variance: improvement happens with certainty iff a > 0.
        return jnp.where(has_std, pi, jnp.where(a > 0, 1.0, 0.0))
class BaseHallucination:
    """Base class for Kriging Believer hallucination strategies.

    A strategy is invoked as ``strategy(mean, std, key, y_max)`` and returns
    a scalar. Inheriting from this class is optional: any callable with that
    signature can serve as a hallucination strategy.
    """
class MeanHallucination(BaseHallucination):
    """Classical Kriging Believer: hallucinate with the GP posterior mean.

    Returns the first entry of ``mean``. ``std``, ``key`` and ``y_max`` are
    accepted for the common strategy signature but not used.
    """

    def __call__(self, mean, std, key, y_max):
        del std, key, y_max  # unused; part of the strategy call signature
        return mean[0]
class SampleHallucination(BaseHallucination):
    """Randomized Kriging Believer (RKB): hallucinate with a posterior sample.

    arXiv 2603.01470.

    Returns ``mean[0] + std[0] * eps`` with ``eps ~ N(0, 1)`` drawn from
    ``key``. ``y_max`` is not used.
    """

    def __call__(self, mean, std, key, y_max):
        # Draw eps in mean's dtype so the result keeps that dtype. The default
        # dtype of jax.random.normal follows the global x64 setting and could
        # otherwise promote a float32 posterior to float64.
        eps = jax.random.normal(key, dtype=mean.dtype)
        return mean[0] + std[0] * eps
class UCBHallucination(BaseHallucination):
    """Optimistic hallucination: ``mean[0] + kappa * std[0]``.

    ``key`` and ``y_max`` are not used.
    """

    def __init__(self, kappa: float = 2.0):
        # Number of posterior standard deviations added to the mean.
        self.kappa = kappa

    def __call__(self, mean, std, key, y_max):
        center, spread = mean[0], std[0]
        return center + self.kappa * spread
class ConstantHallucination(BaseHallucination):
    """Ginsbourger et al. 2010: hallucinate with y_max or a fixed constant.

    If value is None, uses the current best observed value (y_max).
    Otherwise uses the fixed value regardless of observations.
    Both branches return a scalar array in ``mean``'s dtype.
    """

    def __init__(self, value: float | None = None):
        # Fixed hallucination value; None means "use y_max at call time".
        self.value = value

    def __call__(self, mean, std, key, y_max):
        """Return the constant hallucinated observation.

        ``std`` and ``key`` are not used; ``mean`` only supplies the dtype.

        Raises:
            ValueError: if ``value`` is None and ``y_max`` is None.
        """
        if self.value is not None:
            return jnp.asarray(self.value, dtype=mean.dtype)
        if y_max is None:
            raise ValueError(
                "ConstantHallucination(value=None) needs y_max, but got None"
            )
        # Cast like the fixed-value branch so the return dtype does not
        # depend on which branch ran.
        return jnp.asarray(y_max, dtype=mean.dtype)