Skip to content

Commit

Permalink
Update lpse (#80)
Browse files Browse the repository at this point in the history
* Arbitrary Driver

* exposing driver formats
  • Loading branch information
joglekara authored Oct 17, 2024
1 parent dc82c25 commit f11e1c3
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 64 deletions.
9 changes: 8 additions & 1 deletion adept/_lpse2d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from .modules import BaseLPSE2D as BaseLPSE2D, save_driver as save_driver
from .modules import (
BaseLPSE2D as BaseLPSE2D,
ArbitraryDriver,
UniformDriver,
GaussianDriver,
LorentzianDriver,
GenerativeDriver,
)
from .helpers import calc_threshold_intensity as calc_threshold_intensity
2 changes: 1 addition & 1 deletion adept/_lpse2d/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base import BaseLPSE2D as BaseLPSE2D
from .driver import save as save_driver
from .driver import ArbitraryDriver, UniformDriver, GaussianDriver, LorentzianDriver, GenerativeDriver
4 changes: 2 additions & 2 deletions adept/_lpse2d/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def init_modules(self) -> Dict:
if "E0" in self.cfg["drivers"]:
DriverModule = driver.choose_driver(self.cfg["drivers"]["E0"]["shape"])
if "file" in self.cfg["drivers"]["E0"]:
modules["driver"] = driver.load(self.cfg, DriverModule)
modules["laser"] = driver.load(self.cfg, DriverModule)
else:
modules["driver"] = DriverModule(self.cfg)
modules["laser"] = DriverModule(self.cfg)

return modules

Expand Down
129 changes: 70 additions & 59 deletions adept/_lpse2d/modules/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,6 @@
from adept.utils import download_from_s3


def save(filename: str, model_cfg: Dict, model: eqx.Module) -> None:
new_filter_spec = lambda f, x: (
None if isinstance(x, driver_nn.PRNGKeyArray) else eqx.default_serialise_filter_spec(f, x)
)
with open(filename, "wb") as f:
model_cfg_str = json.dumps(model_cfg)
f.write((model_cfg_str + "\n").encode())
eqx.tree_serialise_leaves(f, model, filter_spec=new_filter_spec)


def load(cfg: Dict, DriverModule: eqx.Module) -> eqx.Module:
filename = cfg["drivers"]["E0"]["file"]
with tempfile.TemporaryDirectory() as td:
Expand Down Expand Up @@ -59,8 +49,8 @@ def choose_driver(shape: str) -> eqx.Module:
elif shape == "lorentzian":
return LorentzianDriver

elif shape == "file":
return FileDriver
elif shape == "arbitrary":
return ArbitraryDriver

elif shape == "vae":
return ITLnVAE
Expand All @@ -72,22 +62,28 @@ def choose_driver(shape: str) -> eqx.Module:
raise NotImplementedError(f"Amplitude shape -- {shape} -- not implemented")


class UniformDriver(eqx.Module):
class ArbitraryDriver(eqx.Module):
intensities: Array
delta_omega: Array
initial_phase: Array
envelope: Dict
phase_key: driver_nn.PRNGKeyArray
amp_output: str
phase_output: str
model_cfg: Dict

def __init__(self, cfg: Dict):
super().__init__()
num_colors = cfg["drivers"]["E0"]["num_colors"]
self.intensities = jnp.array(np.ones(num_colors))
delta_omega_max = cfg["drivers"]["E0"]["delta_omega_max"]
self.delta_omega = jnp.linspace(-delta_omega_max, delta_omega_max, num_colors)
self.initial_phase = jnp.array(np.random.uniform(-np.pi, np.pi, num_colors))
self.model_cfg = cfg["drivers"]["E0"]["params"]
self.intensities = np.ones(cfg["drivers"]["E0"]["num_colors"])
self.delta_omega = np.linspace(
-cfg["drivers"]["E0"]["delta_omega_max"],
cfg["drivers"]["E0"]["delta_omega_max"],
cfg["drivers"]["E0"]["num_colors"],
)
self.initial_phase = jnp.array(np.random.uniform(-np.pi, np.pi, cfg["drivers"]["E0"]["num_colors"]))
self.envelope = cfg["drivers"]["E0"]["derived"]
self.phase_key = driver_nn.PRNGKeyArray(PRNGKey(seed=np.random.randint(2**20)))
self.amp_output = "none"
self.phase_output = "none"

def scale_ints_and_phases(self, intensities, phases) -> tuple:
if self.amp_output == "linear":
Expand All @@ -111,6 +107,24 @@ def scale_ints_and_phases(self, intensities, phases) -> tuple:

return ints, phases

def save(self, filename: str) -> None:
"""
Save the model to a file
Parameters
----------
filename : str
The name of the file to save the model to
"""
new_filter_spec = lambda f, x: (
None if isinstance(x, driver_nn.PRNGKeyArray) else eqx.default_serialise_filter_spec(f, x)
)
with open(filename, "wb") as f:
model_cfg_str = json.dumps(self.model_cfg)
f.write((model_cfg_str + "\n").encode())
eqx.tree_serialise_leaves(f, self, filter_spec=new_filter_spec)

def __call__(self, state: Dict, args: Dict) -> tuple:
ints = self.intensities / jnp.sum(self.intensities)
args["drivers"]["E0"] = {
Expand All @@ -121,6 +135,20 @@ def __call__(self, state: Dict, args: Dict) -> tuple:
return state, args


class UniformDriver(ArbitraryDriver):

phase_key: driver_nn.PRNGKeyArray

def __init__(self, cfg: Dict):
super().__init__(cfg)
num_colors = cfg["drivers"]["E0"]["num_colors"]
self.intensities = jnp.array(np.ones(num_colors))
delta_omega_max = cfg["drivers"]["E0"]["delta_omega_max"]
self.delta_omega = jnp.linspace(-delta_omega_max, delta_omega_max, num_colors)
self.initial_phase = jnp.array(np.random.uniform(-np.pi, np.pi, num_colors))
self.phase_key = driver_nn.PRNGKeyArray(PRNGKey(seed=np.random.randint(2**20)))


class GaussianDriver(UniformDriver):

def __init__(self, cfg: Dict):
Expand All @@ -134,15 +162,6 @@ def __init__(self, cfg: Dict):
* np.exp(-4 * np.log(2) * (self.delta_omega / delta_omega_max) ** 2.0)
)

def __call__(self, state: Dict, args: Dict) -> tuple:
ints = self.intensities / jnp.sum(self.intensities)
args["drivers"]["E0"] = {
"delta_omega": stop_gradient(self.delta_omega),
"initial_phase": self.initial_phase,
"intensities": ints,
} | {k: stop_gradient(v) for k, v in self.envelope.items()}
return state, args


class LorentzianDriver(UniformDriver):

Expand All @@ -153,36 +172,6 @@ def __init__(self, cfg: Dict):
1 / np.pi * (delta_omega_max / 2) / (self.delta_omega**2.0 + (delta_omega_max / 2) ** 2.0)
)

