diff --git a/gpytorch/kernels/constant_kernel.py b/gpytorch/kernels/constant_kernel.py
index 98a3560e2..ab177519c 100644
--- a/gpytorch/kernels/constant_kernel.py
+++ b/gpytorch/kernels/constant_kernel.py
@@ -90,7 +90,6 @@ def forward(
         x1: Tensor,
         x2: Tensor,
         diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
     ) -> Tensor:
         """Evaluates the constant kernel.

@@ -98,17 +97,11 @@ def forward(
             x1: First input tensor of shape (batch_shape x n1 x d).
             x2: Second input tensor of shape (batch_shape x n2 x d).
             diag: If True, returns the diagonal of the covariance matrix.
-            last_dim_is_batch: If True, the last dimension of size `d` of the input
-                tensors are treated as a batch dimension.

         Returns:
             A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
             constant covariance values if diag is False, resp. True.
         """
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
         dtype = torch.promote_types(x1.dtype, x2.dtype)
         batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
         shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
         if not diag:
             constant = constant.unsqueeze(-1)

-        if last_dim_is_batch:
-            constant = constant.unsqueeze(-1)
-
         return constant.expand(shape)
diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py
index 67e576db3..261c4e663 100644
--- a/gpytorch/kernels/kernel.py
+++ b/gpytorch/kernels/kernel.py
@@ -231,9 +231,7 @@ def _set_lengthscale(self, value: Tensor):
         self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

     @abstractmethod
-    def forward(
-        self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
-    ) -> Union[Tensor, LinearOperator]:
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         r"""
         Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
         This method should be implemented by all Kernel subclasses.
@@ -242,16 +240,11 @@ def forward(
         :param x2: Second set of data (... x M x D).
         :param diag: Should the Kernel compute the whole kernel, or just the diag?
             If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)

         :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

             * `full_covar`: `... x N x M`
-            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
             * `diag`: `... x N`
-            * `diag` with `last_dim_is_batch=True`: `... x K x N`
         """
         raise NotImplementedError()
@@ -314,7 +307,6 @@ def covar_dist(
         x1: Tensor,
         x2: Tensor,
         diag: bool = False,
-        last_dim_is_batch: bool = False,
         square_dist: bool = False,
         **params,
     ) -> Tensor:
@@ -326,22 +318,13 @@
         :param x2: Second set of data (... x M x D).
         :param diag: Should the Kernel compute the whole kernel, or just the diag?
             If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)
         :param square_dist: If True, returns the squared distance rather than the standard distance. (Default: False.)

         :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

             * `full_covar`: `... x N x M`
-            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
             * `diag`: `... x N`
-            * `diag` with `last_dim_is_batch=True`: `... x K x N`
         """
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
         x1_eq_x2 = torch.equal(x1, x2)
         res = None
@@ -457,7 +440,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
             yield kernel

     def __call__(
-        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, **params
     ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
         r"""
         Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -473,27 +456,13 @@ def __call__(
             (If `None`, then `x2` is set to `x1`.)
         :param diag: Should the Kernel compute the whole kernel, or just the diag?
             If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)

         :return: An object that will lazily evaluate to the kernel matrix or vector. The shape depends on the
             kernel's evaluation mode:

             * `full_covar`: `... x N x M`
-            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
             * `diag`: `... x N`
-            * `diag` with `last_dim_is_batch=True`: `... x K x N`
         """
-        if last_dim_is_batch:
-            warnings.warn(
-                "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. "
-                "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, "
-                'please update your code according to the "Kernels with Additive or Product Structure" '
-                "tutorial in the GPyTorch docs.",
-                DeprecationWarning,
-            )
-
         x1_, x2_ = x1, x2

         # Select the active dimensions
@@ -523,7 +492,7 @@ def __call__(
             )

         if diag:
-            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
+            res = super(Kernel, self).__call__(x1_, x2_, diag=True, **params)

             # Did this Kernel eat the diag option?
             # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
             if not isinstance(res, LazyEvaluatedKernelTensor):
@@ -533,11 +502,9 @@ def __call__(
         else:
             if settings.lazily_evaluate_kernels.on():
-                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
+                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params)
             else:
-                res = to_linear_operator(
-                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
-                )
+                res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params))
         return res

     def __getstate__(self):
diff --git a/gpytorch/kernels/linear_kernel.py b/gpytorch/kernels/linear_kernel.py
index d7ecd1014..59f1885e8 100644
--- a/gpytorch/kernels/linear_kernel.py
+++ b/gpytorch/kernels/linear_kernel.py
@@ -85,12 +85,8 @@ def _set_variance(self, value: Union[float, Tensor]):
         value = torch.as_tensor(value).to(self.raw_variance)
         self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))

-    def forward(
-        self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
-    ) -> LinearOperator:
+    def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> LinearOperator:
         x1_ = x1 * self.variance.sqrt()
-        if last_dim_is_batch:
-            x1_ = x1_.transpose(-1, -2).unsqueeze(-1)

         if x1.size() == x2.size() and torch.equal(x1, x2):
             # Use RootLinearOperator when x1 == x2 for efficiency when composing
@@ -99,9 +95,6 @@ def forward(
         else:
             x2_ = x2 * self.variance.sqrt()

-            if last_dim_is_batch:
-                x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
-
             prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1))

         if diag:
diff --git a/gpytorch/kernels/matern_kernel.py b/gpytorch/kernels/matern_kernel.py
index baf145e36..3824ef49b 100644
--- a/gpytorch/kernels/matern_kernel.py
+++ b/gpytorch/kernels/matern_kernel.py
@@ -89,7 +89,6 @@ def forward(self, x1, x2, diag=False, **params):
             or x2.requires_grad
             or (self.ard_num_dims is not None and self.ard_num_dims > 1)
             or diag
-            or params.get("last_dim_is_batch", False)
             or trace_mode.on()
         ):
             mean = x1.mean(dim=-2, keepdim=True)
diff --git a/gpytorch/kernels/multitask_kernel.py b/gpytorch/kernels/multitask_kernel.py
index 79a4f1388..f2e0b16d7 100644
--- a/gpytorch/kernels/multitask_kernel.py
+++ b/gpytorch/kernels/multitask_kernel.py
@@ -43,9 +43,7 @@ def __init__(
         self.data_covar_module = data_covar_module
         self.num_tasks = num_tasks

-    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
-        if last_dim_is_batch:
-            raise RuntimeError("MultitaskKernel does not accept the last_dim_is_batch argument.")
+    def forward(self, x1, x2, diag=False, **params):
         covar_i = self.task_covar_module.covar_matrix
         if len(x1.shape[:-2]):
             covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
diff --git a/gpytorch/kernels/periodic_kernel.py b/gpytorch/kernels/periodic_kernel.py
index 2972b523a..f8c543921 100644
--- a/gpytorch/kernels/periodic_kernel.py
+++ b/gpytorch/kernels/periodic_kernel.py
@@ -124,23 +124,20 @@ def _set_period_length(self, value):
         self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value))

     def forward(self, x1, x2, diag=False, **params):
-        # Pop this argument so that we can manually sum over dimensions
-        last_dim_is_batch = params.pop("last_dim_is_batch", False)
         # Get lengthscale
         lengthscale = self.lengthscale

         x1_ = x1.div(self.period_length / math.pi)
         x2_ = x2.div(self.period_length / math.pi)
-        # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
-        diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params)
+        diff = self.covar_dist(
+            x1_.transpose(-1, -2).unsqueeze(-1), x2_.transpose(-1, -2).unsqueeze(-1), diag=diag, **params
+        )
         # A ... x D x N x N kernel
         if diag:
             lengthscale = lengthscale[..., 0, :, None]
         else:
             lengthscale = lengthscale[..., 0, :, None, None]

         exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)
-
-        if not last_dim_is_batch:
-            exp_term = exp_term.sum(dim=(-2 if diag else -3))
+        exp_term = exp_term.sum(dim=(-2 if diag else -3))

         return exp_term.exp()
diff --git a/gpytorch/kernels/piecewise_polynomial_kernel.py b/gpytorch/kernels/piecewise_polynomial_kernel.py
index 8135979f0..0bdb358c7 100644
--- a/gpytorch/kernels/piecewise_polynomial_kernel.py
+++ b/gpytorch/kernels/piecewise_polynomial_kernel.py
@@ -101,20 +101,13 @@ def __init__(self, q: Optional[int] = 2, **kwargs):
             raise ValueError("q expected to be 0, 1, 2 or 3")
         self.q = q

-    def forward(self, x1: Tensor, x2: Tensor, last_dim_is_batch: bool = False, diag: bool = False, **params) -> Tensor:
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
         x1_ = x1.div(self.lengthscale)
         x2_ = x2.div(self.lengthscale)
-        if last_dim_is_batch is True:
-            D = x1.shape[1]
-        else:
-            D = x1.shape[-1]
+        D = x1.shape[-1]
         j = math.floor(D / 2.0) + self.q + 1
-        if last_dim_is_batch and diag:
-            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True)
-        elif diag:
+        if diag:
             r = self.covar_dist(x1_, x2_, diag=True)
-        elif last_dim_is_batch:
-            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True)
         else:
             r = self.covar_dist(x1_, x2_)
         cov_matrix = _fmax(r, j, self.q) * _get_cov(r, j, self.q)
diff --git a/gpytorch/kernels/polynomial_kernel.py b/gpytorch/kernels/polynomial_kernel.py
index 3a98e8d4e..1dd57be88 100644
--- a/gpytorch/kernels/polynomial_kernel.py
+++ b/gpytorch/kernels/polynomial_kernel.py
@@ -81,15 +81,10 @@ def forward(
         x1: torch.Tensor,
         x2: torch.Tensor,
         diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
         **params,
     ) -> torch.Tensor:
         offset = self.offset.view(*self.batch_shape, 1, 1)

-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
         if diag:
             return ((x1 * x2).sum(dim=-1) + self.offset).pow(self.power)

diff --git a/gpytorch/kernels/polynomial_kernel_grad.py b/gpytorch/kernels/polynomial_kernel_grad.py
index f499bc23a..a8a17313d 100644
--- a/gpytorch/kernels/polynomial_kernel_grad.py
+++ b/gpytorch/kernels/polynomial_kernel_grad.py
@@ -13,7 +13,6 @@ def forward(
         x1: torch.Tensor,
         x2: torch.Tensor,
         diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
         **params,
     ) -> torch.Tensor:
         offset = self.offset.view(*self.batch_shape, 1, 1)
diff --git a/gpytorch/kernels/rbf_kernel.py b/gpytorch/kernels/rbf_kernel.py
index 932e59724..073d30f3e 100644
--- a/gpytorch/kernels/rbf_kernel.py
+++ b/gpytorch/kernels/rbf_kernel.py
@@ -71,7 +71,6 @@ def forward(self, x1, x2, diag=False, **params):
             or x2.requires_grad
             or (self.ard_num_dims is not None and self.ard_num_dims > 1)
             or diag
-            or params.get("last_dim_is_batch", False)
             or trace_mode.on()
         ):
             x1_ = x1.div(self.lengthscale)
diff --git a/gpytorch/kernels/rff_kernel.py b/gpytorch/kernels/rff_kernel.py
index c6b5e4ccd..c11c221ef 100644
--- a/gpytorch/kernels/rff_kernel.py
+++ b/gpytorch/kernels/rff_kernel.py
@@ -111,10 +111,7 @@ def _init_weights(
         )
         self.register_buffer("randn_weights", randn_weights)

-    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **kwargs) -> Tensor:
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> Tensor:
         num_dims = x1.size(-1)
         if not hasattr(self, "randn_weights"):
             self._init_weights(num_dims, self.num_samples)
diff --git a/gpytorch/kernels/scale_kernel.py b/gpytorch/kernels/scale_kernel.py
index 520913265..9a8a8f652 100644
--- a/gpytorch/kernels/scale_kernel.py
+++ b/gpytorch/kernels/scale_kernel.py
@@ -105,11 +105,9 @@ def _set_outputscale(self, value):
         value = torch.as_tensor(value).to(self.raw_outputscale)
         self.initialize(raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value))

-    def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
-        orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
+    def forward(self, x1, x2, diag=False, **params):
+        orig_output = self.base_kernel.forward(x1, x2, diag=diag, **params)
         outputscales = self.outputscale
-        if last_dim_is_batch:
-            outputscales = outputscales.unsqueeze(-1)
         if diag:
             outputscales = outputscales.unsqueeze(-1)
         return to_dense(orig_output) * outputscales
diff --git a/gpytorch/kernels/spectral_mixture_kernel.py b/gpytorch/kernels/spectral_mixture_kernel.py
index c8de79010..7b4e56ea2 100644
--- a/gpytorch/kernels/spectral_mixture_kernel.py
+++ b/gpytorch/kernels/spectral_mixture_kernel.py
@@ -268,7 +268,7 @@ def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor, **k
         self.mixture_weights = train_y.std().div(self.num_mixtures)

     def _create_input_grid(
-        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This is a helper method for creating a grid of the kernel's inputs.
@@ -280,33 +280,20 @@ def _create_input_grid(
         :param torch.Tensor x2: ... x m x d (for diag mode, these must be the same inputs)
         :param diag: Should the Kernel compute the whole kernel, or just the diag? (Default: True.)
         :type diag: bool, optional
-        :param last_dim_is_batch: If this is true, it treats the last dimension
-            of the data as another batch dimension. (Useful for additive
-            structure over the dimensions). (Default: False.)
-        :type last_dim_is_batch: bool, optional
         :rtype: torch.Tensor, torch.Tensor
         :return: Grid corresponding to x1 and x2. The shape depends on the kernel's mode:

             * `full_covar`: (`... x n x 1 x d` and `... x 1 x m x d`)
-            * `full_covar` with `last_dim_is_batch=True`: (`... x k x n x 1 x 1` and `... x k x 1 x m x 1`)
             * `diag`: (`... x n x d` and `... x n x d`)
-            * `diag` with `last_dim_is_batch=True`: (`... x k x n x 1` and `... x k x n x 1`)
         """
         x1_, x2_ = x1, x2
-        if last_dim_is_batch:
-            x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
-            if torch.equal(x1, x2):
-                x2_ = x1_
-            else:
-                x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
-
         if diag:
             return x1_, x2_
         else:
             return x1_.unsqueeze(-2), x2_.unsqueeze(-3)

     def forward(
-        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         n, num_dims = x1.shape[-2:]
@@ -344,10 +331,5 @@ def forward(
         res = (res * mixture_weights).sum(-3 if diag else -4)

         # Product over dimensions
-        if last_dim_is_batch:
-            # Put feature-dimension in front of data1/data2 dimensions
-            res = res.permute(*list(range(0, res.dim() - 3)), -1, -3, -2)
-        else:
-            res = res.prod(-1)
-
+        res = res.prod(-1)
         return res
diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
index 3efe398b4..d2c23e810 100644
--- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
+++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -31,20 +31,17 @@ def wrapped(self, *args, **kwargs):
 class LazyEvaluatedKernelTensor(LinearOperator):
     _check_size = False

-    def _check_args(self, x1, x2, kernel, last_dim_is_batch=False, **params):
+    def _check_args(self, x1, x2, kernel, **params):
         if not torch.is_tensor(x1):
             return "x1 must be a tensor. Got {}".format(x1.__class__.__name__)
         if not torch.is_tensor(x2):
             return "x1 must be a tensor. Got {}".format(x1.__class__.__name__)

-    def __init__(self, x1, x2, kernel, last_dim_is_batch=False, **params):
-        super(LazyEvaluatedKernelTensor, self).__init__(
-            x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, **params
-        )
+    def __init__(self, x1, x2, kernel, **params):
+        super(LazyEvaluatedKernelTensor, self).__init__(x1, x2, kernel=kernel, **params)
         self.kernel = kernel
         self.x1 = x1
         self.x2 = x2
-        self.last_dim_is_batch = last_dim_is_batch
         self.params = params
         self._is_grad_enabled = torch.is_grad_enabled()  # records grad state at instantiation
@@ -92,7 +89,6 @@ def _bilinear_derivative(self, left_vecs, right_vecs):
                     sub_x1,
                     x2,
                     diag=False,
-                    last_dim_is_batch=self.last_dim_is_batch,
                     **self.params,
                 )
             )
@@ -115,9 +111,7 @@ def _diagonal(self) -> torch.Tensor:
             x1 = self.x1
             x2 = self.x2

-            res = super(Kernel, self.kernel).__call__(
-                x1, x2, diag=True, last_dim_is_batch=self.last_dim_is_batch, **self.params
-            )
+            res = super(Kernel, self.kernel).__call__(x1, x2, diag=True, **self.params)

             # Now we'll make sure that the shape we're getting from diag makes sense
             if settings.debug.on():
@@ -193,12 +187,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
             col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None)

         # Define the index we're using for the last index
-        # If the last index corresponds to a batch, then we'll use the appropriate batch_index
-        # Otherwise, we'll use the _noop_index
-        if self.last_dim_is_batch:
-            *batch_indices, dim_index = batch_indices
-        else:
-            dim_index = _noop_index
+        dim_index = _noop_index

         # Get the indices of x1 and x2 that matter for the kernel
         # Call x1[*batch_indices, row_index, :]
@@ -238,7 +227,6 @@ def _getitem(self, row_index, col_index, *batch_indices):
             x1,
             x2,
             kernel=new_kernel,
-            last_dim_is_batch=self.last_dim_is_batch,
             **self.params,
         )

@@ -265,7 +253,6 @@ def _matmul(self, rhs):
                     sub_x1,
                     x2,
                     diag=False,
-                    last_dim_is_batch=self.last_dim_is_batch,
                     **self.params,
                 )
             )
@@ -312,9 +299,6 @@ def _size(self):
                     f"Got x1.shape = {x1.shape} and x2.shape = {x2.shape}"
                 )

-        # Handle when the last dim is batch
-        if self.last_dim_is_batch:
-            expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:]
         return expected_size

     @recall_grad_state
@@ -323,7 +307,6 @@ def _transpose_nonbatch(self):
             self.x2,
             self.x1,
             kernel=self.kernel,
-            last_dim_is_batch=self.last_dim_is_batch,
             **self.params,
         )

@@ -335,7 +318,6 @@ def _unsqueeze_batch(self, dim):
             x1,
             x2,
             kernel=self.kernel,
-            last_dim_is_batch=self.last_dim_is_batch,
             **self.params,
         )

@@ -356,7 +338,6 @@ def evaluate_kernel(self):
                 x1,
                 x2,
                 diag=False,
-                last_dim_is_batch=self.last_dim_is_batch,
                 **self.params,
             )
             self.kernel.active_dims = temp_active_dims
@@ -383,7 +364,6 @@ def repeat(self, *repeats):
             x1,
             x2,
             kernel=self.kernel,
-            last_dim_is_batch=self.last_dim_is_batch,
             **self.params,
         )

diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py
index 2e95e2162..ad4760e86 100644
--- a/gpytorch/models/exact_prediction_strategies.py
+++ b/gpytorch/models/exact_prediction_strategies.py
@@ -835,7 +835,6 @@ def exact_prediction(self, joint_mean, joint_covar):
                 test_test_covar.x1,
                 test_test_covar.x2,
                 test_test_covar.kernel.base_kernel,
-                test_test_covar.last_dim_is_batch,
                 **test_test_covar.params,
             )

diff --git a/test/kernels/test_constant_kernel.py b/test/kernels/test_constant_kernel.py
index 849ec3996..af46029fe 100644
--- a/test/kernels/test_constant_kernel.py
+++ b/test/kernels/test_constant_kernel.py
@@ -46,17 +46,6 @@ def _test_constant_kernel(self, device: torch.device):
                 # standard deviation is zero iff KM is constant
                 self.assertAlmostEqual(KM.std().item(), 0, places=places)

-            # testing last_dim_is_batch
-            with self.subTest(last_dim_is_batch=True):
-                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
-                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
-                KM = KD.to_dense()
-                self.assertIsInstance(KM, Tensor)
-                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
-                self.assertAlmostEqual(KM.std().item(), 0, places=places)
-                self.assertEqual(KM.dtype, dtype)
-                self.assertEqual(KM.device.type, device.type)
-
             # testing diag
             with self.subTest(diag=True):
                 KD = constant_kernel(X, diag=True)
@@ -66,15 +55,6 @@ def _test_constant_kernel(self, device: torch.device):
                 self.assertEqual(KD.dtype, dtype)
                 self.assertEqual(KD.device.type, device.type)

-            # testing diag and last_dim_is_batch
-            with self.subTest(diag=True, last_dim_is_batch=True):
-                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
-                self.assertIsInstance(KD, Tensor)
-                self.assertEqual(KD.shape, (*batch_shape, d, n))
-                self.assertAlmostEqual(KD.std().item(), 0, places=places)
-                self.assertEqual(KD.dtype, dtype)
-                self.assertEqual(KD.device.type, device.type)
-
             # testing AD
             with self.subTest(requires_grad=True):
                 X.requires_grad = True
diff --git a/test/kernels/test_cosine_kernel.py b/test/kernels/test_cosine_kernel.py
index e6d903bd5..6cd3196c9 100644
--- a/test/kernels/test_cosine_kernel.py
+++ b/test/kernels/test_cosine_kernel.py
@@ -30,20 +30,6 @@ def test_computes_periodic_function(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = torch.zeros(2, 3, 3)
-        for i in range(3):
-            for j in range(3):
-                for l in range(2):
-                    actual[l, i, j] = torch.cos(math.pi * ((a[i, l] - b[j, l]) / period))
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_batch(self):
         a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1)
         b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1)
@@ -81,21 +67,6 @@ def test_batch_separate(self):
         actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = torch.zeros(2, 2, 3, 3)
-        for k in range(2):
-            for i in range(3):
-                for j in range(3):
-                    for l in range(2):
-                        actual[k, l, i, j] = torch.cos(math.pi * ((a[k, i, l] - b[k, j, l]) / period[k]))
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def create_kernel_with_prior(self, period_length_prior):
         return CosineKernel(period_length_prior=period_length_prior)

diff --git a/test/kernels/test_linear_kernel.py b/test/kernels/test_linear_kernel.py
index b520842fd..708fc1253 100644
--- a/test/kernels/test_linear_kernel.py
+++ b/test/kernels/test_linear_kernel.py
@@ -42,18 +42,6 @@ def test_computes_linear_function_square(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-4)

-        # batch_dims
-        dim_group_a = a
-        dim_group_a = dim_group_a.permute(1, 0).reshape(-1, 3)
-        actual = 3.14 * torch.mul(dim_group_a.unsqueeze(-1), dim_group_a.unsqueeze(-2))
-        res = kernel(a, a, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-4)
-
-        # batch_dims + diag
-        res = kernel(a, a, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-4)
-
     def test_computes_linear_function_square_batch(self):
         a = torch.tensor([[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]], dtype=torch.float)

@@ -68,18 +56,6 @@ def test_computes_linear_function_square_batch(self):
         actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
         self.assertLess(torch.norm(res - actual), 1e-4)

-        # batch_dims
-        dim_group_a = a
-        dim_group_a = dim_group_a.transpose(-1, -2).unsqueeze(-1)
-        actual = dim_group_a.matmul(dim_group_a.transpose(-2, -1))
-        res = kernel(a, a, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-4)
-
-        # batch_dims + diag
-        res = kernel(a, a, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-4)
-
     def create_kernel_with_prior(self, variance_prior):
         return self.create_kernel_no_ard(variance_prior=variance_prior)

diff --git a/test/kernels/test_matern_kernel.py b/test/kernels/test_matern_kernel.py
index a544947e8..20b3a18f9 100644
--- a/test/kernels/test_matern_kernel.py
+++ b/test/kernels/test_matern_kernel.py
@@ -96,18 +96,6 @@ def test_ard(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        dist = torch.tensor([[[0, 0], [2, 2]], [[1, 1], [0, 0]]], dtype=torch.float)
-        dist.mul_(math.sqrt(5))
-        actual = (dist**2 / 3 + dist + 1).mul(torch.exp(-dist))
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 3]], [[2, -1, 2], [2, -1, 0]]], dtype=torch.float)
         b = torch.tensor([[[1, 4, 3]], [[2, -1, 0]]], dtype=torch.float)
@@ -141,26 +129,6 @@ def test_ard_separate_batch(self):
         actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        dist = torch.tensor(
-            [
-                [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
-                [[[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
-                [[[0.0, 0.0], [0.0, 0.0]], [[4.0, 4.0], [0.0, 0.0]]],
-            ]
-        )
-
-        dist.mul_(math.sqrt(5))
-        dist = dist.view(3, 2, 2, 2).transpose(0, 1)
-        actual = (dist**2 / 3 + dist + 1).mul(torch.exp(-dist))
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def create_kernel_with_prior(self, lengthscale_prior):
         return self.create_kernel_no_ard(lengthscale_prior=lengthscale_prior)

diff --git a/test/kernels/test_piecewise_polynomial_kernel.py b/test/kernels/test_piecewise_polynomial_kernel.py
index 3b5f7e766..f09181070 100644
--- a/test/kernels/test_piecewise_polynomial_kernel.py
+++ b/test/kernels/test_piecewise_polynomial_kernel.py
@@ -51,19 +51,6 @@ def test_fmax(r, j, q):
         res = kernel(a, b).diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = torch.zeros(2, 3, 3)
-        for i in range(2):
-            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense()
-
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_piecewise_polynomial_kernel_batch(self):
         a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1)
         b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1)
diff --git a/test/kernels/test_polynomial_kernel.py b/test/kernels/test_polynomial_kernel.py
index ad57536a0..a82bfc12a 100644
--- a/test/kernels/test_polynomial_kernel.py
+++ b/test/kernels/test_polynomial_kernel.py
@@ -31,19 +31,6 @@ def test_computes_quadratic_kernel(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = torch.zeros(2, 3, 3)
-        for i in range(2):
-            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense()
-
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_computes_cubic_kernel(self):
         a = torch.tensor([[4, 1], [2, 2], [8, 0]], dtype=torch.float)
         b = torch.tensor([[0, 0], [2, 1], [1, 0]], dtype=torch.float)
@@ -63,19 +50,6 @@ def test_computes_cubic_kernel(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = torch.zeros(2, 3, 3)
-        for i in range(2):
-            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense()
-
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims + diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_quadratic_kernel_batch(self):
         a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1)
         b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1)
diff --git a/test/kernels/test_rbf_kernel.py b/test/kernels/test_rbf_kernel.py
index 718fb4e26..729cb0b95 100644
--- a/test/kernels/test_rbf_kernel.py
+++ b/test/kernels/test_rbf_kernel.py
@@ -38,17 +38,6 @@ def test_ard(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2)
-        actual = actual.pow(2).mul_(-0.5).exp()
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-1, dim2=-2)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
         b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
@@ -69,19 +58,6 @@ def test_ard_batch(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        double_batch_a = scaled_a.transpose(-1, -2).unsqueeze(-1)
-        double_batch_b = scaled_b.transpose(-1, -2).unsqueeze(-2)
-        actual = double_batch_a - double_batch_b
-        actual = actual.pow(2).mul_(-0.5).exp()
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_separate_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
         b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
diff --git a/test/kernels/test_rq_kernel.py b/test/kernels/test_rq_kernel.py
index b7a92e726..05eb806f7 100644
--- a/test/kernels/test_rq_kernel.py
+++ b/test/kernels/test_rq_kernel.py
@@ -38,17 +38,6 @@ def test_ard(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        diff = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2)
-        actual = diff.pow(2).div_(2 * kernel.alpha).add_(1.0).pow(-kernel.alpha)
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-1, dim2=-2)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
         b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
@@ -71,20 +60,6 @@ def test_ard_batch(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # # batch_dims
-        double_batch_a = scaled_a.transpose(-1, -2).unsqueeze(-1)
-        double_batch_b = scaled_b.transpose(-1, -2).unsqueeze(-2)
-        actual = double_batch_a - double_batch_b
-        alpha = kernel.alpha.view(2, 1, 1, 1)
-        actual = actual.pow_(2).div_(2 * alpha).add_(1.0).pow(-alpha)
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_separate_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
         b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
diff --git a/test/kernels/test_scale_kernel.py b/test/kernels/test_scale_kernel.py
index 57ad0bdf8..3959064bb 100644
--- a/test/kernels/test_scale_kernel.py
+++ b/test/kernels/test_scale_kernel.py
@@ -44,18 +44,6 @@ def test_ard(self):
         actual = actual.diagonal(dim1=-1, dim2=-2)
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        actual = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2)
-        actual = actual.pow(2).mul_(-0.5).exp().view(2, 2, 2)
-        actual.mul_(3)
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_ard_batch(self):
         a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
         b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
@@ -79,20 +67,6 @@ def test_ard_batch(self):
         actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))])
         self.assertLess(torch.norm(res - actual), 1e-5)

-        # batch_dims
-        double_batch_a = scaled_a.transpose(-1, -2)
-        double_batch_b = scaled_b.transpose(-1, -2)
-        actual = double_batch_a.unsqueeze(-1) - double_batch_b.unsqueeze(-2)
-        actual = actual.pow(2).mul_(-0.5).exp()
-        actual[1, :, :, :].mul_(2)
-        res = kernel(a, b, last_dim_is_batch=True).to_dense()
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
-        # batch_dims and diag
-        res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2)
-        actual = actual.diagonal(dim1=-2, dim2=-1)
-        self.assertLess(torch.norm(res - actual), 1e-5)
-
     def test_initialize_outputscale(self):
         kernel = ScaleKernel(RBFKernel())
         kernel.initialize(outputscale=3.14)
diff --git a/test/lazy/test_lazy_evaluated_kernel_tensor.py b/test/lazy/test_lazy_evaluated_kernel_tensor.py
index 5a3528704..2041f8ca4 100644
--- a/test/lazy/test_lazy_evaluated_kernel_tensor.py
+++ b/test/lazy/test_lazy_evaluated_kernel_tensor.py
@@ -181,33 +181,3 @@ def test_half(self):
         lazy_tensor = self.create_linear_op()
         lazy_tensor.kernel.data_covar_module.raw_lengthscale_constraint.transform = lambda x: x + 0.1
         self._test_half(lazy_tensor)
-
-
-class TestLazyEvaluatedKernelTensorAdditive(TestLazyEvaluatedKernelTensorBatch):
-    seed = 0
-
-    def create_linear_op(self):
-        kern = gpytorch.kernels.AdditiveStructureKernel(gpytorch.kernels.RBFKernel(), num_dims=6)
-        mat1 = torch.randn(5, 6)
-        mat2 = mat1.detach().clone()
-        return kern(mat1, mat2)
-
-    def evaluate_linear_op(self, lazy_tensor):
-        res = to_dense(
-            gpytorch.Module.__call__(
-                lazy_tensor.kernel.base_kernel,
-                lazy_tensor.x1.transpose(-1, -2).unsqueeze(-1),
-                lazy_tensor.x2.transpose(-1, -2).unsqueeze(-1),
-            )
-        ).sum(0)
-        return res
-
-    def test_inv_matmul_matrix_with_checkpointing(self):
-        pass
-
-    def test_half(self):
-        # many transform operations aren't supported in half so we overwrite
-        # this test
-        lazy_tensor = self.create_linear_op()
-        lazy_tensor.kernel.base_kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1
-        self._test_half(lazy_tensor)
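
Migration note (not part of the diff above): callers that previously passed `last_dim_is_batch=True` can obtain the same per-dimension covariances by reshaping the inputs themselves, using the same `transpose(-1, -2).unsqueeze(-1)` pattern the removed code paths applied internally. A minimal sketch, assuming a plain `RBFKernel` and the standard `Kernel.__call__` / `to_dense` API; the variable names are illustrative only:

```python
import torch
from gpytorch.kernels import RBFKernel

x = torch.randn(10, 3)  # n = 10 data points, d = 3 input dimensions
kernel = RBFKernel()

# Old (removed): K_per_dim = kernel(x, last_dim_is_batch=True)  -> d x n x n
# New: fold the feature dimension into the batch dimension by hand.
x_batched = x.transpose(-1, -2).unsqueeze(-1)  # d x n x 1
K_per_dim = kernel(x_batched).to_dense()       # d x n x n, one covariance matrix per dimension
K_diag_per_dim = kernel(x_batched, diag=True)  # d x n, replaces last_dim_is_batch with diag=True

K_additive = K_per_dim.sum(dim=-3)   # additive structure over dimensions
K_product = K_per_dim.prod(dim=-3)   # product structure over dimensions
```

Using a kernel with `batch_shape=torch.Size([d])` in this sketch would additionally give each dimension its own hyperparameters, which is the pattern the "Kernels with Additive or Product Structure" tutorial referenced in the removed deprecation warning describes.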
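
For test code, the deleted `last_dim_is_batch` cases can be rewritten by building the per-dimension reference with an explicit loop and comparing it against the reshaped batched evaluation. The sketch below mirrors the construction used in the deleted tests; `PolynomialKernel` and the tolerance are arbitrary choices for illustration:

```python
import torch
from gpytorch.kernels import PolynomialKernel

a = torch.tensor([[4.0, 1.0], [2.0, 2.0], [8.0, 0.0]])
b = torch.tensor([[0.0, 0.0], [2.0, 1.0], [1.0, 0.0]])
kernel = PolynomialKernel(power=2)

# Per-dimension reference: one 3 x 3 covariance matrix per input dimension.
expected = torch.stack(
    [kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense() for i in range(2)]
)

# Equivalent batched evaluation with the feature dimension folded into the batch.
actual = kernel(a.transpose(-1, -2).unsqueeze(-1), b.transpose(-1, -2).unsqueeze(-1)).to_dense()

assert torch.allclose(expected, actual, atol=1e-5)
```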