Matern52 grad #2512

Merged (6 commits) on Jun 20, 2024

Changes from all commits
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -15,6 +15,7 @@
from .kernel import AdditiveKernel, Kernel, ProductKernel
from .lcm_kernel import LCMKernel
from .linear_kernel import LinearKernel
from .matern52_kernel_grad import Matern52KernelGrad
from .matern_kernel import MaternKernel
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
@@ -69,4 +70,5 @@
"ScaleKernel",
"SpectralDeltaKernel",
"SpectralMixtureKernel",
"Matern52KernelGrad",
]
152 changes: 152 additions & 0 deletions gpytorch/kernels/matern52_kernel_grad.py
@@ -0,0 +1,152 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt5 = math.sqrt(5)
five_thirds = 5.0 / 3.0


class Matern52KernelGrad(MaternKernel):
r"""
Computes a covariance matrix of the Matern52 kernel that models the covariance
between function values and partial derivatives for inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`.

See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

.. note::

This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

:param ard_num_dims: Set this if you want a separate lengthscale for each input
dimension. It should be `d` if x1 is an `n x d` matrix. (Default: `None`.)
:param batch_shape: Set this if you want a separate lengthscale for each batch of input
data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
:param active_dims: Set this if you want to compute the covariance of only
a few input dimensions. The ints correspond to the indices of the
dimensions. (Default: `None`.)
:param lengthscale_prior: Set this if you want to apply a prior to the
lengthscale parameter. (Default: `None`)
:param lengthscale_constraint: Set this if you want to apply a constraint
to the lengthscale parameter. (Default: `Positive`.)
:param eps: The minimum value that the lengthscale can take (prevents
divide by zero errors). (Default: `1e-6`.)

:ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
ard_num_dims and batch_shape arguments.

Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
>>> covar = covar_module(x) # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2]))) # noqa: E501
>>> covar = covar_module(batch_x) # Output: LinearOperator of size (2 x 60 x 60)
"""

def __init__(self, **kwargs):

# remove nu in case it was set
kwargs.pop("nu", None)
super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs)

def forward(self, x1, x2, diag=False, **params):

lengthscale = self.lengthscale

batch_shape = x1.shape[:-2]
n_batch_dims = len(batch_shape)
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

if not diag:

K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)

# differences matrix in each dimension to be used for derivatives
# shape of n1 x n2 x d
outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
outer = outer / lengthscale.unsqueeze(-2) ** 2
# shape of n1 x d x n2
outer = torch.transpose(outer, -1, -2).contiguous()

# 1) Kernel block, cov(f^m, f^n)
# shape is n1 x n2
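# closed form: k(r) = (1 + sqrt(5)*r + (5/3)*r^2) * exp(-sqrt(5)*r),
# with r the lengthscale-scaled distance held in `distance_matrix`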
exp_component = torch.exp(-sqrt5 * distance_matrix)
constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2)

K[..., :n1, :n2] = constant_component * exp_component

# 2) First gradient block, cov(f^m, omega^n_d)
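# the scalar factor below is k'(r)/r = -(5/3) * (1 + sqrt(5)*r) * exp(-sqrt(5)*r),
# applied to the scaled differences (x1 - x2) / lengthscale^2 stored in `outer`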
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * (n_batch_dims + 1)), d]
)

# 3) Second gradient block, cov(omega^m_d, f^n)
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
# the - signs on -outer2 and -five_thirds cancel out
K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * n_batch_dims), d, 1]
)

# 4) Hessian block, cov(omega^m_d, omega^n_d)
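# outer3 collects the pairwise products of scaled differences,
# (x1_i - x2_i) * (x1_j - x2_j) / (lengthscale_i^2 * lengthscale_j^2),
# laid out over the d x d grid of derivative blocks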
outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
kp = KroneckerProductLinearOperator(
torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
)

part1 = -five_thirds * exp_neg_sqrt5r
part2 = 5 * outer3
part3 = 1 + sqrt5 * distance_matrix
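
# combined below as part1 * (part2 - kp * part3)
# = (5/3) * exp(-sqrt(5)*r) * ((1 + sqrt(5)*r) * I / lengthscale^2 - 5 * outer3)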

K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_(
# need to use kp.to_dense().mul instead of kp.to_dense().mul_
# because otherwise a RuntimeError is raised due to how autograd works with
# view + inplace operations in the case of 1-dimensional input
part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d])))
)

# Symmetrize for stability
if n1 == n2 and torch.eq(x1, x2).all():
K = 0.5 * (K.transpose(-1, -2) + K)

# Apply a perfect shuffle permutation to match the MultiTask ordering
pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
K = K[..., pi1, :][..., :, pi2]

return K
else:
if not (n1 == n2 and torch.eq(x1, x2).all()):
raise RuntimeError("diag=True only works when x1 == x2")

# nu is set to 2.5
kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True)
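# at x1 == x2 the scaled distance is zero, so the derivative-block diagonal
# reduces to (5/3) / lengthscale^2 in each input dimension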
grad_diag = (
five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
) / lengthscale**2
grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
return k_diag[..., pi]

def num_outputs_per_input(self, x1, x2):
return x1.size(-1) + 1
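
For context, here is a minimal usage sketch (not part of this diff) showing how the new kernel plugs into the same derivative-GP pattern gpytorch already uses with RBFKernelGrad; the toy data, model class name, and task count below are illustrative assumptions.

# --- usage sketch (illustrative, not part of this PR) ---
import torch
import gpytorch
from gpytorch.kernels import Matern52KernelGrad

# toy 1-d data: targets stack f(x) and df/dx per point, shape n x (d + 1)
train_x = torch.linspace(0, 1, 10).unsqueeze(-1)
train_y = torch.stack([train_x.sin().squeeze(-1), train_x.cos().squeeze(-1)], dim=-1)

class GPWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(Matern52KernelGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

# one task per output column: the function value plus d partial derivatives
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = GPWithDerivatives(train_x, train_y, likelihood)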
78 changes: 78 additions & 0 deletions test/kernels/test_matern52_kernel_grad.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import unittest

import torch

from gpytorch.kernels import Matern52KernelGrad
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestMatern52KernelGrad(unittest.TestCase, BaseKernelTestCase):
def create_kernel_no_ard(self, **kwargs):
return Matern52KernelGrad(**kwargs)

def create_kernel_ard(self, num_dims, **kwargs):
return Matern52KernelGrad(ard_num_dims=num_dims, **kwargs)

def test_kernel(self, cuda=False):
a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float)
b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float)

actual = torch.tensor(
[
[0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742],
[0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832],
[-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792],
[0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000],
[0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000],
[0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784],
]
)

kernel = Matern52KernelGrad()

if cuda:
a = a.cuda()
b = b.cuda()
actual = actual.cuda()
kernel = kernel.cuda()

res = kernel(a, b).to_dense()

self.assertLess(torch.norm(res - actual), 1e-5)

def test_kernel_cuda(self):
if torch.cuda.is_available():
self.test_kernel(cuda=True)

def test_kernel_batch(self):
a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)

kernel = Matern52KernelGrad()
res = kernel(a, b).to_dense()

# Compute each batch separately
actual = torch.zeros(2, 8, 8)
actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense()
actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense()

self.assertLess(torch.norm(res - actual), 1e-5)

def test_initialize_lengthscale(self):
kernel = Matern52KernelGrad()
kernel.initialize(lengthscale=3.14)
actual_value = torch.tensor(3.14).view_as(kernel.lengthscale)
self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)

def test_initialize_lengthscale_batch(self):
kernel = Matern52KernelGrad(batch_shape=torch.Size([2]))
ls_init = torch.tensor([3.14, 4.13])
kernel.initialize(lengthscale=ls_init)
actual_value = ls_init.view_as(kernel.lengthscale)
self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)


if __name__ == "__main__":
unittest.main()