Skip to content

Commit

Permalink
Merge pull request #4 from Magritte-code/float32
Browse files Browse the repository at this point in the history
Fixed interpolation algorithm
  • Loading branch information
ThomasCeulemans authored Nov 7, 2024
2 parents 87f3a4a + 364ce5d commit 38d74ab
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
7 changes: 4 additions & 3 deletions magrittetorch/algorithms/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def get_total_optical_depth_single_direction(model: Model, raydir: torch.Tensor,
device (torch.device): Device on which to compute and return the result.
Returns:
torch.Tensor: The computed intensities [W/m**2/Hz/rad**2]. Has dimensions [NPOINTS, NFREQS]
torch.Tensor: The computed optical depths [.]. Has dimensions [NPOINTS, NFREQS]
"""
#no work distribution, as this should happen in a higher function

Expand Down Expand Up @@ -464,8 +464,9 @@ def compute_level_populations_statistical_equilibrium(model: Model, indices: tor
"""

if (ALI_diag is None):
ALI_diag_fraction = torch.zeros((model.parameters.npoints.get(), model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES]
ALI_diag_Jdiff = torch.zeros((model.parameters.npoints.get(), model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES]
npoints = len(indices)
ALI_diag_fraction = torch.zeros((npoints, model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES]
ALI_diag_Jdiff = torch.zeros((npoints, model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES]
else:
ALI_diag_fraction, ALI_diag_Jdiff = ALI_diag #dims: [NPOINTS, NLINES]

Expand Down
4 changes: 2 additions & 2 deletions magrittetorch/algorithms/torch_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def interpolate2D_linear(interpolation_position: torch.Tensor, interpolation_val
eval_size_repeat = list(evaluation_points.size())
eval_size_repeat[0] = 1

minval = torch.min(interpolation_value, dim=0).values
maxval = torch.max(interpolation_value, dim=0).values#dims: [OTHERDIMS]
minval = interpolation_value[0, :]
maxval = interpolation_value[-1, :]#dims: [OTHERDIMS]
minval = minval.repeat(evaldim_size, 1)#dims: [Any1, OTHERDIMS]
maxval = maxval.repeat(evaldim_size, 1)
result = torch.zeros_like(evaluation_points, dtype=interpolation_value.dtype)
Expand Down
8 changes: 7 additions & 1 deletion magrittetorch/model/sources/lineproducingspecies.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,15 @@ def compute_ng_accelerated_level_pops(self, previous_level_pops: torch.Tensor, d
ng_accelerated_pops = ng_accelerated_pops * self.population_tot.get(device)[:, None] / torch.sum(ng_accelerated_pops, dim=1)[:, None]
return ng_accelerated_pops

def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.device) -> torch.Tensor:
def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.device, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Computes the line cooling rate for each point, based on the given level populations
TODO: test whether memory error might occur for NLTE models with many levels; if so, consider adding memory management
Args:
current_level_pops (torch.Tensor): The current level populations. Has dimensions [parameters.npoints, linedata.nlev].
device (torch.device): The device on which to compute.
indices (Optional[torch.Tensor], optional): The indices of the points to consider. Defaults to None.
Returns:
torch.Tensor: The line cooling rate. Has dimensions [parameters.npoints] and units W/m^3
Expand All @@ -414,6 +415,11 @@ def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.d
temperature = self.dataCollection.get_data("gas temperature").get(device)#dims = [parameters.npoints]
abundance = self.dataCollection.get_data("species abundance").get(device)#dims: [parameters.npoints, parameters.nspecs]
cooling_rate = torch.zeros(self.parameters.npoints.get(), device=device, dtype=Types.LevelPopsInfo)
if indices is not None:
temperature = temperature[indices]
abundance = abundance[indices]
cooling_rate = cooling_rate[indices]

for colpar in self.linedata.colpar:
upper_levels = colpar.icol.get(device)
lower_levels = colpar.jcol.get(device)
Expand Down

0 comments on commit 38d74ab

Please sign in to comment.