diff --git a/gpytorch/kernels/cosine_kernel.py b/gpytorch/kernels/cosine_kernel.py
index 11add6f2f..a688eba67 100644
--- a/gpytorch/kernels/cosine_kernel.py
+++ b/gpytorch/kernels/cosine_kernel.py
@@ -91,7 +91,7 @@ def period_length(self):

     @period_length.setter
     def period_length(self, value):
-        return self._set_period_length(value)
+        self._set_period_length(value)

     def _set_period_length(self, value):
         if not torch.is_tensor(value):
diff --git a/gpytorch/kernels/cylindrical_kernel.py b/gpytorch/kernels/cylindrical_kernel.py
index 48f24958c..ea66c956a 100644
--- a/gpytorch/kernels/cylindrical_kernel.py
+++ b/gpytorch/kernels/cylindrical_kernel.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3

-from typing import Optional
+from typing import Optional, Union

 import torch
+from torch import Tensor

 from .. import settings
 from ..constraints import Interval, Positive
@@ -94,39 +95,45 @@ def __init__(
         self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

     @property
-    def angular_weights(self) -> torch.Tensor:
+    def angular_weights(self) -> Tensor:
         return self.raw_angular_weights_constraint.transform(self.raw_angular_weights)

     @angular_weights.setter
-    def angular_weights(self, value: torch.Tensor) -> None:
+    def angular_weights(self, value: Tensor) -> None:
         if not torch.is_tensor(value):
             value = torch.tensor(value)
         self.initialize(raw_angular_weights=self.raw_angular_weights_constraint.inverse_transform(value))

     @property
-    def alpha(self) -> torch.Tensor:
+    def alpha(self) -> Tensor:
         return self.raw_alpha_constraint.transform(self.raw_alpha)

     @alpha.setter
-    def alpha(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def alpha(self, value: Tensor) -> None:
+        self._set_alpha(value)
+
+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
+        # Used by the alpha_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_alpha)
         self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

     @property
-    def beta(self) -> torch.Tensor:
+    def beta(self) -> Tensor:
         return self.raw_beta_constraint.transform(self.raw_beta)

     @beta.setter
-    def beta(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def beta(self, value: Tensor) -> None:
+        self._set_beta(value)
+
+    def _set_beta(self, value: Union[Tensor, float]) -> None:
+        # Used by the beta_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_beta)
         self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

-    def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, **params) -> torch.Tensor:
+    def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> Tensor:
         x1_, x2_ = x1.clone(), x2.clone()

         # Jitter datapoints that are exactly 0
@@ -156,12 +163,12 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
         radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
         return radial_kernel.mul(angular_kernel)

-    def kuma(self, x: torch.Tensor) -> torch.Tensor:
+    def kuma(self, x: Tensor) -> Tensor:
         alpha = self.alpha.view(*self.batch_shape, 1, 1)
         beta = self.beta.view(*self.batch_shape, 1, 1)

         res = 1 - (1 - x.pow(alpha) + self.eps).pow(beta)
         return res

-    def num_outputs_per_input(self, x1: torch.Tensor, x2: torch.Tensor) -> int:
+    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
         return self.radial_base_kernel.num_outputs_per_input(x1, x2)
diff --git a/gpytorch/kernels/hamming_kernel.py b/gpytorch/kernels/hamming_kernel.py
index 6a28a2aa9..d942872b8 100644
--- a/gpytorch/kernels/hamming_kernel.py
+++ b/gpytorch/kernels/hamming_kernel.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import nn, Tensor
@@ -95,13 +95,13 @@ def _alpha_param(self, m: Kernel) -> Tensor:
         # Used by the alpha_prior
         return m.alpha

-    def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _alpha_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
         # Used by the alpha_prior
-        return m._set_alpha(v)
+        m._set_alpha(v)

-    def _set_alpha(self, value: Tensor):
+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
         # Used by the alpha_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
             value = torch.as_tensor(value).to(self.raw_alpha)
         self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))
@@ -117,13 +117,13 @@ def _beta_param(self, m: Kernel) -> Tensor:
         # Used by the beta_prior
         return m.beta

-    def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _beta_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
         # Used by the beta_prior
-        return m._set_beta(v)
+        m._set_beta(v)

-    def _set_beta(self, value: Tensor):
+    def _set_beta(self, value: Union[Tensor, float]) -> None:
         # Used by the beta_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
             value = torch.as_tensor(value).to(self.raw_beta)
         self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))
diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py
index 67e576db3..0a4c49efa 100644
--- a/gpytorch/kernels/kernel.py
+++ b/gpytorch/kernels/kernel.py
@@ -216,9 +216,9 @@ def _lengthscale_param(self, m: Kernel) -> Tensor:
         # Used by the lengthscale_prior
         return m.lengthscale

-    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> None:
         # Used by the lengthscale_prior
-        return m._set_lengthscale(v)
+        m._set_lengthscale(v)

     def _set_lengthscale(self, value: Tensor):
         # Used by the lengthscale_prior
diff --git a/gpytorch/kernels/scale_kernel.py b/gpytorch/kernels/scale_kernel.py
index 520913265..fdfadb0af 100644
--- a/gpytorch/kernels/scale_kernel.py
+++ b/gpytorch/kernels/scale_kernel.py
@@ -90,7 +90,7 @@ def _outputscale_param(self, m):
         return m.outputscale

     def _outputscale_closure(self, m, v):
-        return m._set_outputscale(v)
+        m._set_outputscale(v)

     @property
     def outputscale(self):
diff --git a/gpytorch/module.py b/gpytorch/module.py
index 57550755a..e5081d878 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -21,8 +21,8 @@
 ModuleSelf = TypeVar("ModuleSelf", bound="Module")  # TODO: replace w/ typing.Self in Python 3.11
 RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin")  # TODO: replace w/ typing.Self in Python 3.11

-Closure = Callable[[nn.Module], Tensor]
-SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf]
+Closure = Callable[[NnModuleSelf], Tensor]
+SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None]
 SamplesDict = Mapping[str, Union[Tensor, float]]
@@ -111,7 +111,7 @@ def added_loss_terms(self):
     def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
         raise NotImplementedError

-    def constraints(self):
+    def constraints(self) -> Iterator[Interval]:
         for _, constraint in self.named_constraints():
             yield constraint
@@ -294,8 +294,8 @@ def closure_new(module: nn.Module) -> Tensor:
         if setting_closure is not None:
             raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

-        def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf:
-            return module.initialize(**{param: val})
+        def setting_closure_new(module: Module, val: Union[Tensor, float]) -> None:
+            module.initialize(**{param: val})

         setting_closure = setting_closure_new
@@ -443,7 +443,7 @@ def pyro_sample_from_prior(self) -> Module:
         new_module = self.to_pyro_random_module()
         return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

-    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None:
+    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[Prior], prefix: str) -> None:
         """
         Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
         sampling mechanism.
@@ -516,7 +516,7 @@ def _set_strict(module: nn.Module, value: bool) -> None:


 def _pyro_sample_from_prior(
-    module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: NnModuleSelf, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
 ) -> NnModuleSelf:
     try:
         import pyro
@@ -546,7 +546,7 @@ def _pyro_sample_from_prior(


 def _pyro_load_from_samples(
-    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
 ) -> None:
     if memo is None:
         memo = set()
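
The common thread in these changes is that prior closures and property setters now mutate the module in place and return None, matching the tightened SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None] alias in gpytorch/module.py. A minimal sketch of that contract (not part of the diff; it assumes the stock RBFKernel and GammaPrior, registering the prior by parameter name so that the closure_new / setting_closure_new pair shown above is used):

    import torch
    from gpytorch.kernels import RBFKernel
    from gpytorch.priors import GammaPrior

    kernel = RBFKernel()
    # Registering by name wires up closure_new / setting_closure_new from
    # gpytorch/module.py; the setting half now returns None.
    kernel.register_prior("lengthscale_prior", GammaPrior(3.0, 6.0), "lengthscale")

    # Assignment goes through the property setter, which mutates in place.
    kernel.lengthscale = 0.5
    assert torch.allclose(kernel.lengthscale, torch.tensor(0.5))

    # Sampling from the prior drives the registered setting closure; its
    # return value is discarded, so returning the module was never observable.
    kernel.sample_from_prior("lengthscale_prior")

Since Python ignores whatever a property setter returns, and sample_from_prior likewise discards the setting closure's result, annotating and implementing both as returning None reflects the behavior callers actually see.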