diff --git a/magrittetorch/algorithms/solvers.py b/magrittetorch/algorithms/solvers.py index 3c7115d..bb87e4d 100644 --- a/magrittetorch/algorithms/solvers.py +++ b/magrittetorch/algorithms/solvers.py @@ -302,7 +302,7 @@ def get_total_optical_depth_single_direction(model: Model, raydir: torch.Tensor, device (torch.device): Device on which to compute and return the result. Returns: - torch.Tensor: The computed intensities [W/m**2/Hz/rad**2]. Has dimensions [NPOINTS, NFREQS] + torch.Tensor: The computed optical depths [.]. Has dimensions [NPOINTS, NFREQS] """ #no work distribution, as this should happen in a higher function @@ -464,8 +464,9 @@ def compute_level_populations_statistical_equilibrium(model: Model, indices: tor """ if (ALI_diag is None): - ALI_diag_fraction = torch.zeros((model.parameters.npoints.get(), model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES] - ALI_diag_Jdiff = torch.zeros((model.parameters.npoints.get(), model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES] + npoints = len(indices) + ALI_diag_fraction = torch.zeros((npoints, model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES] + ALI_diag_Jdiff = torch.zeros((npoints, model.sources.lines.get_total_number_lines()), dtype=Types.FrequencyInfo, device=device) #dims: [NPOINTS, NLINES] else: ALI_diag_fraction, ALI_diag_Jdiff = ALI_diag #dims: [NPOINTS, NLINES] diff --git a/magrittetorch/algorithms/torch_algorithms.py b/magrittetorch/algorithms/torch_algorithms.py index b28fd62..9a83d39 100644 --- a/magrittetorch/algorithms/torch_algorithms.py +++ b/magrittetorch/algorithms/torch_algorithms.py @@ -131,8 +131,8 @@ def interpolate2D_linear(interpolation_position: torch.Tensor, interpolation_val eval_size_repeat = list(evaluation_points.size()) eval_size_repeat[0] = 1 - minval = torch.min(interpolation_value, dim=0).values - maxval = torch.max(interpolation_value, dim=0).values#dims: [OTHERDIMS] + minval = interpolation_value[0, :] + maxval = interpolation_value[-1, :]#dims: [OTHERDIMS] minval = minval.repeat(evaldim_size, 1)#dims: [Any1, OTHERDIMS] maxval = maxval.repeat(evaldim_size, 1) result = torch.zeros_like(evaluation_points, dtype=interpolation_value.dtype) diff --git a/magrittetorch/model/sources/lineproducingspecies.py b/magrittetorch/model/sources/lineproducingspecies.py index 61f4604..390bb17 100644 --- a/magrittetorch/model/sources/lineproducingspecies.py +++ b/magrittetorch/model/sources/lineproducingspecies.py @@ -397,7 +397,7 @@ def compute_ng_accelerated_level_pops(self, previous_level_pops: torch.Tensor, d ng_accelerated_pops = ng_accelerated_pops * self.population_tot.get(device)[:, None] / torch.sum(ng_accelerated_pops, dim=1)[:, None] return ng_accelerated_pops - def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.device) -> torch.Tensor: + def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.device, indices: Optional[torch.Tensor] = None) -> torch.Tensor: """ Computes the line cooling rate for each point, based on the given level populations TODO: test whether memory error might occur for NLTE models with many levels; if so, consider adding memory management @@ -405,6 +405,7 @@ def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.d Args: current_level_pops (torch.Tensor): The current level populations. Has dimensions [parameters.npoints, linedata.nlev]. device (torch.device): The device on which to compute. + indices (Optional[torch.Tensor], optional): The indices of the points to consider. Defaults to None. Returns: torch.Tensor: The line cooling rate. Has dimensions [parameters.npoints] and units W/m^3 @@ -414,6 +415,11 @@ def compute_line_cooling(self, current_level_pops: torch.Tensor, device: torch.d temperature = self.dataCollection.get_data("gas temperature").get(device)#dims = [parameters.npoints] abundance = self.dataCollection.get_data("species abundance").get(device)#dims: [parameters.npoints, parameters.nspecs] cooling_rate = torch.zeros(self.parameters.npoints.get(), device=device, dtype=Types.LevelPopsInfo) + if indices is not None: + temperature = temperature[indices] + abundance = abundance[indices] + cooling_rate = cooling_rate[indices] + for colpar in self.linedata.colpar: upper_levels = colpar.icol.get(device) lower_levels = colpar.jcol.get(device)