From b0e3c0d05d8b773e0dfd518882256808623775c9 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 09:42:44 -0500 Subject: [PATCH] Guassian to torch --- src/caustics/utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/caustics/utils.py b/src/caustics/utils.py index f6399979..236b7abf 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -7,7 +7,6 @@ import torch from torch import Tensor from torch.func import jacfwd -import numpy as np from scipy.special import roots_legendre @@ -1104,22 +1103,26 @@ def batch_lm( def gaussian(pixelscale, nx, ny, sigma, upsample=1, dtype=torch.float32, device=None): - X, Y = np.meshgrid( - np.linspace( + X, Y = torch.meshgrid( + torch.linspace( -(nx * upsample - 1) * pixelscale / 2, (nx * upsample - 1) * pixelscale / 2, nx * upsample, + dtype=dtype, + device=device, ), - np.linspace( + torch.linspace( -(ny * upsample - 1) * pixelscale / 2, (ny * upsample - 1) * pixelscale / 2, ny * upsample, + dtype=dtype, + device=device, ), indexing="xy", ) - Z = np.exp(-0.5 * (X**2 + Y**2) / sigma**2) + Z = torch.exp(-0.5 * (X**2 + Y**2) / sigma**2) - Z = Z.reshape(ny, upsample, nx, upsample).sum(axis=(1, 3)) + Z = Z.reshape(ny, upsample, nx, upsample).sum(dim=(1, 3)) - return torch.tensor(Z / np.sum(Z), dtype=dtype, device=device) + return Z / Z.sum()