gpytorch.module: fix typing annotation (#2611)
* Add typing annotations to gpytorch.Module

* Undo removal of Module.register_parameter()

* Fix typing annotations in gpytorch.Module

* Remove return value from SettingClosure

---------

Co-authored-by: Geoff Pleiss <[email protected]>
chrisyeh96 and gpleiss authored Dec 6, 2024
1 parent 1a8aa8f commit bb86550
Showing 6 changed files with 42 additions and 35 deletions.
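
For orientation, the change that motivates most of these diffs is the closure contract: a "getting" closure maps a module to a Tensor, while a setting closure now mutates the module in place and returns None rather than returning the module. Below is a minimal sketch of that pattern using register_prior as it appears in the diffs; the ScaleKernel and GammaPrior choices are illustrative and not part of this commit.

from typing import Union

import gpytorch
from torch import Tensor


def outputscale_param(module: gpytorch.kernels.ScaleKernel) -> Tensor:
    # Matches the Closure alias: module -> Tensor.
    return module.outputscale


def outputscale_closure(module: gpytorch.kernels.ScaleKernel, value: Union[Tensor, float]) -> None:
    # Matches the revised SettingClosure alias: mutate the module, return None.
    module._set_outputscale(value)


kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
kernel.register_prior(
    "outputscale_prior",
    gpytorch.priors.GammaPrior(2.0, 0.15),
    outputscale_param,
    outputscale_closure,
)
kernel.sample_from_prior("outputscale_prior")  # invokes the setting closure internally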
2 changes: 1 addition & 1 deletion gpytorch/kernels/cosine_kernel.py
@@ -91,7 +91,7 @@ def period_length(self):

    @period_length.setter
    def period_length(self, value):
-        return self._set_period_length(value)
+        self._set_period_length(value)

    def _set_period_length(self, value):
        if not torch.is_tensor(value):
35 changes: 21 additions & 14 deletions gpytorch/kernels/cylindrical_kernel.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3

-from typing import Optional
+from typing import Optional, Union

import torch
+from torch import Tensor

from .. import settings
from ..constraints import Interval, Positive
@@ -94,39 +95,45 @@ def __init__(
        self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

    @property
-    def angular_weights(self) -> torch.Tensor:
+    def angular_weights(self) -> Tensor:
        return self.raw_angular_weights_constraint.transform(self.raw_angular_weights)

    @angular_weights.setter
-    def angular_weights(self, value: torch.Tensor) -> None:
+    def angular_weights(self, value: Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.tensor(value)

        self.initialize(raw_angular_weights=self.raw_angular_weights_constraint.inverse_transform(value))

    @property
-    def alpha(self) -> torch.Tensor:
+    def alpha(self) -> Tensor:
        return self.raw_alpha_constraint.transform(self.raw_alpha)

    @alpha.setter
-    def alpha(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def alpha(self, value: Tensor) -> None:
+        self._set_alpha(value)
+
+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
+        # Used by the alpha_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_alpha)
        self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

    @property
-    def beta(self) -> torch.Tensor:
+    def beta(self) -> Tensor:
        return self.raw_beta_constraint.transform(self.raw_beta)

    @beta.setter
-    def beta(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def beta(self, value: Tensor) -> None:
+        self._set_beta(value)
+
+    def _set_beta(self, value: Union[Tensor, float]) -> None:
+        # Used by the beta_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_beta)
        self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

-    def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, **params) -> torch.Tensor:
+    def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> Tensor:

        x1_, x2_ = x1.clone(), x2.clone()
        # Jitter datapoints that are exactly 0
@@ -156,12 +163,12 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
        radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
        return radial_kernel.mul(angular_kernel)

-    def kuma(self, x: torch.Tensor) -> torch.Tensor:
+    def kuma(self, x: Tensor) -> Tensor:
        alpha = self.alpha.view(*self.batch_shape, 1, 1)
        beta = self.beta.view(*self.batch_shape, 1, 1)

        res = 1 - (1 - x.pow(alpha) + self.eps).pow(beta)
        return res

-    def num_outputs_per_input(self, x1: torch.Tensor, x2: torch.Tensor) -> int:
+    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        return self.radial_base_kernel.num_outputs_per_input(x1, x2)
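
The new _set_alpha and _set_beta helpers above follow the same convention as hamming_kernel.py below: accept a Tensor or a plain float, and cast floats onto the raw parameter's dtype and device before applying the constraint's inverse transform. A standalone sketch of that cast (the parameter below is a hypothetical stand-in, not taken from the kernel):

import torch
from torch import Tensor

# Hypothetical raw parameter standing in for self.raw_alpha / self.raw_beta.
raw_alpha = torch.nn.Parameter(torch.zeros(1, dtype=torch.float64))

value = 2.0  # a plain Python float, e.g. passed in by a prior's setting closure
if not isinstance(value, Tensor):
    # torch.as_tensor(...).to(raw_alpha) matches the raw parameter's dtype and device,
    # which the old torch.tensor(value) path would not do.
    value = torch.as_tensor(value).to(raw_alpha)

print(value.dtype)  # torch.float64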
18 changes: 9 additions & 9 deletions gpytorch/kernels/hamming_kernel.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

import torch
from torch import nn, Tensor
@@ -95,13 +95,13 @@ def _alpha_param(self, m: Kernel) -> Tensor:
        # Used by the alpha_prior
        return m.alpha

-    def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _alpha_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
        # Used by the alpha_prior
-        return m._set_alpha(v)
+        m._set_alpha(v)

-    def _set_alpha(self, value: Tensor):
+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
        # Used by the alpha_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
            value = torch.as_tensor(value).to(self.raw_alpha)
        self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

@@ -117,13 +117,13 @@ def _beta_param(self, m: Kernel) -> Tensor:
        # Used by the beta_prior
        return m.beta

-    def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _beta_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
        # Used by the beta_prior
-        return m._set_beta(v)
+        m._set_beta(v)

-    def _set_beta(self, value: Tensor):
+    def _set_beta(self, value: Union[Tensor, float]) -> None:
        # Used by the beta_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
            value = torch.as_tensor(value).to(self.raw_beta)
        self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

4 changes: 2 additions & 2 deletions gpytorch/kernels/kernel.py
@@ -216,9 +216,9 @@ def _lengthscale_param(self, m: Kernel) -> Tensor:
        # Used by the lengthscale_prior
        return m.lengthscale

-    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> None:
        # Used by the lengthscale_prior
-        return m._set_lengthscale(v)
+        m._set_lengthscale(v)

    def _set_lengthscale(self, value: Tensor):
        # Used by the lengthscale_prior
2 changes: 1 addition & 1 deletion gpytorch/kernels/scale_kernel.py
@@ -90,7 +90,7 @@ def _outputscale_param(self, m):
        return m.outputscale

    def _outputscale_closure(self, m, v):
-        return m._set_outputscale(v)
+        m._set_outputscale(v)

    @property
    def outputscale(self):
16 changes: 8 additions & 8 deletions gpytorch/module.py
@@ -21,8 +21,8 @@
ModuleSelf = TypeVar("ModuleSelf", bound="Module")  # TODO: replace w/ typing.Self in Python 3.11
RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin")  # TODO: replace w/ typing.Self in Python 3.11

-Closure = Callable[[nn.Module], Tensor]
-SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf]
+Closure = Callable[[NnModuleSelf], Tensor]
+SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None]
SamplesDict = Mapping[str, Union[Tensor, float]]


@@ -111,7 +111,7 @@ def added_loss_terms(self):
    def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
        raise NotImplementedError

-    def constraints(self):
+    def constraints(self) -> Iterator[Interval]:
        for _, constraint in self.named_constraints():
            yield constraint

@@ -294,8 +294,8 @@ def closure_new(module: nn.Module) -> Tensor:
            if setting_closure is not None:
                raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

-            def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf:
-                return module.initialize(**{param: val})
+            def setting_closure_new(module: Module, val: Union[Tensor, float]) -> None:
+                module.initialize(**{param: val})

            setting_closure = setting_closure_new

@@ -443,7 +443,7 @@ def pyro_sample_from_prior(self) -> Module:
        new_module = self.to_pyro_random_module()
        return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

-    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None:
+    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[Prior], prefix: str) -> None:
        """
        Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
        sampling mechanism.
@@ -516,7 +516,7 @@ def _set_strict(module: nn.Module, value: bool) -> None:


def _pyro_sample_from_prior(
-    module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: NnModuleSelf, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> NnModuleSelf:
    try:
        import pyro
@@ -546,7 +546,7 @@ def _pyro_sample_from_prior(


def _pyro_load_from_samples(
-    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> None:
    if memo is None:
        memo = set()
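
As a quick check of the annotations touched in gpytorch/module.py (constraints() now advertises Iterator[Interval], and the Pyro helpers type their memo as MutableSet[Prior]), the short sketch below iterates a kernel's constraints; the kernel choice is illustrative and not part of this commit.

import gpytorch

kernel = gpytorch.kernels.RBFKernel()

# named_constraints() yields (name, constraint) pairs; constraints() yields Interval
# subclasses (here Positive for the raw lengthscale), matching the new annotation.
for name, constraint in kernel.named_constraints():
    print(name, type(constraint).__name__)

for constraint in kernel.constraints():
    assert isinstance(constraint, gpytorch.constraints.Interval)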
