# Source code for gpytorch.utils.nearest_neighbors

import warnings

import torch
from torch.nn import Module

class NNUtil(Module):
    r"""
    Utility for nearest neighbor search. It first tries to use `faiss`_ (which requires a separate package
    installation) as the backend for better computational performance. Otherwise, `scikit-learn` is used,
    as it is pre-installed with gpytorch.

    :param int k: number of nearest neighbors
    :param int dim: dimensionality of data
    :param torch.Size batch_shape: batch shape for train data
    :param str preferred_nnlib: currently supports `faiss` and `scikit-learn` (default: faiss). Any value other
        than ``"faiss"`` selects the scikit-learn backend.
    :param torch.device device: device that the NN search will be performed on. Note that the constructor
        itself does not use this argument; call :meth:`to`, :meth:`cuda` or :meth:`cpu` to place the index.

    Example:

        >>> train_x = torch.randn(10, 5)
        >>> nn_util = NNUtil(k=3, dim=train_x.size(-1), device=train_x.device)
        >>> nn_util.set_nn_idx(train_x)
        >>> test_x = torch.randn(2, 5)
        >>> test_nn_indices = nn_util.find_nn_idx(test_x)  # finding 3 nearest neighbors for test_x
        >>> test_nn_indices = nn_util.find_nn_idx(test_x, k=2)  # finding 2 nearest neighbors for test_x
        >>> sequential_nn_idx = nn_util.build_sequential_nn_idx(train_x)  # build up sequential nearest neighbor
        >>>     # structure for train_x

    .. _faiss:
        https://github.com/facebookresearch/faiss
    """

    def __init__(self, k, dim, batch_shape=torch.Size([]), preferred_nnlib="faiss", device="cpu"):
        super().__init__()
        assert k > 0, f"k must be greater than 0, but got k = {k}."
        self.k = k
        self.dim = dim
        if not isinstance(batch_shape, torch.Size):
            raise RuntimeError(f"batch_shape must be an instance of torch.Size, but got {type(batch_shape)}")
        self.batch_shape = batch_shape

        # Number of training points; stays None until set_nn_idx is called.
        self.train_n = None

        if preferred_nnlib == "faiss":
            try:
                import faiss  # noqa F401
                import faiss.contrib.torch_utils  # noqa F401

                self.nnlib = "faiss"
                self.cpu()  # Initializes the (CPU) faiss index, one per batch element
            except ImportError:
                warnings.warn(
                    "Tried to import faiss, but failed. Falling back to scikit-learn nearest neighbor search.",
                    ImportWarning,
                )
                self.nnlib = "sklearn"
                # One fitted sklearn estimator per batch element, built by set_nn_idx.
                self.train_neighbors = None
        else:
            self.nnlib = "sklearn"
            # One fitted sklearn estimator per batch element, built by set_nn_idx.
            self.train_neighbors = None

    def cuda(self, device=None):
        """
        Move the utility to the GPU. With the faiss backend this creates fresh (empty) GPU indexes, so any
        training data previously added must be added again with :meth:`set_nn_idx`.

        :param device: forwarded to :meth:`torch.nn.Module.cuda`.
        :return: self
        """
        super().cuda(device=device)
        if self.nnlib == "faiss":
            from faiss import GpuIndexFlatL2, StandardGpuResources

            self.res = StandardGpuResources()
            self.index = [GpuIndexFlatL2(self.res, self.dim) for _ in range(self.batch_shape.numel())]
        return self

    def cpu(self):
        """
        Move the utility to the CPU. With the faiss backend this creates fresh (empty) CPU indexes, so any
        training data previously added must be added again with :meth:`set_nn_idx`.

        :return: self
        """
        super().cpu()
        if self.nnlib == "faiss":
            from faiss import IndexFlatL2

            # No GPU resources on the CPU; build_sequential_nn_idx checks this to decide where to search.
            self.res = None
            self.index = [IndexFlatL2(self.dim) for _ in range(self.batch_shape.numel())]
        return self

    def find_nn_idx(self, test_x, k=None):
        """
        Find :math:`k` nearest neighbors for test data `test_x` among the training data stored in this utility.

        :param test_x: test data, shape (... x N x D), where ``...`` must equal ``batch_shape``
        :param int k: number of nearest neighbors. Default is the value used in utility initialization.
            It may exceed that value, but must not exceed the number of training points.
        :rtype: torch.LongTensor
        :return: the indices of nearest neighbors in the training data, shape (... x N x k), on the same
            device as `test_x`
        :raises AssertionError: if no training data has been set, if `k` is not positive or exceeds the number
            of training points, or if `test_x` has the wrong batch shape or dimensionality.
        :raises RuntimeError: (scikit-learn backend) if the neighbor structures have not been built.
        """
        assert self.train_n is not None, "Please initialize with training data first."
        if k is None:
            k = self.k
        else:
            assert k > 0, f"k must be greater than 0, but got k = {k}."
        assert k <= self.train_n, (
            f"k should be smaller than number of train data, "
            f"but got k = {k}, number of train data = {self.train_n}."
        )

        test_x = self._expand_and_check_shape(test_x)
        test_n = test_x.shape[-2]
        # reshape (rather than view) so that non-contiguous inputs are accepted as well
        test_x = test_x.reshape(-1, test_n, self.dim)
        nn_idx = torch.empty(self.batch_shape.numel(), test_n, k, dtype=torch.int64, device=test_x.device)

        with torch.no_grad():
            if self.nnlib == "sklearn":
                if self.train_neighbors is None:
                    raise RuntimeError("The nearest neighbor set has not been defined. First call `set_nn_idx`")

                for i in range(self.batch_shape.numel()):
                    # Request k neighbors explicitly: the estimators were fitted with n_neighbors=self.k, so
                    # relying on their default would return too few columns whenever k > self.k.
                    _, nn_idx_i = self.train_neighbors[i].kneighbors(test_x[i].cpu().numpy(), n_neighbors=k)
                    nn_idx[i] = torch.from_numpy(nn_idx_i).long().to(test_x.device)
            else:
                for i in range(self.batch_shape.numel()):
                    # faiss' torch bindings expect contiguous float32 data
                    query = test_x[i].float().contiguous()
                    nn_idx[i] = self.index[i].search(query, k)[1]

        nn_idx = nn_idx.view(*self.batch_shape, test_n, k)
        return nn_idx

    def set_nn_idx(self, train_x):
        """
        Set the indices of training data to facilitate nearest neighbor search.
        This function needs to be called every time that the data changes.

        :param torch.Tensor train_x: training data points (... x N x D), where ``...`` must equal
            ``batch_shape``
        """
        train_x = self._expand_and_check_shape(train_x)
        self.train_n = train_x.shape[-2]

        with torch.no_grad():
            # Flatten the batch dimensions so each batch element can be indexed separately.
            train_x = train_x.reshape(-1, self.train_n, self.dim)
            if self.nnlib == "sklearn":
                from sklearn.neighbors import NearestNeighbors

                self.train_neighbors = [
                    NearestNeighbors(n_neighbors=self.k, algorithm="auto").fit(train_x[i].cpu().numpy())
                    for i in range(self.batch_shape.numel())
                ]
            elif self.nnlib == "faiss":
                for i in range(self.batch_shape.numel()):
                    self.index[i].reset()
                    # faiss' torch bindings expect contiguous float32 data
                    self.index[i].add(train_x[i].float().contiguous())

    def build_sequential_nn_idx(self, x):
        r"""
        Build the sequential :math:`k` nearest neighbor structure within training data in the following way:
        for the :math:`i`-th data point :math:`x_i`, find its :math:`k` nearest neighbors among preceding
        training data :math:`x_1, \cdots, x_{i-1}`, for `i=k+1:N` where `N` is the size of training data.

        :param x: training data. Shape `(... x N x D)`, where ``...`` must equal ``batch_shape``
        :rtype: torch.LongTensor
        :return: indices of nearest neighbors. Shape: `(... x N-k x k)`. The result is allocated on the CPU.
        """
        x = self._expand_and_check_shape(x)
        N = x.shape[-2]
        assert self.k < N, f"k should be smaller than number of data, but got k = {self.k}, number of data = {N}."

        nn_idx = torch.empty(self.batch_shape.numel(), N - self.k, self.k, dtype=torch.int64)
        x_np = x.reshape(-1, N, self.dim).detach().float().cpu().numpy()

        if self.nnlib == "faiss":
            from faiss import IndexFlatL2

            # building nearest neighbor structure within inducing points
            index = IndexFlatL2(self.dim)
            with torch.no_grad():
                if self.res is not None:
                    from faiss import index_cpu_to_gpu

                    index = index_cpu_to_gpu(self.res, 0, index)

                for bi in range(self.batch_shape.numel()):
                    # Seed the index with the first k points, then grow it one point at a time so that each
                    # query only sees the points that precede it.
                    index.reset()
                    index.add(x_np[bi][: self.k])
                    for i in range(self.k, N):
                        row = x_np[bi][i][None, :]
                        # search returns (distances, indices), each of shape (1, k) for a single query row
                        nn_idx[bi][i - self.k].copy_(
                            torch.from_numpy(index.search(row, self.k)[1][..., 0, :]).long()
                        )
                        index.add(row)
        else:
            assert self.nnlib == "sklearn"
            from sklearn.neighbors import NearestNeighbors

            for bi in range(self.batch_shape.numel()):
                for i in range(self.k, N):
                    # Refit on the i preceding points and query the i-th point against them.
                    train_neighbors = NearestNeighbors(n_neighbors=self.k, algorithm="auto").fit(x_np[bi][:i])
                    nn_idx_i = torch.from_numpy(train_neighbors.kneighbors(x_np[bi][i][None, :])[1]).squeeze()
                    nn_idx[bi][i - self.k].copy_(nn_idx_i)

        nn_idx = nn_idx.view(*self.batch_shape, N - self.k, self.k)
        return nn_idx

    def to(self, device):
        """
        Put the utility to a cpu or gpu device.

        :param torch.device device: Target device. Only ``"cpu"`` and CUDA devices are supported; for CUDA the
            device index is not forwarded (``self.cuda()`` is called without arguments).
        :return: self
        :raises ValueError: for any other device type.
        """
        if str(device) == "cpu":
            return self.cpu()
        elif "cuda" in str(device):
            return self.cuda()
        else:
            raise ValueError(f"Unknown device {device}")

    def _expand_and_check_shape(self, x):
        """
        Validate `x` against this utility's ``batch_shape`` and ``dim``.

        A 1-D input of shape (N,) is treated as N points of dimension 1 and reshaped to (N x 1).

        :raises AssertionError: if the batch shape or the last dimension does not match.
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(-1)
        assert x.shape[:-2] == self.batch_shape, (
            f"x's batch shape must be equal to self.batch_shape, "
            f"but got x's batch shape={x.shape[:-2]}, self.batch_shape={self.batch_shape}."
        )
        assert x.shape[-1] == self.dim, (
            f"x's dim must be equal to self.dim, " f"but got x's dim = {x.shape[-1]}, self.dim = {self.dim}"
        )
        return x