Exact GP Regression on Classification Labels
In this notebook, we demonstrate how to convert a classification problem into a regression problem by performing fixed-noise regression on the classification labels.
We follow the method of the paper "Dirichlet-based Gaussian Processes for Large-Scale Calibrated Classification", which transforms classification targets into regression targets by using an approximate likelihood.
[1]:
import math
import torch
import numpy as np
import gpytorch
from matplotlib import pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2
[pyKeOps]: Warning, no cuda detected. Switching to cpu only.
Generate Data
First, we generate 500 data points whose inputs are iid standard Gaussian and whose labels come from a smooth underlying latent function. The decision boundaries are given by rounding the latent function:
\(f(x,y) = \sin(0.15 \pi u (x + y)) + 1,\) where \(u \sim \text{Unif}(0,1)\) is drawn once for the whole data set. The class label is then \(z = \text{round}(f(x,y)),\) which takes values in \(\{0, 1, 2\}.\)
[2]:
def gen_data(num_data, seed=2019):
    torch.random.manual_seed(seed)
    x = torch.randn(num_data, 1)
    y = torch.randn(num_data, 1)
    u = torch.rand(1)
    data_fn = lambda x, y: 1 * torch.sin(0.15 * u * 3.1415 * (x + y)) + 1
    latent_fn = data_fn(x, y)
    z = torch.round(latent_fn).long().squeeze()
    return torch.cat((x, y), dim=1), z, data_fn
[3]:
train_x, train_y, genfn = gen_data(500)
[4]:
plt.scatter(train_x[:,0].numpy(), train_x[:,1].numpy(), c = train_y)
[4]:
<matplotlib.collections.PathCollection at 0x7fa6fde35090>

The plot below illustrates the true decision boundary. We will predict the class logits, and ultimately the class labels, on a grid over \([-3, 3]^2\); this region includes both interpolation and extrapolation relative to the training inputs.
[5]:
test_d1 = np.linspace(-3, 3, 20)
test_d2 = np.linspace(-3, 3, 20)
test_x_mat, test_y_mat = np.meshgrid(test_d1, test_d2)
test_x_mat, test_y_mat = torch.Tensor(test_x_mat), torch.Tensor(test_y_mat)
test_x = torch.cat((test_x_mat.view(-1,1), test_y_mat.view(-1,1)),dim=1)
test_labels = torch.round(genfn(test_x_mat, test_y_mat))
test_y = test_labels.view(-1)
[6]:
plt.contourf(test_x_mat.numpy(), test_y_mat.numpy(), test_labels.numpy())
[6]:
<matplotlib.contour.QuadContourSet at 0x7fa6fe09ed10>

Setting Up the Model
The Dirichlet GP model is an exact GP model with two caveats. First, it uses a special likelihood, DirichletClassificationLikelihood. Second, it is natively a multi-output model (for each data point, we need to predict num_classes, \(C\), outputs), so we need to specify the batch shape for our mean and covariance functions.
The DirichletClassificationLikelihood is a special type of FixedNoiseGaussianLikelihood that transforms the data into a regression problem for us. Succinctly, we soft one-hot encode the labels into \(C\) outputs so that \(\alpha_c = \alpha_\epsilon\) if \(y_c=0\) and \(\alpha_c = 1 + \alpha_\epsilon\) if \(y_c=1.\) Then, our variances are \(\sigma_c^2 = \log(1/\alpha_c + 1)\) and our targets are \(\log(\alpha_c) - 0.5 \sigma_c^2.\)
That is, rather than a classification problem, we have a regression problem with \(C\) outputs. For more details, please see the original paper.
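As a worked illustration of the transformation described above, the following dependency-free sketch computes the regression targets and noise variances for a single label with three classes. The helper name `dirichlet_transform` and the value `alpha_epsilon=0.01` are assumptions made for this example; it is not GPyTorch's implementation, which performs the transformation internally.

```python
import math


def dirichlet_transform(label, num_classes, alpha_epsilon=0.01):
    """Return (targets, variances) for one integer class label.

    Hypothetical helper for illustration only; alpha_epsilon=0.01 is an
    assumed value, not one stated in the text above.
    """
    targets, variances = [], []
    for c in range(num_classes):
        # Soft one-hot encoding: 1 + eps for the observed class, eps otherwise
        alpha = (1.0 if c == label else 0.0) + alpha_epsilon
        # Fixed noise variance and regression target for this output
        sigma2 = math.log(1.0 / alpha + 1.0)
        targets.append(math.log(alpha) - 0.5 * sigma2)
        variances.append(sigma2)
    return targets, variances


targets, variances = dirichlet_transform(label=2, num_classes=3)
for c, (t, v) in enumerate(zip(targets, variances)):
    print(f"class {c}: target = {t:.3f}, noise variance = {v:.3f}")
```

The observed class receives the largest target and the smallest noise variance, so the regression treats it as the most informative of the \(C\) outputs.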
[7]:
from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
[8]:
# We will use the simplest form of GP model, exact inference
class DirichletGPModel(ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_classes):
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = ScaleKernel(
            RBFKernel(batch_shape=torch.Size((num_classes,))),
            batch_shape=torch.Size((num_classes,)),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
# we let the DirichletClassificationLikelihood compute the targets for us
likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
model = DirichletGPModel(train_x, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes)
Now we train and fit the model as we would any other GPyTorch model.
[9]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes likelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients.
    # The model regresses on the transformed targets, so the loss is evaluated
    # against them rather than the raw class labels; sum over the class batch.
    loss = -mll(output, likelihood.transformed_targets).sum()
    loss.backward()
    print('Iter %d/%d - Loss: %.3f lengthscale: %.3f noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.mean().item(),
        model.likelihood.second_noise_covar.noise.mean().item()
    ))
    optimizer.step()
Iter 1/50 - Loss: 4.777 lengthscale: 0.693 noise: 0.693
Iter 2/50 - Loss: 4.735 lengthscale: 0.744 noise: 0.644
Iter 3/50 - Loss: 4.695 lengthscale: 0.798 noise: 0.598
Iter 4/50 - Loss: 4.659 lengthscale: 0.853 noise: 0.555
Iter 5/50 - Loss: 4.625 lengthscale: 0.909 noise: 0.513
Iter 6/50 - Loss: 4.594 lengthscale: 0.966 noise: 0.475
Iter 7/50 - Loss: 4.565 lengthscale: 1.023 noise: 0.438
Iter 8/50 - Loss: 4.538 lengthscale: 1.080 noise: 0.404
Iter 9/50 - Loss: 4.513 lengthscale: 1.137 noise: 0.373
Iter 10/50 - Loss: 4.490 lengthscale: 1.194 noise: 0.344
Iter 11/50 - Loss: 4.469 lengthscale: 1.250 noise: 0.316
Iter 12/50 - Loss: 4.450 lengthscale: 1.305 noise: 0.291
Iter 13/50 - Loss: 4.432 lengthscale: 1.359 noise: 0.268
Iter 14/50 - Loss: 4.416 lengthscale: 1.413 noise: 0.247
Iter 15/50 - Loss: 4.401 lengthscale: 1.465 noise: 0.228
Iter 16/50 - Loss: 4.387 lengthscale: 1.516 noise: 0.210
Iter 17/50 - Loss: 4.375 lengthscale: 1.566 noise: 0.193
Iter 18/50 - Loss: 4.363 lengthscale: 1.614 noise: 0.179
Iter 19/50 - Loss: 4.352 lengthscale: 1.661 noise: 0.165
Iter 20/50 - Loss: 4.342 lengthscale: 1.707 noise: 0.153
Iter 21/50 - Loss: 4.332 lengthscale: 1.751 noise: 0.141
Iter 22/50 - Loss: 4.323 lengthscale: 1.794 noise: 0.131
Iter 23/50 - Loss: 4.315 lengthscale: 1.835 noise: 0.122
Iter 24/50 - Loss: 4.307 lengthscale: 1.875 noise: 0.113
Iter 25/50 - Loss: 4.300 lengthscale: 1.912 noise: 0.105
Iter 26/50 - Loss: 4.294 lengthscale: 1.948 noise: 0.098
Iter 27/50 - Loss: 4.288 lengthscale: 1.983 noise: 0.092
Iter 28/50 - Loss: 4.283 lengthscale: 2.016 noise: 0.086
Iter 29/50 - Loss: 4.278 lengthscale: 2.047 noise: 0.081
Iter 30/50 - Loss: 4.273 lengthscale: 2.077 noise: 0.076
Iter 31/50 - Loss: 4.269 lengthscale: 2.106 noise: 0.072
Iter 32/50 - Loss: 4.266 lengthscale: 2.133 noise: 0.068
Iter 33/50 - Loss: 4.262 lengthscale: 2.159 noise: 0.064
Iter 34/50 - Loss: 4.259 lengthscale: 2.184 noise: 0.060
Iter 35/50 - Loss: 4.257 lengthscale: 2.208 noise: 0.057
Iter 36/50 - Loss: 4.254 lengthscale: 2.231 noise: 0.054
Iter 37/50 - Loss: 4.252 lengthscale: 2.253 noise: 0.052
Iter 38/50 - Loss: 4.249 lengthscale: 2.275 noise: 0.049
Iter 39/50 - Loss: 4.247 lengthscale: 2.295 noise: 0.047
Iter 40/50 - Loss: 4.245 lengthscale: 2.315 noise: 0.045
Iter 41/50 - Loss: 4.244 lengthscale: 2.335 noise: 0.043
Iter 42/50 - Loss: 4.242 lengthscale: 2.353 noise: 0.041
Iter 43/50 - Loss: 4.240 lengthscale: 2.371 noise: 0.040
Iter 44/50 - Loss: 4.239 lengthscale: 2.389 noise: 0.038
Iter 45/50 - Loss: 4.238 lengthscale: 2.406 noise: 0.037
Iter 46/50 - Loss: 4.237 lengthscale: 2.422 noise: 0.035
Iter 47/50 - Loss: 4.235 lengthscale: 2.438 noise: 0.034
Iter 48/50 - Loss: 4.234 lengthscale: 2.454 noise: 0.033
Iter 49/50 - Loss: 4.233 lengthscale: 2.469 noise: 0.032
Iter 50/50 - Loss: 4.232 lengthscale: 2.483 noise: 0.031
Model Predictions
[10]:
model.eval()
likelihood.eval()

with gpytorch.settings.fast_pred_var(), torch.no_grad():
    test_dist = model(test_x)
    pred_means = test_dist.loc
We have predicted the logits for each class in the classification problem. The logits for class 0 are highest in the bottom left, the logits for class 2 are highest in the top right, and the logits for class 1 are highest in the middle.
[11]:
fig, ax = plt.subplots(1, 3, figsize = (15, 5))

for i in range(3):
    im = ax[i].contourf(
        test_x_mat.numpy(), test_y_mat.numpy(), pred_means[i].numpy().reshape((20,20))
    )
    fig.colorbar(im, ax=ax[i])
    ax[i].set_title("Logits: Class " + str(i), fontsize = 20)

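Unfortunately, we can’t get closed form estimates of the probabilities; however, we can approximate them with a lightweight sampling step using \(J\) samples from the posterior as:
\(\mathbb{E}[p(y = c \mid x_*)] \approx \frac{1}{J} \sum_{j=1}^{J} \frac{\exp(f_c^{(j)}(x_*))}{\sum_{k=1}^{C} \exp(f_k^{(j)}(x_*))},\) where \(f^{(j)}\) is the \(j\)-th posterior sample of the logits.
Here, we draw \(256\) samples from the posterior.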
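To make the sampling step concrete, here is a small dependency-free sketch for a single test point: it draws Gaussian logit samples per class, applies a softmax to each draw, and averages the results. The function name `mc_class_probabilities`, the example means and standard deviations, and the use of independent per-class Gaussians at one point are assumptions for illustration; the notebook itself samples from the full GP posterior.

```python
import math
import random


def mc_class_probabilities(means, stddevs, num_samples=256, seed=0):
    """Average the softmax of sampled logits at one test point.

    means[c] and stddevs[c] describe a Gaussian over the logit of class c.
    Hypothetical helper; the inputs used below are made-up values.
    """
    rng = random.Random(seed)
    num_classes = len(means)
    totals = [0.0] * num_classes
    for _ in range(num_samples):
        logits = [rng.gauss(m, s) for m, s in zip(means, stddevs)]
        # Subtract the max before exponentiating for numerical stability
        shift = max(logits)
        exps = [math.exp(f - shift) for f in logits]
        norm = sum(exps)
        for c in range(num_classes):
            totals[c] += exps[c] / norm
    # Monte Carlo average over the samples
    return [t / num_samples for t in totals]


probs = mc_class_probabilities(means=[-0.5, -4.0, -6.0], stddevs=[0.3, 1.0, 1.0])
print([round(p, 3) for p in probs])
```

Because each sampled softmax sums to one, the averaged probabilities also sum to one across classes.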
[12]:
pred_samples = test_dist.sample(torch.Size((256,))).exp()
probabilities = (pred_samples / pred_samples.sum(-2, keepdim=True)).mean(0)
/Users/wesleymaddox/Documents/GitHub/wjm_gpytorch/gpytorch/utils/cholesky.py:46: NumericalWarning: A not p.d., added jitter of 1.0e-05 to the diagonal
warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
[13]:
fig, ax = plt.subplots(1, 3, figsize = (15, 5))

levels = np.linspace(0, 1.05, 20)
for i in range(3):
    im = ax[i].contourf(
        test_x_mat.numpy(), test_y_mat.numpy(), probabilities[i].numpy().reshape((20,20)), levels=levels
    )
    fig.colorbar(im, ax=ax[i])
    ax[i].set_title("Probabilities: Class " + str(i), fontsize = 20)

Finally, we plot the true decision boundary (left) next to the decision boundary estimated by the model (right). They align closely.
To get the decision boundary from our model, we take, at each test point, the class whose predicted mean logit is largest (an argmax over the classes).
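The class-selection step described above can be sketched without any tensor library, along with a simple accuracy check against known labels. The helper names `predict_classes` and `accuracy` and the toy logit values are hypothetical and chosen only for illustration.

```python
def predict_classes(logit_means):
    """Pick the class with the largest mean logit at every test point.

    logit_means[c][i] is the mean logit of class c at test point i.
    """
    num_classes = len(logit_means)
    num_points = len(logit_means[0])
    return [
        max(range(num_classes), key=lambda c: logit_means[c][i])
        for i in range(num_points)
    ]


def accuracy(predicted, actual):
    """Fraction of test points whose predicted class matches the label."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# Toy values: 3 classes (rows) by 4 test points (columns)
logit_means = [
    [-0.4, -5.0, -6.5, -3.0],
    [-4.0, -0.6, -3.0, -0.9],
    [-6.0, -4.5, -0.3, -2.0],
]
true_labels = [0, 1, 2, 2]

predicted = predict_classes(logit_means)
print(predicted, accuracy(predicted, true_labels))  # [0, 1, 2, 1] 0.75
```

With the notebook's tensors, `pred_means.argmax(0)` returns the same indices as `pred_means.max(0)[1]`, and comparing them with `test_y` gives the corresponding accuracy on the test grid.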
[14]:
fig, ax = plt.subplots(1,2, figsize=(10, 5))
ax[0].contourf(test_x_mat.numpy(), test_y_mat.numpy(), test_labels.numpy())
ax[0].set_title('True Response', fontsize=20)
ax[1].contourf(test_x_mat.numpy(), test_y_mat.numpy(), pred_means.max(0)[1].reshape((20,20)))
ax[1].set_title('Estimated Response', fontsize=20)
[14]:
Text(0.5, 1.0, 'Estimated Response')
