diff --git a/python/adjoint/utils.py b/python/adjoint/utils.py index fb37d2b0b..84f8f9866 100644 --- a/python/adjoint/utils.py +++ b/python/adjoint/utils.py @@ -24,22 +24,33 @@ class DesignRegion: - def __init__(self, design_parameters, volume=None, size=None, center=mp.Vector3()): + def __init__( + self, + design_parameters: Iterable[onp.ndarray], + volume: mp.Volume = None, + size: mp.Vector3 = None, + center: mp.Vector3 = mp.Vector3(), + ): self.volume = volume or mp.Volume(center=center, size=size) self.size = self.volume.size self.center = self.volume.center self.design_parameters = design_parameters self.num_design_params = design_parameters.num_params - def update_design_parameters(self, design_parameters): + def update_design_parameters(self, design_parameters) -> None: self.design_parameters.update_weights(design_parameters) - def update_beta(self, beta): + def update_beta(self, beta: float) -> None: self.design_parameters.beta = beta def get_gradient( - self, sim, fields_a, fields_f, frequencies, finite_difference_step - ): + self, + sim: mp.Simulation, + fields_a: List[mp.DftFields], + fields_f: List[mp.DftFields], + frequencies: List[float], + finite_difference_step: float, + ) -> onp.ndarray: num_freqs = onp.array(frequencies).size """We have the option to linearly scale the gradients up front using the scalegrad parameter (leftover from MPB API). Not @@ -67,11 +78,11 @@ def get_gradient( return onp.squeeze(grad).T -def _check_if_cylindrical(sim): +def _check_if_cylindrical(sim: mp.Simulation) -> bool: return sim.is_cylindrical or (sim.dimensions == mp.CYLINDRICAL) -def _compute_components(sim): +def _compute_components(sim: mp.Simulation) -> List[int]: return ( _ADJOINT_FIELD_COMPONENTS_CYL if _check_if_cylindrical(sim) @@ -88,8 +99,8 @@ def calculate_vjps( simulation: mp.Simulation, design_regions: List[DesignRegion], frequencies: List[float], - fwd_fields: List[List[onp.ndarray]], - adj_fields: List[List[onp.ndarray]], + fwd_fields: List[List[mp.DftFields]], + adj_fields: List[List[mp.DftFields]], design_variable_shapes: List[Tuple[int, ...]], sum_freq_partials: bool = True, finite_difference_step: float = FD_DEFAULT, @@ -132,7 +143,7 @@ def install_design_region_monitors( design_regions: List[DesignRegion], frequencies: List[float], decimation_factor: int = 0, -) -> List[mp.DftFields]: +) -> List[List[mp.DftFields]]: """Installs DFT field monitors at the design regions of the simulation.""" return [ [ @@ -168,41 +179,6 @@ def gather_monitor_values(monitors: List[ObjectiveQuantity]) -> onp.ndarray: return monitor_values -def gather_design_region_fields( - simulation: mp.Simulation, - design_region_monitors: List[mp.DftFields], - frequencies: List[float], -) -> List[List[onp.ndarray]]: - """Collects the design region DFT fields from the simulation. - - Args: - simulation: the simulation object. - design_region_monitors: the installed design region monitors. - frequencies: the frequencies to monitor. - - Returns: - A list of lists. Each entry (list) in the overall list corresponds one-to- - one with a declared design region. For each such contained list, the - entries correspond to the field components that are monitored. The entries - are ndarrays of rank 4 with dimensions (freq, x, y, (z-or-pad)). - - The design region fields are sampled on the *Yee grid*. This makes them - fairly awkward to inspect directly. Their primary use case is supporting - gradient calculations. - """ - design_region_fields = [] - for monitor in design_region_monitors: - fields_by_component = [] - for component in _compute_components(simulation): - fields_by_freq = [] - for freq_idx, _ in enumerate(frequencies): - fields = simulation.get_dft_array(monitor, component, freq_idx) - fields_by_freq.append(_make_at_least_nd(fields)) - fields_by_component.append(onp.stack(fields_by_freq)) - design_region_fields.append(fields_by_component) - return design_region_fields - - def validate_and_update_design( design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray] ) -> None: diff --git a/python/adjoint/wrapper.py b/python/adjoint/wrapper.py index 06d9ad96e..6e037e36f 100644 --- a/python/adjoint/wrapper.py +++ b/python/adjoint/wrapper.py @@ -46,7 +46,7 @@ def loss(x): value, grad = jax.value_and_grad(loss)(x) ``` """ -from typing import Callable, List, Tuple +from typing import Callable, Iterable, List, Tuple import jax import jax.numpy as jnp @@ -137,7 +137,9 @@ def __call__(self, designs: List[jnp.ndarray]) -> jnp.ndarray: """ return self._simulate_fn(designs) - def _run_fwd_simulation(self, design_variables): + def _run_fwd_simulation( + self, design_variables: Iterable[onp.ndarray] + ) -> (jnp.ndarray, List[List[mp.DftFields]]): """Runs forward simulation, returning monitor values and design region fields.""" utils.validate_and_update_design(self.design_regions, design_variables) self.simulation.reset_meep() @@ -161,7 +163,9 @@ def _run_fwd_simulation(self, design_variables): monitor_values = utils.gather_monitor_values(self.monitors) return (jnp.asarray(monitor_values), fwd_design_region_monitors) - def _run_adjoint_simulation(self, monitor_values_grad): + def _run_adjoint_simulation( + self, monitor_values_grad: onp.ndarray + ) -> List[List[mp.DftFields]]: """Runs adjoint simulation, returning design region fields.""" if not self.design_regions: raise RuntimeError( @@ -195,11 +199,11 @@ def _run_adjoint_simulation(self, monitor_values_grad): def _calculate_vjps( self, - fwd_fields, - adj_fields, - design_variable_shapes, - sum_freq_partials=True, - ): + fwd_fields: List[List[mp.DftFields]], + adj_fields: List[List[mp.DftFields]], + design_variable_shapes: List[Tuple[int, ...]], + sum_freq_partials: bool = True, + ) -> List[onp.ndarray]: """Calculates the VJP for a given set of forward and adjoint fields.""" return utils.calculate_vjps( self.simulation, diff --git a/python/tests/test_adjoint_jax.py b/python/tests/test_adjoint_jax.py index ea007ddc5..eb6c5f506 100644 --- a/python/tests/test_adjoint_jax.py +++ b/python/tests/test_adjoint_jax.py @@ -9,16 +9,20 @@ import meep as mp -# The calculation of finite difference gradients requires that JAX be operated with double precision +# The calculation of finite-difference gradients +# requires that JAX be operated with double precision jax.config.update("jax_enable_x64", True) -# The step size for the finite difference gradient calculation +# The step size for the finite-difference +# gradient calculation _FD_STEP = 1e-4 -# The tolerance for the adjoint and finite difference gradient comparison +# The tolerance for the adjoint and finite-difference +# gradient comparison _TOL = 0.1 if mp.is_single_precision() else 0.025 -# We expect 3 design region monitor pointers (one for each field component) +# We expect 3 design region monitor pointers +# (one for each field component) _NUM_DES_REG_MON = 3 mp.verbosity(0) @@ -257,8 +261,8 @@ def loss_fn(x, excite_port_idx=0): frequencies, ) monitor_values = wrapped_meep([x]) - s1p, s1m, s2m, s2p = monitor_values - t = s2m / s1p if excite_port_idx == 0 else s1m / s2p + s1p, s1m, s2p, s2m = monitor_values + t = s2p / s1p if excite_port_idx == 0 else s1m / s2m return jnp.mean(jnp.square(jnp.abs(t))) value, adjoint_grad = jax.value_and_grad(loss_fn)(