Source code for hyperoptax.kernels

from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


[docs] def cdist(x: jax.Array, y: jax.Array) -> jax.Array: """Pairwise Euclidean distance (``cdist``) between two 2-D arrays. Parameters ---------- x, y : jax.Array Arrays with shape ``(N, D)`` and ``(M, D)``, respectively. Returns ------- jax.Array A distance matrix of shape ``(N, M)``. """ # jax compatible cdist https://github.com/jax-ml/jax/discussions/15862 # Use double-where trick to avoid inf gradients at zero distance (sqrt'(0) = inf). d2 = jnp.sum((x[:, None] - y[None, :]) ** 2, -1) safe_d2 = jnp.where(d2 == 0, jnp.ones_like(d2), d2) return jnp.where(d2 == 0, jnp.zeros_like(d2), jnp.sqrt(safe_d2))
[docs] class BaseKernel(ABC): """Abstract base class for positive-definite kernels.""" @abstractmethod def __call__(self, x: jax.Array, y: jax.Array, length_scale=None) -> jax.Array: raise NotImplementedError
# TODO: add basic operations between kernels
[docs] class RBF(BaseKernel): """Radial basis function (RBF) / squared-exponential kernel."""
[docs] def __init__(self, length_scale: float = 1.0): self.length_scale = length_scale
def __call__(self, x: jax.Array, y: jax.Array, length_scale=None) -> jax.Array: ls = self.length_scale if length_scale is None else length_scale return jnp.exp(-(cdist(x, y) ** 2) / (2 * ls**2))
[docs] class Matern(BaseKernel): """Matern kernel family. Parameters ---------- length_scale : float, default = 1.0 Characteristic length scale. nu : float, default = 2.5 Controls smoothness (``nu`` ∈ {0.5, 1.5, 2.5, ∞}). """ _VALID_NU = {0.5, 1.5, 2.5, float("inf")}
[docs] def __init__(self, length_scale: float = 1.0, nu: float = 2.5): if nu not in self._VALID_NU: valid = sorted(v for v in self._VALID_NU if v != float("inf")) raise ValueError( f"Matern kernel with nu={nu} is not supported. " f"Choose from {valid} or inf." ) self.length_scale = length_scale self.nu = nu # controls smoothness of the kernel, lower is less smooth
def __call__(self, x: jax.Array, y: jax.Array, length_scale=None) -> jax.Array: ls = self.length_scale if length_scale is None else length_scale dists = cdist(x / ls, y / ls) if self.nu == 0.5: return jnp.exp(-dists) elif self.nu == 1.5: K = jnp.sqrt(3) * dists return (1 + K) * jnp.exp(-K) elif self.nu == 2.5: K = jnp.sqrt(5) * dists return (1 + K + K**2 / 3) * jnp.exp(-K) else: # nu == inf: RBF kernel return jnp.exp(-(dists**2) / 2)