From 4dad1077ec485410283ccab5e1004ec392dda421 Mon Sep 17 00:00:00 2001 From: Archis Joglekar Date: Sat, 14 Sep 2024 20:33:06 -0700 Subject: [PATCH] Refactoring for package (#72) * more refactoring for packaging * refactor for package * fixing vfp test * vfp test * update badge --- .github/workflows/cpu-tests.yaml | 3 +- README.md | 2 +- adept/__init__.py | 410 +-------------------- adept/_base_.py | 409 ++++++++++++++++++++ adept/_lpse2d/__init__.py | 2 + adept/_lpse2d/core/driver.py | 4 +- adept/_lpse2d/core/integrator.py | 2 +- adept/_lpse2d/core/vector_field.py | 2 +- adept/_lpse2d/helpers.py | 4 +- adept/_lpse2d/modules/__init__.py | 2 + adept/_lpse2d/modules/base.py | 3 +- adept/_vlasov1d/helpers.py | 15 +- adept/_vlasov1d/modules.py | 3 +- adept/_vlasov1d/solvers/pushers/field.py | 2 +- adept/_vlasov1d/solvers/vector_field.py | 2 +- adept/lpse2d.py | 4 +- adept/tf1d/solvers/pushers.py | 2 +- adept/utils.py | 88 ----- adept/vfp1d/base.py | 2 +- adept/vfp1d/helpers.py | 15 +- tests/test_lpse2d/test_tpd_threshold.py | 5 - tests/test_tf1d/test_against_vlasov.py | 4 +- tests/test_tf1d/test_landau_damping.py | 4 +- tests/test_tf1d/test_resonance.py | 4 +- tests/test_vfp1d/epp-short.yaml | 2 +- tests/test_vlasov1d/test_absorbing_wave.py | 2 +- tests/test_vlasov1d/test_landau_damping.py | 4 +- 27 files changed, 457 insertions(+), 544 deletions(-) create mode 100644 adept/_base_.py diff --git a/.github/workflows/cpu-tests.yaml b/.github/workflows/cpu-tests.yaml index 0f155cd..f5d9c95 100644 --- a/.github/workflows/cpu-tests.yaml +++ b/.github/workflows/cpu-tests.yaml @@ -8,6 +8,7 @@ on: push: branches: - main + - feature/package # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -34,4 +35,4 @@ jobs: - name: Test with pytest run: | - CPU_ONLY=True pytest tests/test_lpse2d tests/test_vlasov1d + CPU_ONLY=True pytest tests/test_lpse2d tests/test_vlasov1d tests/test_vfp1d diff --git a/README.md b/README.md index 182fe9b..6650642 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # ADEPT ![Docs](https://readthedocs.org/projects/adept/badge/?version=latest) -![Tests](https://github.com/ergodicio/adept/actions/workflows/test.yaml/badge.svg) +![Tests](https://github.com/ergodicio/adept/actions/workflows/cpu-tests.yaml/badge.svg) ![ADEPT](./docs/source/adept-logo.png) diff --git a/adept/__init__.py b/adept/__init__.py index 655ba0e..f04e818 100644 --- a/adept/__init__.py +++ b/adept/__init__.py @@ -1,409 +1,3 @@ -from typing import Dict, Tuple, Callable -import jax.flatten_util -import os, time, tempfile, yaml, pickle +from ._base_ import ergoExo, ADEPTModule - -from diffrax import Solution, Euler, RESULTS -from equinox import Module, filter_jit -import mlflow, jax, numpy as np -from jax import numpy as jnp - - -def get_envelope(p_wL, p_wR, p_L, p_R, ax): - return 0.5 * (jnp.tanh((ax - p_L) / p_wL) - jnp.tanh((ax - p_R) / p_wR)) - - -class Stepper(Euler): - """ - This is just a dummy stepper - - :param cfg: - """ - - def step(self, terms, t0, t1, y0, args, solver_state, made_jump): - del solver_state, made_jump - y1 = terms.vf(t0, y0, args) - dense_info = dict(y0=y0, y1=y1) - return y1, None, dense_info, None, RESULTS.successful - - -class ADEPTModule: - """ - This class is the base class for all the ADEPT modules. It defines the interface that all the ADEPT modules must implement so that - the `ergoExo` class can call them in the right order. - - Args: - cfg: The configuration dictionary - - """ - - def __init__(self, cfg) -> None: - self.cfg = cfg - self.state = None - self.args = None - self.diffeqsolve_quants = None - self.time_quantities = None - - def post_process(self, run_output: Dict, td: str) -> Dict: - """ - This function is responsible for post-processing the results of the simulation. It is called after the simulation is run and the results are available. - - Args: - run_output (Dict): The output of the simulation - td (str): The temporary directory where the results are stored - - Returns: - A dictionary of the post-processed results. This can include the metrics, the ``xarray`` datasets, and any other information that is relevant to the simulation - - """ - - return {} - - def write_units(self) -> Dict: - """ - This function is responsible for writing the units, normalizing constants, and other important physical quantities to a dictionary. - This dictionary is then dumped to a yaml file and logged to mlflow by the ``ergoExo`` class. - - Returns: - A dictionary of the units - - """ - return {} - - def init_diffeqsolve(self) -> Dict: - """ - This function is responsible for initializing the differential equation solver ``diffrax.diffeqsolve``. It sets up the time quantities, the solver quantities, and the save function. - - Returns: - A dictionary of the differential equation solver quantities - - """ - pass - - def get_derived_quantities(self) -> Dict: - """ - This function is responsible for getting the derived quantities from the configuration dictionary. This is needed for running the simulation. These quantities do get logged to mlflow - by the ``ergoExo`` class. - - Returns: - An updated configuration dictionary - - """ - pass - - def get_solver_quantities(self): - """ - This function is responsible for getting the solver quantities from the configuration dictionary. This is needed for running the simulation. These quantities do NOT get logged - to mlflow because they are often arrays - - Returns: - An updated configuration dictionary - - """ - pass - - def get_save_func(self): - """ - This function is responsible for getting the save function for the differential equation solver. This is needed for running the simulation. - This function lets you subsample your simulation state so as to not save the entire thing at every timestep. - - This dictionary is set as a class attribute for the ``ADEPTModule`` and are used in the ``__call__`` function - - """ - pass - - def init_state_and_args(self): - """ - This function initializes the state and the arguments that are required to run the simulation. The state is the initial conditions of the simulation and - the arguments are often the drivers - - These are set as class attributes for the ``ADEPTModule`` and are used in the ``__call__`` function - - """ - return {} - - def init_modules(self) -> Dict[str, Module]: - """ - This function initializes the necessary (trainable) physics modules that are required to run the simulation. These can be modules that - change the initial conditions, or the driver (boundary conditions), or the metric calculation. These modules are usually `eqx.Module`s - so that you can take derivatives against the (parameters of the) modules. - - Returns: - Dict: A dictionary of the (trainable) modules that are required to run the simulation - - """ - return {} - - def __call__(self, trainable_modules: Dict, args: Dict): - return {} - - def vg(self, trainable_modules: Dict, args: Dict): - raise NotImplementedError( - "This is the base class and does not have a gradient implemented. This is " - + "likely because there is no metric in place. Subclass this class and implement the gradient" - ) - # return eqx.filter_value_and_grad(self.__call__)(trainable_modules) - - -class ergoExo: - """ - This class is the main interface for running a simulation. It is responsible for calling all the ADEPT modules in the right order - and logging parameters and results to mlflow. - - This approach helps decouple the numerical solvers from the experiment management and facilitates the addition of new solvers - - Typical usage is as follows - - .. code-block:: python - - exoskeleton = ergoExo() - modules = exoskeleton.setup(cfg) - run_output, post_processing_output, mlflow_run_id = exoskeleton(modules, args=None) - - - If you are resuming an existing mlflow run, you can do the following - - .. code-block:: python - - exoskeleton = ergoExo(mlflow_run_id=mlflow_run_id) - modules = exoskeleton.setup(cfg) - run_output, post_processing_output, mlflow_run_id = exoskeleton(modules, args=None) - - If you are introducing a custom `ADEPTModule`, you can do the following - - .. code-block:: python - - exoskeleton = ergoExo() - modules = exoskeleton.setup(cfg, exoskeleton_module=custom_module) - run_output, post_processing_output, mlflow_run_id = adept(modules, args=None) - - - """ - - def __init__(self, mlflow_run_id: str = None, mlflow_nested: bool = None) -> None: - - self.mlflow_run_id = mlflow_run_id - # if mlflow_run_id is not None: - # assert self.mlflow_nested is not None - if mlflow_nested is None: - self.mlflow_nested = False - else: - self.mlflow_nested = mlflow_nested - - if "BASE_TEMPDIR" in os.environ: - self.base_tempdir = os.environ["BASE_TEMPDIR"] - else: - self.base_tempdir = None - - self.ran_setup = False - - def setup(self, cfg: Dict, adept_module: ADEPTModule = None) -> Dict[str, Module]: - """ - This function sets up the differentiable simulation by getting the chosen solver and setting it up - At this point in time, the setup includes - - 1. initializing the mlflow run and setting the runid or resuming an existing run - 2. getting the right ``ADEPTModule`` or using the one passed in. This gets assigned to ``self.adept_module``. - 3. updating the config, units, derived quantities, and array config as defined by the ``ADEPTModule``. It also dumps this information to the temporary directory, which will be logged later, and logging the parameters to mlflow - 4. initializing the state and args as defined by the ``ADEPTModule`` - 5. initializing the `diffeqsolve` as defined by the ``ADEPTModule`` - 6. initializing the necessary (trainable) physics modules as defined by the ``ADEPTModule`` - - Args: - cfg: The configuration dictionary - - Returns: - A dictionary of trainable modules (``Dict[str, eqx.Module]``) - - This is a dictionary of the (trainable) modules that are required to run the simulation. These can be modules that - change the initial conditions, or the driver (boundary conditions), or the metric calculation. These modules are - ``equinox`` modules in order to play nice with ``diffrax`` - - """ - - with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: - if self.mlflow_run_id is None: - mlflow.set_experiment(cfg["mlflow"]["experiment"]) - with mlflow.start_run(run_name=cfg["mlflow"]["run"], nested=self.mlflow_nested) as mlflow_run: - modules = self._setup_(cfg, td, adept_module) - mlflow.log_artifacts(td) # logs the temporary directory to mlflow - self.mlflow_run_id = mlflow_run.info.run_id - - else: - from adept.utils import get_cfg - - with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested) as mlflow_run: - # with tempfile.TemporaryDirectory(dir=self.base_tempdir) as temp_path: - # cfg = get_cfg(artifact_uri=mlflow_run.info.artifact_uri, temp_path=temp_path) - modules = self._setup_(cfg, td, adept_module) - mlflow.log_artifacts(td) # logs the temporary directory to mlflow - - return modules - - def _get_adept_module_(self, cfg: Dict) -> ADEPTModule: - """ - This function returns the helper functions for the given solver - - Args: - solver: The solver to use - - - """ - - if cfg["solver"] == "tf-1d": - from adept.tf1d.modules import BaseTwoFluid1D as this_module - from adept.tf1d.datamodel import ConfigModel - - # config = ConfigModel(**cfg) - - elif cfg["solver"] == "vlasov-1d": - from adept.vlasov1d import BaseVlasov1D as this_module - from adept._vlasov1d.datamodel import ConfigModel - - # config = ConfigModel(**cfg) - - elif cfg["solver"] == "envelope-2d": - from adept.lpse2d import BaseLPSE2D as this_module - - # from adept.lpse2d.datamodel import ConfigModel - - # config = ConfigModel(**cfg) - - elif cfg["solver"] == "vfp-1d": - from adept.vfp1d.base import BaseVFP1D as this_module - else: - raise NotImplementedError("This solver approach has not been implemented yet") - - return this_module(cfg) - - def _setup_(self, cfg: Dict, td: str, adept_module: ADEPTModule = None, log: bool = True) -> Dict[str, Module]: - from adept.utils import log_params - - if adept_module is None: - self.adept_module = self._get_adept_module_(cfg) - else: - self.adept_module = adept_module(cfg) - - # dump raw config - if log: - with open(os.path.join(td, "config.yaml"), "w") as fi: - yaml.dump(self.adept_module.cfg, fi) - - # dump units - quants_dict = self.adept_module.write_units() # writes the units to the temporary directory - if log: - with open(os.path.join(td, "units.yaml"), "w") as fi: - yaml.dump(quants_dict, fi) - - # dump derived config - self.adept_module.get_derived_quantities() # gets the derived quantities - - if log: - log_params(self.adept_module.cfg) # logs the parameters to mlflow - with open(os.path.join(td, "derived_config.yaml"), "w") as fi: - yaml.dump(self.adept_module.cfg, fi) - - # dump array config - self.adept_module.get_solver_quantities() - if log: - with open(os.path.join(td, "array_config.pkl"), "wb") as fi: - pickle.dump(self.adept_module.cfg, fi) - - self.adept_module.init_state_and_args() - self.adept_module.init_diffeqsolve() - modules = self.adept_module.init_modules() - - self.ran_setup = True - - return modules - - def __call__(self, modules: Dict = None) -> Tuple[Solution, Dict, str]: - """ - This function is the main entry point for running a simulation. It takes a configuration dictionary and returns a - ``diffrax.Solution`` object and a dictionary of datasets. It calls the ``self.adept_module``'s ``__call__`` function. - - It is also responsible for logging the artifacts and metrics to mlflow. - - Args: - modules (Dict(str, eqx.Module)): The trainable modules that are required to run the simulation. All the other parameters are static and initialized during the setup call - - Returns: - a tuple of the run_output (``diffrax.Solution``), post_processing_output (``Dict[str, xarray.dataset]``), and the mlflow_run_id (``str``). - - The run_output comes from the ``__call__`` function of the ``self.adept_module``. The post_processing_output comes from the ``post_process`` method of the ``self.adept_module``. - The mlflow_run_id is the id of the mlflow run that was created during the setup call or passed in during the initialization of the class - - """ - - assert self.ran_setup, "You must run self.setup() before running the simulation" - - with mlflow.start_run( - run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True - ) as mlflow_run: - t0 = time.time() - run_output = filter_jit(self.adept_module.__call__)(modules, None) - mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow - - t0 = time.time() - with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: - post_processing_output = self.adept_module.post_process(run_output, td) - mlflow.log_artifacts(td) # logs the temporary directory to mlflow - - if "metrics" in post_processing_output: - mlflow.log_metrics(post_processing_output["metrics"]) - mlflow.log_metrics({"postprocess_time": round(time.time() - t0, 4)}) - - return run_output, post_processing_output, self.mlflow_run_id - - def val_and_grad(self, modules: Dict = None) -> Tuple[float, Dict, Tuple[Solution, Dict, str]]: - """ - This function is the value and gradient of the simulation. This is a very similar looking function to the ``__call__`` function but calls the ``self.adept_module.vg`` rather than the ``self.adept_module.__call__``. - - It is also responsible for logging the artifacts and metrics to mlflow. - - - Args: - modules: The (trainable) modules that are required to run the simulation and take the gradient against. All the other parameters are static and initialized during the setup call - - Returns: - a tuple of the value (``float``), gradient (``Dict``), and a tuple of the run_output (``diffrax.Solution``), post_processing_output (``Dict[str, xarray.dataset]``), and the mlflow_run_id (``str``). - - The value and gradient, and run_output come from the ``adept_module.vg`` function. The run_output is the same as that from ``__call__`` function of the ``self.adept_module``. The post_processing_output comes from the ``post_process`` method of the ``self.adept_module``. - The mlflow_run_id is the id of the mlflow run that was created during the setup call or passed in during the initialization - """ - assert self.ran_setup, "You must run self.setup() before running the simulation" - with mlflow.start_run( - run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True - ) as mlflow_run: - t0 = time.time() - (val, run_output), grad = filter_jit(self.adept_module.vg)(modules, None) - flattened_grad, _ = jax.flatten_util.ravel_pytree(grad) - mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow - mlflow.log_metrics({"val": float(val), "l2-grad": float(np.linalg.norm(flattened_grad))}) - - t0 = time.time() - with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: - post_processing_output = self.adept_module.post_process(run_output, td) - mlflow.log_artifacts(td) # logs the temporary directory to mlflow - if "metrics" in post_processing_output: - mlflow.log_metrics(post_processing_output["metrics"]) - mlflow.log_metrics({"postprocess_time": round(time.time() - t0, 4)}) - return val, grad, (run_output, post_processing_output, self.mlflow_run_id) - - def _log_flops_(_run_: Callable, models: Dict, state: Dict, args: Dict, tqs): - """ - Logs the number of flops to mlflow - - Args: - _run_: The function that runs the simulation - models: The models used in the simulation - tqs: The time quantities used in the simulation - - """ - wrapped = jax.xla_computation(_run_) - computation = wrapped(models, state, args, tqs) - module = computation.as_hlo_module() - client = jax.lib.xla_bridge.get_backend() - analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, module) - flops_sum = analysis["flops"] - mlflow.log_metrics({"total GigaFLOP": flops_sum / 1e9}) # logs the flops to mlflow +from . import lpse2d, vlasov1d diff --git a/adept/_base_.py b/adept/_base_.py new file mode 100644 index 0000000..655ba0e --- /dev/null +++ b/adept/_base_.py @@ -0,0 +1,409 @@ +from typing import Dict, Tuple, Callable +import jax.flatten_util +import os, time, tempfile, yaml, pickle + + +from diffrax import Solution, Euler, RESULTS +from equinox import Module, filter_jit +import mlflow, jax, numpy as np +from jax import numpy as jnp + + +def get_envelope(p_wL, p_wR, p_L, p_R, ax): + return 0.5 * (jnp.tanh((ax - p_L) / p_wL) - jnp.tanh((ax - p_R) / p_wR)) + + +class Stepper(Euler): + """ + This is just a dummy stepper + + :param cfg: + """ + + def step(self, terms, t0, t1, y0, args, solver_state, made_jump): + del solver_state, made_jump + y1 = terms.vf(t0, y0, args) + dense_info = dict(y0=y0, y1=y1) + return y1, None, dense_info, None, RESULTS.successful + + +class ADEPTModule: + """ + This class is the base class for all the ADEPT modules. It defines the interface that all the ADEPT modules must implement so that + the `ergoExo` class can call them in the right order. + + Args: + cfg: The configuration dictionary + + """ + + def __init__(self, cfg) -> None: + self.cfg = cfg + self.state = None + self.args = None + self.diffeqsolve_quants = None + self.time_quantities = None + + def post_process(self, run_output: Dict, td: str) -> Dict: + """ + This function is responsible for post-processing the results of the simulation. It is called after the simulation is run and the results are available. + + Args: + run_output (Dict): The output of the simulation + td (str): The temporary directory where the results are stored + + Returns: + A dictionary of the post-processed results. This can include the metrics, the ``xarray`` datasets, and any other information that is relevant to the simulation + + """ + + return {} + + def write_units(self) -> Dict: + """ + This function is responsible for writing the units, normalizing constants, and other important physical quantities to a dictionary. + This dictionary is then dumped to a yaml file and logged to mlflow by the ``ergoExo`` class. + + Returns: + A dictionary of the units + + """ + return {} + + def init_diffeqsolve(self) -> Dict: + """ + This function is responsible for initializing the differential equation solver ``diffrax.diffeqsolve``. It sets up the time quantities, the solver quantities, and the save function. + + Returns: + A dictionary of the differential equation solver quantities + + """ + pass + + def get_derived_quantities(self) -> Dict: + """ + This function is responsible for getting the derived quantities from the configuration dictionary. This is needed for running the simulation. These quantities do get logged to mlflow + by the ``ergoExo`` class. + + Returns: + An updated configuration dictionary + + """ + pass + + def get_solver_quantities(self): + """ + This function is responsible for getting the solver quantities from the configuration dictionary. This is needed for running the simulation. These quantities do NOT get logged + to mlflow because they are often arrays + + Returns: + An updated configuration dictionary + + """ + pass + + def get_save_func(self): + """ + This function is responsible for getting the save function for the differential equation solver. This is needed for running the simulation. + This function lets you subsample your simulation state so as to not save the entire thing at every timestep. + + This dictionary is set as a class attribute for the ``ADEPTModule`` and are used in the ``__call__`` function + + """ + pass + + def init_state_and_args(self): + """ + This function initializes the state and the arguments that are required to run the simulation. The state is the initial conditions of the simulation and + the arguments are often the drivers + + These are set as class attributes for the ``ADEPTModule`` and are used in the ``__call__`` function + + """ + return {} + + def init_modules(self) -> Dict[str, Module]: + """ + This function initializes the necessary (trainable) physics modules that are required to run the simulation. These can be modules that + change the initial conditions, or the driver (boundary conditions), or the metric calculation. These modules are usually `eqx.Module`s + so that you can take derivatives against the (parameters of the) modules. + + Returns: + Dict: A dictionary of the (trainable) modules that are required to run the simulation + + """ + return {} + + def __call__(self, trainable_modules: Dict, args: Dict): + return {} + + def vg(self, trainable_modules: Dict, args: Dict): + raise NotImplementedError( + "This is the base class and does not have a gradient implemented. This is " + + "likely because there is no metric in place. Subclass this class and implement the gradient" + ) + # return eqx.filter_value_and_grad(self.__call__)(trainable_modules) + + +class ergoExo: + """ + This class is the main interface for running a simulation. It is responsible for calling all the ADEPT modules in the right order + and logging parameters and results to mlflow. + + This approach helps decouple the numerical solvers from the experiment management and facilitates the addition of new solvers + + Typical usage is as follows + + .. code-block:: python + + exoskeleton = ergoExo() + modules = exoskeleton.setup(cfg) + run_output, post_processing_output, mlflow_run_id = exoskeleton(modules, args=None) + + + If you are resuming an existing mlflow run, you can do the following + + .. code-block:: python + + exoskeleton = ergoExo(mlflow_run_id=mlflow_run_id) + modules = exoskeleton.setup(cfg) + run_output, post_processing_output, mlflow_run_id = exoskeleton(modules, args=None) + + If you are introducing a custom `ADEPTModule`, you can do the following + + .. code-block:: python + + exoskeleton = ergoExo() + modules = exoskeleton.setup(cfg, exoskeleton_module=custom_module) + run_output, post_processing_output, mlflow_run_id = adept(modules, args=None) + + + """ + + def __init__(self, mlflow_run_id: str = None, mlflow_nested: bool = None) -> None: + + self.mlflow_run_id = mlflow_run_id + # if mlflow_run_id is not None: + # assert self.mlflow_nested is not None + if mlflow_nested is None: + self.mlflow_nested = False + else: + self.mlflow_nested = mlflow_nested + + if "BASE_TEMPDIR" in os.environ: + self.base_tempdir = os.environ["BASE_TEMPDIR"] + else: + self.base_tempdir = None + + self.ran_setup = False + + def setup(self, cfg: Dict, adept_module: ADEPTModule = None) -> Dict[str, Module]: + """ + This function sets up the differentiable simulation by getting the chosen solver and setting it up + At this point in time, the setup includes + + 1. initializing the mlflow run and setting the runid or resuming an existing run + 2. getting the right ``ADEPTModule`` or using the one passed in. This gets assigned to ``self.adept_module``. + 3. updating the config, units, derived quantities, and array config as defined by the ``ADEPTModule``. It also dumps this information to the temporary directory, which will be logged later, and logging the parameters to mlflow + 4. initializing the state and args as defined by the ``ADEPTModule`` + 5. initializing the `diffeqsolve` as defined by the ``ADEPTModule`` + 6. initializing the necessary (trainable) physics modules as defined by the ``ADEPTModule`` + + Args: + cfg: The configuration dictionary + + Returns: + A dictionary of trainable modules (``Dict[str, eqx.Module]``) + + This is a dictionary of the (trainable) modules that are required to run the simulation. These can be modules that + change the initial conditions, or the driver (boundary conditions), or the metric calculation. These modules are + ``equinox`` modules in order to play nice with ``diffrax`` + + """ + + with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: + if self.mlflow_run_id is None: + mlflow.set_experiment(cfg["mlflow"]["experiment"]) + with mlflow.start_run(run_name=cfg["mlflow"]["run"], nested=self.mlflow_nested) as mlflow_run: + modules = self._setup_(cfg, td, adept_module) + mlflow.log_artifacts(td) # logs the temporary directory to mlflow + self.mlflow_run_id = mlflow_run.info.run_id + + else: + from adept.utils import get_cfg + + with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested) as mlflow_run: + # with tempfile.TemporaryDirectory(dir=self.base_tempdir) as temp_path: + # cfg = get_cfg(artifact_uri=mlflow_run.info.artifact_uri, temp_path=temp_path) + modules = self._setup_(cfg, td, adept_module) + mlflow.log_artifacts(td) # logs the temporary directory to mlflow + + return modules + + def _get_adept_module_(self, cfg: Dict) -> ADEPTModule: + """ + This function returns the helper functions for the given solver + + Args: + solver: The solver to use + + + """ + + if cfg["solver"] == "tf-1d": + from adept.tf1d.modules import BaseTwoFluid1D as this_module + from adept.tf1d.datamodel import ConfigModel + + # config = ConfigModel(**cfg) + + elif cfg["solver"] == "vlasov-1d": + from adept.vlasov1d import BaseVlasov1D as this_module + from adept._vlasov1d.datamodel import ConfigModel + + # config = ConfigModel(**cfg) + + elif cfg["solver"] == "envelope-2d": + from adept.lpse2d import BaseLPSE2D as this_module + + # from adept.lpse2d.datamodel import ConfigModel + + # config = ConfigModel(**cfg) + + elif cfg["solver"] == "vfp-1d": + from adept.vfp1d.base import BaseVFP1D as this_module + else: + raise NotImplementedError("This solver approach has not been implemented yet") + + return this_module(cfg) + + def _setup_(self, cfg: Dict, td: str, adept_module: ADEPTModule = None, log: bool = True) -> Dict[str, Module]: + from adept.utils import log_params + + if adept_module is None: + self.adept_module = self._get_adept_module_(cfg) + else: + self.adept_module = adept_module(cfg) + + # dump raw config + if log: + with open(os.path.join(td, "config.yaml"), "w") as fi: + yaml.dump(self.adept_module.cfg, fi) + + # dump units + quants_dict = self.adept_module.write_units() # writes the units to the temporary directory + if log: + with open(os.path.join(td, "units.yaml"), "w") as fi: + yaml.dump(quants_dict, fi) + + # dump derived config + self.adept_module.get_derived_quantities() # gets the derived quantities + + if log: + log_params(self.adept_module.cfg) # logs the parameters to mlflow + with open(os.path.join(td, "derived_config.yaml"), "w") as fi: + yaml.dump(self.adept_module.cfg, fi) + + # dump array config + self.adept_module.get_solver_quantities() + if log: + with open(os.path.join(td, "array_config.pkl"), "wb") as fi: + pickle.dump(self.adept_module.cfg, fi) + + self.adept_module.init_state_and_args() + self.adept_module.init_diffeqsolve() + modules = self.adept_module.init_modules() + + self.ran_setup = True + + return modules + + def __call__(self, modules: Dict = None) -> Tuple[Solution, Dict, str]: + """ + This function is the main entry point for running a simulation. It takes a configuration dictionary and returns a + ``diffrax.Solution`` object and a dictionary of datasets. It calls the ``self.adept_module``'s ``__call__`` function. + + It is also responsible for logging the artifacts and metrics to mlflow. + + Args: + modules (Dict(str, eqx.Module)): The trainable modules that are required to run the simulation. All the other parameters are static and initialized during the setup call + + Returns: + a tuple of the run_output (``diffrax.Solution``), post_processing_output (``Dict[str, xarray.dataset]``), and the mlflow_run_id (``str``). + + The run_output comes from the ``__call__`` function of the ``self.adept_module``. The post_processing_output comes from the ``post_process`` method of the ``self.adept_module``. + The mlflow_run_id is the id of the mlflow run that was created during the setup call or passed in during the initialization of the class + + """ + + assert self.ran_setup, "You must run self.setup() before running the simulation" + + with mlflow.start_run( + run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True + ) as mlflow_run: + t0 = time.time() + run_output = filter_jit(self.adept_module.__call__)(modules, None) + mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow + + t0 = time.time() + with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: + post_processing_output = self.adept_module.post_process(run_output, td) + mlflow.log_artifacts(td) # logs the temporary directory to mlflow + + if "metrics" in post_processing_output: + mlflow.log_metrics(post_processing_output["metrics"]) + mlflow.log_metrics({"postprocess_time": round(time.time() - t0, 4)}) + + return run_output, post_processing_output, self.mlflow_run_id + + def val_and_grad(self, modules: Dict = None) -> Tuple[float, Dict, Tuple[Solution, Dict, str]]: + """ + This function is the value and gradient of the simulation. This is a very similar looking function to the ``__call__`` function but calls the ``self.adept_module.vg`` rather than the ``self.adept_module.__call__``. + + It is also responsible for logging the artifacts and metrics to mlflow. + + + Args: + modules: The (trainable) modules that are required to run the simulation and take the gradient against. All the other parameters are static and initialized during the setup call + + Returns: + a tuple of the value (``float``), gradient (``Dict``), and a tuple of the run_output (``diffrax.Solution``), post_processing_output (``Dict[str, xarray.dataset]``), and the mlflow_run_id (``str``). + + The value and gradient, and run_output come from the ``adept_module.vg`` function. The run_output is the same as that from ``__call__`` function of the ``self.adept_module``. The post_processing_output comes from the ``post_process`` method of the ``self.adept_module``. + The mlflow_run_id is the id of the mlflow run that was created during the setup call or passed in during the initialization + """ + assert self.ran_setup, "You must run self.setup() before running the simulation" + with mlflow.start_run( + run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True + ) as mlflow_run: + t0 = time.time() + (val, run_output), grad = filter_jit(self.adept_module.vg)(modules, None) + flattened_grad, _ = jax.flatten_util.ravel_pytree(grad) + mlflow.log_metrics({"run_time": round(time.time() - t0, 4)}) # logs the run time to mlflow + mlflow.log_metrics({"val": float(val), "l2-grad": float(np.linalg.norm(flattened_grad))}) + + t0 = time.time() + with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td: + post_processing_output = self.adept_module.post_process(run_output, td) + mlflow.log_artifacts(td) # logs the temporary directory to mlflow + if "metrics" in post_processing_output: + mlflow.log_metrics(post_processing_output["metrics"]) + mlflow.log_metrics({"postprocess_time": round(time.time() - t0, 4)}) + return val, grad, (run_output, post_processing_output, self.mlflow_run_id) + + def _log_flops_(_run_: Callable, models: Dict, state: Dict, args: Dict, tqs): + """ + Logs the number of flops to mlflow + + Args: + _run_: The function that runs the simulation + models: The models used in the simulation + tqs: The time quantities used in the simulation + + """ + wrapped = jax.xla_computation(_run_) + computation = wrapped(models, state, args, tqs) + module = computation.as_hlo_module() + client = jax.lib.xla_bridge.get_backend() + analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, module) + flops_sum = analysis["flops"] + mlflow.log_metrics({"total GigaFLOP": flops_sum / 1e9}) # logs the flops to mlflow diff --git a/adept/_lpse2d/__init__.py b/adept/_lpse2d/__init__.py index e69de29..64716d6 100644 --- a/adept/_lpse2d/__init__.py +++ b/adept/_lpse2d/__init__.py @@ -0,0 +1,2 @@ +from .modules import BaseLPSE2D as BaseLPSE2D, save_driver as save_driver +from .helpers import calc_threshold_intensity as calc_threshold_intensity diff --git a/adept/_lpse2d/core/driver.py b/adept/_lpse2d/core/driver.py index 9add7ef..bb3e708 100644 --- a/adept/_lpse2d/core/driver.py +++ b/adept/_lpse2d/core/driver.py @@ -1,8 +1,6 @@ from typing import Dict -import jax -import equinox as eqx from jax import numpy as jnp -from adept import get_envelope +from adept._base_ import get_envelope class Driver: diff --git a/adept/_lpse2d/core/integrator.py b/adept/_lpse2d/core/integrator.py index e00a8de..910d8c0 100644 --- a/adept/_lpse2d/core/integrator.py +++ b/adept/_lpse2d/core/integrator.py @@ -3,7 +3,7 @@ from jax import numpy as jnp import numpy as np -from adept import get_envelope +from adept._base_ import get_envelope from adept._lpse2d.core import epw, laser diff --git a/adept/_lpse2d/core/vector_field.py b/adept/_lpse2d/core/vector_field.py index b530f0f..28caf0d 100644 --- a/adept/_lpse2d/core/vector_field.py +++ b/adept/_lpse2d/core/vector_field.py @@ -3,7 +3,7 @@ from jax import numpy as jnp, Array import numpy as np -from adept import get_envelope +from adept._base_ import get_envelope from adept._lpse2d.core import epw, laser diff --git a/adept/_lpse2d/helpers.py b/adept/_lpse2d/helpers.py index f8dc660..6dbc88d 100644 --- a/adept/_lpse2d/helpers.py +++ b/adept/_lpse2d/helpers.py @@ -1,18 +1,16 @@ import os from typing import Dict, Tuple -from collections import defaultdict from functools import partial import matplotlib.pyplot as plt from jax import Array, numpy as jnp import numpy as np -import equinox as eqx import xarray as xr import interpax from astropy.units import Quantity as _Q -from adept import get_envelope +from adept._base_ import get_envelope def write_units(cfg: Dict) -> Dict: diff --git a/adept/_lpse2d/modules/__init__.py b/adept/_lpse2d/modules/__init__.py index e69de29..382e32f 100644 --- a/adept/_lpse2d/modules/__init__.py +++ b/adept/_lpse2d/modules/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseLPSE2D as BaseLPSE2D +from .driver import save as save_driver diff --git a/adept/_lpse2d/modules/base.py b/adept/_lpse2d/modules/base.py index 83da7e2..6c87057 100644 --- a/adept/_lpse2d/modules/base.py +++ b/adept/_lpse2d/modules/base.py @@ -3,7 +3,8 @@ from diffrax import diffeqsolve, SaveAt, ODETerm, SubSaveAt from equinox import filter_jit -from adept import ADEPTModule, Stepper +from adept import ADEPTModule +from adept._base_ import Stepper from adept._lpse2d.helpers import ( write_units, post_process, diff --git a/adept/_vlasov1d/helpers.py b/adept/_vlasov1d/helpers.py index fc9a4a6..21aff09 100644 --- a/adept/_vlasov1d/helpers.py +++ b/adept/_vlasov1d/helpers.py @@ -9,24 +9,25 @@ import numpy as np import xarray, mlflow, pint from jax import numpy as jnp +from scipy.special import gamma from diffrax import Solution from matplotlib import pyplot as plt -from adept import get_envelope +from adept._base_ import get_envelope from adept._vlasov1d.storage import store_f, store_fields -gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "gamma_func_for_sg.nc")) -m_ax = gamma_da.coords["m"].data -g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data) -g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data) +# gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "gamma_func_for_sg.nc")) +# m_ax = gamma_da.coords["m"].data +# g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data) +# g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data) def gamma_3_over_m(m): - return np.interp(m, m_ax, g_3_m) + return gamma(3.0 / m) # np.interp(m, m_ax, g_3_m) def gamma_5_over_m(m): - return np.interp(m, m_ax, g_5_m) + return gamma(5.0 / m) # np.interp(m, m_ax, g_5_m) def _initialize_distribution_( diff --git a/adept/_vlasov1d/modules.py b/adept/_vlasov1d/modules.py index 654533e..04611ff 100644 --- a/adept/_vlasov1d/modules.py +++ b/adept/_vlasov1d/modules.py @@ -8,7 +8,8 @@ from jax import numpy as jnp from diffrax import ODETerm, SubSaveAt, diffeqsolve, SaveAt -from adept import Stepper, ADEPTModule +from adept import ADEPTModule +from adept._base_ import Stepper from adept._vlasov1d.storage import get_save_quantities from adept._vlasov1d.helpers import _initialize_total_distribution_, post_process from adept._vlasov1d.solvers.vector_field import VlasovMaxwell diff --git a/adept/_vlasov1d/solvers/pushers/field.py b/adept/_vlasov1d/solvers/pushers/field.py index 3802f2d..b1e5b9d 100644 --- a/adept/_vlasov1d/solvers/pushers/field.py +++ b/adept/_vlasov1d/solvers/pushers/field.py @@ -3,7 +3,7 @@ from typing import Dict from jax import numpy as jnp -from adept import get_envelope +from adept._base_ import get_envelope class Driver: diff --git a/adept/_vlasov1d/solvers/vector_field.py b/adept/_vlasov1d/solvers/vector_field.py index e1c694d..c3f8501 100644 --- a/adept/_vlasov1d/solvers/vector_field.py +++ b/adept/_vlasov1d/solvers/vector_field.py @@ -3,7 +3,7 @@ from jax import numpy as jnp, Array -from adept import get_envelope +from adept._base_ import get_envelope from adept._vlasov1d.solvers.pushers import field, fokker_planck, vlasov diff --git a/adept/lpse2d.py b/adept/lpse2d.py index aa38b6e..54ada8a 100644 --- a/adept/lpse2d.py +++ b/adept/lpse2d.py @@ -1,3 +1 @@ -from adept._lpse2d.modules.base import BaseLPSE2D -from adept._lpse2d.helpers import calc_threshold_intensity -from adept._lpse2d.modules.driver import save as save_driver +from ._lpse2d import BaseLPSE2D, save_driver, calc_threshold_intensity diff --git a/adept/tf1d/solvers/pushers.py b/adept/tf1d/solvers/pushers.py index 320e4ec..b6ae2dd 100644 --- a/adept/tf1d/solvers/pushers.py +++ b/adept/tf1d/solvers/pushers.py @@ -6,7 +6,7 @@ import equinox as eqx from adept.electrostatic import get_complex_frequency_table -from adept import get_envelope +from adept._base_ import get_envelope class WaveSolver(eqx.Module): diff --git a/adept/utils.py b/adept/utils.py index 5adf658..37705e5 100644 --- a/adept/utils.py +++ b/adept/utils.py @@ -175,91 +175,3 @@ def export_run(run_id, prefix="individual", step=0): t0 = time.time() upload_dir_to_s3(td2, "remote-mlflow-staging", f"artifacts/{run_id}", run_id, prefix, step) # print(f"Uploading took {round(time.time() - t0, 2)} s") - - -def setup_parsl(parsl_provider="local", num_gpus=4, nodes=1): - import parsl - from parsl.config import Config - from parsl.providers import SlurmProvider, LocalProvider - from parsl.launchers import SrunLauncher - from parsl.executors import HighThroughputExecutor - - if parsl_provider == "local": - - if nodes == 1: - this_provider = LocalProvider - provider_args = dict( - worker_init="source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \ - module load cudnn/8.9.3_cuda12.lua; \ - export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \ - export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \ - export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow'; \ - export MLFLOW_EXPORT=True", - init_blocks=1, - max_blocks=1, - nodes_per_block=1, - ) - htex = HighThroughputExecutor( - available_accelerators=num_gpus, - label="tpd", - provider=this_provider(**provider_args), - cpu_affinity="block", - ) - print(f"{htex.workers_per_node=}") - else: - this_provider = LocalProvider - provider_args = dict( - worker_init="source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \ - module load cudnn/8.9.3_cuda12.lua; \ - export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \ - export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \ - export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow'; \ - export MLFLOW_EXPORT=True", - nodes_per_block=nodes, - launcher=SrunLauncher(overrides="-c 32 --gpus-per-node 4"), - cmd_timeout=120, - init_blocks=1, - max_blocks=1, - ) - - htex = HighThroughputExecutor( - available_accelerators=num_gpus * nodes, - label="tpd", - provider=this_provider(**provider_args), - max_workers_per_node=4, - cpu_affinity="block", - ) - print(f"{htex.workers_per_node=}") - - elif parsl_provider == "gpu": - - this_provider = SlurmProvider - sched_args = ["#SBATCH -C gpu", "#SBATCH --qos=regular"] - provider_args = dict( - partition=None, - account="m4490_g", - scheduler_options="\n".join(sched_args), - worker_init="export SLURM_CPU_BIND='cores';\ - source /pscratch/sd/a/archis/venvs/adept-gpu/bin/activate; \ - module load cudnn/8.9.3_cuda12.lua; \ - export PYTHONPATH='$PYTHONPATH:/global/homes/a/archis/adept/'; \ - export BASE_TEMPDIR='/pscratch/sd/a/archis/tmp/'; \ - export MLFLOW_TRACKING_URI='/pscratch/sd/a/archis/mlflow';\ - export MLFLOW_EXPORT=True", - launcher=SrunLauncher(overrides="--gpus-per-node 4 -c 128"), - walltime="12:00:00", - cmd_timeout=120, - nodes_per_block=1, - # init_blocks=1, - max_blocks=nodes, - ) - - htex = HighThroughputExecutor( - available_accelerators=4, label="tpd-learn", provider=this_provider(**provider_args), cpu_affinity="block" - ) - print(f"{htex.workers_per_node=}") - - config = Config(executors=[htex], retries=4) - - # load the Parsl config - parsl.load(config) diff --git a/adept/vfp1d/base.py b/adept/vfp1d/base.py index 6eb05de..f65a9ef 100644 --- a/adept/vfp1d/base.py +++ b/adept/vfp1d/base.py @@ -6,7 +6,7 @@ from diffrax import diffeqsolve, SaveAt, ODETerm, SubSaveAt from jax import numpy as jnp, tree_util as jtu -from adept import ADEPTModule, Stepper +from adept._base_ import ADEPTModule, Stepper from adept.vfp1d.vector_field import OSHUN1D from adept.vfp1d.helpers import _initialize_total_distribution_, calc_logLambda from adept.vfp1d.storage import get_save_quantities, post_process diff --git a/adept/vfp1d/helpers.py b/adept/vfp1d/helpers.py index 3ed08c9..75c1682 100644 --- a/adept/vfp1d/helpers.py +++ b/adept/vfp1d/helpers.py @@ -2,20 +2,13 @@ # research@ergodic.io from typing import Dict, Tuple -import os import numpy as np from jax import Array -import xarray, yaml -from astropy import units as u, constants as csts +from scipy.special import gamma from astropy.units import Quantity as _Q from jax import numpy as jnp -from adept import get_envelope - -gamma_da = xarray.open_dataarray(os.path.join(os.path.dirname(__file__), "..", "vlasov1d", "gamma_func_for_sg.nc")) -m_ax = gamma_da.coords["m"].data -g_3_m = np.squeeze(gamma_da.loc[{"gamma": "3/m"}].data) -g_5_m = np.squeeze(gamma_da.loc[{"gamma": "5/m"}].data) +from adept._base_ import get_envelope def gamma_3_over_m(m: float) -> Array: @@ -26,7 +19,7 @@ def gamma_3_over_m(m: float) -> Array: :return: Array """ - return np.interp(m, m_ax, g_3_m) + return gamma(3.0 / m) # np.interp(m, m_ax, g_3_m) def gamma_5_over_m(m: float) -> Array: @@ -36,7 +29,7 @@ def gamma_5_over_m(m: float) -> Array: :param m: float between 2 and 5 :return: Array """ - return np.interp(m, m_ax, g_5_m) + return gamma(5.0 / m) # np.interp(m, m_ax, g_5_m) def calc_logLambda(cfg: Dict, ne: float, Te: float, Z: int, ion_species: str) -> Tuple[float, float]: diff --git a/tests/test_lpse2d/test_tpd_threshold.py b/tests/test_lpse2d/test_tpd_threshold.py index d9e5fb7..8035eae 100644 --- a/tests/test_lpse2d/test_tpd_threshold.py +++ b/tests/test_lpse2d/test_tpd_threshold.py @@ -2,12 +2,8 @@ from adept.lpse2d import calc_threshold_intensity import numpy as np -from adept.utils import setup_parsl -# from parsl.app.app import python_app - -# @python_app def run_once(L, Te, I0): import yaml from adept import ergoExo @@ -36,7 +32,6 @@ def test_threshold(): if "CPU_ONLY" in os.environ: pass else: - setup_parsl() ess = [] c = 3e8 lam0 = 351e-9 diff --git a/tests/test_tf1d/test_against_vlasov.py b/tests/test_tf1d/test_against_vlasov.py index 31009d8..5ee5be0 100644 --- a/tests/test_tf1d/test_against_vlasov.py +++ b/tests/test_tf1d/test_against_vlasov.py @@ -5,11 +5,13 @@ import numpy as np from jax import config +from adept import ergoExo + config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) import xarray as xr -from adept import ergoExo, electrostatic +from adept import electrostatic def _modify_defaults_(defaults): diff --git a/tests/test_tf1d/test_landau_damping.py b/tests/test_tf1d/test_landau_damping.py index ce1b5c4..9d6780d 100644 --- a/tests/test_tf1d/test_landau_damping.py +++ b/tests/test_tf1d/test_landau_damping.py @@ -5,13 +5,15 @@ import numpy as np from jax import config +from adept import ergoExo + config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) from jax import numpy as jnp import mlflow -from adept import ergoExo, electrostatic +from adept import electrostatic def _modify_defaults_(defaults, rng): diff --git a/tests/test_tf1d/test_resonance.py b/tests/test_tf1d/test_resonance.py index 6bf673f..93d61ba 100644 --- a/tests/test_tf1d/test_resonance.py +++ b/tests/test_tf1d/test_resonance.py @@ -5,11 +5,13 @@ import numpy as np from jax import config +from adept import ergoExo + config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) from jax import numpy as jnp -from adept import ergoExo, electrostatic +from adept import electrostatic def _modify_defaults_(defaults, rng, gamma): diff --git a/tests/test_vfp1d/epp-short.yaml b/tests/test_vfp1d/epp-short.yaml index 1d13c2c..80fe198 100644 --- a/tests/test_vfp1d/epp-short.yaml +++ b/tests/test_vfp1d/epp-short.yaml @@ -1,6 +1,6 @@ units: laser_wavelength: 351nm - reference electron temperature: 2000eV + reference electron temperature: 300eV reference ion temperature: 300eV reference electron density: 1.5e21/cm^3 Z: 6 diff --git a/tests/test_vlasov1d/test_absorbing_wave.py b/tests/test_vlasov1d/test_absorbing_wave.py index c5b7209..7062d6c 100644 --- a/tests/test_vlasov1d/test_absorbing_wave.py +++ b/tests/test_vlasov1d/test_absorbing_wave.py @@ -9,7 +9,7 @@ from diffrax import ODETerm, diffeqsolve from adept._vlasov1d.solvers.pushers.field import Driver, WaveSolver -from adept import Stepper +from adept._base_ import Stepper class VectorField(eqx.Module): diff --git a/tests/test_vlasov1d/test_landau_damping.py b/tests/test_vlasov1d/test_landau_damping.py index b47aa9d..70b6d77 100644 --- a/tests/test_vlasov1d/test_landau_damping.py +++ b/tests/test_vlasov1d/test_landau_damping.py @@ -7,10 +7,12 @@ import numpy as np from jax import config +from adept import ergoExo + config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) -from adept import ergoExo, electrostatic +from adept import electrostatic def _modify_defaults_(defaults, rng, real_or_imag, time, field, edfdv):