Skip to content

Commit

Permalink
Bugfixes (#13)
Browse files Browse the repository at this point in the history
* save config

* test damping

* local eval

* zk skeleton w/ tests

* tests
  • Loading branch information
joglekara authored Sep 30, 2023
1 parent 2c7eee5 commit 0744ec5
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 157 deletions.
55 changes: 51 additions & 4 deletions adept/es1d/pushers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,37 @@ def __call__(self, n, u):

class VelocityStepper(eqx.Module):
kx: jax.Array
kxr: jax.Array
wr_corr: jax.Array
wis: jax.Array
nuee: jnp.float64
vph: jnp.float64
model_kld: jnp.float64
trapping_model: str

def __init__(self, kx, kxr, one_over_kxr, physics):
self.kx = kx
self.kxr = kxr
if physics["gamma"] == "kinetic":
kinetic_real_epw = True
else:
kinetic_real_epw = False

table_wrs, table_wis, table_klds = get_complex_frequency_table(1024, kinetic_real_epw)
wrs, wis, klds = get_complex_frequency_table(1024, True if physics["gamma"] == "kinetic" else False)
wrs = jnp.array(jnp.interp(kxr, klds, wrs, left=1.0, right=wrs[-1]))
wrs = jnp.interp(kxr, table_klds, table_wrs, left=1.0, right=table_wrs[-1])
self.wis = jnp.interp(kxr, table_klds, table_wis, left=0.0, right=0.0)
self.nuee = 0.0
self.model_kld = 0.0
self.vph = 0.0
self.trapping_model = "none"

if physics["trapping"]["is_on"]:
self.nuee = physics["trapping"]["nuee"]
self.trapping_model = physics["trapping"]["model"]
self.model_kld = physics["trapping"]["kld"]
# table_klds = table_klds
self.vph = jnp.interp(self.model_kld, table_klds, table_wrs, left=1.0, right=table_wrs[-1]) / self.model_kld

if physics["gamma"] == "kinetic":
self.wr_corr = (jnp.square(wrs) - 1.0) * one_over_kxr**2.0
Expand All @@ -153,18 +176,42 @@ def __init__(self, kx, kxr, one_over_kxr, physics):
else:
self.wis = jnp.zeros_like(kxr)

def landau_damping_term(self, u):
return 2 * jnp.real(jnp.fft.irfft(self.wis * jnp.fft.rfft(u)))
def landau_damping_term(self, u, e, delta):
baseline = 2 * jnp.real(jnp.fft.irfft(self.wis * jnp.fft.rfft(u)))
if self.trapping_model == "zk":
coeff = self.zk_coeff(e)
elif self.trapping_model == "delta":
coeff = 1.0 / (1.0 + delta**2.0)
elif self.trapping_model == "none":
coeff = 1.0
else:
raise NotImplementedError

return baseline * coeff

def restoring_force_term(self, gradp_over_nm):
return jnp.real(jnp.fft.irfft(self.wr_corr * jnp.fft.rfft(gradp_over_nm)))

def zk_coeff(self, e):
beta = 3.0 * np.sqrt(2.0) * (7.0 * np.pi + 6.0) / (4.0 * np.pi**2.0)
vt = 1.0

ek = jnp.fft.rfft(e, axis=0) * 2.0 / self.kx.size
ek = jnp.interp(self.model_kld, self.kxr, jnp.abs(ek))

vtrap_sq = ek / self.model_kld
tau1 = 1.0 / self.nuee * vtrap_sq / self.vph**2.0
tau2 = 2.0 * np.pi / self.model_kld / jnp.sqrt(vtrap_sq)
coeff = 0.5 #beta * (vt / self.vph) ** 2.0 * tau2 / tau1

return coeff

def __call__(self, n, u, p_over_m, q_over_m_times_e, delta):
return (
-u * gradient(u, self.kx)
- self.restoring_force_term(gradient(p_over_m, self.kx) / n)
- q_over_m_times_e
+ self.landau_damping_term(u) / (1.0 + delta**2)
+ self.landau_damping_term(u, q_over_m_times_e, delta)
)


Expand Down
27 changes: 14 additions & 13 deletions configs/damping.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ mlflow:
run: test

models:
file: False #models/weights.eqx
file: models/weights.eqx
nu_g:
in_size: 3
out_size: 1
width_size: 16
depth: 3
activation: tanh
final_activation: tanh
activation: tanh
depth: 4
final_activation: tanh
in_size: 3
out_size: 1
width_size: 8


grid:
Expand All @@ -35,9 +35,9 @@ save:
xmax: 20.94
nx: 16
kx:
is_on: False
is_on: True
kxmin: 0.0
kxmax: 0.6
kxmax: 0.3
nkx: 2

physics:
Expand All @@ -61,21 +61,22 @@ physics:
charge: -1.0
trapping:
is_on: True
model: zk
kld: 0.3
nuee: 1.0e-9
nuee: 1.0e-7
nn: 8|8


drivers:
"ex":
"0":
"k0": 0.22
"w0": 1.1
"k0": 0.3
"w0": 1.16
"dw0": 0.0
"t_c": 40.0
"t_w": 40.0
"t_r": 5.0
"x_c": 400.0
"x_w": 1000000.0
"x_r": 10.0
"a0": 1.e-5
"a0": 1.e-3
57 changes: 26 additions & 31 deletions configs/epw.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
drivers:
ex:
'0':
a0: 0.005
a0: 0.01
dw0: 0.0
k0: 0.3
k0: 0.32
t_c: 40.0
t_r: 5.0
t_w: 40.0
w0: 1.1598464805919155
x_c: 300.0
w0: 1.1433284742365162
x_c: 400.0
x_r: 10.0
x_w: 300.0

x_w: 1000000.0
grid:
nx: 2048
nx: 16
tmax: 500.0
tmin: 0.0
xmax: 2000
xmax: 19.634954084936208
xmin: 0.0

mlflow:
experiment: es1d-epw-test
run: wp-l

run: nl-fluid-noml
mode: es-1d

models:
file: false
nu_g:
activation: tanh
depth: 4
final_activation: tanh
in_size: 3
out_size: 1
width_size: 8
physics:
electron:
T0: 1.0
Expand All @@ -35,9 +40,9 @@ physics:
mass: 1.0
trapping:
is_on: false
kld: 0.3
kld: 0.32
nn: 8|8
nuee: 1.0e-05
nuee: 0.0001
ion:
T0: 1.0
charge: 1.0
Expand All @@ -50,30 +55,20 @@ physics:
kld: 0.3
nuee: 1.0e-09
landau_damping: true

save:
func:
is_on: true
kx:
is_on: true
kxmax: 0.9
kxmin: 0.3
nkx: 3
is_on: false
kxmax: 0.32
kxmin: 0.0
nkx: 2
t:
nt: 1000
tmax: 500.0
tmin: 0.5
x:
is_on: true
nx: 256
xmax: 1000
xmin: 0.0
models:
file: False #models/weights.eqx
# nu_g:
# activation: tanh
# depth: 4
# final_activation: tanh
# in_size: 3
# out_size: 1
# width_size: 8
nx: 16
xmax: 19.634954084936208
xmin: 0.0
20 changes: 10 additions & 10 deletions configs/es1d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mlflow:
run: single-k=0.3-finite-amplitude-fluid-trapping

grid:
nx: 4096
nx: 6144
xmin: 0.0
xmax: 4000
xmax: 6000
tmin: 0.0
tmax: 1000.0

Expand Down Expand Up @@ -35,16 +35,16 @@ models:
in_size: 3
out_size: 1
width_size: 8
depth: 3
activation: tanh
final_activation: tanh
nu_d:
in_size: 3
out_size: 1
width_size: 8
depth: 3
depth: 4
activation: tanh
final_activation: tanh
# nu_d:
# in_size: 3
# out_size: 1
# width_size: 8
# depth: 3
# activation: tanh
# final_activation: tanh


physics:
Expand Down
73 changes: 73 additions & 0 deletions configs/wp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
drivers:
ex:
'0':
a0: 0.005
dw0: 0.0
k0: 0.3
t_c: 40.0
t_r: 5.0
t_w: 40.0
w0: 1.1598464805919155
x_c: 300.0
x_r: 10.0
x_w: 300.0
grid:
nx: 4096
tmax: 500.0
tmin: 0.0
xmax: 4000
xmin: 0.0
mlflow:
experiment: es1d-epw-test
run: wp-nl-local
mode: es-1d
models:
file: models/weights.eqx
nu_g:
activation: tanh
depth: 4
final_activation: tanh
in_size: 3
out_size: 1
width_size: 8
physics:
electron:
T0: 1.0
charge: -1.0
gamma: kinetic
is_on: true
landau_damping: true
mass: 1.0
trapping:
is_on: true
kld: 0.3
nn: 8|8
nuee: 1.0e-05
ion:
T0: 1.0
charge: 1.0
gamma: 3
is_on: false
landau_damping: false
mass: 1836.0
trapping:
is_on: true
kld: 0.3
nuee: 1.0e-09
save:
func:
is_on: true
kx:
is_on: true
kxmax: 0.9
kxmin: 0.3
nkx: 3
t:
nt: 1000
tmax: 500.0
tmin: 0.5
x:
is_on: true
nx: 256
xmax: 1000
xmin: 0.0
4 changes: 2 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
config.update("jax_disable_jit", True)

import yaml, mlflow
from utils.runner import run

if __name__ == "__main__":
with open("configs/epw.yaml", "r") as fi:
with open("configs/damping.yaml", "r") as fi:
# with open("tests/configs/resonance.yaml", "r") as fi:
cfg = yaml.safe_load(fi)

Expand Down
Loading

0 comments on commit 0744ec5

Please sign in to comment.