def __call__(self, state: Dict, args: Dict) -> tuple:
ints = self.intensities / jnp.sum(self.intensities)
args["drivers"]["E0"] = {
"delta_omega": stop_gradient(self.delta_omega),
"initial_phase": self.initial_phase,
"intensities": ints,
} | {k: stop_gradient(v) for k, v in self.envelope.items()}
return state, args


class FileDriver(UniformDriver):
amp_output: str
phase_output: str

def __init__(self, cfg: Dict):
super().__init__(cfg)

self.amp_output = cfg["drivers"]["E0"]["output"]["amp"]
self.phase_output = cfg["drivers"]["E0"]["output"]["phase"]

def __call__(self, state: Dict, args: Dict) -> tuple:
ints, phases = self.scale_ints_and_phases(self.intensities, self.initial_phase)

args["drivers"]["E0"] = {
"delta_omega": stop_gradient(self.delta_omega),
"initial_phase": phases,
"intensities": ints,
} | {k: stop_gradient(v) for k, v in self.envelope.items()}
return state, args


class GenerativeDriver(UniformDriver):
input_width: int
Expand Down Expand Up @@ -241,6 +230,28 @@ def __init__(self, cfg: Dict):

self.inputs = jnp.array((rescaled_I0, rescaled_Te, rescaled_Ln))

def scale_ints_and_phases(self, intensities, phases) -> tuple:
if self.amp_output == "linear":
ints = 0.5 * (jnp.tanh(ints) + 1.0)
elif self.amp_output == "log":
ints = 3 * (jnp.tanh(intensities) + 1.0) - 6
ints = 10**ints
else:
raise NotImplementedError(f"Amplitude Output type -- {self.amp_output} -- not implemented")

if self.phase_output == "learned":
phases = jnp.tanh(phases) * jnp.pi * 4
elif self.phase_output == "random":
phases = stop_gradient(
uniform(self.phase_key.key, (self.initial_phase.size,), minval=-jnp.pi, maxval=jnp.pi)
)
else:
raise NotImplementedError(f"Phase Output type -- {self.phase_output} -- not implemented")

ints /= jnp.sum(ints)

return ints, phases

def __call__(self, state: Dict, args: Dict) -> tuple:
ints_and_phases = self.model(self.inputs)
ints, phases = self.scale_ints_and_phases(ints_and_phases["amps"], ints_and_phases["phases"])
Expand Down
10 changes: 9 additions & 1 deletion adept/lpse2d.py
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
from ._lpse2d import BaseLPSE2D, save_driver, calc_threshold_intensity
from ._lpse2d import (
BaseLPSE2D,
calc_threshold_intensity,
ArbitraryDriver,
UniformDriver,
GaussianDriver,
LorentzianDriver,
GenerativeDriver,
)

0 comments on commit f11e1c3

Please sign in to comment.