Skip to content

Commit

Permalink
Clamp noise in FixedGaussianNoise via min_fixed_noise context man…
Browse files Browse the repository at this point in the history
…ager (#2009)

* add min_fixed_noise context manager

* modify defaults

* Change defaults

Co-authored-by: Geoff Pleiss <[email protected]>
  • Loading branch information
saitcakmak and gpleiss authored Jun 27, 2022
1 parent 538648b commit 2533800
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
11 changes: 11 additions & 0 deletions gpytorch/likelihoods/noise_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import warnings
from typing import Any, Optional

import torch
Expand All @@ -12,6 +13,7 @@
from ..lazy import ConstantDiagLazyTensor, DiagLazyTensor, ZeroLazyTensor
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.warnings import NumericalWarning


class Noise(Module):
Expand Down Expand Up @@ -138,6 +140,15 @@ def forward(
class FixedGaussianNoise(Module):
def __init__(self, noise: Tensor) -> None:
super().__init__()
min_noise = settings.min_fixed_noise.value(noise.dtype)
if noise.lt(min_noise).any():
warnings.warn(
"Very small noise values detected. This will likely "
"lead to numerical instabilities. Rounding small noise "
f"values up to {min_noise}.",
NumericalWarning,
)
noise = noise.clamp_min(min_noise)
self.noise = noise

def forward(
Expand Down
15 changes: 15 additions & 0 deletions gpytorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,21 @@ class max_cg_iterations(_value_context):
_global_value = 1000


class min_fixed_noise(_dtype_value_context):
"""
The minimum noise value that can be used in :obj:`~gpytorch.likelihoods.FixedNoiseGaussianLikelihood`.
If the supplied noise values are smaller than this, they are rounded up and a warning is raised.
- Default for `float`: 1e-4
- Default for `double`: 1e-6
- Default for `half`: 1e-3
"""

_global_float_value = 1e-4
_global_double_value = 1e-6
_global_half_value = 1e-3


class min_variance(_dtype_value_context):
"""
The minimum variance that can be returned from :obj:`~gpytorch.distributions.MultivariateNormal#variance`.
Expand Down
8 changes: 8 additions & 0 deletions test/likelihoods/test_gaussian_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch

from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import DiagLazyTensor
from gpytorch.likelihoods import (
Expand Down Expand Up @@ -84,6 +85,13 @@ def test_fixed_noise_gaussian_likelihood(self, cuda=False):
obs_noise = 0.1 + torch.rand(5, device=device, dtype=dtype)
out = lkhd(mvn, noise=obs_noise)
self.assertTrue(torch.allclose(out.variance, 1 + obs_noise))
# test noise smaller than min_fixed_noise
expected_min_noise = settings.min_fixed_noise.value(dtype)
noise[:2] = 0
lkhd = FixedNoiseGaussianLikelihood(noise=noise)
expected_noise = noise.clone()
expected_noise[:2] = expected_min_noise
self.assertTrue(torch.allclose(lkhd.noise, expected_noise))


class TestFixedNoiseGaussianLikelihoodBatch(BaseLikelihoodTestCase, unittest.TestCase):
Expand Down

0 comments on commit 2533800

Please sign in to comment.