#!/usr/bin/env python3

from typing import List, Optional

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import CatLinearOperator
from torch.nn.parallel import DataParallel

from .. import settings
from .kernel import Kernel


class MultiDeviceKernel(DataParallel, Kernel):
    r"""
    Allocates the covariance matrix on distributed devices, e.g. multiple GPUs.

    Args:
        base_kernel: Base kernel to distribute.
        device_ids: List of `torch.device` objects to place kernel chunks on.
        output_device: Device where outputs will be placed.
        create_cuda_context: If True (default), initialize a CUDA context on each device
            at construction time, so that the first scatter in forward is not slowed down
            by lazy context creation.
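
    Example:
        A minimal sketch with an RBF base kernel, assuming at least two CUDA devices
        (``cuda:0`` and ``cuda:1``) are available:

        >>> base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        >>> covar_module = gpytorch.kernels.MultiDeviceKernel(
        >>>     base_covar_module,
        >>>     device_ids=[torch.device("cuda:0"), torch.device("cuda:1")],
        >>>     output_device=torch.device("cuda:0"),
        >>> )
        >>> x = torch.randn(1000, 3, device="cuda:0")
        >>> covar = covar_module(x)  # evaluated in row chunks, one chunk per device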
"""

    def __init__(
        self,
        base_kernel: Kernel,
        device_ids: List[torch.device],
        output_device: Optional[torch.device] = None,
        create_cuda_context: Optional[bool] = True,
        **kwargs,
    ):
        # Need to warm up each GPU, otherwise scattering in forward will be
        # EXTREMELY slow. This memory will be available as soon as we leave __init__.
        if create_cuda_context:
            for d in device_ids:
                _ = torch.tensor([], device=d)

        DataParallel.__init__(self, module=base_kernel, device_ids=device_ids, output_device=output_device, dim=-2)

        self.output_device = output_device if output_device else device_ids[0]

        self.__cached_x1 = torch.empty(1)
        self.__cached_x2 = torch.empty(1)

    @property
    def base_kernel(self):
        return self.module

    def forward(self, x1, x2, diag=False, **kwargs):
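        # Scatter x1 into row chunks across the devices (along dim=-2), copy x2 in full
        # to each device, evaluate the base kernel on every chunk in parallel, then
        # gather the per-device results into a single lazily concatenated covariance.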
        if diag:
            return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device)

        # Too few rows to split across all devices: evaluate on a single device instead
        if x1.size(-2) < len(self.device_ids) + 1:
            return self.module.forward(x1, x2, diag=diag, **kwargs).to(self.output_device)
        # Scattering is expensive, so reuse the scattered chunks while the inputs are
        # unchanged (e.g. across training iterations with a fixed training set)
        if x1.device != self.__cached_x1.device or not torch.equal(x1, self.__cached_x1):
            self._x1_scattered, self._kwargs = self.scatter((x1,), kwargs, self.device_ids)
            self.__cached_x1 = x1

        if x2.device != self.__cached_x2.device or not torch.equal(x2, self.__cached_x2):
            self._x2_subs = [x2.to(x1_[0].device) for x1_ in self._x1_scattered]
            self.__cached_x2 = x2
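
        # Each device receives (its x1 chunk, the full x2) and therefore computes one
        # block of rows of the full covariance matrix.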
        inputs = tuple((x1_[0], x2_) for x1_, x2_ in zip(self._x1_scattered, self._x2_subs))

        if not self.device_ids:
            return self.module.forward(*inputs, **self._kwargs)

        if len(self.device_ids) == 1:
            return self.module.forward(*inputs[0], **self._kwargs[0])

        # JIT modules can't be pickled and replicated yet. Reinitializing the
        # distance_module on every forward pass is slow, so this workaround should be
        # removed once JIT modules can be pickled.
        def set_distance_module_to_none(module):
            if hasattr(module, "distance_module"):
                module.distance_module = None

        self.module.apply(set_distance_module_to_none)

        # Can't cache the replication because the base kernel module can change every time (e.g. param updates)
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])

        # TODO: parallel_apply might be too heavyweight in some cases?
        with settings.lazily_evaluate_kernels(False):
            outputs = self.parallel_apply(replicas, inputs, self._kwargs)

        return self.gather(outputs, self.output_device)

    def gather(self, outputs, output_device):
        # Concatenate the per-device row blocks along dim=-2 without eagerly copying
        # them all onto a single device
        return CatLinearOperator(
            *(to_linear_operator(o) for o in outputs), dim=self.dim, output_device=self.output_device
        )

    def num_outputs_per_input(self, x1, x2):
        return self.base_kernel.num_outputs_per_input(x1, x2)