Description
Hi, I've noticed that trying to draw samples from a conditioned GP sometimes returns NaNs if the kernel argument in condition() is not None, using version 0.2.3. Below is a simple example script that shows this using a GP with a Cosine + Matern-3/2 covariance function:
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import numpy as np
import tinygp
from tinygp.helpers import JAXArray, dataclass
from matplotlib import pyplot as plt


@dataclass
class Kernel(tinygp.kernels.quasisep.Quasisep):
    """
    Define Matern32 + Cosine kernel.
    """

    ell: float
    period: float

    kernel1 = tinygp.kernels.quasisep.Matern32
    kernel2 = tinygp.kernels.quasisep.Cosine

    def kernel(self):
        return self.kernel1(scale=self.ell, sigma=1.0) + self.kernel2(
            scale=self.period, sigma=1.0
        )

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        return X

    def design_matrix(self) -> JAXArray:
        return self.kernel().design_matrix()

    def observation_model(self, X: JAXArray) -> JAXArray:
        return self.kernel().observation_model(X)

    def stationary_covariance(self) -> JAXArray:
        return self.kernel().stationary_covariance()

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        return self.kernel().transition_matrix(X1, X2)


def build_gp(kernel: tinygp.kernels.Kernel) -> tinygp.GaussianProcess:
    """
    Create GP.

    Parameters
    ----------
    kernel : tinygp.kernels.Kernel
        Covariance function.

    Returns
    -------
    tinygp.GaussianProcess
        GP.
    """
    return tinygp.GaussianProcess(kernel, x_train, diag=yerr**2)


def f(x):
    return np.sin(2 * np.pi * x / 3) + 0.1 * (x - 5) ** 2


# define training set
x_train = np.linspace(0, 10, 20)
y = f(x_train)
yerr = 1e-4

# define test points
x_test = np.linspace(0, 10, 100)

# get GP posterior
gp = build_gp(Kernel(1.0, 3.0))
cond_gp = gp.condition(y, X_test=x_test).gp
mu, std = cond_gp.mean, np.sqrt(cond_gp.variance)

# plot training set
fig, ax = plt.subplots(tight_layout=True, dpi=300)
ax.plot(x_test, f(x_test), "k-", label="truth")
ax.plot(x_train, y, "kx", label="training set")

# plot GP posterior
ax.plot(x_test, mu, "r--", label="GP predictive mean")
ax.fill_between(x_test, mu + std, mu - std, color="grey", alpha=0.5, label="$1 \\sigma$")
ax.legend()
ax.set_xlabel("x")
ax.set_ylabel("y")

# get GP posterior for cosine covariance function only
cond_gp_2 = gp.condition(
    y, X_test=x_test, kernel=tinygp.kernels.quasisep.Cosine(scale=3.0, sigma=1.0)
).gp
mu_2, std_2 = cond_gp_2.mean, np.sqrt(cond_gp_2.variance)

# plot GP posterior with only cosine covariance function
fig, ax = plt.subplots(tight_layout=True, dpi=300)
ax.plot(x_test, mu_2, "r-")
ax.fill_between(x_test, mu_2 + std_2, mu_2 - std_2, color="grey", alpha=0.5)
ax.set_xlabel("x")
ax.set_ylabel("y")

# print samples from both posteriors
print(cond_gp.sample(jax.random.PRNGKey(1), shape=(1,)))  # looks fine
print(cond_gp_2.sample(jax.random.PRNGKey(1), shape=(1,)))  # all NaNs

plt.show()
The above script outputs:
[[ 2.50012735 2.59014386 2.85667705 3.00609789 2.9937565 2.93048325
2.73319407 2.65074389 2.54421483 2.58738723 2.46443243 2.15823879
1.92819261 1.75723368 1.47599347 1.15749331 0.89880361 0.65716534
0.51971902 0.33985576 0.07506202 -0.14893402 -0.19965154 -0.26591499
-0.15513785 -0.21372362 -0.14032008 -0.03931875 0.01916854 0.19023129
0.32316472 0.56286932 0.89426615 1.14260549 1.35545179 1.45583628
1.29725162 1.09479702 0.94947624 0.86849479 0.86490132 0.73593167
0.58713621 0.47952291 0.32324959 0.07787159 -0.21878404 -0.49717701
-0.73023776 -1.05082951 -1.20287238 -1.15972825 -1.01116336 -0.93524627
-0.8279752 -0.78530024 -0.69443229 -0.47788543 -0.15773303 0.17291732
0.35685938 0.55664173 0.7749113 0.81277362 0.84833338 0.94217954
1.08111466 1.25301934 1.32023165 1.33545543 1.28766961 1.19355587
1.03967649 0.81478082 0.55663554 0.34715331 0.09951966 0.02361216
0.08205901 0.19257153 0.16334865 0.09058106 0.14834335 0.22836011
0.33910977 0.44856155 0.5423875 0.85391957 1.29340175 1.5078933
1.65621456 1.95303683 2.16557529 2.51817753 2.92529502 3.24003017
3.40978806 3.40567284 3.4124543 3.3659661 ]]
for the sample from the GP when the kernel argument of .condition() is unspecified, and:
[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan]]
for the sample from the GP when the kernel argument is specified as the Cosine component of Kernel(). As a sanity check, I plotted this distribution (the second figure produced by the script above) and I see no reason why sampling from it should return NaNs.
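To narrow things down, here is a minimal diagnostic sketch I ran on top of the script above. It assumes the conditioned GP exposes its dense posterior covariance via a covariance attribute (everything else is defined in the script); the idea is that sampling goes through a Cholesky factorization, and JAX's Cholesky silently returns NaNs for a matrix that is not positive definite (NumPy would raise an error instead):

# Diagnostic sketch (assumption: the conditioned GP exposes a dense posterior
# covariance via `cond_gp_2.covariance`; all other names come from the script).
cov = np.asarray(cond_gp_2.covariance)

# If any of these flag a problem, the NaNs likely come from the Cholesky
# factorization used for sampling rather than from the sampler itself.
print("non-finite covariance entries:", np.any(~np.isfinite(cov)))
print("negative variances on diagonal:", np.any(np.diag(cov) < 0))
print("smallest eigenvalue:", np.linalg.eigvalsh(cov).min())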
Bizarrely, if I set the kernel argument to the Matern-3/2 component instead, sampling from the conditioned GP does not return NaNs. Any insight into what might be going on here? Am I overlooking an obvious mistake in my code?
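In the meantime, a workaround sketch that sidesteps .sample() entirely is to draw from the conditioned mean and covariance by hand with a small jitter term. This again assumes the dense covariance is available via cond_gp_2.covariance, and the jitter value is an arbitrary choice on my part:

# Workaround sketch: sample manually from the conditioned mean and covariance,
# adding a small diagonal jitter to keep the Cholesky factorization stable.
# Assumption: `cond_gp_2.covariance` gives the dense posterior covariance.
mean = np.asarray(cond_gp_2.mean)
cov = np.asarray(cond_gp_2.covariance)

jitter = 1e-10 * np.eye(len(mean))  # small nugget; tune as needed
chol = np.linalg.cholesky(cov + jitter)  # raises if still not PD

rng = np.random.default_rng(1)
sample = mean + chol @ rng.standard_normal(len(mean))
print(sample)  # finite, provided the jittered covariance is positive definite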