
Sampling from a conditioned GP sometimes returns nans if the kernel argument of .condition() is not None #175

@zairving

Description

Hi, I've noticed that drawing samples from a conditioned GP sometimes returns NaNs when the kernel argument of condition() is not None (I'm using version 0.2.3). Here's a simple example script that demonstrates this with a GP whose covariance function is a Cosine + Matern-3/2 sum:

import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import numpy as np
import tinygp
from tinygp.helpers import JAXArray, dataclass
from matplotlib import pyplot as plt

@dataclass
class Kernel(tinygp.kernels.quasisep.Quasisep):
    """
    Define Matern32 + Cosine kernel.
    """
    
    ell: float
    period: float
    
    kernel1 = tinygp.kernels.quasisep.Matern32
    kernel2 = tinygp.kernels.quasisep.Cosine
    
    def kernel(self):
        return self.kernel1(scale=self.ell, sigma=1.) + self.kernel2(scale=self.period, sigma=1.)
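
    # the Quasisep interface methods below simply delegate to the summed
    # kernel's state-space representation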
    
    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        return X
    
    def design_matrix(self) -> JAXArray:
        return self.kernel().design_matrix()
    
    def observation_model(self, X: JAXArray) -> JAXArray:
        return self.kernel().observation_model(X)
    
    def stationary_covariance(self) -> JAXArray:
        return self.kernel().stationary_covariance()
    
    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        return self.kernel().transition_matrix(X1, X2)

def build_gp(kernel: tinygp.kernels.Kernel) -> tinygp.GaussianProcess:
    """
    Create GP.
    
    Parameters
    ----------
    kernel : tinygp.kernels.Kernel
        Covariance function.

    Returns
    -------
    tinygp.GaussianProcess
        GP.
    """
    
    return tinygp.GaussianProcess(kernel, x_train, diag=yerr**2)


def f(x):
    return np.sin(2*np.pi*x/3) + .1*(x-5)**2


# define training set
x_train = np.linspace(0, 10, 20)
y = f(x_train)
yerr = 1e-4

# define test points
x_test = np.linspace(0, 10, 100)

# get GP posterior
gp = build_gp(Kernel(ell=1.0, period=3.0))
cond_gp = gp.condition(y, X_test=x_test).gp
mu, std = cond_gp.mean, np.sqrt(cond_gp.variance)

# plot training set
fig, ax = plt.subplots(tight_layout=True, dpi=300)
ax.plot(x_test, f(x_test), "k-", label="truth")
ax.plot(x_train, y, "kx", label="training set")

# plot GP posterior
ax.plot(x_test, mu, "r--", label="GP predictive mean")
ax.fill_between(x_test, mu + std, mu - std, color="grey", alpha=.5, label="$1 \\sigma$")

ax.legend()
ax.set_xlabel("x")
ax.set_ylabel("y")

# get GP posterior for cosine covariance function only
cond_gp_2 = gp.condition(y, X_test=x_test, kernel=tinygp.kernels.quasisep.Cosine(scale=3., sigma=1.)).gp
mu_2, std_2 = cond_gp_2.mean, np.sqrt(cond_gp_2.variance)

# plot GP posterior with only cosine covariance function
fig, ax = plt.subplots(tight_layout=True, dpi=300)
ax.plot(x_test, mu_2, "r-")
ax.fill_between(x_test, mu_2 + std_2, mu_2 - std_2, color="grey", alpha=.5)

ax.set_xlabel("x")
ax.set_ylabel("y")

# print samples from both posteriors
print(cond_gp.sample(jax.random.PRNGKey(1), shape=(1,)))  # looks fine
print(cond_gp_2.sample(jax.random.PRNGKey(1), shape=(1,)))  # all nans

plt.show()

The above script outputs:

[[ 2.50012735  2.59014386  2.85667705  3.00609789  2.9937565   2.93048325
   2.73319407  2.65074389  2.54421483  2.58738723  2.46443243  2.15823879
   1.92819261  1.75723368  1.47599347  1.15749331  0.89880361  0.65716534
   0.51971902  0.33985576  0.07506202 -0.14893402 -0.19965154 -0.26591499
  -0.15513785 -0.21372362 -0.14032008 -0.03931875  0.01916854  0.19023129
   0.32316472  0.56286932  0.89426615  1.14260549  1.35545179  1.45583628
   1.29725162  1.09479702  0.94947624  0.86849479  0.86490132  0.73593167
   0.58713621  0.47952291  0.32324959  0.07787159 -0.21878404 -0.49717701
  -0.73023776 -1.05082951 -1.20287238 -1.15972825 -1.01116336 -0.93524627
  -0.8279752  -0.78530024 -0.69443229 -0.47788543 -0.15773303  0.17291732
   0.35685938  0.55664173  0.7749113   0.81277362  0.84833338  0.94217954
   1.08111466  1.25301934  1.32023165  1.33545543  1.28766961  1.19355587
   1.03967649  0.81478082  0.55663554  0.34715331  0.09951966  0.02361216
   0.08205901  0.19257153  0.16334865  0.09058106  0.14834335  0.22836011
   0.33910977  0.44856155  0.5423875   0.85391957  1.29340175  1.5078933
   1.65621456  1.95303683  2.16557529  2.51817753  2.92529502  3.24003017
   3.40978806  3.40567284  3.4124543   3.3659661 ]]

for the sample from the GP when the kernel argument of .condition() is unspecified, and:

[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan nan nan nan nan nan nan nan nan]]

for the sample from the GP when the kernel argument is specified as the Cosine component of Kernel(). As a sanity check, I plotted this conditional distribution and see no reason why sampling from it should return NaNs:

[figure: cosine_gp — posterior mean and 1-sigma band for the Cosine-only conditioning]

Bizarrely, if I set the kernel argument to the Matern-3/2 component instead, sampling from the conditioned GP does not return NaNs. Any insight into what might be going on here? Am I overlooking an obvious mistake in my code?
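One extra data point that might help with debugging: since the plotted mean and variance look sensible, the NaNs presumably come from the matrix factorisation that sample() performs on the posterior covariance. Here's a quick check I can append to the script above (a sketch, assuming cond_gp_2.covariance returns the dense posterior covariance evaluated at x_test):

import jax.numpy as jnp

# sketch: check whether the Cosine-only posterior covariance is positive
# semi-definite; a clearly negative eigenvalue would make any Cholesky-style
# factorisation inside sample() produce NaNs
cov = cond_gp_2.covariance  # assumed: dense posterior covariance at x_test
print(jnp.min(jnp.diag(cov)))             # negative variances are a red flag
print(jnp.min(jnp.linalg.eigvalsh(cov)))  # most negative eigenvalue

If that smallest eigenvalue comes out significantly below zero, the NaNs are produced by the square root of the covariance rather than by the mean or variance, which would be consistent with the plot above looking fine.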
