{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GP Regression with a Spectral Mixture Kernel\n", "\n", "## Introduction\n", "\n", "This example shows how to use a `SpectralMixtureKernel` module on an `ExactGP` model. This kernel is designed for cases where:\n", "\n", "- You want to use exact inference (e.g. for regression)\n", "- You want to use a more sophisticated kernel than RBF\n", "\n", "The Spectral Mixture (SM) kernel was introduced in [Wilson et al., 2013](https://arxiv.org/pdf/1302.4245.pdf)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torch\n", "import gpytorch\n", "from matplotlib import pyplot as plt\n", "\n", "%matplotlib inline\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the next cell, we set up the training data for this example. We use 15 regularly spaced points on [0, 1] and evaluate `sin(2 * pi * x)` at each of them to get the training labels. No observation noise is added to the labels in the code below." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# 15 evenly spaced training inputs on [0, 1]\n", "train_x = torch.linspace(0, 1, 15)\n", "# Noise-free labels: one full period of a sine wave\n", "train_y = torch.sin(train_x * (2 * math.pi))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up the model\n", "\n", "The model is very similar to the `ExactGP` model in the [simple regression example](./Simple_GP_Regression.ipynb).\n", "\n", "The only difference is that here we use a more complex kernel (the `SpectralMixtureKernel`). This kernel requires careful initialization to work properly. To that end, in the model `__init__` function, we call\n", "\n", "```\n", "self.covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4)\n", "self.covar_module.initialize_from_data(train_x, train_y)\n", "```\n", "\n", "This ensures that, when we optimize the kernel hyperparameters, we start from a reasonable initialization." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class SpectralMixtureGPModel(gpytorch.models.ExactGP):\n", "    def __init__(self, train_x, train_y, likelihood):\n", "        super(SpectralMixtureGPModel, self).__init__(train_x, train_y, likelihood)\n", "        self.mean_module = gpytorch.means.ConstantMean()\n", "        self.covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4)\n", "        # Set the initial kernel hyperparameters from the training data\n", "        self.covar_module.initialize_from_data(train_x, train_y)\n", "\n", "    def forward(self, x):\n", "        mean_x = self.mean_module(x)\n", "        covar_x = self.covar_module(x)\n", "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", "\n", "\n", "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", "model = SpectralMixtureGPModel(train_x, train_y, likelihood)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the next cell, we train the hyperparameters of the Gaussian process using Type-II MLE, i.e. by maximizing the marginal log likelihood.\n", "The spectral mixture kernel's hyperparameters start from the values set by `initialize_from_data`.\n", "\n", "See the [simple regression example](./Simple_GP_Regression.ipynb) for more info on this step." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter 1/100 - Loss: 1.281\n", "Iter 2/100 - Loss: 1.258\n", "Iter 3/100 - Loss: 1.232\n", "Iter 4/100 - Loss: 1.212\n", "Iter 5/100 - Loss: 1.192\n", "Iter 6/100 - Loss: 1.172\n", "Iter 7/100 - Loss: 1.156\n", "Iter 8/100 - Loss: 1.143\n", "Iter 9/100 - Loss: 1.131\n", "Iter 10/100 - Loss: 1.121\n", "Iter 11/100 - Loss: 1.114\n", "Iter 12/100 - Loss: 1.110\n", "Iter 13/100 - Loss: 1.106\n", "Iter 14/100 - Loss: 1.103\n", "Iter 15/100 - Loss: 1.101\n", "Iter 16/100 - Loss: 1.098\n", "Iter 17/100 - Loss: 1.094\n", "Iter 18/100 - Loss: 1.089\n", "Iter 19/100 - Loss: 1.085\n", "Iter 20/100 - Loss: 1.080\n", "Iter 21/100 - Loss: 1.076\n", "Iter 22/100 - Loss: 1.072\n", "Iter 23/100 - Loss: 1.069\n", "Iter 24/100 - Loss: 1.067\n", "Iter 25/100 - Loss: 1.065\n", "Iter 26/100 - Loss: 1.063\n", "Iter 27/100 - Loss: 1.061\n", "Iter 28/100 - Loss: 1.060\n", "Iter 29/100 - Loss: 1.057\n", "Iter 30/100 - Loss: 1.054\n", "Iter 31/100 - Loss: 1.051\n", "Iter 32/100 - Loss: 1.048\n", "Iter 33/100 - Loss: 1.044\n", "Iter 34/100 - Loss: 1.039\n", "Iter 35/100 - Loss: 1.035\n", "Iter 36/100 - Loss: 1.029\n", "Iter 37/100 - Loss: 1.023\n", "Iter 38/100 - Loss: 1.015\n", "Iter 39/100 - Loss: 1.006\n", "Iter 40/100 - Loss: 0.995\n", "Iter 41/100 - Loss: 0.981\n", "Iter 42/100 - Loss: 0.965\n", "Iter 43/100 - Loss: 0.946\n", "Iter 44/100 - Loss: 0.924\n", "Iter 45/100 - Loss: 0.898\n", "Iter 46/100 - Loss: 0.870\n", "Iter 47/100 - Loss: 0.839\n", "Iter 48/100 - Loss: 0.806\n", "Iter 49/100 - Loss: 0.770\n", "Iter 50/100 - Loss: 0.731\n", "Iter 51/100 - Loss: 0.686\n", "Iter 52/100 - Loss: 0.637\n", "Iter 53/100 - Loss: 0.583\n", "Iter 54/100 - Loss: 0.523\n", "Iter 55/100 - Loss: 0.460\n", "Iter 56/100 - Loss: 0.394\n", "Iter 57/100 - Loss: 0.327\n", "Iter 58/100 - Loss: 0.260\n", "Iter 59/100 - Loss: 0.194\n", "Iter 60/100 - Loss: 
0.133\n", "Iter 61/100 - Loss: 0.078\n", "Iter 62/100 - Loss: 0.032\n", "Iter 63/100 - Loss: -0.005\n", "Iter 64/100 - Loss: -0.040\n", "Iter 65/100 - Loss: -0.086\n", "Iter 66/100 - Loss: -0.144\n", "Iter 67/100 - Loss: -0.206\n", "Iter 68/100 - Loss: -0.264\n", "Iter 69/100 - Loss: -0.313\n", "Iter 70/100 - Loss: -0.354\n", "Iter 71/100 - Loss: -0.392\n", "Iter 72/100 - Loss: -0.430\n", "Iter 73/100 - Loss: -0.472\n", "Iter 74/100 - Loss: -0.518\n", "Iter 75/100 - Loss: -0.567\n", "Iter 76/100 - Loss: -0.616\n", "Iter 77/100 - Loss: -0.662\n", "Iter 78/100 - Loss: -0.703\n", "Iter 79/100 - Loss: -0.740\n", "Iter 80/100 - Loss: -0.775\n", "Iter 81/100 - Loss: -0.816\n", "Iter 82/100 - Loss: -0.860\n", "Iter 83/100 - Loss: -0.902\n", "Iter 84/100 - Loss: -0.940\n", "Iter 85/100 - Loss: -0.975\n", "Iter 86/100 - Loss: -1.008\n", "Iter 87/100 - Loss: -1.042\n", "Iter 88/100 - Loss: -1.077\n", "Iter 89/100 - Loss: -1.112\n", "Iter 90/100 - Loss: -1.145\n", "Iter 91/100 - Loss: -1.173\n", "Iter 92/100 - Loss: -1.199\n", "Iter 93/100 - Loss: -1.226\n", "Iter 94/100 - Loss: -1.254\n", "Iter 95/100 - Loss: -1.279\n", "Iter 96/100 - Loss: -1.302\n", "Iter 97/100 - Loss: -1.321\n", "Iter 98/100 - Loss: -1.341\n", "Iter 99/100 - Loss: -1.361\n", "Iter 100/100 - Loss: -1.379\n" ] } ], "source": [ "# this is for running the notebook in our testing framework\n", "import os\n", "smoke_test = ('CI' in os.environ)\n", "training_iter = 2 if smoke_test else 100\n", "\n", "# Find optimal model hyperparameters\n", "model.train()\n", "likelihood.train()\n", "\n", "# Use the adam optimizer\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n", "\n", "# \"Loss\" for GPs - the marginal log likelihood\n", "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n", "\n", "for i in range(training_iter):\n", " optimizer.zero_grad()\n", " output = model(train_x)\n", " loss = -mll(output, train_y)\n", " loss.backward()\n", " print('Iter %d/%d - Loss: %.3f' % (i + 1, 
training_iter, loss.item()))\n", " optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've learned good hyperparameters, it's time to use our model to make predictions. The spectral mixture kernel is especially good at extrapolation. To that end, we'll see how well the model extrapolates past the interval `[0, 1]`.\n", "\n", "In the next cell, we plot the mean and confidence region of the Gaussian process model. The `confidence_region` method is a helper method that returns 2 standard deviations above and below the mean." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAADGCAYAAADWg+V4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3hUVfrA8e+ZSZt00kggdJHeew29qoCNRVAUFUHdn6zrrrCWVVERl2XVXQRBbCurIgIqqDSRHnovUkILpEBIAulTzu+PSUKAhCQzd1pyPs8zTzKTe899CZN3zj1VSClRFEXRuToARVHcg0oGiqIAKhkoilJIJQNFUQCVDBRFKaSSgaIogAbJQAjhJ4TYIYTYL4Q4LIR4XYvAFEVxLmHvOAMhhAACpJRZQghvYDPwnJQyXosAFUVxDi97C5DWbJJV+NS78KFGMimKh9GkzUAIoRdC7ANSgTVSyu1alKsoivPYXTMAkFKagbZCiFBgmRCipZTyUMljhBATgYkAAQEBHZo2barFpasFCVgsErNFotMJvHTC1SEpHmr37t2XpZSRpf3M7jaDWwoU4u9AtpRyVlnHdOzYUe7atUvT61ZF+89nsC0hjdwC8w2vh/p70ygykDuiAokJ8cPabKMo5RNC7JZSdiztZ3bXDIQQkYBRSpkhhDAAA4CZ9pZbnUkp2XTiMrvPppf684wcI7vPprP7bDo1/L1pX68GzWOC8dKrnmLFdlrcJsQAnwsh9FjbIBZLKVdoUG61ZDJb+OVwMidSsso/GEjPMbLuaCrbTqXROjaUNnVC8PfR5O5PqWa06E04ALTTIJZqz2i2sHRPIhcz8ip9bk6BmfiENHaeuULjqEBa1g6hTpi/A6JUqir1EeJG9p7LsCkRlGS2SI4lX+NY8jXCA31oGh1MvXB/ooJ8bWpbMBqNJCYmkpdnX1yKc/n5+REbG4u3t3eFz1HJwE3kGc3sOntF0zLTsgrYcvIyW06Cv4+eOmH+xNYwEBbgQ1iAT4VuJxITEwkKCqJ+/fqqodJDSClJS0sjMTGRBg0aVPg8lQzcxK4z6eQbLQ4rP6fAzO/J1/g9+Vrxa37eekL9vfH30ePnrcfgbf3qpbd2X+p1AnN2DtG165BvsnBDKiglL6hU4To+Xvri74UQhIeHc+nSpUqVoZKBG8jON7HvfOk9B46UZzSTnGm+7THdw81czTM5KSLFVlFBuhtqbrbU4lRflBvYcfoKRrMawV2WixcSGT/mAbq1a0mXNs15+cU/U1BQAMDXi/7LtBemuDjCWz
WsFVHq67VqBNC/Zxd6d2lPvx6dmfef97FYbl8jPHf2LEu//doRYd5AJQMXy8w1cvBCpqvD0FRKchIjhw4kNSXZ7rKklEwY9weGDL+bbXsPsXXPQbKzspnxxt81iLR0JpPjakJ+BgPrNm9n4/Y9fLN8BetWr2LWO2/d9pzz586y9NvFDoupiEoGLrbtVBpmS9WqFcx+dwbb47fyz5lv213W5g2/4efnx5hxjwCg1+t5Y8a7fP3lF+Tk5ABw8cIFxtx7Dz06tC7+w8rOzmbsA6Po16MzcV07sPy7bwHYv3cPI4cNZFDv7vxh1N2kJCcBMGr4IN5+/VVGDhvIe7Nm0rFVk+JP7JycHNo3vwOj0ciZhATG3HsPg3p3Z8SQ/pw4/jsAZ8+cYfiAOAb36cHMNys2iz8yMopZ78/h0/nzkFJy7uxZRgzpz8Be3RjYqxs7t28D4K3XXmb7ti3079mFj+Z8UOZx9lJtBi6UmWvkWPJVV4ehmXpRNcjPv94F+fnCBXy+cAG+vn6cTbWtTeT3Y0do3fbGYSxBwcHUjq3DmYRTAOzdvYvf4ndhMPgzpG9PBgwaQuL5c0RHx7Do22UAXM3MxGg08tJfn+ezr74lIiKS5d99y4zpr/HenI8AyMzMZPlPawA4uH8fWzdvomfvOFb/vJI+/Qbi7e3NC889w7vv/ZuGje5gz64dTH3+Ob5b8QuvTH2B8Y9P5MExY/lkwbyK/84aNMBisXD5UioRkZF8s3wlfn5+JJw6yaQJ41m9YQsvvfYmc//9Hl8uXgpYk9PNxw3fY//wfpUMXOhEyjWq0rYVOw4c4fWXp/Hzih/Izc3FYDAw9K4RvPbWDJvLlFKW2hgmuf567779CAsLB2DY3SPYEb+V/oOG8PrL05j+6ksMHDKUrt17cvTIYY4dPcLokXcBYDabqVkzurjMEffeX+L7+/hh6RJ69o7j+6Xf8ujjT5GdlcWuHfE8OX5s8XEF+fkA7IzfxsL/fgXAA6Mf4s2/v1zxf2PhjH+T0cjf/vInDh08gF6vJ+HkiVKPr+hxlaWSgQudSK3YkGNPUTM6hsCgIPLy8/H18yMvP5+g4CCiSvzBVVaTps1Z+cPyG167dvUqFxMTqdegIfv37b0lWQghaHRHY1Zv2Mq6Nat4+/VXies3gGF33UOTps1YuXZDqdfy978+YnPw0Lt46/VXSb9yhQP79tIzrg852dkEh4SybnPpM/RtacE/e/o0ep2eiMgoZr3zFhGRUfy6ZQcWi4V6UaGlnvPRh/+u0HGVpdoMXORqnpHkzKo3qu9S6iXGP/YEP63dwPjHniA1JdWu8nr16UtOTi6Lv1oEWD/NX3tpKg+OHVf8x7tx/TrSr1whNzeXX1b+SKcu3UhOuojB35/7R49h8h+ncHD/Xho1vpO0y5fZtcO6CJfRaOTY0SOlXjcgMJB27TvyytQXGDB4GHq9nqDgYOrWq8cPy74DrLWWwwcPANCpa7fidonvFles5f/y5Uv89U9/5LGJkxBCcO1qJjWjo9HpdHz79f8wm63dvoGBgWRlXR8fUtZx9lI1Axep6EQkT/Ppout/CO/Mft/u8oQQfLroa6b+eQr/encGFouF/oMG87dX3yg+pnPX7jz71OOcSTjFqAdG07Z9B9avXcMbr/4NnU6Hl5cXM2d/gI+PDx9/8T9efvHPXL16FZPJxMTJz9K0WfNSrz3i3vt5cvxYlq5cXfzanAWfMfX5/+O9WTMxGo2MvO8BWrRqzfR3ZvH0E+NZMPc/3HXPyDL/PXm5ufTv2QWj0YiXlxf3jx7DpGefA+DRJ57i8YfH8OPypfToFYd/QAAAzVu2wkvvRb8enRn90Lgyj7OX5usZVIRazwC+2XnO7nkIztA9PI/6d9zp6jCUcpQ29+To0aM0a9bshtdut56Buk1wgWt5RpKq4C2C4tlUMnCBk6lZVaoXQakaVDJwgaraXqB4NpUMnCwr38TFzFxXh6Eot1DJwMnULYLirrTYXq
2OEGK9EOJo4fZqz2kRWFV1IuVa+QcpigtoUTMwAX+WUjYDugLPCCFK77it5nIKTB7RnehuokMMPDtxQvFzk8lE84Z1GPfgvS6MquqxOxlIKZOklHsKv78GHAVq21tuVZSYnotF3SNUmn9AAMeOHCE319rWsmH9OmJq1XJxVFWPpm0GQoj6WFdKVturlSIxPcfVIXisfgMHsXbVzwAsW7KYkfc9UPyz7OxspjzzFIP79GBAz678svJHgDKn+m7ZtJFRwwfx+MNj6NmxDU8/8SiuGHznbjQbjiyECAS+A6ZIKW+Zl1tye7W6detqdVmPciHds3sRXpnqxaGD2rY5t2xlYfo75S8mMvK+B/jnzLcZOGQYRw8fYsy48WzfthWA92fNpGfvPrw35yMyMzIY2q8Xvfr0K3NKMMChA/vZEL+b6Jha3D2oLzvit9KlWw9N/22eRpNkULgV+3fAIinl0tKOkVLOB+aDdTiyFtf1JLkFZtKyC1wdhsdq3rIV58+dY9mSxfQfOPiGn/326zpW/bySuf9+D4D8/DwuJJ4nOjqmzKm+7dp3pFbtWABatGrD+XPnVDKwtwBhHRC9EDgqpZxtf0hV04WMHI/vUqzIJ7gjDR42nDdensbSlau4cqXEsvJSsvC/X3FH4xvnUPxjxptlTvX18fUt/l6v1zl0qTNPoUWdrwfwMNBPCLGv8DFMg3KrlEQPv0VwB2PGPcLzL06jWYuWN7zep/8AFn70YfF9/8H9+wDHTfWtqrToTdgspRRSytZSyraFj5+0CK4qUcnAfrVqx/Lk5Gdvef1Pf52GyWiib/dOxHXtwMy3rNObH33iKRZ/tYhh/XuTcPKEZlN9qyo1hdkJ8oxm5m045ZG3CWoKs2dQU5g9xIWMXI9MBEr1opKBE3h6l6JSPahk4AQXMlQyUNyfSgYOlm8yk3o139VhKEq5VDJwsKSMPDUfQfEIKhk4mOpSVDyFSgYOdiFDTU7SQmpKMk899jBd2jSnV+d2PHT/SE7ZsJNQ/NbN9O7Snv49u5B08QKPPzym1ONGDR/Evj277Q3bo6h9ExzIaLaQUsXaCxZsTNC0vCd7Nyz3GCklj40dzYNjxvHRp/8FrBONLqWm0OiOxpW63tLF3zD5j1OKN3It2hJNUTUDh0rOzKtyOyy7wpaNG/D29mb8408Wv9aydRu6dOvB6y9PI65rB/p061i8o1FZU5QXff4pPyz7jtkz3+bpJx7l3NmzxHXtAEBubi5PPfYwfbt3YuKj48jLvX5799u6tQwfEMfAXt144pGHyM6yLmjbsVUT3n17OgN7daNPt47FOzJnZ2Xx3NMT6dOtI327d2LF98tuW467UMnAgZKvqlWNtHDs6OFbdmIGWPnDcg4fPMCvW3bw7fcrmf7q34q3WD90YD/T3/kHG3fs5eyZ0+yI38rY8Y8xaNhwXp3+Nh9+/NkNZX2+cD4Gf3/Wb93JlBde5MC+vQCkpV3mvVnvsPj7n1izaRtt2rVn3pwPis8LCw9nzaZtjJ/wJHM/sM6anP3uDIKDg/lt2y7Wb91Jz959yi3HHajbBAeqinspupMd8VsZef+D6PV6IqNq0q1HL/bt2U1gUHClpyjHb93ME089A1inSzdv0QqA3Tt3cPzYMe4Z3A+AgoICOnbqUnze8LtHANCmXTt++vF7ADb9tp55n35RfExojRqs/uWn25bjDlQycKCUm2oGR3f4k3HJm5AIk/URbiIgxIwNm/dWK02aNi+uapd0u3k1Nk1RLu0/Qkp69+3HvE++uPVnJa6j0+kxmU3FcQluKqucctyBuk1wkOx8E9fyrr8Bzx/3ZeGrtfn2/Zp8/Ept/jm5Hq8+2IgvZ0RjsbgwUA/QM64P+fn5fPnZJ8Wv7d29i5DQGvywdAlms5nLly+xbetm2nUodQ5Oubp278nSwt2Tjx45zJHDBwFo36kzO7dv4/SpUwDk5OSU24sR168/nyyYW/w8Iz3dpnKcTSUDBynZXmA2wT
ezaxIYaubFj8/wf++fY/wrF+k1Mp29vwWzfnENF0bq/qw7MX/DhvXr6NKmOb27tGfWO29x7wOjadaiJf16dOb+u4fyyutvEVUz2qZrjH98ItnZWfTt3ok5788uTioREZG8/+ECJj3+CH27d2L4gDhOFjYUluVPf5lKRkYGcV070K9HZ7Zs2mBTOc6mpjA7yNaTl9l+2roaz5r/hfHzZxFMeO0CLbtnFx8jJXw5I5p9G4N4akYid7ZzvwFKagqzZ1BTmN1YyjVrzSDlnA+rF4XRNu7aDYkArLeoD/4phag6Bfz37RjSU1UTjuI6Khk4SHJmPhaz9fbA108y6unUUo/zNUgee/UiJqPg8zdjMBWo1kTFNVQycICMnALyjGY2/xDKmSMGRk5OJahG2evvRdUxMuaFFM4dM7BiYYQTI1WU6zRJBkKIT4QQqUKIQ1qU5+mSr+ZRkC/4+fNwmnbMpkP/8vdXbN0zi+53ZbD5+1BSznk7IcqKUxuMeB5b/s+0qhl8BgzRqKwKS0pKIi4ujuTk5Bu+d7XkzDxO7PEnP0dP3L3pFR5HMOSRNLz9LKxcGOnYACshyyTIykxXCcGDSClJS0vDz8+vUudp0mIlpdxYuLWaU02fPp3NmzfzxhvW1XCLvv/www+dHcoNUq/mc3BrMH4BZhq1qfisxcBQM/1Hp/PTpxEkHDTQsJXrexeOXvUG0gi8fNnVoSi3ccVwY23Sz8+P2NjYSpWhWddiYTJYIaVsWcbPS26v1uHs2bM2X8tgMJCX5w30xprPdIAeSATiAesvIzfX+X9MFovk3+tO8tIDDWjSPodx0ypXUynIE8yYUJ+QCBPPvX9ejU5UKmTKgMa3dC2Wxi26FqWU86WUHaWUHSMj7asGJyQkULf+98AKYDmwFPgW2IyfXzPGjh3L6dOn7Q/aBpez8zl+wI/sTC9a9aj8rDQfP8nQ8WmcO2Zg/6ZAB0SoKKXzyN4EX98YEs/1ABbh5d0FaA30AiAvbzzBwcFER9s2Es1eKZn5HNoSiJe3haadsss/oRQdB1wlpkE+Kz+JwGTUOEBFKYPHJYOkpCQ6dPwPFosPDz54hl0759OgQRYNGlyga49MfHyf5tSpRJc1JiZn5nFwayB3ts/B12DbLZhOD3c9cYm0iz5sXRFa/gmKogGtuha/ArYBTYQQiUKIx7UotzRvvPEmZ04PJiLqON988xJt2rQhISGBhIQEXnulBgX5QZhMI29oWHSm7bvMpKd423SLUFLTjjk0bpfN6i/Dyc3yuJyteCBN3mVSyjFSyhgppbeUMlZKuVCLcksyGAwIIZg37wTQmMupf0cIgcFgKD5mxIgA4AS//nonFouFuXPn3nKMIxnNFjav9UXoJC262naLUEQIuGfiZXKzdKz9OkyjCBWlbB7zkZOQkMBDDz2ETvdHIAWD4ZdbGgpPnz5Fm7bxQE+gNf7+/k5tTEy9ls+BLYE0aJFLYKj9O/7WbpRPh/5X2bQslCspat6C4lgekwxiYmIQoh4WyzD0+s/Iz796S0NhTEwM7druA3LR6Z4lLy/PqY2Juw4WkHTat/gWwUtnf7/g0EfTQMBPn6phyopjeUwyAIjf0RYQ/Pjj3UyaNKnUBsLMzNM0arwHL+/HmDBhilMbEVd8b/3jb9XdmgwGtYima8Nwu8qsEWUi7t509vwazPnjvuWfoCg28pi6Z0EBpF25j76DzAwd2pyhQ+eUetzSpUvZsQO6dIFmLd/l+ef0Totxw2ofajfKIyzaRMPIAJpEBxX/LD4hzeZy+41OJ/7nEH5cEMnkdxPVQCTFITymZrBkiSQjTc+fnys/f3XqBM1amZgzB6dthZ6UaubkIV9adMvGx0tH36ZRxT/r1ijcrhqCIcDCoHFXOLnfn6M7ArQIV1Fu4THJ4MuvLETFmBk6pCJDLmHyUzoSTug5etQJwQE/rzEhpaBx2xy6Nwon2O/GseLdGoXTpaHtvQLdh2cQWbuA7z+KpCBfVQ0U7X
lEMrhwIYnVP2fTs2cuugpGPHyo9cARI/7llHaDXzdY0Htb6NxZ0rZO6QOFujYIx8/bttsWvRfc98dULiX68PNn9rVDKEppPCIZTJnyKWZzMJlXllT4nAYNwD8gjZMn6zpl8NHOeD31muQzrF3NMieM6HSCBhH+Nl/jzvY5dL8rg41La5Bw0DljJ5Tqw62TQdFAoyVLrNNn1617qUKDiAwGAzqdICf7RyCOuXPnOXTwUVYWnDriTaduZiICb9/i3zDSvslHdz95ibBoI1/Nqkl+rrpdULTj1sng+kCjAcBx/P0zKjSIqOg8H59tQAR+fp0cOvjot41mzGbBgD7l/zrrhfujt2P8ga9BMuaFFK4ke7PiY/dZBEXxfG6dDGJiYggMDMVi6YFev7HCg4hiYmIIDg7GaFwNQH5+V4cOPlr1qxmhk9wzuPxxAL5eeuqE2VdDadgql973prPlx1CO77H9tkNRSnLrZABw/HgAEMJbbw0sc6BRaVJSUpg8eRg1o/OoU/cxhzYibtkMDe4sICq8Yo2DDSPsX6dg6KNpRNXJ53/vRnP5onutmah4JrffRGXmTMnUqYKkJLDlg/2RRy18vxzSr+gq3BNRGfn5EBwiufehPL76pGKf+NfyjHy8yf5blqTTPsx5oQ6+BgvPzj5PjagK7CeoVEketdKRrVattdCosdmmRAAwaICOq5k69u/XNq4iu3dDQb5gcP+KdxkG+XlTM7hyi1WWJqZBAZPeSSQ3W8eHf40lM815oy2Vqsetk4HRCPFbdfTvZ3uDW9++1q/r12sU1E1Wr7XOThw2sHJV9YaR2owkjG2cz8S3L5CV7sW8F2O5lq4SgmIbt04Gu3ZBbo5g4ADbw6xdG+o3tLBqjf1Tikvz20ZJ3YYmoqIql7C0SgYA9Zvl8fj0C1xJ8ebDv8SqCU2KTdw6Gfz6q7U9Iy7OvnL694OtW3SYNL6lNpth13Yd3XtWfk/1qCA/gg3aNfzd0TqXJ6ZfICdLz/v/V5cfP45Qw5aVSnHrZLBqrYU7m5qxczFlBvTXkXVNsHevNnEVOXAAsrN0DOpn2+RPLWsHAI3b5jL14zN0GnyV9YvDmPVUPY7v8XfaZC3Fs2m1BuIQIcTvQoiTQoipWpR55kwSmzca6dbV/r0P+vSxfh09ep6mXYwrVmQC0LqlbRuMNNKgi/FmhkALo/+UwuSZ55ES5k2N5Z0J9Vn1ZZjqglRuy+5kIITQA3OAoUBzYIwQorm95U6Z8j+kxY/kpEX2FkV0NASHXOT06fqazlNY+MlJ4DQLF75m0/m1axg0WQ2pNI3b5fKXj84y+vlkQiJMrP5vOG8/2oB/PVuXJR9EsXVFCKcP+5GbrVM1BwXQYJyBEKIb8JqUcnDh82kAUsoZZZ1zu3EG1t2S8oCXgdeBCCDd5h2Srpf3H2A8UAMw2bXj0vUyU4BfCsu1bRenb3ae42JGnk1xVEZ6qhd7fwviSHwAF0/7kpd9vdfBy8dCUKiZgBDrw8fXgrePxNtX4uUj0eslOr1Epwedzvp+ETqKF1kRovA9JG74UmElu8cr+na05Rx7VPZ6zo7vD/3CGT/evnEGWqx0VBs4X+J5ItCllCBKbq9WZmEJCQm88MILfPNNG8zm/fj75zNq1FhmzZplU3BF5S1ZspWCgmfw9e3J/ffXtrm8ojKffHI2K1dGARvx9/dn1KhRNpUZE2JwSjKoEWWi34Pp9HswHSmtySEpwZeU8z5kZeitj0wvsq/qybzkhbFAYCzQYSwQSDNYzAKLxfpVAkiQUjVQuoucM5Lx4+0rQ4tkUNo74pZcKKWcD8wHa82grMKK5hVYLKPx9Y20e1HT6/MUlgJQUNCd4OB0u+YpxMTEkJnZHgAfn612xVgr1I/dtm87aRMhIKymibCaJlp0s29J96JPveJPv0p+CpZ2eHkpxpZzbj6/ssdX5nrOjg/guQGNK3nGrbRIBolAnRLPY4GL9hRonVcwiYkTJzJ//nySkpLsCt
Ba3v2s/DkHY8FYkpNftqs8gJOn6uIfkMaWzV+xYIHtMUaHePa6BNdvFVwbR3Wn12CsmRZtBl7AcaA/cAHYCTwkpTxc1jmVmZugpWeelXzyCWRmCHx8bC/HYoGISAtDhsL/vrS/Q2bh5tNczVWbKiq2c4u5CVJKE/AssAo4Ciy+XSJwpQH9BXm5gh077Cvn8GHrxKdBdoyMLKlWiP3zFBTFXlptr/aTlPJOKWUjKeVbWpTpCHFx1pbvopGNtio6v2jeg72iVTJQ3IBbj0DUWlgYNGthYfXayg8fLmnVWguxdS3Uq6dNXLVCPbvdQKkaqlUyAOvQ5J3bdeTZ2JtnscCWTUKzWgFAZKAv3nrVAqe4VjVMBoKCAkF8vG3n798PVzN1DOyv3a9OpxNEabC+gaLYo9olg969raPoxo//1KZ5CsuXW+cjtGiRqmlctTy8i1HxfNUuGYSEQGiNs5w718imeQqffHoGOMHHH7+maVyqEVFxtWqVDIr2YbiSthjoyty5n1V4PwXruV4knq8PrGfu3Lma7sVQK1QlA8W1qlUyKNpPwdc3HvDB17dfhfdTSEhIYPDgaUAIsB5/f39N92Lw9/Ei1F9NMVZcp1olg6J5CgUF6wAT+fndKzynICYmhoyMdgD4+sbbPWei1GuoWwXFhapVMoCieQrjaNIsl4jI0ZVqRDyVUI/gkCS2b19eqT0cKipGNSIqLqTFRCWPsnSpdfbiX1+UzJ4dyBdfLK3QeUYjZGe3Y+w4aNMmhjlz5mgeW4xqN1BcqNrVDIoMHSIwmwRLllZsNOJPP1nIzdEVb/XuCBEBvvh4Vdv/EsXFqu07Ly4Oatc18/6/K5YM/vmBmfBIM8OHOy4mnU5osrmKotii2iYDnQ4mT4Z9u7w4cOD2xyYkwOb1Xkx4QuLt4AZ/1YiouEq1TQYAk57U4+VtYeCg5bdtDHzznUykNDPmgSsOj0kNPlJcpVong/BwqFsvntSUAbz88rulHpOXB19+rgO+Z8GC1xwek6oZKK5SbZNB0WjEhJNTgEAWLsy7ZUShwWDAYHgYY0EQMEfzUYelUYOPFFeptsmgaDSiwXAY2I0Qz/DQQzeOKExISCC0xisI8TuOGHVYFlU7UFyh2iaDotGI+fl5eHktQMoWZGe3u2FEYVJSDBnpdwIf4ufn55BRh6XGpgYfKS5gVzIQQjwghDgshLAIIUpdZNGdpaSkMGnSJDZtehov7xy2xbcnLi6O5ORkkpKSGDr0R3S6XCZM8CY+Pt4how5Lo2oGiivYtTqyEKIZYAE+Al6QUlZoyWNXrY58O089Y2LBXJByCRExDZAI0pJacseduznxey+nxmKxSOZuOEWByb7l2ZTqQ4vVke0ajiylPFp4AXuKcTnrdmnRwDKgLZeTLIAZOMTJ408ixO92bcdWWTqdICrIl8R051xPUcCJbQZCiIlCiF1CiF2XLl1y1mUrxNqY2B2DoTvQDL2+NXp9O6AL/v7nndJoeDPVbqA4W7nJQAixVghxqJTHiMpcSEo5X0rZUUrZMTIy0vaIHeB6Y2I+fn5+mM1mzGazUxsNb4lJTVpSnKzc2wQp5QBnBOJqRY2JEydOZNSoUQAsW7ZMk+3dbKEaERVnq3ZTmMtSNLUZrLcNRRwxVbki/H28CDF4k6m2XVOcxN6uxVFCiESgG7BSCLFKm7AUUOsiKs5lVzKQUi6TUsZKKX2llDWllIO1Ckzx/B2aFc9SbUcgegK1IaviTCoZuLGIQLXykeI86p3mxnQ6QW21KaviJCoZuLk6Yf6uDkGpJlQycHP1wlUyUJxDJSt6r/4AAAaJSURBVAM3FxHoS6CvGg6iOJ5KBh6gTphqN1AcTyUDD1A3LMDVISjVgEoGHqCuajdQnEAlAw8Q6OtFeKCPq8NQqjiVDDxEXdXFqDiYSgYeQiUDxdFUMvAQsTX80es8e3k5xb2pZOAhfLx0RKtNWRUHUsnAg6heBcWRVDLwIKrdQHEklQw8SHSwH7
7e6r9McQz1zvIgOp2gSc0gV4ehVFH2roH4DyHEMSHEASHEMiFEqFaBKaVrXivY1SEoVZS9NYM1QEspZWvgODDN/pCU24kJMajRiIpD2Lsg6moppanwaTwQa39ISnmax6jagaI9LdsMJgA/a1ieUoZmMcHoPHx/S8X9lLtqhhBiLVDa3mIvSSm/LzzmJcAELLpNOROBiQB169a1KVjFKsDXi/oR/iRcynZ1KEoVYvf2akKI8cBdQH95m/3dpZTzgflg3ZK9knEqN2lRK1glA0VT9vYmDAFeBO6RUuZoE5JSEQ0iAvH30bs6DKUKsbfN4D9AELBGCLFPCDFPg5iUCtDrBE2i1ZgDRTt2rbQppbxDq0CUymtRK4S95zJcHYZSRagRiB4sMsiX2BpqsVRFGyoZeLi+TaPUOgeKJlQy8HARgb60r1vD1WEoVYBKBlVAl4ZhBBu8XR2G4uFUMqgCvPU6+jaJdHUYiodTyaCKaBgZyB1Rga4OQ/FgKhlUIX2aROLjpf5LFduod04VEuTnzah2tTGokYmKDVQyqGJqhRp4sGMd1aCoVJpKBlVQWIAPf+hUh6hgX1eHongQlQyqqABfLx7oUIcWtYLVoCSlQlQyqMJ8vHQMahHNYz3q06l+mFpZWbktuyYqKZ4hyM+bno0j6NwgjBOp10jOzCPlaj6Xs/IxW9TSEoqVSgbViI+Xjha1QmhRKwQAk9nClZwCcvLNZBeYyC0wk1NgxmSxYDRLTGaJyWJBSrBIef1rUYESSjwrVvYSN4qjCA2WwVPJoBrz0uuICvKzrkihVHvqJlJRFEAlA0VRCqlkoCgKYP+CqNMLt1bbJ4RYLYSopVVgiqI4l701g39IKVtLKdsCK4BXNYhJURQXsHd7taslngZAKf1MiqJ4BLu7FoUQbwGPAJlAX7sjUhTFJcRtNkGyHlCB7dUKj5sG+Ekp/15GOcXbqwFNgN8rEF8EcLkCx7mSu8fo7vGB+8fo7vFBxWOsJ6UsdVmscpNBRQkh6gErpZQtNSnQWuYuKWVHrcpzBHeP0d3jA/eP0d3jA21itLc3oXGJp/cAx+wpT1EU17G3zeAdIUQTwAKcBSbZH5KiKK5g7/Zq92kVSBnmO7h8Lbh7jO4eH7h/jO4eH2gQo2ZtBoqieDY1HFlRFMCNk4EQYogQ4nchxEkhxFRXx3MzIcQnQohUIcQhV8dSGiFEHSHEeiHEUSHEYSHEc66OqSQhhJ8QYocQYn9hfK+7OqayCCH0Qoi9QogVro6lNEKIM0KIg4XTAnbZXI473iYIIfTAcWAgkAjsBMZIKY+4NLAShBC9gSzgCy27U7UihIgBYqSUe4QQQcBuYKS7/A6FdTWOAClllhDCG9gMPCeljHdxaLcQQjwPdASCpZR3uTqemwkhzgAdpZR2jYVw15pBZ+CklDJBSlkAfA2McHFMN5BSbgSuuDqOskgpk6SUewq/vwYcBWq7NqrrpFVW4VPvwofbfTIJIWKB4cDHro7F0dw1GdQGzpd4nogbvZE9jRCiPtAO2O7aSG5UWP3eB6QCa6SUbhVfofeAv2LtPndXElgthNhdONLXJu6aDEpb0M3tPjU8gRAiEPgOmHLTxDKXk1KaC2e8xgKdhRBudbslhLgLSJVS7nZ1LOXoIaVsDwwFnim8ha00d00GiUCdEs9jgYsuisVjFd6LfwcsklIudXU8ZZFSZgC/AUNcHMrNegD3FN6Tfw30E0J86dqQbiWlvFj4NRVYhvU2u9LcNRnsBBoLIRoIIXyAPwA/uDgmj1LYQLcQOCqlnO3qeG4mhIgUQoQWfm8ABuBmw9mllNOklLFSyvpY34O/SinHuTisGwghAgobiBFCBACDAJt6uNwyGUgpTcCzwCqsDV+LpZSHXRvVjYQQXwHbgCZCiEQhxOOujukmPYCHsX6a7St8DHN1UCXEAOuFEAewJv81Ukq37LpzczWBzUKI/cAOrJMFf7GlILfsWlQUxfncsmagKIrzqWSgKA
qgkoGiKIVUMlAUBVDJQFGUQioZKIoCqGSgKEohlQwURQHg/wGjQiQl9JhptQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Test points every 0.1 between 0 and 5\n", "test_x = torch.linspace(0, 5, 51)\n", "\n", "# Get into evaluation (predictive posterior) mode\n", "model.eval()\n", "likelihood.eval()\n", "\n", "# The gpytorch.settings.fast_pred_var flag activates LOVE (for fast variances)\n", "# See https://arxiv.org/abs/1803.06058\n", "with torch.no_grad(), gpytorch.settings.fast_pred_var():\n", " # Make predictions\n", " observed_pred = likelihood(model(test_x))\n", "\n", " # Initialize plot\n", " f, ax = plt.subplots(1, 1, figsize=(4, 3))\n", "\n", " # Get upper and lower confidence bounds\n", " lower, upper = observed_pred.confidence_region()\n", " # Plot training data as black stars\n", " ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n", " # Plot predictive means as blue line\n", " ax.plot(test_x.numpy(), observed_pred.mean.numpy(), 'b')\n", " # Shade between the lower and upper confidence bounds\n", " ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n", " ax.set_ylim([-3, 3])\n", " ax.legend(['Observed Data', 'Mean', 'Confidence'])" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 1 }