Skip to content

Commit

Permalink
Vlasov 2D2V (#14)
Browse files Browse the repository at this point in the history
* passing tests

* passing tests

* moved vlasov files over

* refactor save funcs
passing tests

* units

* working vlasov 2d2v

* passing all tests!

* ampere solver wrong

but probably converges with timesteps or dv etc

probably best to have a higher order timestepper

* half timesteps

still doesnt quite fix ampere solver
  • Loading branch information
joglekara authored Nov 8, 2023
1 parent 0744ec5 commit 73170e0
Show file tree
Hide file tree
Showing 43 changed files with 1,800 additions and 483 deletions.
File renamed without changes.
File renamed without changes.
63 changes: 38 additions & 25 deletions adept/es1d/helpers.py → adept/tf1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@
from jax import tree_util as jtu
from flatdict import FlatDict
import equinox as eqx
from diffrax import ODETerm, Tsit5

from jax import numpy as jnp
from adept.es1d import pushers
from utils import nn
from adept.tf1d import pushers
from equinox import nn


def save_arrays(result, td, cfg, label):
if label is None:
label = "x"
flattened_dict = dict(FlatDict(result.ys, delimiter="-"))
save_ax = cfg["grid"]["x"]
else:
flattened_dict = dict(FlatDict(result.ys[label], delimiter="-"))
save_ax = cfg["save"][label]["ax"]
data_vars = {
k: xr.DataArray(v, coords=(("t", cfg["save"]["t"]["ax"]), (label, cfg["save"][label]["ax"])))
for k, v in flattened_dict.items()
k: xr.DataArray(v, coords=(("t", cfg["save"]["t"]["ax"]), (label, save_ax))) for k, v in flattened_dict.items()
}

saved_arrays_xr = xr.Dataset(data_vars)
Expand Down Expand Up @@ -69,21 +71,23 @@ def plot_xrs(which, td, xrs):
plt.close(fig)


def post_process(result, cfg: Dict, td: str) -> None:
def post_process(result, cfg: Dict, td: str) -> Dict:
os.makedirs(os.path.join(td, "binary"))
os.makedirs(os.path.join(td, "plots"))

if cfg["save"]["func"]["is_on"]:
if cfg["save"]["x"]["is_on"]:
xrs = save_arrays(result, td, cfg, label="x")
plot_xrs("x", td, xrs)

if cfg["save"]["kx"]["is_on"]:
xrs = save_arrays(result, td, cfg, label="kx")
plot_xrs("kx", td, xrs)
datasets = {}
if any(x in ["x", "kx"] for x in cfg["save"]):
if "x" in cfg["save"].keys():
datasets["x"] = save_arrays(result, td, cfg, label="x")
plot_xrs("x", td, datasets["x"])
if "kx" in cfg["save"].keys():
datasets["kx"] = save_arrays(result, td, cfg, label="kx")
plot_xrs("kx", td, datasets["kx"])
else:
xrs = save_arrays(result, td, cfg, label=None)
plot_xrs("x", td, xrs)
datasets["full"] = save_arrays(result, td, cfg, label=None)
plot_xrs("x", td, datasets["full"])

return datasets


def get_derived_quantities(cfg_grid: Dict) -> Dict:
Expand All @@ -109,7 +113,7 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
return cfg_grid


def get_solver_quantities(cfg_grid: Dict) -> Dict:
def get_solver_quantities(cfg: Dict) -> Dict:
"""
This function just updates the config with the derived quantities that are arrays
Expand All @@ -118,6 +122,8 @@ def get_solver_quantities(cfg_grid: Dict) -> Dict:
:param cfg_grid:
:return:
"""
cfg_grid = cfg["grid"]

cfg_grid = {
**cfg_grid,
**{
Expand Down Expand Up @@ -148,12 +154,20 @@ def get_save_quantities(cfg: Dict) -> Dict:
:param cfg:
:return:
"""
cfg["save"]["func"] = {**cfg["save"]["func"], **{"callable": get_save_func(cfg)}}
cfg["save"]["func"] = {"callable": get_save_func(cfg)}
cfg["save"]["t"]["ax"] = jnp.linspace(cfg["save"]["t"]["tmin"], cfg["save"]["t"]["tmax"], cfg["save"]["t"]["nt"])

return cfg


def get_diffeqsolve_quants(cfg):
return dict(
terms=ODETerm(VectorField(cfg)),
solver=Tsit5(),
saveat=dict(ts=cfg["save"]["t"]["ax"], fn=cfg["save"]["func"]["callable"]),
)


def init_state(cfg: Dict) -> Dict:
"""
This function initializes the state
Expand Down Expand Up @@ -190,7 +204,7 @@ class VectorField(eqx.Module):
push_driver: Callable
poisson_solver: Callable

def __init__(self, cfg, models):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.pusher_dict = {"ion": {}, "electron": {}}
Expand All @@ -203,7 +217,7 @@ def __init__(self, cfg, models):
cfg["grid"]["kx"], cfg["physics"][species_name]
)
if cfg["physics"][species_name]["trapping"]["is_on"]:
self.pusher_dict[species_name]["particle_trapper"] = pushers.ParticleTrapper(cfg, species_name, models)
self.pusher_dict[species_name]["particle_trapper"] = pushers.ParticleTrapper(cfg, species_name)

self.push_driver = pushers.Driver(cfg["grid"]["x"])
# if "ey" in self.cfg["drivers"]:
Expand Down Expand Up @@ -267,16 +281,16 @@ def __call__(self, t: float, y: Dict, args: Dict):


def get_save_func(cfg):
if cfg["save"]["func"]["is_on"]:
if cfg["save"]["x"]["is_on"]:
if any(x in ["x", "kx"] for x in cfg["save"]):
if "x" in cfg["save"].keys():
dx = (cfg["save"]["x"]["xmax"] - cfg["save"]["x"]["xmin"]) / cfg["save"]["x"]["nx"]
cfg["save"]["x"]["ax"] = jnp.linspace(
cfg["save"]["x"]["xmin"] + dx / 2.0, cfg["save"]["x"]["xmax"] - dx / 2.0, cfg["save"]["x"]["nx"]
)

save_x = partial(jnp.interp, cfg["save"]["x"]["ax"], cfg["grid"]["x"])

if cfg["save"]["kx"]["is_on"]:
if "kx" in cfg["save"].keys():
cfg["save"]["kx"]["ax"] = jnp.linspace(
cfg["save"]["kx"]["kxmin"], cfg["save"]["kx"]["kxmax"], cfg["save"]["kx"]["nkx"]
)
Expand All @@ -288,15 +302,14 @@ def save_kx(field):

def save_func(t, y, args):
save_dict = {}
if cfg["save"]["x"]["is_on"]:
if "x" in cfg["save"].keys():
save_dict["x"] = jtu.tree_map(save_x, y)
if cfg["save"]["kx"]["is_on"]:
if "kx" in cfg["save"].keys():
save_dict["kx"] = jtu.tree_map(save_kx, y)

return save_dict

else:
cfg["save"]["x"]["ax"] = cfg["grid"]["x"]
save_func = None

return save_func
Expand Down
14 changes: 7 additions & 7 deletions adept/es1d/pushers.py → adept/tf1d/pushers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def zk_coeff(self, e):
vtrap_sq = ek / self.model_kld
tau1 = 1.0 / self.nuee * vtrap_sq / self.vph**2.0
tau2 = 2.0 * np.pi / self.model_kld / jnp.sqrt(vtrap_sq)
coeff = 0.5 #beta * (vt / self.vph) ** 2.0 * tau2 / tau1
coeff = 0.5 # beta * (vt / self.vph) ** 2.0 * tau2 / tau1

return coeff

Expand Down Expand Up @@ -247,7 +247,7 @@ class ParticleTrapper(eqx.Module):
nu_g_model: eqx.Module
# nu_d_model: eqx.Module

def __init__(self, cfg, species="electron", models=None):
def __init__(self, cfg, species="electron"):
nuee = cfg["physics"][species]["trapping"]["nuee"]
if cfg["physics"][species]["gamma"] == "kinetic":
kinetic_real_epw = True
Expand All @@ -266,17 +266,17 @@ def __init__(self, cfg, species="electron", models=None):
self.vph = jnp.interp(self.model_kld, table_klds, table_wrs, left=1.0, right=table_wrs[-1]) / self.model_kld

# Make models
if models:
self.nu_g_model = models["nu_g"]
else:
self.nu_g_model = lambda x: 1e-3
# if models:
# self.nu_g_model = models["nu_g"]
# else:
# self.nu_g_model = lambda x: 1e-3

def __call__(self, e, delta, args):
ek = jnp.fft.rfft(e, axis=0) * 2.0 / self.kx.size
norm_e = (jnp.log10(jnp.interp(self.model_kld, self.kxr, jnp.abs(ek)) + 1e-10) + 10.0) / -10.0
func_inputs = jnp.stack([norm_e, self.norm_kld, self.norm_nuee], axis=-1)
# jax.debug.print("{x}", x=func_inputs)
growth_rates = 10 ** (3 * jnp.squeeze(self.nu_g_model(func_inputs)))
growth_rates = 10 ** (3 * jnp.squeeze(args["nu_g"](func_inputs)))

return -self.vph * gradient(delta, self.kx) + growth_rates * jnp.abs(
jnp.fft.irfft(ek * self.kx.size / 2.0 * self.wis)
Expand Down
15 changes: 7 additions & 8 deletions train-damping.py → adept/tf1d/train_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# config.update("jax_debug_nans", True)
# config.update("jax_disable_jit", True)

import jax
from jax import numpy as jnp
import xarray as xr
import tempfile, time
import mlflow, optax, pickle
import mlflow, optax
import equinox as eqx
from tqdm import tqdm

from adept.es1d import helpers
from adept.tf1d import helpers
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
from utils import misc, plotters

Expand All @@ -43,7 +42,7 @@ def _modify_defaults_(defaults, k0, a0, nuee):

def train_loop():
# modify config
fks = xr.open_dataset("./epws.nc")
fks = xr.open_dataset("../../epws.nc")

nus = np.copy(fks.coords[r"$\nu_{ee}$"].data) # [::4]
k0s = np.copy(fks.coords["$k_0$"].data) # [::4]
Expand Down Expand Up @@ -138,7 +137,7 @@ def remote_train_loop():
batch_size = 16

# modify config
fks = xr.open_dataset("./epws.nc")
fks = xr.open_dataset("../../epws.nc")

nus = np.copy(fks.coords[r"$\nu_{ee}$"].data[::3])
k0s = np.copy(fks.coords["$k_0$"].data[::2])
Expand Down Expand Up @@ -216,7 +215,7 @@ def update_w_and_b(job_done, run_ids, optimizer, opt_state, w_and_b):


def queue_sim(fks, nuee, k0, a0, run_ids, job_done, w_and_b, epoch, i_batch, sim, t_or_v="grad"):
with open("./configs/damping.yaml", "r") as file:
with open("../../configs/tf-1d/damping.yaml", "r") as file:
defaults = yaml.safe_load(file)

mod_defaults = _modify_defaults_(defaults, float(k0), float(a0), float(nuee))
Expand Down Expand Up @@ -249,14 +248,14 @@ def queue_sim(fks, nuee, k0, a0, run_ids, job_done, w_and_b, epoch, i_batch, sim


def eval_over_all():
with open("./configs/damping.yaml", "r") as file:
with open("../../configs/tf-1d/damping.yaml", "r") as file:
defaults = yaml.safe_load(file)
trapping_models = helpers.get_models(defaults["models"])

# batch_size = 16

# modify config
fks = xr.open_dataset("./epws.nc")
fks = xr.open_dataset("../../epws.nc")

nus = np.copy(fks.coords[r"$\nu_{ee}$"].data[::3])
k0s = np.copy(fks.coords["$k_0$"].data[::2])
Expand Down
Empty file added adept/vlasov2d/__init__.py
Empty file.
Binary file added adept/vlasov2d/gamma_func_for_sg.nc
Binary file not shown.
Loading

0 comments on commit 73170e0

Please sign in to comment.