[Breaking Change] remove last_dim_is_batch from remaining kernels
gpleiss committed Jul 17, 2024
1 parent 650368e commit 8d859b9
Showing 25 changed files with 25 additions and 388 deletions.
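Callers that passed last_dim_is_batch=True can get the same behavior by reshaping the inputs themselves before calling the kernel, which is exactly what the removed code paths did internally. A minimal sketch, not part of this commit, using a stock RBFKernel and made-up shapes:

import torch
from gpytorch.kernels import RBFKernel

n, d = 5, 3
x = torch.randn(n, d)
# This is the reshape that last_dim_is_batch=True used to apply internally:
# each of the d input dimensions becomes its own batch of 1-D inputs.
x_batched = x.transpose(-1, -2).unsqueeze(-1)    # (d, n, 1)

kernel = RBFKernel(batch_shape=torch.Size([d]))  # one lengthscale per dimension-batch
covar = kernel(x_batched, x_batched)             # lazily evaluated (d, n, n) kernel
print(covar.to_dense().shape)                    # torch.Size([3, 5, 5])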
10 changes: 0 additions & 10 deletions gpytorch/kernels/constant_kernel.py
@@ -90,25 +90,18 @@ def forward(
x1: Tensor,
x2: Tensor,
diag: Optional[bool] = False,
last_dim_is_batch: Optional[bool] = False,
) -> Tensor:
"""Evaluates the constant kernel.
Args:
x1: First input tensor of shape (batch_shape x n1 x d).
x2: Second input tensor of shape (batch_shape x n2 x d).
diag: If True, returns the diagonal of the covariance matrix.
last_dim_is_batch: If True, the last dimension of size `d` of the input
tensors are treated as a batch dimension.
Returns:
A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
constant covariance values if diag is False, resp. True.
"""
if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

dtype = torch.promote_types(x1.dtype, x2.dtype)
batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
if not diag:
constant = constant.unsqueeze(-1)

if last_dim_is_batch:
constant = constant.unsqueeze(-1)

return constant.expand(shape)
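The shape logic that remains in ConstantKernel.forward reduces to a broadcast and an expand. A pure-torch sketch with made-up sizes and an arbitrary constant:

import torch

x1 = torch.randn(2, 4, 3)     # batch_shape=(2,), n1=4, d=3
x2 = torch.randn(2, 6, 3)     # batch_shape=(2,), n2=6, d=3
constant = torch.tensor(0.5)

batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
full = constant.expand(batch_shape + (x1.shape[-2], x2.shape[-2]))   # diag=False
diag = constant.expand(batch_shape + (x1.shape[-2],))                # diag=True
print(full.shape, diag.shape)     # torch.Size([2, 4, 6]) torch.Size([2, 4])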
43 changes: 5 additions & 38 deletions gpytorch/kernels/kernel.py
@@ -231,9 +231,7 @@ def _set_lengthscale(self, value: Tensor):
self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

@abstractmethod
def forward(
self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
) -> Union[Tensor, LinearOperator]:
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
r"""
Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
This method should be implemented by all Kernel subclasses.
@@ -242,16 +240,11 @@ def forward(
:param x2: Second set of data (... x M x D).
:param diag: Should the Kernel compute the whole kernel, or just the diag?
If True, it must be the case that `x1 == x2`. (Default: False.)
:param last_dim_is_batch: If True, treat the last dimension
of `x1` and `x2` as another batch dimension.
(Useful for additive structure over the dimensions). (Default: False.)
:return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:
* `full_covar`: `... x N x M`
* `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
* `diag`: `... x N`
* `diag` with `last_dim_is_batch=True`: `... x K x N`
"""
raise NotImplementedError()

@@ -314,7 +307,6 @@ def covar_dist(
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
square_dist: bool = False,
**params,
) -> Tensor:
@@ -326,22 +318,13 @@ def covar_dist(
:param x2: Second set of data (... x M x D).
:param diag: Should the Kernel compute the whole kernel, or just the diag?
If True, it must be the case that `x1 == x2`. (Default: False.)
:param last_dim_is_batch: If True, treat the last dimension
of `x1` and `x2` as another batch dimension.
(Useful for additive structure over the dimensions). (Default: False.)
:param square_dist:
If True, returns the squared distance rather than the standard distance. (Default: False.)
:return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:
* `full_covar`: `... x N x M`
* `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
* `diag`: `... x N`
* `diag` with `last_dim_is_batch=True`: `... x K x N`
"""
if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

x1_eq_x2 = torch.equal(x1, x2)
res = None

@@ -457,7 +440,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
yield kernel

def __call__(
self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, **params
) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
r"""
Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -473,27 +456,13 @@ def __call__(
(If `None`, then `x2` is set to `x1`.)
:param diag: Should the Kernel compute the whole kernel, or just the diag?
If True, it must be the case that `x1 == x2`. (Default: False.)
:param last_dim_is_batch: If True, treat the last dimension
of `x1` and `x2` as another batch dimension.
(Useful for additive structure over the dimensions). (Default: False.)
:return: An object that will lazily evaluate to the kernel matrix or vector.
The shape depends on the kernel's evaluation mode:
* `full_covar`: `... x N x M`
* `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
* `diag`: `... x N`
* `diag` with `last_dim_is_batch=True`: `... x K x N`
"""
if last_dim_is_batch:
warnings.warn(
"The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. "
"If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, "
'please update your code according to the "Kernels with Additive or Product Structure" '
"tutorial in the GPyTorch docs.",
DeprecationWarning,
)

x1_, x2_ = x1, x2

# Select the active dimensions
@@ -523,7 +492,7 @@ def __call__(
)

if diag:
res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
res = super(Kernel, self).__call__(x1_, x2_, diag=True, **params)
# Did this Kernel eat the diag option?
# If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
if not isinstance(res, LazyEvaluatedKernelTensor):
@@ -533,11 +502,9 @@ def __call__(

else:
if settings.lazily_evaluate_kernels.on():
res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params)
else:
res = to_linear_operator(
super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
)
res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params))
return res

def __getstate__(self):
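With the trimmed abstract signature, a custom kernel only has to handle diag; covar_dist keeps its square_dist option. A minimal subclass sketch (the inverse-quadratic kernel below is illustrative and not part of GPyTorch):

import torch
from gpytorch.kernels import Kernel

class InverseQuadraticKernel(Kernel):
    def forward(self, x1, x2, diag=False, **params):
        # Pairwise squared distances: (... x N x M), or (... x N) when diag=True.
        dist_sq = self.covar_dist(x1, x2, diag=diag, square_dist=True, **params)
        return 1.0 / (1.0 + dist_sq)

k = InverseQuadraticKernel()
x = torch.randn(10, 2)
print(k(x, x).to_dense().shape)   # torch.Size([10, 10])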
9 changes: 1 addition & 8 deletions gpytorch/kernels/linear_kernel.py
@@ -85,12 +85,8 @@ def _set_variance(self, value: Union[float, Tensor]):
value = torch.as_tensor(value).to(self.raw_variance)
self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))

def forward(
self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
) -> LinearOperator:
def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> LinearOperator:
x1_ = x1 * self.variance.sqrt()
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)

if x1.size() == x2.size() and torch.equal(x1, x2):
# Use RootLinearOperator when x1 == x2 for efficiency when composing
@@ -99,9 +95,6 @@ def forward(

else:
x2_ = x2 * self.variance.sqrt()
if last_dim_is_batch:
x2_ = x2_.transpose(-1, -2).unsqueeze(-1)

prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1))

if diag:
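Only the reshape branch is gone from LinearKernel; the RootLinearOperator fast path for x1 == x2 is untouched. A quick sanity sketch using the kernel's freshly initialized variance:

import torch
from gpytorch.kernels import LinearKernel

k = LinearKernel()
x = torch.randn(6, 3)
dense = k(x, x).to_dense()                        # variance * x @ x^T via RootLinearOperator
manual = k.variance * (x @ x.transpose(-2, -1))
print(torch.allclose(dense, manual, atol=1e-5))   # True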
1 change: 0 additions & 1 deletion gpytorch/kernels/matern_kernel.py
@@ -89,7 +89,6 @@ def forward(self, x1, x2, diag=False, **params):
or x2.requires_grad
or (self.ard_num_dims is not None and self.ard_num_dims > 1)
or diag
or params.get("last_dim_is_batch", False)
or trace_mode.on()
):
mean = x1.mean(dim=-2, keepdim=True)
4 changes: 1 addition & 3 deletions gpytorch/kernels/multitask_kernel.py
@@ -43,9 +43,7 @@ def __init__(
self.data_covar_module = data_covar_module
self.num_tasks = num_tasks

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
if last_dim_is_batch:
raise RuntimeError("MultitaskKernel does not accept the last_dim_is_batch argument.")
def forward(self, x1, x2, diag=False, **params):
covar_i = self.task_covar_module.covar_matrix
if len(x1.shape[:-2]):
covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
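With the explicit RuntimeError gone, MultitaskKernel accepts the same arguments as every other kernel; the Kronecker-structured output is unchanged. A small usage sketch with arbitrary sizes:

import torch
from gpytorch.kernels import MultitaskKernel, RBFKernel

k = MultitaskKernel(RBFKernel(), num_tasks=2, rank=1)
x = torch.randn(4, 3)
# Kronecker product of a 4x4 data covariance with a 2x2 task covariance.
print(k(x, x).to_dense().shape)   # torch.Size([8, 8])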
11 changes: 4 additions & 7 deletions gpytorch/kernels/periodic_kernel.py
@@ -124,23 +124,20 @@ def _set_period_length(self, value):
self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value))

def forward(self, x1, x2, diag=False, **params):
# Pop this argument so that we can manually sum over dimensions
last_dim_is_batch = params.pop("last_dim_is_batch", False)
# Get lengthscale
lengthscale = self.lengthscale

x1_ = x1.div(self.period_length / math.pi)
x2_ = x2.div(self.period_length / math.pi)
# We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params)
diff = self.covar_dist(
x1_.transpose(-1, -2).unsqueeze(-1), x2_.transpose(-1, -2).unsqueeze(-1), diag=diag, **params
) # A ... x D x N x N kernel

if diag:
lengthscale = lengthscale[..., 0, :, None]
else:
lengthscale = lengthscale[..., 0, :, None, None]
exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)

if not last_dim_is_batch:
exp_term = exp_term.sum(dim=(-2 if diag else -3))
exp_term = exp_term.sum(dim=(-2 if diag else -3))

return exp_term.exp()
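PeriodicKernel now spells out the per-dimension handling itself: the inputs are stacked into a ... x D x N x N batch of one-dimensional distance matrices and the sin^2 terms are summed over D. A pure-torch sketch of just the shape bookkeeping (lengthscale and period scaling omitted):

import torch

n, d = 4, 3
x = torch.randn(n, d)
xb = x.transpose(-1, -2).unsqueeze(-1)      # (d, n, 1): one batch per input dimension
diff = (xb - xb.transpose(-2, -1)).abs()    # (d, n, n) pairwise distances, per dimension
per_dim = diff.sin().pow(2.0)               # stand-in for the sin^2 term
summed = per_dim.sum(dim=-3)                # sum over the d dimension-batches
print(per_dim.shape, summed.shape)          # torch.Size([3, 4, 4]) torch.Size([4, 4])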
13 changes: 3 additions & 10 deletions gpytorch/kernels/piecewise_polynomial_kernel.py
@@ -101,20 +101,13 @@ def __init__(self, q: Optional[int] = 2, **kwargs):
raise ValueError("q expected to be 0, 1, 2 or 3")
self.q = q

def forward(self, x1: Tensor, x2: Tensor, last_dim_is_batch: bool = False, diag: bool = False, **params) -> Tensor:
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)
if last_dim_is_batch is True:
D = x1.shape[1]
else:
D = x1.shape[-1]
D = x1.shape[-1]
j = math.floor(D / 2.0) + self.q + 1
if last_dim_is_batch and diag:
r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True)
elif diag:
if diag:
r = self.covar_dist(x1_, x2_, diag=True)
elif last_dim_is_batch:
r = self.covar_dist(x1_, x2_, last_dim_is_batch=True)
else:
r = self.covar_dist(x1_, x2_)
cov_matrix = _fmax(r, j, self.q) * _get_cov(r, j, self.q)
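The exponent j is now always derived from the true feature dimension x1.shape[-1]. For reference, with made-up values of D and q:

import math

D, q = 3, 2                       # feature dimension and smoothness parameter
j = math.floor(D / 2.0) + q + 1   # exponent used by the piecewise polynomial kernel
print(j)                          # 4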
5 changes: 0 additions & 5 deletions gpytorch/kernels/polynomial_kernel.py
@@ -81,15 +81,10 @@ def forward(
x1: torch.Tensor,
x2: torch.Tensor,
diag: Optional[bool] = False,
last_dim_is_batch: Optional[bool] = False,
**params,
) -> torch.Tensor:
offset = self.offset.view(*self.batch_shape, 1, 1)

if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

if diag:
return ((x1 * x2).sum(dim=-1) + self.offset).pow(self.power)

1 change: 0 additions & 1 deletion gpytorch/kernels/polynomial_kernel_grad.py
@@ -13,7 +13,6 @@ def forward(
x1: torch.Tensor,
x2: torch.Tensor,
diag: Optional[bool] = False,
last_dim_is_batch: Optional[bool] = False,
**params,
) -> torch.Tensor:
offset = self.offset.view(*self.batch_shape, 1, 1)
1 change: 0 additions & 1 deletion gpytorch/kernels/rbf_kernel.py
@@ -71,7 +71,6 @@ def forward(self, x1, x2, diag=False, **params):
or x2.requires_grad
or (self.ard_num_dims is not None and self.ard_num_dims > 1)
or diag
or params.get("last_dim_is_batch", False)
or trace_mode.on()
):
x1_ = x1.div(self.lengthscale)
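Per-dimension lengthscales remain available through the usual ard_num_dims argument; a quick sketch with arbitrary sizes:

import torch
from gpytorch.kernels import RBFKernel

k = RBFKernel(ard_num_dims=3)      # one lengthscale per input dimension
x = torch.randn(8, 3)
print(k.lengthscale.shape)         # torch.Size([1, 3])
print(k(x, x).to_dense().shape)    # torch.Size([8, 8])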
5 changes: 1 addition & 4 deletions gpytorch/kernels/rff_kernel.py
@@ -111,10 +111,7 @@ def _init_weights(
)
self.register_buffer("randn_weights", randn_weights)

def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **kwargs) -> Tensor:
if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> Tensor:
num_dims = x1.size(-1)
if not hasattr(self, "randn_weights"):
self._init_weights(num_dims, self.num_samples)
6 changes: 2 additions & 4 deletions gpytorch/kernels/scale_kernel.py
@@ -105,11 +105,9 @@ def _set_outputscale(self, value):
value = torch.as_tensor(value).to(self.raw_outputscale)
self.initialize(raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value))

def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
def forward(self, x1, x2, diag=False, **params):
orig_output = self.base_kernel.forward(x1, x2, diag=diag, **params)
outputscales = self.outputscale
if last_dim_is_batch:
outputscales = outputscales.unsqueeze(-1)
if diag:
outputscales = outputscales.unsqueeze(-1)
return to_dense(orig_output) * outputscales
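ScaleKernel's diag path still applies the outputscale; a quick consistency sketch with arbitrary shapes:

import torch
from gpytorch.kernels import RBFKernel, ScaleKernel

k = ScaleKernel(RBFKernel())
x = torch.randn(7, 2)
full = k(x, x).to_dense()     # (7, 7), scaled by the outputscale
diag = k(x, x, diag=True)     # (7,), same scaling applied
print(torch.allclose(full.diagonal(dim1=-2, dim2=-1), diag, atol=1e-5))   # True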
24 changes: 3 additions & 21 deletions gpytorch/kernels/spectral_mixture_kernel.py
@@ -268,7 +268,7 @@ def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor, **k
self.mixture_weights = train_y.std().div(self.num_mixtures)

def _create_input_grid(
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a helper method for creating a grid of the kernel's inputs.
@@ -280,33 +280,20 @@ def _create_input_grid(
:param torch.Tensor x2: ... x m x d (for diag mode, these must be the same inputs)
:param diag: Should the Kernel compute the whole kernel, or just the diag? (Default: True.)
:type diag: bool, optional
:param last_dim_is_batch: If this is true, it treats the last dimension
of the data as another batch dimension. (Useful for additive
structure over the dimensions). (Default: False.)
:type last_dim_is_batch: bool, optional
:rtype: torch.Tensor, torch.Tensor
:return: Grid corresponding to x1 and x2. The shape depends on the kernel's mode:
* `full_covar`: (`... x n x 1 x d` and `... x 1 x m x d`)
* `full_covar` with `last_dim_is_batch=True`: (`... x k x n x 1 x 1` and `... x k x 1 x m x 1`)
* `diag`: (`... x n x d` and `... x n x d`)
* `diag` with `last_dim_is_batch=True`: (`... x k x n x 1` and `... x k x n x 1`)
"""
x1_, x2_ = x1, x2
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
if torch.equal(x1, x2):
x2_ = x1_
else:
x2_ = x2_.transpose(-1, -2).unsqueeze(-1)

if diag:
return x1_, x2_
else:
return x1_.unsqueeze(-2), x2_.unsqueeze(-3)

def forward(
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
) -> Tuple[torch.Tensor, torch.Tensor]:
n, num_dims = x1.shape[-2:]

@@ -344,10 +331,5 @@ def forward(
res = (res * mixture_weights).sum(-3 if diag else -4)

# Product over dimensions
if last_dim_is_batch:
# Put feature-dimension in front of data1/data2 dimensions
res = res.permute(*list(range(0, res.dim() - 3)), -1, -3, -2)
else:
res = res.prod(-1)

res = res.prod(-1)
return res
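The product over input dimensions is now unconditional (res.prod(-1)); calling the kernel is unchanged. A minimal smoke-test sketch with arbitrary hyperparameters:

import torch
from gpytorch.kernels import SpectralMixtureKernel

k = SpectralMixtureKernel(num_mixtures=3, ard_num_dims=2)
x = torch.randn(5, 2)
print(k(x, x).to_dense().shape)   # torch.Size([5, 5])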