diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 1cd9882d9..d079c09c2 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -11,7 +11,7 @@ We are happy to get you started in using our software. ## Issues If you think you have encountered a software issue, please raise this on the "Issues" tab in Github. -In general the more details you can provide the better, +In general the more details you can provide the better, we recommend reading section 3.3 of [this article](https://livecomsjournal.org/index.php/livecoms/article/view/v3i1e1473) to understand the problem solving process. diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b89275985..c506d6302 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -123,9 +123,9 @@ jobs: # we only want to upload a slow report if # 1) it isn't a schedule run # 2) it wasn't from a PR (we don't run slow tests on PRs) - if: ${{ github.repository == 'OpenFreeEnergy/openfe' - && github.event_name != 'schedule' - && github.event_name != 'pull_request' }} + if: ${{ github.repository == 'OpenFreeEnergy/openfe' + && github.event_name != 'schedule' + && github.event_name != 'pull_request' }} uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/clean_cache.yaml b/.github/workflows/clean_cache.yaml index c3db0a6f0..e9a6c50d9 100644 --- a/.github/workflows/clean_cache.yaml +++ b/.github/workflows/clean_cache.yaml @@ -11,18 +11,18 @@ jobs: steps: - name: Check out code uses: actions/checkout@v3 - + - name: Cleanup run: | gh extension install actions/gh-actions-cache - + REPO=${{ github.repository }} BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" echo "Fetching list of cache key" cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 ) - ## Setting this to not fail the workflow while deleting cache keys. + ## Setting this to not fail the workflow while deleting cache keys. 
set +e echo "Deleting caches..." for cacheKey in $cacheKeysForPR diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 8eb165688..29f5aa547 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -50,4 +50,4 @@ jobs: micromamba list - name: "Lint with mypy" - run: mypy + run: mypy diff --git a/Code_of_Conduct.md b/Code_of_Conduct.md index 724db5122..4cdaf25a9 100644 --- a/Code_of_Conduct.md +++ b/Code_of_Conduct.md @@ -1,6 +1,6 @@ ## Code of Conduct ## -This project is dedicated to providing a welcoming and supportive environment for all people, regardless of background or identity. Members do not tolerate harassment for any reason, but especially harassment based on gender, sexual orientation, disability, physical appearance, body size, race, nationality, sex, color, ethnic or social origin, pregnancy, citizenship, familial status, veteran status, genetic information, religion or belief, political or any other opinion, membership of a national minority, property, age, or preference of text editor. +This project is dedicated to providing a welcoming and supportive environment for all people, regardless of background or identity. Members do not tolerate harassment for any reason, but especially harassment based on gender, sexual orientation, disability, physical appearance, body size, race, nationality, sex, color, ethnic or social origin, pregnancy, citizenship, familial status, veteran status, genetic information, religion or belief, political or any other opinion, membership of a national minority, property, age, or preference of text editor. 
### Expected Behavior ### diff --git a/devtools/data/gen-serialized-results.py b/devtools/data/gen-serialized-results.py index bd74a01f5..39a7db46d 100644 --- a/devtools/data/gen-serialized-results.py +++ b/devtools/data/gen-serialized-results.py @@ -9,36 +9,32 @@ - MDProtocol_json_results.gz - used in md_json fixture """ + import gzip import json import logging import pathlib import tempfile -from openff.toolkit import ( - Molecule, RDKitToolkitWrapper, AmberToolsToolkitWrapper -) -from openff.toolkit.utils.toolkit_registry import ( - toolkit_registry_manager, ToolkitRegistry -) -from openff.units import unit -from kartograf.atom_aligner import align_mol_shape -from kartograf import KartografAtomMapper + import gufe from gufe.tokenization import JSON_HANDLER +from kartograf import KartografAtomMapper +from kartograf.atom_aligner import align_mol_shape +from openff.toolkit import AmberToolsToolkitWrapper, Molecule, RDKitToolkitWrapper +from openff.toolkit.utils.toolkit_registry import ToolkitRegistry, toolkit_registry_manager +from openff.units import unit + import openfe -from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocol from openfe.protocols.openmm_afe import AbsoluteSolvationProtocol +from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocol from openfe.protocols.openmm_rfe import RelativeHybridTopologyProtocol - logger = logging.getLogger(__name__) LIGA = "[H]C([H])([H])C([H])([H])C(=O)C([H])([H])C([H])([H])[H]" LIGB = "[H]C([H])([H])C(=O)C([H])([H])C([H])([H])C([H])([H])[H]" -amber_rdkit = ToolkitRegistry( - [RDKitToolkitWrapper(), AmberToolsToolkitWrapper()] -) +amber_rdkit = ToolkitRegistry([RDKitToolkitWrapper(), AmberToolsToolkitWrapper()]) def get_molecule(smi, name): @@ -58,7 +54,7 @@ def execute_and_serialize(dag, protocol, simname): shared_basedir=workdir, scratch_basedir=workdir, keep_shared=False, - n_retries=3 + n_retries=3, ) protres = protocol.gather([dagres]) @@ -66,13 +62,10 @@ def execute_and_serialize(dag, 
protocol, simname): "estimate": protres.get_estimate(), "uncertainty": protres.get_uncertainty(), "protocol_result": protres.to_dict(), - "unit_results": { - unit.key: unit.to_keyed_dict() - for unit in dagres.protocol_unit_results - } + "unit_results": {unit.key: unit.to_keyed_dict() for unit in dagres.protocol_unit_results}, } - with gzip.open(f"{simname}_json_results.gz", 'wt') as zipfile: + with gzip.open(f"{simname}_json_results.gz", "wt") as zipfile: json.dump(outdict, zipfile, cls=JSON_HANDLER.encoder) @@ -95,27 +88,19 @@ def generate_ahfe_json(smc): settings.solvent_simulation_settings.production_length = 500 * unit.picosecond settings.vacuum_simulation_settings.equilibration_length = 10 * unit.picosecond settings.vacuum_simulation_settings.production_length = 1000 * unit.picosecond - settings.lambda_settings.lambda_elec = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0] - settings.lambda_settings.lambda_vdw = [0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, - 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, - 1.0] + settings.lambda_settings.lambda_elec = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + settings.lambda_settings.lambda_vdw = [0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0] settings.protocol_repeats = 3 settings.solvent_simulation_settings.n_replicas = 14 settings.vacuum_simulation_settings.n_replicas = 14 settings.solvent_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole settings.vacuum_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole - settings.vacuum_engine_settings.compute_platform = 'CPU' - settings.solvent_engine_settings.compute_platform = 'CUDA' + settings.vacuum_engine_settings.compute_platform = "CPU" + settings.solvent_engine_settings.compute_platform = "CUDA" protocol = AbsoluteSolvationProtocol(settings=settings) - sysA = openfe.ChemicalSystem( - {"ligand": smc, "solvent": openfe.SolventComponent()} - ) 
- sysB = openfe.ChemicalSystem( - {"solvent": openfe.SolventComponent()} - ) + sysA = openfe.ChemicalSystem({"ligand": smc, "solvent": openfe.SolventComponent()}) + sysB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) dag = protocol.create(stateA=sysA, stateB=sysB, mapping=None) @@ -133,15 +118,13 @@ def generate_rfe_json(smcA, smcB): mapper = KartografAtomMapper(atom_map_hydrogens=True) mapping = next(mapper.suggest_mappings(smcA, a_smcB)) - systemA = openfe.ChemicalSystem({'ligand': smcA}) - systemB = openfe.ChemicalSystem({'ligand': a_smcB}) + systemA = openfe.ChemicalSystem({"ligand": smcA}) + systemB = openfe.ChemicalSystem({"ligand": a_smcB}) - dag = protocol.create( - stateA=systemA, stateB=systemB, mapping=mapping - ) + dag = protocol.create(stateA=systemA, stateB=systemB, mapping=mapping) execute_and_serialize(dag, protocol, "RHFEProtocol") - + if __name__ == "__main__": molA = get_molecule(LIGA, "ligandA") diff --git a/docs/_ext/sass.py b/docs/_ext/sass.py index 43888aa06..8fcc4cb78 100644 --- a/docs/_ext/sass.py +++ b/docs/_ext/sass.py @@ -17,13 +17,11 @@ from pathlib import Path from typing import Optional, Union - import sass from sphinx.application import Sphinx from sphinx.environment import BuildEnvironment from sphinx.util import logging - logger = logging.getLogger(__name__) @@ -44,10 +42,7 @@ def get_targets(app: Sphinx) -> dict[Path, Path]: if isinstance(app.config.sass_targets, dict): targets = app.config.sass_targets else: - targets = { - path: path.relative_to(src_dir).with_suffix(".css") - for path in src_dir.glob("**/[!_]*.s[ca]ss") - } + targets = {path: path.relative_to(src_dir).with_suffix(".css") for path in src_dir.glob("**/[!_]*.s[ca]ss")} return {src_dir / src: dst_dir / dst for src, dst in targets.items()} diff --git a/docs/_sass/deflist-flowchart.scss b/docs/_sass/deflist-flowchart.scss index 199bf5e9f..b65f54bb8 100644 --- a/docs/_sass/deflist-flowchart.scss +++ b/docs/_sass/deflist-flowchart.scss @@ -263,7 +263,7 
@@ ul.deflist-flowchart { margin-left: 0; } } - + dl { display: flex; flex-direction: row-reverse; @@ -394,4 +394,4 @@ ul.deflist-flowchart { } } } -} \ No newline at end of file +} diff --git a/docs/_static/API.svg b/docs/_static/API.svg index 34e5d0cc9..9c2cb314e 100644 --- a/docs/_static/API.svg +++ b/docs/_static/API.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/CLI.svg b/docs/_static/CLI.svg index eb94cf249..3d170a96c 100644 --- a/docs/_static/CLI.svg +++ b/docs/_static/CLI.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/Cookbook.svg b/docs/_static/Cookbook.svg index 50e97f434..d2f72b489 100644 --- a/docs/_static/Cookbook.svg +++ b/docs/_static/Cookbook.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/Download.svg b/docs/_static/Download.svg index 101317541..3e425ca88 100644 --- a/docs/_static/Download.svg +++ b/docs/_static/Download.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/Rocket.svg b/docs/_static/Rocket.svg index 292fe01ba..ae1ed81f9 100644 --- a/docs/_static/Rocket.svg +++ b/docs/_static/Rocket.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/Squaredcircle.svg b/docs/_static/Squaredcircle.svg index a8fea245e..0c622a519 100644 --- a/docs/_static/Squaredcircle.svg +++ b/docs/_static/Squaredcircle.svg @@ -18,4 +18,4 @@ - \ No newline at end of file + diff --git a/docs/_static/Tutorial.svg b/docs/_static/Tutorial.svg index 157d3a8d3..a42592d7b 100644 --- a/docs/_static/Tutorial.svg +++ b/docs/_static/Tutorial.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/_static/UserGuide.svg b/docs/_static/UserGuide.svg index 6657e0a85..e8cf53aa7 100644 --- a/docs/_static/UserGuide.svg +++ b/docs/_static/UserGuide.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/docs/conf.py b/docs/conf.py index c107a40f5..cedd37dc9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,15 +13,15 @@ import os import sys from 
importlib.metadata import version -from packaging.version import parse -from pathlib import Path from inspect import cleandoc +from pathlib import Path -from git import Repo -import nbsphinx import nbformat +import nbsphinx +from git import Repo +from packaging.version import parse -sys.path.insert(0, os.path.abspath('../')) +sys.path.insert(0, os.path.abspath("../")) os.environ["SPHINX"] = "True" @@ -133,7 +133,7 @@ "url": "https://github.com/OpenFreeEnergy/openfe", "icon": "fa-brands fa-square-github", "type": "fontawesome", - } + }, ], "accent_color": "DarkGoldenYellow", "navigation_with_keys": False, @@ -170,7 +170,7 @@ try: if example_notebooks_path.exists(): repo = Repo(example_notebooks_path) - repo.remote('origin').pull() + repo.remote("origin").pull() else: repo = Repo.clone_from( "https://github.com/OpenFreeEnergy/ExampleNotebooks.git", @@ -181,11 +181,10 @@ filename = e.__traceback__.tb_frame.f_code.co_filename lineno = e.__traceback__.tb_lineno - getLogger('sphinx.ext.openfe_git').warning( - f"Getting ExampleNotebooks failed in {filename} line {lineno}: {e}" - ) + getLogger("sphinx.ext.openfe_git").warning(f"Getting ExampleNotebooks failed in {filename} line {lineno}: {e}") -nbsphinx_prolog = cleandoc(r""" +nbsphinx_prolog = cleandoc( + r""" {%- set path = env.doc2path(env.docname, base="ExampleNotebooks") -%} {%- set gh_repo = "OpenFreeEnergy/openfe" -%} {%- set gh_branch = "main" -%} @@ -245,4 +244,5 @@ :octicon:`rocket` Run in Colab -""") +""", +) diff --git a/docs/cookbook/creating_atom_mappings.rst b/docs/cookbook/creating_atom_mappings.rst index 39286ae93..29e9e3dd9 100644 --- a/docs/cookbook/creating_atom_mappings.rst +++ b/docs/cookbook/creating_atom_mappings.rst @@ -22,7 +22,7 @@ which uses an MCS approach based on the RDKit. This takes various parameters which control how it produces mappings, these are viewable through ``help(LomapAtomMapper)``. 
-This is how we can create a mapping between two ligands: +This is how we can create a mapping between two ligands: .. code:: diff --git a/docs/cookbook/index.rst b/docs/cookbook/index.rst index 7e3fac9c2..69cb811a7 100644 --- a/docs/cookbook/index.rst +++ b/docs/cookbook/index.rst @@ -160,5 +160,3 @@ List of Cookbooks create_alchemical_network under_the_hood user_charges - - diff --git a/docs/cookbook/under_the_hood.rst b/docs/cookbook/under_the_hood.rst index ddd821adb..d47da8b40 100644 --- a/docs/cookbook/under_the_hood.rst +++ b/docs/cookbook/under_the_hood.rst @@ -142,4 +142,3 @@ If you want to implement your own atom mapper or free energy procedure, or you w - :class:`ProtocolDAGResult` A completed transformation with multiple user-defined replicas. - diff --git a/docs/environment.yaml b/docs/environment.yaml index 0933cca80..25ad5aacd 100644 --- a/docs/environment.yaml +++ b/docs/environment.yaml @@ -28,7 +28,7 @@ dependencies: - git+https://github.com/OpenFreeEnergy/gufe@main - git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@main -# These are added automatically by RTD, so we include them here +# These are added automatically by RTD, so we include them here # for a consistent environment. - mock - pillow diff --git a/docs/guide/execution/index.rst b/docs/guide/execution/index.rst index 741ece248..54414c131 100644 --- a/docs/guide/execution/index.rst +++ b/docs/guide/execution/index.rst @@ -11,4 +11,3 @@ create a :class:`.ProtocolResult`. .. TODO: add information about failures etc... - diff --git a/docs/guide/introduction.rst b/docs/guide/introduction.rst index ba8a0a19a..158a5da10 100644 --- a/docs/guide/introduction.rst +++ b/docs/guide/introduction.rst @@ -1,6 +1,6 @@ .. 
_guide-introduction: -Introduction +Introduction ============ Here we present an overview of the workflow for calculating free energies in diff --git a/docs/guide/models/execution.rst b/docs/guide/models/execution.rst index 3662b5ac4..41d8ee2c2 100644 --- a/docs/guide/models/execution.rst +++ b/docs/guide/models/execution.rst @@ -1,7 +1,7 @@ Protocols and the Execution Model ================================= -Protocols in OpenFE are built on a flexible execution model. +Protocols in OpenFE are built on a flexible execution model. Result objects are shaped by this model, and therefore some basic background on it can be useful when looking into the details of simulation results. In general, most users don't need to work with the details of the diff --git a/docs/guide/setup/define_ligand_network.rst b/docs/guide/setup/define_ligand_network.rst index e02850838..ed0218499 100644 --- a/docs/guide/setup/define_ligand_network.rst +++ b/docs/guide/setup/define_ligand_network.rst @@ -2,5 +2,3 @@ Defining the Ligand Network =========================== - - diff --git a/docs/guide/setup/index.rst b/docs/guide/setup/index.rst index 1e829acbf..492ca8638 100644 --- a/docs/guide/setup/index.rst +++ b/docs/guide/setup/index.rst @@ -33,4 +33,3 @@ for more details. :hidden: define_ligand_network - diff --git a/docs/index.rst b/docs/index.rst index 1e465dbf6..f2dea9e33 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,7 +24,7 @@ The **OpenFE** toolkit provides open-source frameworks for calculating alchemica :link-type: doc New to *OpenFE*? Check out our installation guide to get it working on your machine! - + .. grid-item-card:: Tutorials :img-top: _static/Tutorial.svg :text-align: center @@ -32,7 +32,7 @@ The **OpenFE** toolkit provides open-source frameworks for calculating alchemica :link-type: doc Worked through examples of how to use the OpenFE toolkit. - + .. 
grid-item-card:: User Guide :img-top: _static/UserGuide.svg :text-align: center @@ -40,7 +40,7 @@ The **OpenFE** toolkit provides open-source frameworks for calculating alchemica :link-type: doc Learn about the underlying concepts of the OpenFE toolkit. - + .. grid-item-card:: API Reference :img-top: _static/API.svg :text-align: center @@ -56,7 +56,7 @@ The **OpenFE** toolkit provides open-source frameworks for calculating alchemica :link-type: doc How-to guides on how to utilise the toolkit components. - + .. grid-item-card:: Using the CLI :img-top: _static/CLI.svg :text-align: center @@ -72,7 +72,7 @@ The **OpenFE** toolkit provides open-source frameworks for calculating alchemica :link-type: doc Tutorial notebook showing the sorts of things OpenFE can do. - + .. grid-item-card:: Relative Free Energy Protocol :img-top: _static/Rocket.svg :text-align: center diff --git a/docs/installation.rst b/docs/installation.rst index 5297cafd4..b2be6f86f 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -9,7 +9,7 @@ installation is working. is tested against Python 3.9, 3.10, and 3.11. When you install ``openfe`` through any of the methods described below, you -will install both the core library and the command line interface (CLI). +will install both the core library and the command line interface (CLI). If you already have a Mamba installation, you can install ``openfe`` with: @@ -120,7 +120,7 @@ Now we need to activate our new environment :: environment, as one of our requirements is not yet available for Apple Silicon. Run the following modified commands - .. parsed-literal:: + .. parsed-literal:: CONDA_SUBDIR=osx-64 mamba create -n openfe_env openfe=\ |version| mamba activate openfe_env @@ -136,7 +136,7 @@ skipped, or xfailed (expected fail). The very first time you run this, the initial check that you can import ``openfe`` will take a while, because some code is compiled the first time it is encountered. 
That compilation only happens once per installation. - + With that, you should be ready to use ``openfe``! Single file installer @@ -144,7 +144,7 @@ Single file installer .. _releases on GitHub: https://github.com/OpenFreeEnergy/openfe/releases -Single file installers are available for x86_64 Linux and MacOS. +Single file installers are available for x86_64 Linux and MacOS. They are attached to our `releases on GitHub`_ and can be downloaded with a browser or ``curl`` (or similar tool). For example, the Linux installer can be downloaded with :: @@ -158,7 +158,7 @@ The single file installer contains all of the dependencies required for ``openfe Both ``conda`` and ``mamba`` are also available in the environment created by the single file installer and can be used to install additional packages. The installer can be installed in batch mode or interactively :: - + $ chmod +x ./OpenFEforge-Linux-x86_64.sh # Make installer executable $ ./OpenFEforge-Linux-x86_64.sh # Run the installer @@ -167,27 +167,27 @@ Example installer output is shown below (click to expand "Installer Output") .. collapse:: Installer Output .. code-block:: - + Welcome to OpenFEforge 0.7.4 - + In order to continue the installation process, please review the license agreement. Please, press ENTER to continue >>> MIT License - + Copyright (c) 2022 OpenFreeEnergy - + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - + The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
- + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -195,41 +195,41 @@ Example installer output is shown below (click to expand "Installer Output") LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - + + Do you accept the license terms? [yes|no] [no] >>> yes - - .. note:: + + .. note:: The install location will be different when you run the installer. - + .. code-block:: - + OpenFEforge will now be installed into this location: /home/mmh/openfeforge - + - Press ENTER to confirm the location - Press CTRL-C to abort the installation - Or specify a different location below - + [/home/mmh/openfeforge] >>> PREFIX=/home/mmh/openfeforge Unpacking payload ... - + Installing base environment... - - + + Downloading and Extracting Packages - - + + Downloading and Extracting Packages - + Preparing transaction: done Executing transaction: \ By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html - + | Enabling notebook extension jupyter-js-widgets/extension... - Validating: OK - + done installation finished. Do you wish the installer to initialize OpenFEforge @@ -247,10 +247,10 @@ Example installer output is shown below (click to expand "Installer Output") no change /home/mmh/openfeforge/lib/python3.9/site-packages/xontrib/conda.xsh no change /home/mmh/openfeforge/etc/profile.d/conda.csh modified /home/mmh/.bashrc - + ==> For changes to take effect, close and re-open your current shell. 
<== - - + + __ __ __ __ / \ / \ / \ / \ / \/ \/ \/ \ @@ -265,14 +265,14 @@ Example installer output is shown below (click to expand "Installer Output") ██║╚██╔╝██║██╔══██║██║╚██╔╝██║██╔══██╗██╔══██║ ██║ ╚═╝ ██║██║ ██║██║ ╚═╝ ██║██████╔╝██║ ██║ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚═════╝ ╚═╝ ╚═╝ - + mamba (1.4.2) supported by @QuantStack - + GitHub: https://github.com/mamba-org/mamba Twitter: https://twitter.com/QuantStack - + █████████████████████████████████████████████████████████████ - + no change /home/mmh/openfeforge/condabin/conda no change /home/mmh/openfeforge/bin/conda no change /home/mmh/openfeforge/bin/conda-env @@ -287,17 +287,17 @@ Example installer output is shown below (click to expand "Installer Output") no change /home/mmh/.bashrc No action taken. Added mamba to /home/mmh/.bashrc - + ==> For changes to take effect, close and re-open your current shell. <== - + If you'd prefer that conda's base environment not be activated on startup, set the auto_activate_base parameter to false: - + conda config --set auto_activate_base false - + Thank you for installing OpenFEforge! -After the installer completes, close and reopen your shell. +After the installer completes, close and reopen your shell. To check if your path is setup correctly, run ``which python`` your output should look something like this :: (base) $ which python @@ -310,22 +310,22 @@ Now the CLI tool should work as well :: (base) $ openfe --help Usage: openfe [OPTIONS] COMMAND [ARGS]... - + This is the command line tool to provide easy access to functionality from the OpenFE Python library. - + Options: --version Show the version and exit. --log PATH logging configuration file -h, --help Show this message and exit. - + Setup Commands: atommapping Check the atom mapping of a given pair of ligands plan-rhfe-network Plan a relative hydration free energy network, saved in a dir with multiple JSON files plan-rbfe-network Plan a relative binding free energy network, saved in a dir with multiple JSON files. 
- + Simulation Commands: gather Gather DAG result jsons for network of RFE results into single TSV file @@ -338,7 +338,7 @@ To make sure everything is working, run the tests :: The test suite contains several hundred individual tests. This will take a few minutes, and all tests should complete with status either passed, skipped, or xfailed (expected fail). - + With that, you should be ready to use ``openfe``! Containers @@ -356,7 +356,7 @@ The Apptainer image is pre-built and can be pulled with :: .. warning:: For production use, we recommend using version tags to prevent disruptions in workflows e.g. - + .. parsed-literal:: $ docker pull ghcr.io/openfreeenergy/openfe:\ |version| @@ -366,7 +366,7 @@ We recommend testing the container to ensure that it can access a GPU (if desire This can be done with the following command :: $ singularity run --nv openfe_latest-apptainer.sif python -m openmm.testInstallation - + OpenMM Version: 8.0 Git Revision: a7800059645f4471f4b91c21e742fe5aa4513cda @@ -385,7 +385,7 @@ This can be done with the following command :: All differences are within tolerance. The ``--nv`` flag is required for the Apptainer image to access the GPU on the host. -Your output may produce different values for the forces, but should list the CUDA platform if everything is working properly. +Your output may produce different values for the forces, but should list the CUDA platform if everything is working properly. You can access the ``openfe`` CLI from the Singularity image with :: @@ -398,7 +398,7 @@ To make sure everything is working, run the tests :: The test suite contains several hundred individual tests. This will take a few minutes, and all tests should complete with status either passed, skipped, or xfailed (expected fail). - + With that, you should be ready to use ``openfe``! 
Developer install @@ -500,7 +500,7 @@ For example, on a login node where there likely is not a GPU or a CUDA environme Now if we run the same command on a HPC node that has a GPU :: $ mamba info - + mamba version : 1.5.1 active environment : base active env location : /lila/home/henrym3/mamba/envs/QA-openfe-0.14.0 diff --git a/docs/tutorials/.gitignore b/docs/tutorials/.gitignore index 0d7d434ae..14a319608 100644 --- a/docs/tutorials/.gitignore +++ b/docs/tutorials/.gitignore @@ -1,2 +1,2 @@ assets/ -inputs/ \ No newline at end of file +inputs/ diff --git a/openfe/__init__.py b/openfe/__init__.py index 0b11ed153..69411add8 100644 --- a/openfe/__init__.py +++ b/openfe/__init__.py @@ -1,36 +1,35 @@ +from importlib.metadata import version + from gufe import ( + AlchemicalNetwork, ChemicalSystem, Component, + LigandAtomMapping, ProteinComponent, SmallMoleculeComponent, SolventComponent, Transformation, - AlchemicalNetwork, - LigandAtomMapping, ) from gufe.protocols import ( Protocol, ProtocolDAG, - ProtocolUnit, - ProtocolUnitResult, ProtocolUnitFailure, ProtocolDAGResult, ProtocolResult, + ProtocolUnit, + ProtocolUnitFailure, + ProtocolUnitResult, execute_DAG, ) -from . import utils -from . import setup +from . import analysis, orchestration, setup, utils from .setup import ( + LigandAtomMapper, + LigandNetwork, LomapAtomMapper, - lomap_scorers, PersesAtomMapper, - perses_scorers, ligand_network_planning, - LigandNetwork, - LigandAtomMapper, + lomap_scorers, + perses_scorers, ) -from . import orchestration -from . import analysis -from importlib.metadata import version __version__ = version("openfe") diff --git a/openfe/analysis/plotting.py b/openfe/analysis/plotting.py index cba66aaef..fffb225f6 100644 --- a/openfe/analysis/plotting.py +++ b/openfe/analysis/plotting.py @@ -1,12 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe from itertools import chain +from typing import Optional, Union + import matplotlib.pyplot as plt -from matplotlib.axes import Axes import numpy as np import numpy.typing as npt +from matplotlib.axes import Axes from openff.units import unit -from typing import Optional, Union def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: @@ -25,47 +26,54 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: """ num_states = len(matrix) fig, ax = plt.subplots(figsize=(num_states / 2, num_states / 2)) - ax.axis('off') + ax.axis("off") for i in range(num_states): if i != 0: ax.axvline(x=i, ls="-", lw=0.5, color="k", alpha=0.25) ax.axhline(y=i, ls="-", lw=0.5, color="k", alpha=0.25) for j in range(num_states): val = matrix[i, j] - val_str = "{:.2f}".format(val)[1:] + val_str = f"{val:.2f}"[1:] rel_prob = val / matrix.max() # shade box ax.fill_between( - [i, i+1], [num_states - j, num_states - j], + [i, i + 1], + [num_states - j, num_states - j], [num_states - (j + 1), num_states - (j + 1)], - color='k', alpha=rel_prob + color="k", + alpha=rel_prob, ) # annotate box ax.annotate( - val_str, xy=(i, j), xytext=(i+0.5, num_states - (j + 0.5)), - size=8, va="center", ha="center", + val_str, + xy=(i, j), + xytext=(i + 0.5, num_states - (j + 0.5)), + size=8, + va="center", + ha="center", color=("k" if rel_prob < 0.5 else "w"), ) # anotate axes - base_settings = { - 'size': 10, 'va': 'center', 'ha': 'center', 'color': 'k', - 'family': 'sans-serif' - } + base_settings = {"size": 10, "va": "center", "ha": "center", "color": "k", "family": "sans-serif"} for i in range(num_states): ax.annotate( - i, xy=(i + 0.5, 1), xytext=(i + 0.5, num_states + 0.5), + i, + xy=(i + 0.5, 1), + xytext=(i + 0.5, num_states + 0.5), **base_settings, ) ax.annotate( - i, xy=(-0.5, num_states - (num_states - 0.5)), + i, + xy=(-0.5, num_states - (num_states - 0.5)), xytext=(-0.5, num_states - (i + 0.5)), **base_settings, ) 
ax.annotate( - r"$\lambda$", xy=(-0.5, num_states - (num_states - 0.5)), + r"$\lambda$", + xy=(-0.5, num_states - (num_states - 0.5)), xytext=(-0.5, num_states + 0.5), **base_settings, ) @@ -79,10 +87,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: return ax -def plot_convergence( - forward_and_reverse: dict[str, Union[npt.NDArray, unit.Quantity]], - units: unit.Quantity -) -> Axes: +def plot_convergence(forward_and_reverse: dict[str, Union[npt.NDArray, unit.Quantity]], units: unit.Quantity) -> Axes: """ Plot a Reverse and Forward convergence analysis of the free energies. @@ -102,18 +107,20 @@ def plot_convergence( An Axes object to plot. """ known_units = { - 'kilojoule_per_mole': 'kJ/mol', - 'kilojoules_per_mole': 'kJ/mol', - 'kilocalorie_per_mole': 'kcal/mol', - 'kilocalories_per_mole': 'kcal/mol', + "kilojoule_per_mole": "kJ/mol", + "kilojoules_per_mole": "kJ/mol", + "kilocalorie_per_mole": "kcal/mol", + "kilocalories_per_mole": "kcal/mol", } try: plt_units = known_units[str(units)] except KeyError: - errmsg = (f"Unknown plotting units {units} passed, acceptable " - "values are kilojoule(s)_per_mole and " - "kilocalorie(s)_per_mole") + errmsg = ( + f"Unknown plotting units {units} passed, acceptable " + "values are kilojoule(s)_per_mole and " + "kilocalorie(s)_per_mole" + ) raise ValueError(errmsg) fig, ax = plt.subplots(figsize=(8, 6)) @@ -129,38 +136,43 @@ def plot_convergence( ax.yaxis.set_ticks_position("left") # Set the overall error bar to the final error for the reverse results - overall_error = forward_and_reverse['reverse_dDGs'][-1].m - final_value = forward_and_reverse['reverse_DGs'][-1].m - ax.fill_between([0, 1], - final_value - overall_error, - final_value + overall_error, - color='#D2B9D3', zorder=1) + overall_error = forward_and_reverse["reverse_dDGs"][-1].m + final_value = forward_and_reverse["reverse_DGs"][-1].m + ax.fill_between([0, 1], final_value - overall_error, final_value + overall_error, color="#D2B9D3", zorder=1) 
ax.errorbar( - forward_and_reverse['fractions'], - [val.m - for val in forward_and_reverse['forward_DGs']], - yerr=[err.m - for err in forward_and_reverse['forward_dDGs']], - color="#736AFF", lw=3, zorder=2, - marker="o", mfc="w", mew=2.5, - mec="#736AFF", ms=8, label='Forward' + forward_and_reverse["fractions"], + [val.m for val in forward_and_reverse["forward_DGs"]], + yerr=[err.m for err in forward_and_reverse["forward_dDGs"]], + color="#736AFF", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#736AFF", + ms=8, + label="Forward", ) ax.errorbar( - forward_and_reverse['fractions'], - [val.m - for val in forward_and_reverse['reverse_DGs']], - yerr=[err.m - for err in forward_and_reverse['reverse_dDGs']], - color="#C11B17", lw=3, zorder=2, - marker="o", mfc="w", mew=2.5, - mec="#C11B17", ms=8, label='Reverse', + forward_and_reverse["fractions"], + [val.m for val in forward_and_reverse["reverse_DGs"]], + yerr=[err.m for err in forward_and_reverse["reverse_dDGs"]], + color="#C11B17", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#C11B17", + ms=8, + label="Reverse", ) ax.legend(frameon=False) - ax.set_ylabel(r'$\Delta G$' + f' ({plt_units})') - ax.set_xlabel('Fraction of uncorrelated samples') + ax.set_ylabel(r"$\Delta G$" + f" ({plt_units})") + ax.set_xlabel("Fraction of uncorrelated samples") return ax @@ -190,24 +202,20 @@ def plot_replica_timeseries( iterations = [i for i in range(len(state_timeseries))] for i in range(num_states): - ax.scatter(iterations, state_timeseries.T[i], label=f'replica {i}', s=8) + ax.scatter(iterations, state_timeseries.T[i], label=f"replica {i}", s=8) ax.set_xlabel("Number of simulation iterations") ax.set_ylabel("Lambda state") ax.set_title("Change in replica lambda state over time") if equilibration_iterations is not None: - ax.axvline( - x=equilibration_iterations, color='grey', - linestyle='--', label='equilibration limit' - ) + ax.axvline(x=equilibration_iterations, color="grey", linestyle="--", 
label="equilibration limit") - ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) return ax -def plot_2D_rmsd(data: list[list[float]], - vmax=5.0) -> plt.Figure: +def plot_2D_rmsd(data: list[list[float]], vmax=5.0) -> plt.Figure: """Plots 2D RMSD for many states Parameters @@ -240,20 +248,14 @@ def plot_2D_rmsd(data: list[list[float]], fig, axes = plt.subplots(nrows, 4) - for i, (arr, ax) in enumerate( - zip(twod_rmsd_arrs, chain.from_iterable(axes))): - ax.imshow(arr, - vmin=0, vmax=vmax, - cmap=plt.get_cmap('cividis')) - ax.axis('off') # turn off ticks/labels - ax.set_title(f'State {i}') + for i, (arr, ax) in enumerate(zip(twod_rmsd_arrs, chain.from_iterable(axes))): + ax.imshow(arr, vmin=0, vmax=vmax, cmap=plt.get_cmap("cividis")) + ax.axis("off") # turn off ticks/labels + ax.set_title(f"State {i}") - plt.colorbar(axes[0][0].images[0], - cax=axes[-1][-1], - label="RMSD scale (A)", - orientation="horizontal") + plt.colorbar(axes[0][0].images[0], cax=axes[-1][-1], label="RMSD scale (A)", orientation="horizontal") - fig.suptitle('Protein 2D RMSD') + fig.suptitle("Protein 2D RMSD") fig.tight_layout() return fig @@ -263,12 +265,12 @@ def plot_ligand_COM_drift(time: list[float], data: list[list[float]]): fig, ax = plt.subplots() for i, s in enumerate(data): - ax.plot(time, s, label=f'State {i}') + ax.plot(time, s, label=f"State {i}") - ax.legend(loc='upper left') - ax.set_xlabel('Time (ps)') - ax.set_ylabel('Distance (A)') - ax.set_title('Ligand COM drift') + ax.legend(loc="upper left") + ax.set_xlabel("Time (ps)") + ax.set_ylabel("Distance (A)") + ax.set_title("Ligand COM drift") return fig @@ -277,11 +279,11 @@ def plot_ligand_RMSD(time: list[float], data: list[list[float]]): fig, ax = plt.subplots() for i, s in enumerate(data): - ax.plot(time, s, label=f'State {i}') + ax.plot(time, s, label=f"State {i}") - ax.legend(loc='upper left') - ax.set_xlabel('Time (ps)') - ax.set_ylabel('RMSD (A)') - 
ax.set_title('Ligand RMSD') + ax.legend(loc="upper left") + ax.set_xlabel("Time (ps)") + ax.set_ylabel("RMSD (A)") + ax.set_title("Ligand RMSD") return fig diff --git a/openfe/due.py b/openfe/due.py index f729f8438..cdc7d8a7f 100644 --- a/openfe/due.py +++ b/openfe/due.py @@ -24,26 +24,29 @@ License: BSD-2 """ -__version__ = '0.0.9' +__version__ = "0.0.9" -class InactiveDueCreditCollector(object): +class InactiveDueCreditCollector: """Just a stub at the Collector which would not do anything""" + def _donothing(self, *args, **kwargs): """Perform no good and no bad""" pass def dcite(self, *args, **kwargs): """If I could cite I would""" + def nondecorating_decorator(func): return func + return nondecorating_decorator active = False activate = add = cite = dump = load = _donothing def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" def _donothing_func(*args, **kwargs): @@ -52,15 +55,15 @@ def _donothing_func(*args, **kwargs): try: - from duecredit import due, BibTeX, Doi, Url, Text # lgtm [py/unused-import] - if 'due' in locals() and not hasattr(due, 'cite'): - raise RuntimeError( - "Imported due lacks .cite. DueCredit is now disabled") + from duecredit import BibTeX, Doi, Text, Url, due # lgtm [py/unused-import] + + if "due" in locals() and not hasattr(due, "cite"): + raise RuntimeError("Imported due lacks .cite. 
DueCredit is now disabled") except Exception as e: if not isinstance(e, ImportError): import logging - logging.getLogger("duecredit").error( - "Failed to import duecredit due to %s" % str(e)) + + logging.getLogger("duecredit").error("Failed to import duecredit due to %s" % str(e)) # Initiate due stub due = InactiveDueCreditCollector() BibTeX = Doi = Url = Text = _donothing_func diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py index 40c10e8aa..0ac89e754 100644 --- a/openfe/protocols/openmm_afe/__init__.py +++ b/openfe/protocols/openmm_afe/__init__.py @@ -7,10 +7,10 @@ from .equil_solvation_afe_method import ( AbsoluteSolvationProtocol, - AbsoluteSolvationSettings, AbsoluteSolvationProtocolResult, - AbsoluteSolvationVacuumUnit, + AbsoluteSolvationSettings, AbsoluteSolvationSolventUnit, + AbsoluteSolvationVacuumUnit, ) __all__ = [ diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 1761a4dee..446816b39 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -17,55 +17,44 @@ from __future__ import annotations import abc -import os import logging +import os +import pathlib +from typing import Any, Optional import gufe -from gufe.components import Component +import mdtraj as mdt import numpy as np import numpy.typing as npt import openmm -from openff.units import unit -from openff.units.openmm import from_openmm, to_openmm, ensure_quantity +import openmmtools +from gufe import ChemicalSystem, ProteinComponent, SmallMoleculeComponent, SolventComponent, settings +from gufe.components import Component from openff.toolkit.topology import Molecule as OFFMolecule -from openmmtools import multistate -from openmmtools.states import (SamplerState, - ThermodynamicState, - create_thermodynamic_state_protocol,) -from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory, - AlchemicalState,) -from typing import Optional +from openff.units 
import unit +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmm import app from openmm import unit as omm_unit from openmmforcefields.generators import SystemGenerator -import pathlib -from typing import Any -import openmmtools -import mdtraj as mdt +from openmmtools import multistate +from openmmtools.alchemy import AbsoluteAlchemicalFactory, AlchemicalRegion, AlchemicalState +from openmmtools.states import SamplerState, ThermodynamicState, create_thermodynamic_state_protocol -from gufe import ( - settings, ChemicalSystem, SmallMoleculeComponent, - ProteinComponent, SolventComponent -) -from openfe.protocols.openmm_utils.omm_settings import ( - SettingsBaseModel, -) -from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, -) from openfe.protocols.openmm_afe.equil_afe_settings import ( BaseSolvationSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, LambdaSettings, OutputSettings, - ThermoSettings, OpenFFPartialChargeSettings, + IntegratorSettings, + LambdaSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OutputSettings, + ThermoSettings, ) from openfe.protocols.openmm_rfe._rfe_utils import compute -from ..openmm_utils import ( - settings_validation, system_creation, - multistate_analysis, charge_generation -) +from openfe.protocols.openmm_utils.omm_settings import BasePartialChargeSettings, SettingsBaseModel from openfe.utils import without_oechem_backend +from ..openmm_utils import charge_generation, multistate_analysis, settings_validation, system_creation logger = logging.getLogger(__name__) @@ -74,14 +63,18 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ Base class for ligand absolute free energy transformations. 
""" - def __init__(self, *, - protocol: gufe.Protocol, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - alchemical_components: dict[str, list[Component]], - generation: int = 0, - repeat_id: int = 0, - name: Optional[str] = None,): + + def __init__( + self, + *, + protocol: gufe.Protocol, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + alchemical_components: dict[str, list[Component]], + generation: int = 0, + repeat_id: int = 0, + name: str | None = None, + ): """ Parameters ---------- @@ -116,10 +109,11 @@ def __init__(self, *, ) @staticmethod - def _get_alchemical_indices(omm_top: openmm.Topology, - comp_resids: dict[Component, npt.NDArray], - alchem_comps: dict[str, list[Component]] - ) -> list[int]: + def _get_alchemical_indices( + omm_top: openmm.Topology, + comp_resids: dict[Component, npt.NDArray], + alchem_comps: dict[str, list[Component]], + ) -> list[int]: """ Get a list of atom indices for all the alchemical species @@ -139,9 +133,7 @@ def _get_alchemical_indices(omm_top: openmm.Topology, """ # concatenate a list of residue indexes for all alchemical components - residxs = np.concatenate( - [comp_resids[key] for key in alchem_comps['stateA']] - ) + residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) # get the alchemicical atom ids atom_ids = [] @@ -153,8 +145,7 @@ def _get_alchemical_indices(omm_top: openmm.Topology, return atom_ids @staticmethod - def _pre_minimize(system: openmm.System, - positions: omm_unit.Quantity) -> npt.NDArray: + def _pre_minimize(system: openmm.System, positions: omm_unit.Quantity) -> npt.NDArray: """ Short CPU minization of System to avoid GPU NaNs @@ -172,22 +163,22 @@ def _pre_minimize(system: openmm.System, """ integrator = openmm.VerletIntegrator(0.001) context = openmm.Context( - system, integrator, - openmm.Platform.getPlatformByName('CPU'), + system, + integrator, + openmm.Platform.getPlatformByName("CPU"), ) context.setPositions(positions) # Do a quick 100 steps minimization, 
usually avoids NaNs - openmm.LocalEnergyMinimizer.minimize( - context, maxIterations=100 - ) + openmm.LocalEnergyMinimizer.minimize(context, maxIterations=100) state = context.getState(getPositions=True) minimized_positions = state.getPositions(asNumpy=True) return minimized_positions def _prepare( - self, verbose: bool, - scratch_basepath: Optional[pathlib.Path], - shared_basepath: Optional[pathlib.Path], + self, + verbose: bool, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, ): """ Set basepaths and do some initial logging. @@ -208,17 +199,21 @@ def _prepare( # set basepaths def _set_optional_path(basepath): if basepath is None: - return pathlib.Path('.') + return pathlib.Path(".") return basepath self.scratch_basepath = _set_optional_path(scratch_basepath) self.shared_basepath = _set_optional_path(shared_basepath) @abc.abstractmethod - def _get_components(self) -> tuple[dict[str, list[Component]], - Optional[gufe.SolventComponent], - Optional[gufe.ProteinComponent], - dict[SmallMoleculeComponent, OFFMolecule]]: + def _get_components( + self, + ) -> tuple[ + dict[str, list[Component]], + gufe.SolventComponent | None, + gufe.ProteinComponent | None, + dict[SmallMoleculeComponent, OFFMolecule], + ]: """ Get the relevant components to create the alchemical system with. @@ -254,8 +249,9 @@ def _handle_settings(self): ... def _get_system_generator( - self, settings: dict[str, SettingsBaseModel], - solvent_comp: Optional[SolventComponent] + self, + settings: dict[str, SettingsBaseModel], + solvent_comp: SolventComponent | None, ) -> SystemGenerator: """ Get a system generator through the system creation @@ -273,7 +269,7 @@ def _get_system_generator( system_generator : openmmforcefields.generator.SystemGenerator System Generator to parameterise this unit. 
""" - ffcache = settings['output_settings'].forcefield_cache + ffcache = settings["output_settings"].forcefield_cache if ffcache is not None: ffcache = self.shared_basepath / ffcache @@ -281,9 +277,9 @@ def _get_system_generator( # smiles roundtripping between rdkit and oechem with without_oechem_backend(): system_generator = system_creation.get_system_generator( - forcefield_settings=settings['forcefield_settings'], - integrator_settings=settings['integrator_settings'], - thermo_settings=settings['thermo_settings'], + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], cache=ffcache, has_solvent=solvent_comp is not None, ) @@ -317,12 +313,12 @@ def _assign_partial_charges( def _get_modeller( self, - protein_component: Optional[ProteinComponent], - solvent_component: Optional[SolventComponent], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, smc_components: dict[SmallMoleculeComponent, OFFMolecule], system_generator: SystemGenerator, partial_charge_settings: BasePartialChargeSettings, - solvation_settings: BaseSolvationSettings + solvation_settings: BaseSolvationSettings, ) -> tuple[app.Modeller, dict[Component, npt.NDArray]]: """ Get an OpenMM Modeller object and a list of residue indices @@ -367,9 +363,7 @@ def _get_modeller( # smiles roundtripping between rdkit and oechem with without_oechem_backend(): for mol in smc_components.values(): - system_generator.create_system( - mol.to_topology().to_openmm(), molecules=[mol] - ) + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # get OpenMM modeller + dictionary of resids for each component system_modeller, comp_resids = system_creation.get_omm_modeller( @@ -424,9 +418,7 @@ def _get_omm_objects( ) return topology, system, positions - def _get_lambda_schedule( - self, settings: dict[str, SettingsBaseModel] - ) -> dict[str, npt.NDArray]: + def 
_get_lambda_schedule(self, settings: dict[str, SettingsBaseModel]) -> dict[str, npt.NDArray]: """ Create the lambda schedule @@ -446,15 +438,15 @@ def _get_lambda_schedule( """ lambdas = dict() - lambda_elec = settings['lambda_settings'].lambda_elec - lambda_vdw = settings['lambda_settings'].lambda_vdw + lambda_elec = settings["lambda_settings"].lambda_elec + lambda_vdw = settings["lambda_settings"].lambda_vdw # Reverse lambda schedule since in AbsoluteAlchemicalFactory 1 # means fully interacting, not stateB - lambda_elec = [1-x for x in lambda_elec] - lambda_vdw = [1-x for x in lambda_vdw] - lambdas['lambda_electrostatics'] = lambda_elec - lambdas['lambda_sterics'] = lambda_vdw + lambda_elec = [1 - x for x in lambda_elec] + lambda_vdw = [1 - x for x in lambda_vdw] + lambdas["lambda_electrostatics"] = lambda_elec + lambdas["lambda_sterics"] = lambda_vdw return lambdas @@ -469,7 +461,7 @@ def _get_alchemical_system( topology: app.Topology, system: openmm.System, comp_resids: dict[Component, npt.NDArray], - alchem_comps: dict[str, list[Component]] + alchem_comps: dict[str, list[Component]], ) -> tuple[AbsoluteAlchemicalFactory, openmm.System, list[int]]: """ Get an alchemically modified system and its associated factory @@ -500,18 +492,14 @@ def _get_alchemical_system( ---- * Add support for all alchemical factory options """ - alchemical_indices = self._get_alchemical_indices( - topology, comp_resids, alchem_comps - ) + alchemical_indices = self._get_alchemical_indices(topology, comp_resids, alchem_comps) alchemical_region = AlchemicalRegion( alchemical_atoms=alchemical_indices, ) alchemical_factory = AbsoluteAlchemicalFactory() - alchemical_system = alchemical_factory.create_alchemical_system( - system, alchemical_region - ) + alchemical_system = alchemical_factory.create_alchemical_system(system, alchemical_region) return alchemical_factory, alchemical_system, alchemical_indices @@ -521,7 +509,7 @@ def _get_states( positions: openmm.unit.Quantity, settings: 
dict[str, SettingsBaseModel], lambdas: dict[str, npt.NDArray], - solvent_comp: Optional[SolventComponent], + solvent_comp: SolventComponent | None, ) -> tuple[list[SamplerState], list[ThermodynamicState]]: """ Get a list of sampler and thermodynmic states from an @@ -549,16 +537,18 @@ def _get_states( """ alchemical_state = AlchemicalState.from_system(alchemical_system) # Set up the system constants - temperature = settings['thermo_settings'].temperature - pressure = settings['thermo_settings'].pressure + temperature = settings["thermo_settings"].temperature + pressure = settings["thermo_settings"].pressure constants = dict() - constants['temperature'] = ensure_quantity(temperature, 'openmm') + constants["temperature"] = ensure_quantity(temperature, "openmm") if solvent_comp is not None: - constants['pressure'] = ensure_quantity(pressure, 'openmm') + constants["pressure"] = ensure_quantity(pressure, "openmm") cmp_states = create_thermodynamic_state_protocol( - alchemical_system, protocol=lambdas, - constants=constants, composable_states=[alchemical_state], + alchemical_system, + protocol=lambdas, + constants=constants, + composable_states=[alchemical_state], ) sampler_state = SamplerState(positions=positions) @@ -599,9 +589,7 @@ def _get_reporter( """ mdt_top = mdt.Topology.from_openmm(topology) - selection_indices = mdt_top.select( - output_settings.output_indices - ) + selection_indices = mdt_top.select(output_settings.output_indices) nc = self.shared_basepath / output_settings.output_filename chk = output_settings.checkpoint_storage_filename @@ -623,15 +611,13 @@ def _get_reporter( positions[selection_indices, :], mdt_top.subset(selection_indices), ) - traj.save_pdb( - self.shared_basepath / output_settings.output_structure - ) + traj.save_pdb(self.shared_basepath / output_settings.output_structure) return reporter def _get_ctx_caches( self, - engine_settings: OpenMMEngineSettings + engine_settings: OpenMMEngineSettings, ) -> 
tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]: """ Set the context caches based on the chosen platform @@ -652,11 +638,15 @@ def _get_ctx_caches( ) energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) return energy_context_cache, sampler_context_cache @@ -664,7 +654,7 @@ def _get_ctx_caches( @staticmethod def _get_integrator( integrator_settings: IntegratorSettings, - simulation_settings: MultiStateSimulationSettings + simulation_settings: MultiStateSimulationSettings, ) -> openmmtools.mcmc.LangevinDynamicsMove: """ Return a LangevinDynamicsMove integrator @@ -679,9 +669,7 @@ def _get_integrator( integrator : openmmtools.mcmc.LangevinDynamicsMove A configured integrator object. """ - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings, integrator_settings - ) + steps_per_iteration = settings_validation.convert_steps_per_iteration(simulation_settings, integrator_settings) integrator = openmmtools.mcmc.LangevinDynamicsMove( timestep=to_openmm(integrator_settings.timestep), @@ -703,7 +691,7 @@ def _get_sampler( cmp_states: list[ThermodynamicState], sampler_states: list[SamplerState], energy_context_cache: openmmtools.cache.ContextCache, - sampler_context_cache: openmmtools.cache.ContextCache + sampler_context_cache: openmmtools.cache.ContextCache, ) -> multistate.MultiStateSampler: """ Get a sampler based on the equilibrium sampling method requested. 
@@ -747,7 +735,7 @@ def _get_sampler( mcmc_moves=integrator, online_analysis_interval=rta_its, online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its + online_analysis_minimum_iterations=rta_min_its, ) elif simulation_settings.sampler_method.lower() == "sams": sampler = multistate.SAMSSampler( @@ -757,7 +745,7 @@ def _get_sampler( flatness_criteria=simulation_settings.sams_flatness_criteria, gamma0=simulation_settings.sams_gamma0, ) - elif simulation_settings.sampler_method.lower() == 'independent': + elif simulation_settings.sampler_method.lower() == "independent": sampler = multistate.MultiStateSampler( mcmc_moves=integrator, online_analysis_interval=rta_its, @@ -765,11 +753,7 @@ def _get_sampler( online_analysis_minimum_iterations=rta_min_its, ) - sampler.create( - thermodynamic_states=cmp_states, - sampler_states=sampler_states, - storage=reporter - ) + sampler.create(thermodynamic_states=cmp_states, sampler_states=sampler_states, storage=reporter) sampler.energy_context_cache = energy_context_cache sampler.sampler_context_cache = sampler_context_cache @@ -781,7 +765,7 @@ def _run_simulation( sampler: multistate.MultiStateSampler, reporter: multistate.MultiStateReporter, settings: dict[str, SettingsBaseModel], - dry: bool + dry: bool, ): """ Run the simulation. 
@@ -805,18 +789,18 @@ def _run_simulation( """ # Get the relevant simulation steps mc_steps = settings_validation.convert_steps_per_iteration( - simulation_settings=settings['simulation_settings'], - integrator_settings=settings['integrator_settings'], + simulation_settings=settings["simulation_settings"], + integrator_settings=settings["integrator_settings"], ) equil_steps = settings_validation.get_simsteps( - sim_length=settings['simulation_settings'].equilibration_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["simulation_settings"].equilibration_length, + timestep=settings["integrator_settings"].timestep, mc_steps=mc_steps, ) prod_steps = settings_validation.get_simsteps( - sim_length=settings['simulation_settings'].production_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["simulation_settings"].production_length, + timestep=settings["integrator_settings"].timestep, mc_steps=mc_steps, ) @@ -824,9 +808,7 @@ def _run_simulation( # minimize if self.verbose: self.logger.info("minimizing systems") - sampler.minimize( - max_iterations=settings['simulation_settings'].minimization_steps - ) + sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps) # equilibrate if self.verbose: self.logger.info("equilibrating systems") @@ -846,8 +828,8 @@ def _run_simulation( analyzer = multistate_analysis.MultistateEquilFEAnalysis( reporter, - sampling_method=settings['simulation_settings'].sampler_method.lower(), - result_units=unit.kilocalorie_per_mole + sampling_method=settings["simulation_settings"].sampler_method.lower(), + result_units=unit.kilocalorie_per_mole, ) analyzer.plot(filepath=self.shared_basepath, filename_prefix="") analyzer.close() @@ -859,15 +841,16 @@ def _run_simulation( reporter.close() # clean up the reporter file - fns = [self.shared_basepath / settings['output_settings'].output_filename, - self.shared_basepath / 
settings['output_settings'].checkpoint_storage_filename] + fns = [ + self.shared_basepath / settings["output_settings"].output_filename, + self.shared_basepath / settings["output_settings"].checkpoint_storage_filename, + ] for fn in fns: os.remove(fn) return None - def run(self, dry=False, verbose=True, - scratch_basepath=None, shared_basepath=None) -> dict[str, Any]: + def run(self, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None) -> dict[str, Any]: """Run the absolute free energy calculation. Parameters @@ -904,14 +887,19 @@ def run(self, dry=False, verbose=True, # 4. Get modeller system_modeller, comp_resids = self._get_modeller( - prot_comp, solv_comp, smc_comps, system_generator, - settings['charge_settings'], - settings['solvation_settings'], + prot_comp, + solv_comp, + smc_comps, + system_generator, + settings["charge_settings"], + settings["solvation_settings"], ) # 5. Get OpenMM topology, positions and system omm_topology, omm_system, positions = self._get_omm_objects( - system_modeller, system_generator, list(smc_comps.values()) + system_modeller, + system_generator, + list(smc_comps.values()), ) # 6. Pre-minimize System (Test + Avoid NaNs) @@ -925,47 +913,48 @@ def run(self, dry=False, verbose=True, # 9. Get alchemical system alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( - omm_topology, omm_system, comp_resids, alchem_comps + omm_topology, + omm_system, + comp_resids, + alchem_comps, ) # 10. Get compound and sampler states - sampler_states, cmp_states = self._get_states( - alchem_system, positions, settings, - lambdas, solv_comp - ) + sampler_states, cmp_states = self._get_states(alchem_system, positions, settings, lambdas, solv_comp) # 11. 
Create the multistate reporter & create PDB reporter = self._get_reporter( - omm_topology, positions, - settings['simulation_settings'], - settings['output_settings'], + omm_topology, + positions, + settings["simulation_settings"], + settings["output_settings"], ) # Wrap in try/finally to avoid memory leak issues try: # 12. Get context caches - energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( - settings['engine_settings'] - ) + energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches(settings["engine_settings"]) # 13. Get integrator integrator = self._get_integrator( - settings['integrator_settings'], - settings['simulation_settings'], + settings["integrator_settings"], + settings["simulation_settings"], ) # 14. Get sampler sampler = self._get_sampler( - integrator, reporter, settings['simulation_settings'], - settings['thermo_settings'], - cmp_states, sampler_states, - energy_ctx_cache, sampler_ctx_cache + integrator, + reporter, + settings["simulation_settings"], + settings["thermo_settings"], + cmp_states, + sampler_states, + energy_ctx_cache, + sampler_ctx_cache, ) # 15. 
Run simulation - unit_result_dict = self._run_simulation( - sampler, reporter, settings, dry - ) + unit_result_dict = self._run_simulation(sampler, reporter, settings, dry) finally: # close reporter when you're done to prevent file handle clashes @@ -978,8 +967,7 @@ def run(self, dry=False, verbose=True, for context in list(sampler_ctx_cache._lru._data.keys()): del sampler_ctx_cache._lru._data[context] # cautiously clear out the global context cache too - for context in list( - openmmtools.cache.global_context_cache._lru._data.keys()): + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] del sampler_ctx_cache, energy_ctx_cache @@ -989,12 +977,12 @@ def run(self, dry=False, verbose=True, del integrator, sampler if not dry: - nc = self.shared_basepath / settings['output_settings'].output_filename - chk = settings['output_settings'].checkpoint_storage_filename + nc = self.shared_basepath / settings["output_settings"].output_filename + chk = settings["output_settings"].checkpoint_storage_filename return { - 'nc': nc, - 'last_checkpoint': chk, + "nc": nc, + "last_checkpoint": chk, **unit_result_dict, } else: - return {'debug': {'sampler': sampler}} + return {"debug": {"sampler": sampler}} diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py index 9f1df2435..d9e495d4c 100644 --- a/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -15,21 +15,18 @@ * Add support for restraints """ -from gufe.settings import ( - SettingsBaseModel, - OpenMMSystemGeneratorFFSettings, - ThermoSettings, -) +import numpy as np +from gufe.settings import OpenMMSystemGeneratorFFSettings, SettingsBaseModel, ThermoSettings + from openfe.protocols.openmm_utils.omm_settings import ( - MultiStateSimulationSettings, BaseSolvationSettings, - OpenMMSolvationSettings, - OpenMMEngineSettings, 
IntegratorSettings, + MultiStateSimulationSettings, OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, OutputSettings, ) -import numpy as np try: from pydantic.v1 import validator @@ -59,18 +56,55 @@ class LambdaSettings(SettingsBaseModel): the same length, defining all the windows of the transformation. """ + lambda_elec: list[float] = [ - 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, + 0.25, + 0.5, + 0.75, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, ] """ - List of floats of lambda values for the electrostatics. + List of floats of lambda values for the electrostatics. Zero means state A and 1 means state B. Length of this list needs to match length of lambda_vdw and lambda_restraints. """ lambda_vdw: list[float] = [ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.1, 0.2, 0.3, 0.4, - 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.05, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + 1.0, ] """ List of floats of lambda values for the van der Waals. @@ -78,8 +112,26 @@ class LambdaSettings(SettingsBaseModel): Length of this list needs to match length of lambda_elec and lambda_restraints. """ lambda_restraints: list[float] = [ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, ] """ List of floats of lambda values for the restraints. @@ -87,21 +139,20 @@ class LambdaSettings(SettingsBaseModel): Length of this list needs to match length of lambda_vdw and lambda_elec. 
""" - @validator('lambda_elec', 'lambda_vdw', 'lambda_restraints') + @validator("lambda_elec", "lambda_vdw", "lambda_restraints") def must_be_between_0_and_1(cls, v): for window in v: if not 0 <= window <= 1: - errmsg = ("Lambda windows must be between 0 and 1, got a" - f" window with value {window}.") + errmsg = "Lambda windows must be between 0 and 1, got a" f" window with value {window}." raise ValueError(errmsg) return v - @validator('lambda_elec', 'lambda_vdw', 'lambda_restraints') + @validator("lambda_elec", "lambda_vdw", "lambda_restraints") def must_be_monotonic(cls, v): difference = np.diff(v) - if not all(i >= 0. for i in difference): + if not all(i >= 0.0 for i in difference): errmsg = f"The lambda schedule is not monotonic, got schedule {v}." raise ValueError(errmsg) @@ -118,14 +169,15 @@ class AbsoluteSolvationSettings(SettingsBaseModel): -------- openfe.protocols.openmm_afe.AbsoluteSolvationProtocol """ + protocol_repeats: int """ - The number of completely independent repeats of the entire sampling - process. The mean of the repeats defines the final estimate of FE - difference, while the variance between repeats is used as the uncertainty. + The number of completely independent repeats of the entire sampling + process. The mean of the repeats defines the final estimate of FE + difference, while the variance between repeats is used as the uncertainty. """ - @validator('protocol_repeats') + @validator("protocol_repeats") def must_be_positive(cls, v): if v <= 0: errmsg = f"protocol_repeats must be a positive value, got {v}." @@ -149,7 +201,7 @@ def must_be_positive(cls, v): """ lambda_settings: LambdaSettings """ - Settings for controlling the lambda schedule for the different components + Settings for controlling the lambda schedule for the different components (vdw, elec, restraints). 
""" diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index e13fb68e4..34e9416c9 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -28,72 +28,83 @@ """ from __future__ import annotations -import pathlib +import itertools import logging +import pathlib +import uuid import warnings from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Optional, Union + import gufe -from gufe.components import Component -import itertools import numpy as np import numpy.typing as npt +from gufe import ChemicalSystem, ProteinComponent, SmallMoleculeComponent, SolventComponent, settings +from gufe.components import Component from openff.units import unit from openmmtools import multistate -from typing import Optional, Union -from typing import Any, Iterable -import uuid -from gufe import ( - settings, - ChemicalSystem, SmallMoleculeComponent, - ProteinComponent, SolventComponent -) +from openfe.due import Doi, due from openfe.protocols.openmm_afe.equil_afe_settings import ( AbsoluteSolvationSettings, - OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, OutputSettings, + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateSimulationSettings, OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + OutputSettings, SettingsBaseModel, ) -from ..openmm_utils import system_validation, settings_validation -from .base import BaseAbsoluteUnit from openfe.utils import log_system_probe -from openfe.due import due, Doi +from ..openmm_utils import settings_validation, system_validation +from .base import BaseAbsoluteUnit -due.cite(Doi("10.5281/zenodo.596504"), - description="Yank", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - 
cite_module=True) +due.cite( + Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.48550/ARXIV.2302.06758"), - description="EspalomaCharge", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.48550/ARXIV.2302.06758"), + description="EspalomaCharge", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) logger = logging.getLogger(__name__) class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteSolvationProtocol - """ + """Dict-like container for the output of a AbsoluteSolvationProtocol""" + def __init__(self, **data): super().__init__(**data) # TODO: Detect when we have extensions and stitch these together? 
- if any(len(pur_list) > 2 for pur_list - in itertools.chain(self.data['solvent'].values(), self.data['vacuum'].values())): + if any( + len(pur_list) > 2 + for pur_list in itertools.chain(self.data["solvent"].values(), self.data["vacuum"].values()) + ): raise NotImplementedError("Can't stitch together results yet") def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Quantity]]]: @@ -111,19 +122,13 @@ def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Q vac_dGs = [] solv_dGs = [] - for pus in self.data['vacuum'].values(): - vac_dGs.append(( - pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error'] - )) + for pus in self.data["vacuum"].values(): + vac_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) - for pus in self.data['solvent'].values(): - solv_dGs.append(( - pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error'] - )) + for pus in self.data["solvent"].values(): + solv_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) - return {'solvent': solv_dGs, 'vacuum': vac_dGs} + return {"solvent": solv_dGs, "vacuum": vac_dGs} def get_estimate(self): """Get the solvation free energy estimate for this calculation. @@ -133,6 +138,7 @@ def get_estimate(self): dG : unit.Quantity The solvation free energy. This is a Quantity defined with units. 
""" + def _get_average(estimates): # Get the unit value of the first value in the estimates u = estimates[0][0].u @@ -143,8 +149,8 @@ def _get_average(estimates): return np.average(dGs) * u individual_estimates = self.get_individual_estimates() - vac_dG = _get_average(individual_estimates['vacuum']) - solv_dG = _get_average(individual_estimates['solvent']) + vac_dG = _get_average(individual_estimates["vacuum"]) + solv_dG = _get_average(individual_estimates["solvent"]) return vac_dG - solv_dG @@ -157,6 +163,7 @@ def get_uncertainty(self): The standard deviation between estimates of the solvation free energy. This is a Quantity defined with units. """ + def _get_stdev(estimates): # Get the unit value of the first value in the estimates u = estimates[0][0].u @@ -167,13 +174,13 @@ def _get_stdev(estimates): return np.std(dGs) * u individual_estimates = self.get_individual_estimates() - vac_err = _get_stdev(individual_estimates['vacuum']) - solv_err = _get_stdev(individual_estimates['solvent']) + vac_err = _get_stdev(individual_estimates["vacuum"]) + solv_err = _get_stdev(individual_estimates["solvent"]) # return the combined error return np.sqrt(vac_err**2 + solv_err**2) - def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]]: + def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[dict[str, npt.NDArray | unit.Quantity]]]: """ Get the reverse and forward analysis of the free energies. @@ -195,13 +202,10 @@ def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[dict[str, Un fraction of data. 
""" - forward_reverse: dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]] = {} + forward_reverse: dict[str, list[dict[str, npt.NDArray | unit.Quantity]]] = {} - for key in ['solvent', 'vacuum']: - forward_reverse[key] = [ - pus[0].outputs['forward_and_reverse_energies'] - for pus in self.data[key].values() - ] + for key in ["solvent", "vacuum"]: + forward_reverse[key] = [pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values()] return forward_reverse @@ -227,11 +231,8 @@ def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: # Loop through and get the repeats and get the matrices overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - for key in ['solvent', 'vacuum']: - overlap_stats[key] = [ - pus[0].outputs['unit_mbar_overlap'] - for pus in self.data[key].values() - ] + for key in ["solvent", "vacuum"]: + overlap_stats[key] = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values()] return overlap_stats @@ -261,14 +262,10 @@ def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDAr """ repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} try: - for key in ['solvent', 'vacuum']: - repex_stats[key] = [ - pus[0].outputs['replica_exchange_statistics'] - for pus in self.data[key].values() - ] + for key in ["solvent", "vacuum"]: + repex_stats[key] = [pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values()] except KeyError: - errmsg = ("Replica exchange statistics were not found, " - "did you run a repex calculation?") + errmsg = "Replica exchange statistics were not found, " "did you run a repex calculation?" raise ValueError(errmsg) return repex_stats @@ -284,9 +281,7 @@ def get_replica_states(self) -> dict[str, list[npt.NDArray]]: the thermodynamic cycle, with lists of replica states timeseries for each repeat of that simulation type. 
""" - replica_states: dict[str, list[npt.NDArray]] = { - 'solvent': [], 'vacuum': [] - } + replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "vacuum": []} def is_file(filename: str): p = pathlib.Path(filename) @@ -302,20 +297,18 @@ def get_replica_state(nc, chk): dir_path = nc.parents[0] chk = is_file(dir_path / chk).name - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode='r' - ) + reporter = multistate.MultiStateReporter(storage=nc, checkpoint_storage=chk, open_mode="r") retval = np.asarray(reporter.read_replica_thermodynamic_states()) reporter.close() return retval - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: for pus in self.data[key].values(): states = get_replica_state( - pus[0].outputs['nc'], - pus[0].outputs['last_checkpoint'], + pus[0].outputs["nc"], + pus[0].outputs["last_checkpoint"], ) replica_states[key].append(states) @@ -335,11 +328,8 @@ def equilibration_iterations(self) -> dict[str, list[float]]: """ equilibration_lengths: dict[str, list[float]] = {} - for key in ['solvent', 'vacuum']: - equilibration_lengths[key] = [ - pus[0].outputs['equilibration_iterations'] - for pus in self.data[key].values() - ] + for key in ["solvent", "vacuum"]: + equilibration_lengths[key] = [pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values()] return equilibration_lengths @@ -359,11 +349,8 @@ def production_iterations(self) -> dict[str, list[float]]: """ production_lengths: dict[str, list[float]] = {} - for key in ['solvent', 'vacuum']: - production_lengths[key] = [ - pus[0].outputs['production_iterations'] - for pus in self.data[key].values() - ] + for key in ["solvent", "vacuum"]: + production_lengths[key] = [pus[0].outputs["production_iterations"] for pus in self.data[key].values()] return production_lengths @@ -380,6 +367,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol): openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit 
openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit """ + result_cls = AbsoluteSolvationProtocolResult _settings: AbsoluteSolvationSettings @@ -400,7 +388,7 @@ def _default_settings(cls): protocol_repeats=3, solvent_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), vacuum_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings( - nonbonded_method='nocutoff', + nonbonded_method="nocutoff", ), thermo_settings=settings.ThermoSettings( temperature=298.15 * unit.kelvin, @@ -408,13 +396,8 @@ def _default_settings(cls): ), alchemical_settings=AlchemicalSettings(), lambda_settings=LambdaSettings( - lambda_elec=[ - 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 - ], - lambda_vdw=[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, - 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], + lambda_elec=[0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + lambda_vdw=[0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], ), partial_charge_settings=OpenFFPartialChargeSettings(), solvation_settings=OpenMMSolvationSettings(), @@ -427,8 +410,8 @@ def _default_settings(cls): production_length=10.0 * unit.nanosecond, ), solvent_output_settings=OutputSettings( - output_filename='solvent.nc', - checkpoint_storage_filename='solvent_checkpoint.nc', + output_filename="solvent.nc", + checkpoint_storage_filename="solvent_checkpoint.nc", ), vacuum_simulation_settings=MultiStateSimulationSettings( n_replicas=14, @@ -436,14 +419,15 @@ def _default_settings(cls): production_length=2.0 * unit.nanosecond, ), vacuum_output_settings=OutputSettings( - output_filename='vacuum.nc', - checkpoint_storage_filename='vacuum_checkpoint.nc' + output_filename="vacuum.nc", + checkpoint_storage_filename="vacuum_checkpoint.nc", ), ) @staticmethod def _validate_solvent_endstates( - stateA: ChemicalSystem, stateB: ChemicalSystem, + stateA: ChemicalSystem, + stateB: ChemicalSystem, ) -> None: """ A solvent transformation is defined (in terms of 
gufe components) @@ -468,27 +452,20 @@ def _validate_solvent_endstates( # Check that there are no protein components for comp in itertools.chain(stateA.values(), stateB.values()): if isinstance(comp, ProteinComponent): - errmsg = ("Protein components are not allowed for " - "absolute solvation free energies") + errmsg = "Protein components are not allowed for " "absolute solvation free energies" raise ValueError(errmsg) # check that there is a solvent component - if not any( - isinstance(comp, SolventComponent) for comp in stateA.values() - ): + if not any(isinstance(comp, SolventComponent) for comp in stateA.values()): errmsg = "No SolventComponent found in stateA" raise ValueError(errmsg) - if not any( - isinstance(comp, SolventComponent) for comp in stateB.values() - ): + if not any(isinstance(comp, SolventComponent) for comp in stateB.values()): errmsg = "No SolventComponent found in stateB" raise ValueError(errmsg) @staticmethod - def _validate_alchemical_components( - alchemical_components: dict[str, list[Component]] - ) -> None: + def _validate_alchemical_components(alchemical_components: dict[str, list[Component]]) -> None: """ Checks that the ChemicalSystem alchemical components are correct. 
@@ -515,28 +492,25 @@ def _validate_alchemical_components( """ # Crash out if there are any alchemical components in state B for now - if len(alchemical_components['stateB']) > 0: - errmsg = ("Components appearing in state B are not " - "currently supported") + if len(alchemical_components["stateB"]) > 0: + errmsg = "Components appearing in state B are not " "currently supported" raise ValueError(errmsg) - if len(alchemical_components['stateA']) > 1: - errmsg = ("More than one alchemical components is not supported " - "for absolute solvation free energies") + if len(alchemical_components["stateA"]) > 1: + errmsg = "More than one alchemical components is not supported " "for absolute solvation free energies" raise ValueError(errmsg) # Crash out if any of the alchemical components are not # SmallMoleculeComponent - for comp in alchemical_components['stateA']: + for comp in alchemical_components["stateA"]: if not isinstance(comp, SmallMoleculeComponent): - errmsg = ("Non SmallMoleculeComponent alchemical species " - "are not currently supported") + errmsg = "Non SmallMoleculeComponent alchemical species " "are not currently supported" raise ValueError(errmsg) @staticmethod def _validate_lambda_schedule( - lambda_settings: LambdaSettings, - simulation_settings: MultiStateSimulationSettings, + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, ) -> None: """ Checks that the lambda schedule is set up correctly. @@ -571,14 +545,16 @@ def _validate_lambda_schedule( errmsg = ( "Components elec and vdw must have equal amount" f" of lambda windows. Got {len(lambda_elec)} elec lambda" - f" windows and {len(lambda_vdw)} vdw lambda windows.") + f" windows and {len(lambda_vdw)} vdw lambda windows." 
+ ) raise ValueError(errmsg) # Ensure that number of overall lambda windows matches number of lambda # windows for individual components if n_replicas != len(lambda_vdw): - errmsg = (f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(lambda_vdw)}") + errmsg = ( + f"Number of replicas {n_replicas} does not equal the" f" number of lambda windows {len(lambda_vdw)}" + ) raise ValueError(errmsg) # Check if there are lambda windows with naked charges @@ -588,15 +564,18 @@ def _validate_lambda_schedule( "There are states along this lambda schedule " "where there are atoms with charges but no LJ " f"interactions: lambda {inx}: " - f"elec {lam} vdW {lambda_vdw[inx]}") + f"elec {lam} vdW {lambda_vdw[inx]}" + ) raise ValueError(errmsg) # Check if there are lambda windows with non-zero restraints if len([r for r in lambda_restraints if r != 0]) > 0: - wmsg = ("Non-zero restraint lambdas applied. The absolute " - "solvation protocol doesn't apply restraints, " - "therefore restraints won't be applied. " - f"Given lambda_restraints: {lambda_restraints}") + wmsg = ( + "Non-zero restraint lambdas applied. The absolute " + "solvation protocol doesn't apply restraints, " + "therefore restraints won't be applied. 
" + f"Given lambda_restraints: {lambda_restraints}" + ) logger.warning(wmsg) warnings.warn(wmsg) @@ -604,8 +583,8 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions if extends: # pragma: no-cover @@ -614,15 +593,14 @@ def _create( # Validate components and get alchemical components self._validate_solvent_endstates(stateA, stateB) alchem_comps = system_validation.get_alchemical_components( - stateA, stateB, + stateA, + stateB, ) self._validate_alchemical_components(alchem_comps) # Validate the lambda schedule - self._validate_lambda_schedule(self.settings.lambda_settings, - self.settings.solvent_simulation_settings) - self._validate_lambda_schedule(self.settings.lambda_settings, - self.settings.vacuum_simulation_settings) + self._validate_lambda_schedule(self.settings.lambda_settings, self.settings.solvent_simulation_settings) + self._validate_lambda_schedule(self.settings.lambda_settings, self.settings.vacuum_simulation_settings) # Check nonbond & solvent compatibility solv_nonbonded_method = self.settings.solvent_forcefield_settings.nonbonded_method @@ -630,14 +608,16 @@ def _create( # Use the more complete system validation solvent checks system_validation.validate_solvent(stateA, solv_nonbonded_method) # Gas phase is always gas phase - if vac_nonbonded_method.lower() != 'nocutoff': - errmsg = ("Only the nocutoff nonbonded_method is supported for " - f"vacuum calculations, {vac_nonbonded_method} was " - "passed") + if vac_nonbonded_method.lower() != "nocutoff": + errmsg = ( + "Only the nocutoff nonbonded_method is supported for " + f"vacuum calculations, {vac_nonbonded_method} was " + "passed" + ) raise ValueError(errmsg) # Get the 
name of the alchemical species - alchname = alchem_comps['stateA'][0].name + alchname = alchem_comps["stateA"][0].name # Create list units for vacuum and solvent transforms @@ -647,9 +627,9 @@ def _create( stateA=stateA, stateB=stateB, alchemical_components=alchem_comps, - generation=0, repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} solvent leg: " - f"repeat {i} generation 0"), + generation=0, + repeat_id=int(uuid.uuid4()), + name=(f"Absolute Solvation, {alchname} solvent leg: " f"repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] @@ -662,18 +642,16 @@ def _create( stateA=stateA, stateB=stateB, alchemical_components=alchem_comps, - generation=0, repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} vacuum leg: " - f"repeat {i} generation 0"), + generation=0, + repeat_id=int(uuid.uuid4()), + name=(f"Absolute Solvation, {alchname} vacuum leg: " f"repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] return solvent_units + vacuum_units - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, dict[str, Any]]: + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, dict[str, Any]]: # result units will have a repeat_id and generation # first group according to repeat_id unsorted_solvent_repeats = defaultdict(list) @@ -683,19 +661,20 @@ def _gather( for pu in d.protocol_unit_results: if not pu.ok(): continue - if pu.outputs['simtype'] == 'solvent': - unsorted_solvent_repeats[pu.outputs['repeat_id']].append(pu) + if pu.outputs["simtype"] == "solvent": + unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) else: - unsorted_vacuum_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_vacuum_repeats[pu.outputs["repeat_id"]].append(pu) repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { - 'solvent': {}, 'vacuum': {}, + "solvent": {}, + "vacuum": {}, } for k, v in unsorted_solvent_repeats.items(): 
- repeats['solvent'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_vacuum_repeats.items(): - repeats['vacuum'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) return repeats @@ -718,11 +697,10 @@ def _get_components(self): is equivalent to the alchemical components in stateA (since we only allow for disappearing ligands). """ - stateA = self._inputs['stateA'] - alchem_comps = self._inputs['alchemical_components'] + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] - off_comps = {m: m.to_openff() - for m in alchem_comps['stateA']} + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} _, prot_comp, _ = system_validation.get_components(stateA) @@ -752,40 +730,41 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * simulation_settings : SimulationSettings * output_settings: OutputSettings """ - prot_settings = self._inputs['protocol'].settings + prot_settings = self._inputs["protocol"].settings settings = {} - settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings - settings['thermo_settings'] = prot_settings.thermo_settings - settings['charge_settings'] = prot_settings.partial_charge_settings - settings['solvation_settings'] = prot_settings.solvation_settings - settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.vacuum_engine_settings - settings['integrator_settings'] = prot_settings.integrator_settings - settings['simulation_settings'] = prot_settings.vacuum_simulation_settings - settings['output_settings'] = prot_settings.vacuum_output_settings + settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings + settings["thermo_settings"] = 
prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.vacuum_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["simulation_settings"] = prot_settings.vacuum_simulation_settings + settings["output_settings"] = prot_settings.vacuum_output_settings settings_validation.validate_timestep( - settings['forcefield_settings'].hydrogen_mass, - settings['integrator_settings'].timestep + settings["forcefield_settings"].hydrogen_mass, + settings["integrator_settings"].timestep, ) return settings def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], - 'simtype': 'vacuum', - **outputs + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + "simtype": "vacuum", + **outputs, } @@ -805,8 +784,8 @@ def _get_components(self): small_mols : dict[SmallMoleculeComponent: OFFMolecule] SmallMoleculeComponents to add to the system. 
""" - stateA = self._inputs['stateA'] - alchem_comps = self._inputs['alchemical_components'] + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) off_comps = {m: m.to_openff() for m in small_mols} @@ -837,38 +816,39 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * simulation_settings : MultiStateSimulationSettings * output_settings: OutputSettings """ - prot_settings = self._inputs['protocol'].settings + prot_settings = self._inputs["protocol"].settings settings = {} - settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings - settings['thermo_settings'] = prot_settings.thermo_settings - settings['charge_settings'] = prot_settings.partial_charge_settings - settings['solvation_settings'] = prot_settings.solvation_settings - settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.solvent_engine_settings - settings['integrator_settings'] = prot_settings.integrator_settings - settings['simulation_settings'] = prot_settings.solvent_simulation_settings - settings['output_settings'] = prot_settings.solvent_output_settings + settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.solvent_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings 
settings_validation.validate_timestep( - settings['forcefield_settings'].hydrogen_mass, - settings['integrator_settings'].timestep + settings["forcefield_settings"].hydrogen_mass, + settings["integrator_settings"].timestep, ) return settings def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], - 'simtype': 'solvent', - **outputs + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + "simtype": "solvent", + **outputs, } diff --git a/openfe/protocols/openmm_md/__init__.py b/openfe/protocols/openmm_md/__init__.py index c157ecbec..a9b52d58d 100644 --- a/openfe/protocols/openmm_md/__init__.py +++ b/openfe/protocols/openmm_md/__init__.py @@ -5,12 +5,7 @@ """ -from .plain_md_methods import ( - PlainMDProtocol, - PlainMDProtocolSettings, - PlainMDProtocolResult, - PlainMDProtocolUnit, -) +from .plain_md_methods import PlainMDProtocol, PlainMDProtocolResult, PlainMDProtocolSettings, PlainMDProtocolUnit __all__ = [ "PlainMDProtocol", diff --git a/openfe/protocols/openmm_md/plain_md_methods.py b/openfe/protocols/openmm_md/plain_md_methods.py index 92df9dcb9..0fab218c9 100644 --- a/openfe/protocols/openmm_md/plain_md_methods.py +++ b/openfe/protocols/openmm_md/plain_md_methods.py @@ -11,43 +11,38 @@ from __future__ import annotations import logging - +import pathlib +import time +import uuid from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Optional + import gufe +import mdtraj +import numpy as np import openmm +import openmm.unit as omm_unit +from gufe import ChemicalSystem, ProteinComponent, SmallMoleculeComponent, SolventComponent, 
settings +from mdtraj.reporters import XTCReporter +from openff.toolkit.topology import Molecule as OFFMolecule from openff.units import unit from openff.units.openmm import from_openmm, to_openmm -import openmm.unit as omm_unit -from typing import Optional from openmm import app -import pathlib -from typing import Any, Iterable -import uuid -import time -import numpy as np -import mdtraj -from mdtraj.reporters import XTCReporter -from openfe.utils import without_oechem_backend, log_system_probe -from gufe import ( - settings, ChemicalSystem, SmallMoleculeComponent, - ProteinComponent, SolventComponent -) -from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, -) + from openfe.protocols.openmm_md.plain_md_settings import ( - PlainMDProtocolSettings, + IntegratorSettings, + MDOutputSettings, + MDSimulationSettings, OpenFFPartialChargeSettings, - OpenMMSolvationSettings, OpenMMEngineSettings, - IntegratorSettings, MDSimulationSettings, MDOutputSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + PlainMDProtocolSettings, ) -from openff.toolkit.topology import Molecule as OFFMolecule - from openfe.protocols.openmm_rfe._rfe_utils import compute -from openfe.protocols.openmm_utils import ( - system_validation, settings_validation, system_creation, - charge_generation, -) +from openfe.protocols.openmm_utils import charge_generation, settings_validation, system_creation, system_validation +from openfe.protocols.openmm_utils.omm_settings import BasePartialChargeSettings +from openfe.utils import log_system_probe, without_oechem_backend logger = logging.getLogger(__name__) @@ -55,6 +50,7 @@ class PlainMDProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a PlainMDProtocol outputs filenames for the pdb file and trajectory""" + def __init__(self, **data): super().__init__(**data) # data is mapping of str(repeat_id): list[protocolunitresults] @@ -85,7 +81,7 @@ def get_traj_filename(self) -> list[pathlib.Path]: 
traj : list[pathlib.Path] list of paths (pathlib.Path) to the simulation trajectory """ - traj = [pus[0].outputs['nc'] for pus in self.data.values()] + traj = [pus[0].outputs["nc"] for pus in self.data.values()] return traj @@ -98,7 +94,7 @@ def get_pdb_filename(self) -> list[pathlib.Path]: pdbs : list[pathlib.Path] list of paths (pathlib.Path) to the pdb files """ - pdbs = [pus[0].outputs['system_pdb'] for pus in self.data.values()] + pdbs = [pus[0].outputs["system_pdb"] for pus in self.data.values()] return pdbs @@ -140,11 +136,11 @@ def _default_settings(cls): ) def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[dict[str, gufe.ComponentMapping]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: dict[str, gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? if extends: @@ -167,25 +163,27 @@ def _create( if comp is not None: comp_type = comp.__class__.__name__ if len(comp.name) == 0: - comp_name = 'NoName' + comp_name = "NoName" else: comp_name = comp.name system_name += f" {comp_type}:{comp_name}" # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats - units = [PlainMDProtocolUnit( - protocol=self, - stateA=stateA, - generation=0, repeat_id=int(uuid.uuid4()), - name=f'{system_name} repeat {i} generation 0') - for i in range(n_repeats)] + units = [ + PlainMDProtocolUnit( + protocol=self, + stateA=stateA, + generation=0, + repeat_id=int(uuid.uuid4()), + name=f"{system_name} repeat {i} generation 0", + ) + for i in range(n_repeats) + ] return units - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: # result units will have a repeat_id and generations within this # repeat_id # first 
group according to repeat_id @@ -196,12 +194,12 @@ def _gather( if not pu.ok(): continue - unsorted_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) # then sort by generation within each repeat_id list repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) # returns a dict of repeat_id: sorted list of ProtocolUnitResult return repeats @@ -219,7 +217,7 @@ def __init__( stateA: ChemicalSystem, generation: int, repeat_id: int, - name: Optional[str] = None, + name: str | None = None, ): """ Parameters @@ -241,28 +239,23 @@ def __init__( The mapping used must not involve any elemental changes. A check for this is done on class creation. """ - super().__init__( - name=name, - protocol=protocol, - stateA=stateA, - repeat_id=repeat_id, - generation=generation - ) + super().__init__(name=name, protocol=protocol, stateA=stateA, repeat_id=repeat_id, generation=generation) @staticmethod - def _run_MD(simulation: openmm.app.Simulation, - positions: omm_unit.Quantity, - simulation_settings: MDSimulationSettings, - output_settings: MDOutputSettings, - temperature: unit.Quantity, - barostat_frequency: unit.Quantity, - timestep: unit.Quantity, - equil_steps_nvt: int, - equil_steps_npt: int, - prod_steps: int, - verbose=True, - shared_basepath=None) -> None: - + def _run_MD( + simulation: openmm.app.Simulation, + positions: omm_unit.Quantity, + simulation_settings: MDSimulationSettings, + output_settings: MDOutputSettings, + temperature: unit.Quantity, + barostat_frequency: unit.Quantity, + timestep: unit.Quantity, + equil_steps_nvt: int, + equil_steps_npt: int, + prod_steps: int, + verbose=True, + shared_basepath=None, + ) -> None: """ Energy minimization, Equilibration and Production MD to be reused in multiple protocols @@ -297,24 +290,20 @@ def 
_run_MD(simulation: openmm.app.Simulation, """ if shared_basepath is None: - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") simulation.context.setPositions(positions) # minimize if verbose: logger.info("minimizing systems") - simulation.minimizeEnergy( - maxIterations=simulation_settings.minimization_steps - ) + simulation.minimizeEnergy(maxIterations=simulation_settings.minimization_steps) # Get the sub selection of the system to save coords for - selection_indices = mdtraj.Topology.from_openmm( - simulation.topology).select(output_settings.output_indices) + selection_indices = mdtraj.Topology.from_openmm(simulation.topology).select(output_settings.output_indices) - positions = to_openmm(from_openmm( - simulation.context.getState(getPositions=True, - enforcePeriodicBox=False - ).getPositions())) + positions = to_openmm( + from_openmm(simulation.context.getState(getPositions=True, enforcePeriodicBox=False).getPositions()), + ) # Store subset of atoms, specified in input, as PDB file mdtraj_top = mdtraj.Topology.from_openmm(simulation.topology) traj = mdtraj.Trajectory( @@ -322,9 +311,7 @@ def _run_MD(simulation: openmm.app.Simulation, mdtraj_top.subset(selection_indices), ) - traj.save_pdb( - shared_basepath / output_settings.minimized_structure - ) + traj.save_pdb(shared_basepath / output_settings.minimized_structure) # equilibrate # NVT equilibration @@ -333,64 +320,54 @@ def _run_MD(simulation: openmm.app.Simulation, # Set barostat frequency to zero for NVT for x in simulation.context.getSystem().getForces(): - if x.getName() == 'MonteCarloBarostat': + if x.getName() == "MonteCarloBarostat": x.setFrequency(0) - simulation.context.setVelocitiesToTemperature( - to_openmm(temperature)) + simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) t0 = time.time() simulation.step(equil_steps_nvt) t1 = time.time() if verbose: - logger.info( - f"Completed NVT equilibration in {t1 - t0} seconds") + logger.info(f"Completed NVT 
equilibration in {t1 - t0} seconds") # Save last frame NVT equilibration positions = to_openmm( - from_openmm(simulation.context.getState( - getPositions=True, enforcePeriodicBox=False - ).getPositions())) + from_openmm(simulation.context.getState(getPositions=True, enforcePeriodicBox=False).getPositions()), + ) traj = mdtraj.Trajectory( positions[selection_indices, :], mdtraj_top.subset(selection_indices), ) - traj.save_pdb( - shared_basepath / output_settings.equil_NVT_structure - ) + traj.save_pdb(shared_basepath / output_settings.equil_NVT_structure) # NPT equilibration if verbose: logger.info("Running NPT equilibration") - simulation.context.setVelocitiesToTemperature( - to_openmm(temperature)) + simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) # Enable the barostat for NPT for x in simulation.context.getSystem().getForces(): - if x.getName() == 'MonteCarloBarostat': + if x.getName() == "MonteCarloBarostat": x.setFrequency(barostat_frequency.m) t0 = time.time() simulation.step(equil_steps_npt) t1 = time.time() if verbose: - logger.info( - f"Completed NPT equilibration in {t1 - t0} seconds") + logger.info(f"Completed NPT equilibration in {t1 - t0} seconds") # Save last frame NPT equilibration positions = to_openmm( - from_openmm(simulation.context.getState( - getPositions=True, enforcePeriodicBox=False - ).getPositions())) + from_openmm(simulation.context.getState(getPositions=True, enforcePeriodicBox=False).getPositions()), + ) traj = mdtraj.Trajectory( positions[selection_indices, :], mdtraj_top.subset(selection_indices), ) - traj.save_pdb( - shared_basepath / output_settings.equil_NPT_structure - ) + traj.save_pdb(shared_basepath / output_settings.equil_NPT_structure) # production if verbose: @@ -410,26 +387,34 @@ def _run_MD(simulation: openmm.app.Simulation, mc_steps=1, ) - simulation.reporters.append(XTCReporter( - file=str(shared_basepath / output_settings.production_trajectory_filename), - reportInterval=write_interval, - 
atomSubset=selection_indices)) - simulation.reporters.append(openmm.app.CheckpointReporter( - file=str(shared_basepath / output_settings.checkpoint_storage_filename), - reportInterval=checkpoint_interval)) - simulation.reporters.append(openmm.app.StateDataReporter( - str(shared_basepath / output_settings.log_output), - checkpoint_interval, - step=True, - time=True, - potentialEnergy=True, - kineticEnergy=True, - totalEnergy=True, - temperature=True, - volume=True, - density=True, - speed=True, - )) + simulation.reporters.append( + XTCReporter( + file=str(shared_basepath / output_settings.production_trajectory_filename), + reportInterval=write_interval, + atomSubset=selection_indices, + ), + ) + simulation.reporters.append( + openmm.app.CheckpointReporter( + file=str(shared_basepath / output_settings.checkpoint_storage_filename), + reportInterval=checkpoint_interval, + ), + ) + simulation.reporters.append( + openmm.app.StateDataReporter( + str(shared_basepath / output_settings.log_output), + checkpoint_interval, + step=True, + time=True, + potentialEnergy=True, + kineticEnergy=True, + totalEnergy=True, + temperature=True, + volume=True, + density=True, + speed=True, + ), + ) t0 = time.time() simulation.step(prod_steps) t1 = time.time() @@ -464,9 +449,7 @@ def _assign_partial_charges( nagl_model=charge_settings.nagl_model, ) - def run(self, *, dry=False, verbose=True, - scratch_basepath=None, - shared_basepath=None) -> dict[str, Any]: + def run(self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None) -> dict[str, Any]: """Run the MD simulation. Parameters @@ -498,13 +481,13 @@ def run(self, *, dry=False, verbose=True, self.logger.info("Creating system") if shared_basepath is None: # use cwd - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") # 0. 
General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: PlainMDProtocolSettings = self._inputs['protocol'].settings - stateA = self._inputs['stateA'] + protocol_settings: PlainMDProtocolSettings = self._inputs["protocol"].settings + stateA = self._inputs["stateA"] forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings @@ -516,21 +499,22 @@ def run(self, *, dry=False, verbose=True, integrator_settings = protocol_settings.integrator_settings # is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, timestep - ) + settings_validation.validate_timestep(forcefield_settings.hydrogen_mass, timestep) equil_steps_nvt = settings_validation.get_simsteps( sim_length=sim_settings.equilibration_length_nvt, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) equil_steps_npt = settings_validation.get_simsteps( sim_length=sim_settings.equilibration_length, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) prod_steps = settings_validation.get_simsteps( sim_length=sim_settings.production_length, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) @@ -564,9 +548,7 @@ def run(self, *, dry=False, verbose=True, # Force creation of smc templates so we can solvate later for mol in smc_components.values(): - system_generator.create_system( - mol.to_topology().to_openmm(), molecules=[mol] - ) + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # c. get OpenMM Modeller + a resids dictionary for each component stateA_modeller, comp_resids = system_creation.get_omm_modeller( @@ -580,9 +562,7 @@ def run(self, *, dry=False, verbose=True, # d. 
get topology & positions # Note: roundtrip positions to remove vec3 issues stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm( - from_openmm(stateA_modeller.getPositions()) - ) + stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) # e. create the stateA System stateA_system = system_generator.create_system( @@ -592,14 +572,10 @@ def run(self, *, dry=False, verbose=True, # f. Save pdb of entire system with open(shared_basepath / output_settings.preminimized_structure, "w") as f: - openmm.app.PDBFile.writeFile( - stateA_topology, stateA_positions, file=f, keepIds=True - ) + openmm.app.PDBFile.writeFile(stateA_topology, stateA_positions, file=f, keepIds=True) # 10. Get platform - platform = compute.get_openmm_platform( - protocol_settings.engine_settings.compute_platform - ) + platform = compute.get_openmm_platform(protocol_settings.engine_settings.compute_platform) # 11. Set the integrator integrator = openmm.LangevinMiddleIntegrator( @@ -608,12 +584,7 @@ def run(self, *, dry=False, verbose=True, to_openmm(timestep), ) - simulation = openmm.app.Simulation( - stateA_modeller.topology, - stateA_system, - integrator, - platform=platform - ) + simulation = openmm.app.Simulation(stateA_modeller.topology, stateA_system, integrator, platform=platform) try: @@ -639,26 +610,23 @@ def run(self, *, dry=False, verbose=True, if not dry: # pragma: no-cover return { - 'system_pdb': shared_basepath / output_settings.preminimized_structure, - 'minimized_pdb': shared_basepath / output_settings.minimized_structure, - 'nvt_equil_pdb': shared_basepath / output_settings.equil_NVT_structure, - 'npt_equil_pdb': shared_basepath / output_settings.equil_NPT_structure, - 'nc': shared_basepath / output_settings.production_trajectory_filename, - 'last_checkpoint': shared_basepath / output_settings.checkpoint_storage_filename, + "system_pdb": shared_basepath / output_settings.preminimized_structure, + "minimized_pdb": shared_basepath / 
output_settings.minimized_structure, + "nvt_equil_pdb": shared_basepath / output_settings.equil_NVT_structure, + "npt_equil_pdb": shared_basepath / output_settings.equil_NPT_structure, + "nc": shared_basepath / output_settings.production_trajectory_filename, + "last_checkpoint": shared_basepath / output_settings.checkpoint_storage_filename, } else: - return {'debug': {'system': stateA_system}} + return {"debug": {"system": stateA_system}} def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) - return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], - **outputs - } + return {"repeat_id": self._inputs["repeat_id"], "generation": self._inputs["generation"], **outputs} diff --git a/openfe/protocols/openmm_md/plain_md_settings.py b/openfe/protocols/openmm_md/plain_md_settings.py index 594a49d36..c1400cf89 100644 --- a/openfe/protocols/openmm_md/plain_md_settings.py +++ b/openfe/protocols/openmm_md/plain_md_settings.py @@ -7,15 +7,18 @@ :class:`openfe.protocols.openmm_md.plain_md_methods.py` """ +from gufe.settings import SettingsBaseModel + from openfe.protocols.openmm_utils.omm_settings import ( - Settings, - OpenMMSolvationSettings, - OpenMMEngineSettings, + IntegratorSettings, + MDOutputSettings, MDSimulationSettings, - IntegratorSettings, MDOutputSettings, OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + Settings, ) -from gufe.settings import SettingsBaseModel + try: from pydantic.v1 import validator except ImportError: @@ -31,7 +34,7 @@ class Config: Number of independent MD runs to perform. 
""" - @validator('protocol_repeats') + @validator("protocol_repeats") def must_be_positive(cls, v): if v <= 0: errmsg = f"protocol_repeats must be a positive value, got {v}." diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py index eb50c0712..226751d77 100644 --- a/openfe/protocols/openmm_rfe/__init__.py +++ b/openfe/protocols/openmm_rfe/__init__.py @@ -2,14 +2,9 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from . import _rfe_utils - -from .equil_rfe_settings import ( - RelativeHybridTopologyProtocolSettings, -) - from .equil_rfe_methods import ( RelativeHybridTopologyProtocol, RelativeHybridTopologyProtocolResult, RelativeHybridTopologyProtocolUnit, ) - +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py b/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py index 81fb48a67..b2166b10f 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py @@ -1,7 +1 @@ -from . import ( - compute, - lambdaprotocol, - multistate, - relative, - topologyhelpers, -) +from . import compute, lambdaprotocol, multistate, relative, topologyhelpers diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py index b3bee28f6..447148409 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py @@ -1,9 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe # Adapted Perses' perses.app.setup_relative_calculation.get_openmm_platform -import warnings import logging - +import warnings logger = logging.getLogger(__name__) @@ -28,31 +27,30 @@ def get_openmm_platform(platform_name=None): # No platform is specified, so retrieve fastest platform that supports # 'mixed' precision from openmmtools.utils import get_fastest_platform - platform = get_fastest_platform(minimum_precision='mixed') + + platform = get_fastest_platform(minimum_precision="mixed") else: try: platform_name = { - 'cpu': 'CPU', - 'opencl': 'OpenCL', - 'cuda': 'CUDA', + "cpu": "CPU", + "opencl": "OpenCL", + "cuda": "CUDA", }[str(platform_name).lower()] except KeyError: pass from openmm import Platform + platform = Platform.getPlatformByName(platform_name) # Set precision and properties name = platform.getName() - if name in ['CUDA', 'OpenCL']: - platform.setPropertyDefaultValue( - 'Precision', 'mixed') - if name == 'CUDA': - platform.setPropertyDefaultValue( - 'DeterministicForces', 'true') - - if name != 'CUDA': - wmsg = (f"Non-GPU platform selected: {name}, this may significantly " - "impact simulation performance") + if name in ["CUDA", "OpenCL"]: + platform.setPropertyDefaultValue("Precision", "mixed") + if name == "CUDA": + platform.setPropertyDefaultValue("DeterministicForces", "true") + + if name != "CUDA": + wmsg = f"Non-GPU platform selected: {name}, this may significantly " "impact simulation performance" warnings.warn(wmsg) logging.warning(wmsg) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py b/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py index a8e764dd2..66869775d 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py @@ -2,13 +2,14 @@ # License: MIT # OpenFE note: eventually we aim to move this to openmmtools where possible -import numpy as np -import warnings import copy 
+import warnings + +import numpy as np from openmmtools.alchemy import AlchemicalState -class LambdaProtocol(object): +class LambdaProtocol: """Protocols for perturbing each of the component energy terms in alchemical free energy simulations. @@ -17,29 +18,21 @@ class LambdaProtocol(object): * Class needs cleaning up and made more consistent """ - default_functions = {'lambda_sterics_core': - lambda x: x, - 'lambda_electrostatics_core': - lambda x: x, - 'lambda_sterics_insert': - lambda x: 2.0 * x if x < 0.5 else 1.0, - 'lambda_sterics_delete': - lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), - 'lambda_electrostatics_insert': - lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), - 'lambda_electrostatics_delete': - lambda x: 2.0 * x if x < 0.5 else 1.0, - 'lambda_bonds': - lambda x: x, - 'lambda_angles': - lambda x: x, - 'lambda_torsions': - lambda x: x - } + default_functions = { + "lambda_sterics_core": lambda x: x, + "lambda_electrostatics_core": lambda x: x, + "lambda_sterics_insert": lambda x: 2.0 * x if x < 0.5 else 1.0, + "lambda_sterics_delete": lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + "lambda_electrostatics_insert": lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + "lambda_electrostatics_delete": lambda x: 2.0 * x if x < 0.5 else 1.0, + "lambda_bonds": lambda x: x, + "lambda_angles": lambda x: x, + "lambda_torsions": lambda x: x, + } # lambda components for each component, # all run from 0 -> 1 following master lambda - def __init__(self, functions='default', windows=10, lambda_schedule=None): + def __init__(self, functions="default", windows=10, lambda_schedule=None): """Instantiates lambda protocol to be used in a free energy calculation. 
Can either be user defined, by passing in a dict, or using one of the pregenerated sets by passing in a string 'default', 'namd' @@ -92,53 +85,51 @@ def __init__(self, functions='default', windows=10, lambda_schedule=None): self.functions = copy.deepcopy(functions) # set the lambda schedule - self.lambda_schedule = self._validate_schedule(lambda_schedule, - windows) + self.lambda_schedule = self._validate_schedule(lambda_schedule, windows) if lambda_schedule: self.lambda_schedule = lambda_schedule else: - self.lambda_schedule = np.linspace(0., 1., windows) + self.lambda_schedule = np.linspace(0.0, 1.0, windows) if type(self.functions) == dict: - self.type = 'user-defined' + self.type = "user-defined" elif type(self.functions) == str: self.functions = None # will be set later self.type = functions if self.functions is None: - if self.type == 'default': - self.functions = copy.deepcopy( - LambdaProtocol.default_functions) - elif self.type == 'namd': + if self.type == "default": + self.functions = copy.deepcopy(LambdaProtocol.default_functions) + elif self.type == "namd": self.functions = { - 'lambda_sterics_core': lambda x: x, - 'lambda_electrostatics_core': lambda x: x, - 'lambda_sterics_insert': lambda x: (3. / 2.) * x if x < (2. / 3.) else 1.0, - 'lambda_sterics_delete': lambda x: 0.0 if x < (1. / 3.) else (x - (1. / 3.)) * (3. 
/ 2.), - 'lambda_electrostatics_insert': lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), - 'lambda_electrostatics_delete': lambda x: 2.0 * x if x < 0.5 else 1.0, - 'lambda_bonds': lambda x: x, - 'lambda_angles': lambda x: x, - 'lambda_torsions': lambda x: x + "lambda_sterics_core": lambda x: x, + "lambda_electrostatics_core": lambda x: x, + "lambda_sterics_insert": lambda x: (3.0 / 2.0) * x if x < (2.0 / 3.0) else 1.0, + "lambda_sterics_delete": lambda x: 0.0 if x < (1.0 / 3.0) else (x - (1.0 / 3.0)) * (3.0 / 2.0), + "lambda_electrostatics_insert": lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + "lambda_electrostatics_delete": lambda x: 2.0 * x if x < 0.5 else 1.0, + "lambda_bonds": lambda x: x, + "lambda_angles": lambda x: x, + "lambda_torsions": lambda x: x, } - elif self.type == 'quarters': + elif self.type == "quarters": self.functions = { - 'lambda_sterics_core': lambda x: x, - 'lambda_electrostatics_core': lambda x: x, - 'lambda_sterics_insert': lambda x: 0. if x < 0.5 else 1 if x > 0.75 else 4 * (x - 0.5), - 'lambda_sterics_delete': lambda x: 0. if x < 0.25 else 1 if x > 0.5 else 4 * (x - 0.25), - 'lambda_electrostatics_insert': lambda x: 0. 
if x < 0.75 else 4 * (x - 0.75), - 'lambda_electrostatics_delete': lambda x: 4.0 * x if x < 0.25 else 1.0, - 'lambda_bonds': lambda x: x, - 'lambda_angles': lambda x: x, - 'lambda_torsions': lambda x: x + "lambda_sterics_core": lambda x: x, + "lambda_electrostatics_core": lambda x: x, + "lambda_sterics_insert": lambda x: 0.0 if x < 0.5 else 1 if x > 0.75 else 4 * (x - 0.5), + "lambda_sterics_delete": lambda x: 0.0 if x < 0.25 else 1 if x > 0.5 else 4 * (x - 0.25), + "lambda_electrostatics_insert": lambda x: 0.0 if x < 0.75 else 4 * (x - 0.75), + "lambda_electrostatics_delete": lambda x: 4.0 * x if x < 0.25 else 1.0, + "lambda_bonds": lambda x: x, + "lambda_angles": lambda x: x, + "lambda_torsions": lambda x: x, } - elif self.type == 'ele-scaled': + elif self.type == "ele-scaled": self.functions = { - 'lambda_electrostatics_insert': lambda x: 0.0 if x < 0.5 else ((2*(x-0.5))**0.5), - 'lambda_electrostatics_delete': lambda x: (2*x)**2 if x < 0.5 else 1.0 + "lambda_electrostatics_insert": lambda x: 0.0 if x < 0.5 else ((2 * (x - 0.5)) ** 0.5), + "lambda_electrostatics_delete": lambda x: (2 * x) ** 2 if x < 0.5 else 1.0, } - elif self.type == 'user-defined': + elif self.type == "user-defined": self.functions = functions else: errmsg = f"LambdaProtocol type : {self.type} not recognised " @@ -170,18 +161,17 @@ def _validate_schedule(schedule, windows): A valid lambda schedule. """ if schedule is None: - return np.linspace(0., 1., windows) + return np.linspace(0.0, 1.0, windows) # Check end states if schedule[0] != 0 or schedule[-1] != 1: - errmsg = ("end and start lambda windows must be lambda 0 and 1 " - "respectively") + errmsg = "end and start lambda windows must be lambda 0 and 1 " "respectively" raise ValueError(errmsg) # Check monotonically increasing difference = np.diff(schedule) - if not all(i >= 0. 
for i in difference): + if not all(i >= 0.0 for i in difference): errmsg = "The lambda schedule is not monotonic" raise ValueError(errmsg) @@ -204,8 +194,7 @@ def _validate_functions(self, n=10): for function in required_functions: if function not in self.functions: # IA switched from warn to error here - errmsg = (f"function {function} is missing from " - "self.lambda_functions.") + errmsg = f"function {function} is missing from " "self.lambda_functions." raise ValueError(errmsg) # Check that the function starts and ends at 0 and 1 respectively @@ -215,14 +204,12 @@ def _validate_functions(self, n=10): raise ValueError("lambda fucntions must end at 1") # now validatate that it's monotonic - global_lambda = np.linspace(0., 1., n) - sub_lambda = [self.functions[function](lam) for - lam in global_lambda] + global_lambda = np.linspace(0.0, 1.0, n) + sub_lambda = [self.functions[function](lam) for lam in global_lambda] difference = np.diff(sub_lambda) - if not all(i >= 0. for i in difference): - wmsg = (f"The function {function} is not monotonic as " - "typically expected.") + if not all(i >= 0.0 for i in difference): + wmsg = f"The function {function} is not monotonic as " "typically expected." 
warnings.warn(wmsg) def _check_for_naked_charges(self): @@ -240,19 +227,21 @@ def check_overlap(ele, sterics, global_lambda, functions, endstate): ster_val = functions[sterics](lam) # if charge > 0 and sterics == 0 raise error if ele_val != endstate and ster_val == endstate: - errmsg = ("There are states along this lambda schedule " - "where there are atoms with charges but no LJ " - f"interactions: {lam} {ele_val} {ster_val}") + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: {lam} {ele_val} {ster_val}" + ) raise ValueError(errmsg) # checking unique new terms first - ele = 'lambda_electrostatics_insert' - sterics = 'lambda_sterics_insert' + ele = "lambda_electrostatics_insert" + sterics = "lambda_sterics_insert" check_overlap(ele, sterics, global_lambda, self.functions, endstate=0) # checking unique old terms now - ele = 'lambda_electrostatics_delete' - sterics = 'lambda_sterics_delete' + ele = "lambda_electrostatics_delete" + sterics = "lambda_sterics_delete" check_overlap(ele, sterics, global_lambda, self.functions, endstate=1) def get_functions(self): @@ -275,12 +264,10 @@ def plot_functions(self, lambda_schedule=None): global_lambda = lambda_schedule if lambda_schedule else self.lambda_schedule for f in self.functions: - plt.plot(global_lambda, - [self.functions[f](lam) for lam in global_lambda], - alpha=0.5, label=f) + plt.plot(global_lambda, [self.functions[f](lam) for lam in global_lambda], alpha=0.5, label=f) - plt.xlabel('global lambda') - plt.ylabel('sub-lambda') + plt.xlabel("global lambda") + plt.ylabel("sub-lambda") plt.legend() plt.show() @@ -311,17 +298,14 @@ class RelativeAlchemicalState(AlchemicalState): class _LambdaParameter(AlchemicalState._LambdaParameter): pass - lambda_sterics_core = _LambdaParameter('lambda_sterics_core') - lambda_electrostatics_core = _LambdaParameter('lambda_electrostatics_core') - lambda_sterics_insert = 
_LambdaParameter('lambda_sterics_insert') - lambda_sterics_delete = _LambdaParameter('lambda_sterics_delete') - lambda_electrostatics_insert = _LambdaParameter( - 'lambda_electrostatics_insert') - lambda_electrostatics_delete = _LambdaParameter( - 'lambda_electrostatics_delete') - - def set_alchemical_parameters(self, global_lambda, - lambda_protocol=LambdaProtocol()): + lambda_sterics_core = _LambdaParameter("lambda_sterics_core") + lambda_electrostatics_core = _LambdaParameter("lambda_electrostatics_core") + lambda_sterics_insert = _LambdaParameter("lambda_sterics_insert") + lambda_sterics_delete = _LambdaParameter("lambda_sterics_delete") + lambda_electrostatics_insert = _LambdaParameter("lambda_electrostatics_insert") + lambda_electrostatics_delete = _LambdaParameter("lambda_electrostatics_delete") + + def set_alchemical_parameters(self, global_lambda, lambda_protocol=LambdaProtocol()): """Set each lambda value according to the lambda_functions protocol. The undefined parameters (i.e. those being set to None) remain undefined. 
diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index a269b1b5b..6545b0707 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -7,24 +7,24 @@ """ import copy -import warnings import logging +import warnings + import numpy as np import openmm +import openmmtools.states as states from openmm import unit -from openmmtools.multistate import replicaexchange, sams, multistatesampler from openmmtools import cache -import openmmtools.states as states -from openmmtools.states import (CompoundThermodynamicState, - SamplerState, ThermodynamicState) from openmmtools.integrators import FIREMinimizationIntegrator -from .lambdaprotocol import RelativeAlchemicalState +from openmmtools.multistate import multistatesampler, replicaexchange, sams +from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState +from .lambdaprotocol import RelativeAlchemicalState logger = logging.getLogger(__name__) -class HybridCompatibilityMixin(object): +class HybridCompatibilityMixin: """ Mixin that allows the MultistateSampler to accommodate the situation where unsampled endpoints have a different number of degrees of freedom. @@ -32,12 +32,18 @@ class HybridCompatibilityMixin(object): def __init__(self, *args, hybrid_factory=None, **kwargs): self._hybrid_factory = hybrid_factory - super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) - - def setup(self, reporter, lambda_protocol, - temperature=298.15 * unit.kelvin, n_replicas=None, - endstates=True, minimization_steps=100, - minimization_platform="CPU"): + super().__init__(*args, **kwargs) + + def setup( + self, + reporter, + lambda_protocol, + temperature=298.15 * unit.kelvin, + n_replicas=None, + endstates=True, + minimization_steps=100, + minimization_platform="CPU", + ): """ Setup MultistateSampler based on the input lambda protocol and number of replicas. 
@@ -75,52 +81,51 @@ class creation of LambdaProtocol. lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system) - thermostate = ThermodynamicState(hybrid_system, - temperature=temperature) - compound_thermostate = CompoundThermodynamicState( - thermostate, - composable_states=[lambda_zero_state]) + thermostate = ThermodynamicState(hybrid_system, temperature=temperature) + compound_thermostate = CompoundThermodynamicState(thermostate, composable_states=[lambda_zero_state]) # create lists for storing thermostates and sampler states thermodynamic_state_list = [] sampler_state_list = [] if n_replicas is None: - msg = (f"setting number of replicas to number of states: {n_states}") + msg = f"setting number of replicas to number of states: {n_states}" warnings.warn(msg) n_replicas = n_states elif n_replicas > n_states: - wmsg = (f"More sampler states: {n_replicas} requested than the " - f"number of available states: {n_states}. Setting " - "the number of replicas to the number of states") + wmsg = ( + f"More sampler states: {n_replicas} requested than the " + f"number of available states: {n_states}. 
Setting " + "the number of replicas to the number of states" + ) warnings.warn(wmsg) n_replicas = n_states lambda_schedule = lambda_protocol.lambda_schedule if len(lambda_schedule) != n_states: - errmsg = ("length of lambda_schedule must match the number of " - "states, n_states") + errmsg = "length of lambda_schedule must match the number of " "states, n_states" raise ValueError(errmsg) # starting with the hybrid factory positions box = hybrid_system.getDefaultPeriodicBoxVectors() - sampler_state = SamplerState(self._factory.hybrid_positions, - box_vectors=box) + sampler_state = SamplerState(self._factory.hybrid_positions, box_vectors=box) # Loop over the lambdas and create & store a compound thermostate at # that lambda value for lambda_val in lambda_schedule: compound_thermostate_copy = copy.deepcopy(compound_thermostate) - compound_thermostate_copy.set_alchemical_parameters( - lambda_val, lambda_protocol) + compound_thermostate_copy.set_alchemical_parameters(lambda_val, lambda_protocol) thermodynamic_state_list.append(compound_thermostate_copy) # now generating a sampler_state for each thermodyanmic state, # with relaxed positions # Note: remove once choderalab/openmmtools#672 is completed - minimize(compound_thermostate_copy, sampler_state, - max_iterations=minimization_steps, - platform_name=minimization_platform) + minimize( + compound_thermostate_copy, + sampler_state, + max_iterations=minimization_steps, + platform_name=minimization_platform, + ) sampler_state_list.append(copy.deepcopy(sampler_state)) del compound_thermostate, sampler_state @@ -129,11 +134,9 @@ class creation of LambdaProtocol. 
if len(sampler_state_list) != n_replicas: # picking roughly evenly spaced sampler states # if n_replicas == 1, then it will pick the first in the list - samples = np.linspace(0, len(sampler_state_list) - 1, - n_replicas) + samples = np.linspace(0, len(sampler_state_list) - 1, n_replicas) idx = np.round(samples).astype(int) - sampler_state_list = [state for i, state in - enumerate(sampler_state_list) if i in idx] + sampler_state_list = [state for i, state in enumerate(sampler_state_list) if i in idx] assert len(sampler_state_list) == n_replicas @@ -141,25 +144,30 @@ class creation of LambdaProtocol. # generating unsampled endstates unsampled_dispersion_endstates = create_endstates( copy.deepcopy(thermodynamic_state_list[0]), - copy.deepcopy(thermodynamic_state_list[-1])) - self.create(thermodynamic_states=thermodynamic_state_list, - sampler_states=sampler_state_list, storage=reporter, - unsampled_thermodynamic_states=unsampled_dispersion_endstates) + copy.deepcopy(thermodynamic_state_list[-1]), + ) + self.create( + thermodynamic_states=thermodynamic_state_list, + sampler_states=sampler_state_list, + storage=reporter, + unsampled_thermodynamic_states=unsampled_dispersion_endstates, + ) else: - self.create(thermodynamic_states=thermodynamic_state_list, - sampler_states=sampler_state_list, storage=reporter) + self.create( + thermodynamic_states=thermodynamic_state_list, + sampler_states=sampler_state_list, + storage=reporter, + ) -class HybridRepexSampler(HybridCompatibilityMixin, - replicaexchange.ReplicaExchangeSampler): +class HybridRepexSampler(HybridCompatibilityMixin, replicaexchange.ReplicaExchangeSampler): """ ReplicaExchangeSampler that supports unsampled end states with a different number of positions """ def __init__(self, *args, hybrid_factory=None, **kwargs): - super(HybridRepexSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs) + super().__init__(*args, hybrid_factory=hybrid_factory, **kwargs) self._factory = hybrid_factory @@ 
-170,22 +178,18 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): """ def __init__(self, *args, hybrid_factory=None, **kwargs): - super(HybridSAMSSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs - ) + super().__init__(*args, hybrid_factory=hybrid_factory, **kwargs) self._factory = hybrid_factory -class HybridMultiStateSampler(HybridCompatibilityMixin, - multistatesampler.MultiStateSampler): +class HybridMultiStateSampler(HybridCompatibilityMixin, multistatesampler.MultiStateSampler): """ MultiStateSampler that supports unsample end states with a different number of positions """ + def __init__(self, *args, hybrid_factory=None, **kwargs): - super(HybridMultiStateSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs - ) + super().__init__(*args, hybrid_factory=hybrid_factory, **kwargs) self._factory = hybrid_factory @@ -214,58 +218,56 @@ def create_endstates(first_thermostate, last_thermostate): The corrected unsampled endstates. 
""" unsampled_endstates = [] - for master_lambda, endstate in zip([0., 1.], - [first_thermostate, last_thermostate]): + for master_lambda, endstate in zip([0.0, 1.0], [first_thermostate, last_thermostate]): dispersion_system = endstate.get_system() energy_unit = unit.kilocalories_per_mole # Find the NonbondedForce (there must be only one) - forces = {force.__class__.__name__: force for - force in dispersion_system.getForces()} + forces = {force.__class__.__name__: force for force in dispersion_system.getForces()} # Set NonbondedForce to use LJPME ljpme = openmm.NonbondedForce.LJPME - forces['NonbondedForce'].setNonbondedMethod(ljpme) + forces["NonbondedForce"].setNonbondedMethod(ljpme) # Set tight PME tolerance TIGHT_PME_TOLERANCE = 1.0e-5 - forces['NonbondedForce'].setEwaldErrorTolerance(TIGHT_PME_TOLERANCE) + forces["NonbondedForce"].setEwaldErrorTolerance(TIGHT_PME_TOLERANCE) # Move alchemical LJ sites from CustomNonbondedForce back to # NonbondedForce - for particle_index in range(forces['NonbondedForce'].getNumParticles()): - charge, sigma, epsilon = forces['NonbondedForce'].getParticleParameters(particle_index) - sigmaA, epsilonA, sigmaB, epsilonB, unique_old, unique_new = forces['CustomNonbondedForce'].getParticleParameters(particle_index) - if (epsilon/energy_unit == 0.0) and ((epsilonA > 0.0) or (epsilonB > 0.0)): - sigma = (1-master_lambda)*sigmaA + master_lambda*sigmaB - epsilon = (1-master_lambda)*epsilonA + master_lambda*epsilonB - forces['NonbondedForce'].setParticleParameters( - particle_index, charge, - sigma, epsilon) + for particle_index in range(forces["NonbondedForce"].getNumParticles()): + charge, sigma, epsilon = forces["NonbondedForce"].getParticleParameters(particle_index) + sigmaA, epsilonA, sigmaB, epsilonB, unique_old, unique_new = forces[ + "CustomNonbondedForce" + ].getParticleParameters(particle_index) + if (epsilon / energy_unit == 0.0) and ((epsilonA > 0.0) or (epsilonB > 0.0)): + sigma = (1 - master_lambda) * sigmaA + master_lambda 
* sigmaB + epsilon = (1 - master_lambda) * epsilonA + master_lambda * epsilonB + forces["NonbondedForce"].setParticleParameters(particle_index, charge, sigma, epsilon) # Delete the CustomNonbondedForce since we have moved all alchemical # particles out of it for force_index, force in enumerate(list(dispersion_system.getForces())): - if force.__class__.__name__ == 'CustomNonbondedForce': + if force.__class__.__name__ == "CustomNonbondedForce": custom_nonbonded_force_index = force_index break dispersion_system.removeForce(custom_nonbonded_force_index) # Set all parameters to master lambda for force_index, force in enumerate(list(dispersion_system.getForces())): - if hasattr(force, 'getNumGlobalParameters'): + if hasattr(force, "getNumGlobalParameters"): for parameter_index in range(force.getNumGlobalParameters()): - if force.getGlobalParameterName(parameter_index)[0:7] == 'lambda_': - force.setGlobalParameterDefaultValue(parameter_index, - master_lambda) + if force.getGlobalParameterName(parameter_index)[0:7] == "lambda_": + force.setGlobalParameterDefaultValue(parameter_index, master_lambda) # Store the unsampled endstate - unsampled_endstates.append(ThermodynamicState( - dispersion_system, temperature=endstate.temperature)) + unsampled_endstates.append(ThermodynamicState(dispersion_system, temperature=endstate.temperature)) return unsampled_endstates -def minimize(thermodynamic_state: states.ThermodynamicState, - sampler_state: states.SamplerState, - max_iterations: int=100, - platform_name: str="CPU") -> states.SamplerState: +def minimize( + thermodynamic_state: states.ThermodynamicState, + sampler_state: states.SamplerState, + max_iterations: int = 100, + platform_name: str = "CPU", +) -> states.SamplerState: """ Adapted from perses.dispersed.feptasks.minimize @@ -292,16 +294,10 @@ def minimize(thermodynamic_state: states.ThermodynamicState, integrator = openmm.VerletIntegrator(1.0) platform = openmm.Platform.getPlatformByName(platform_name) dummy_cache = 
cache.DummyContextCache(platform=platform) - context, integrator = dummy_cache.get_context( - thermodynamic_state, integrator - ) + context, integrator = dummy_cache.get_context(thermodynamic_state, integrator) try: - sampler_state.apply_to_context( - context, ignore_velocities=True - ) - openmm.LocalEnergyMinimizer.minimize( - context, maxIterations=max_iterations - ) + sampler_state.apply_to_context(context, ignore_velocities=True) + openmm.LocalEnergyMinimizer.minimize(context, maxIterations=max_iterations) sampler_state.update_from_context(context) finally: del context, integrator, dummy_cache diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/relative.py b/openfe/protocols/openmm_rfe/_rfe_utils/relative.py index db4002ace..c6d6cb46a 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/relative.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/relative.py @@ -3,16 +3,17 @@ # The eventual goal is to move a version of this towards openmmtools # LICENSE: MIT -import logging -import openmm -from openmm import unit, app -import numpy as np import copy import itertools -# OpenMM constant for Coulomb interactions (implicitly in md_unit_system units) -from openmmtools.constants import ONE_4PI_EPS0 +import logging + import mdtraj as mdt +import numpy as np +import openmm +from openmm import app, unit +# OpenMM constant for Coulomb interactions (implicitly in md_unit_system units) +from openmmtools.constants import ONE_4PI_EPS0 logger = logging.getLogger(__name__) @@ -78,15 +79,22 @@ class HybridTopologyFactory: """ - def __init__(self, - old_system, old_positions, old_topology, - new_system, new_positions, new_topology, - old_to_new_atom_map, old_to_new_core_atom_map, - use_dispersion_correction=False, - softcore_alpha=0.5, - softcore_LJ_v2=True, - softcore_LJ_v2_alpha=0.85, - interpolate_old_and_new_14s=False): + def __init__( + self, + old_system, + old_positions, + old_topology, + new_system, + new_positions, + new_topology, + old_to_new_atom_map, + 
old_to_new_core_atom_map, + use_dispersion_correction=False, + softcore_alpha=0.5, + softcore_LJ_v2=True, + softcore_LJ_v2_alpha=0.85, + interpolate_old_and_new_14s=False, + ): """ Initialize the Hybrid topology factory. @@ -172,10 +180,8 @@ def __init__(self, self._set_atom_classes() # Construct dictionary of exceptions in old and new systems - self._old_system_exceptions = self._generate_dict_from_exceptions( - self._old_system_forces['NonbondedForce']) - self._new_system_exceptions = self._generate_dict_from_exceptions( - self._new_system_forces['NonbondedForce']) + self._old_system_exceptions = self._generate_dict_from_exceptions(self._old_system_forces["NonbondedForce"]) + self._new_system_exceptions = self._generate_dict_from_exceptions(self._new_system_forces["NonbondedForce"]) # check for exceptions clashes between unique and env atoms self._validate_disjoint_sets() @@ -196,8 +202,7 @@ def __init__(self, self._add_torsion_force_terms() - has_nonbonded_force = ('NonbondedForce' in self._old_system_forces or - 'NonbondedForce' in self._new_system_forces) + has_nonbonded_force = "NonbondedForce" in self._old_system_forces or "NonbondedForce" in self._new_system_forces if has_nonbonded_force: self._add_nonbonded_force_terms() @@ -213,8 +218,7 @@ def __init__(self, if has_nonbonded_force: self._handle_nonbonded() - if not (len(self._old_system_exceptions.keys()) == 0 and - len(self._new_system_exceptions.keys()) == 0): + if not (len(self._old_system_exceptions.keys()) == 0 and len(self._new_system_exceptions.keys()) == 0): self._handle_old_new_exceptions() # Get positions for the hybrid @@ -314,30 +318,32 @@ def _check_and_store_system_forces(self): def _check_unknown_forces(forces, system_name): # TODO: double check that CMMotionRemover is ok being here - known_forces = {'HarmonicBondForce', 'HarmonicAngleForce', - 'PeriodicTorsionForce', 'NonbondedForce', - 'MonteCarloBarostat', 'CMMotionRemover'} + known_forces = { + "HarmonicBondForce", + 
"HarmonicAngleForce", + "PeriodicTorsionForce", + "NonbondedForce", + "MonteCarloBarostat", + "CMMotionRemover", + } force_names = forces.keys() unknown_forces = set(force_names) - set(known_forces) if unknown_forces: - errmsg = (f"Unknown forces {unknown_forces} encountered in " - f"{system_name} system") + errmsg = f"Unknown forces {unknown_forces} encountered in " f"{system_name} system" raise ValueError(errmsg) # Prepare dicts of forces, which will be useful later # TODO: Store this as self._system_forces[name], name in ('old', # 'new', 'hybrid') for compactness - self._old_system_forces = {type(force).__name__: force for force in - self._old_system.getForces()} - _check_unknown_forces(self._old_system_forces, 'old') - self._new_system_forces = {type(force).__name__: force for force in - self._new_system.getForces()} - _check_unknown_forces(self._new_system_forces, 'new') + self._old_system_forces = {type(force).__name__: force for force in self._old_system.getForces()} + _check_unknown_forces(self._old_system_forces, "old") + self._new_system_forces = {type(force).__name__: force for force in self._new_system.getForces()} + _check_unknown_forces(self._new_system_forces, "new") # TODO: check if this is actually used much, otherwise ditch it # Get and store the nonbonded method from the system: - self._nonbonded_method = self._old_system_forces['NonbondedForce'].getNonbondedMethod() + self._nonbonded_method = self._old_system_forces["NonbondedForce"].getNonbondedMethod() def _add_particles(self): """ @@ -360,8 +366,7 @@ def _add_particles(self): if particle_idx in self._old_to_new_map.keys(): particle_idx_new_system = self._old_to_new_map[particle_idx] - mass_new = self._new_system.getParticleMass( - particle_idx_new_system) + mass_new = self._new_system.getParticleMass(particle_idx_new_system) # Take the average of the masses if the atom is mapped particle_mass = (mass_old + mass_new) / 2 else: @@ -393,8 +398,7 @@ def _handle_box(self): # Check that if there is 
a barostat in the old system, # it is added to the hybrid system if "MonteCarloBarostat" in self._old_system_forces.keys(): - barostat = copy.deepcopy( - self._old_system_forces["MonteCarloBarostat"]) + barostat = copy.deepcopy(self._old_system_forces["MonteCarloBarostat"]) self._hybrid_system.addForce(barostat) # Copy over the box vectors from the old system @@ -407,41 +411,47 @@ def _set_atom_classes(self): unique new, core, or environment, as defined in the class docstring. All indices are indices in the hybrid system. """ - self._atom_classes = {'unique_old_atoms': set(), - 'unique_new_atoms': set(), - 'core_atoms': set(), - 'environment_atoms': set()} + self._atom_classes = { + "unique_old_atoms": set(), + "unique_new_atoms": set(), + "core_atoms": set(), + "environment_atoms": set(), + } # First, find the unique old atoms for atom_idx in self._unique_old_atoms: hybrid_idx = self._old_to_hybrid_map[atom_idx] - self._atom_classes['unique_old_atoms'].add(hybrid_idx) + self._atom_classes["unique_old_atoms"].add(hybrid_idx) # Then the unique new atoms for atom_idx in self._unique_new_atoms: hybrid_idx = self._new_to_hybrid_map[atom_idx] - self._atom_classes['unique_new_atoms'].add(hybrid_idx) + self._atom_classes["unique_new_atoms"].add(hybrid_idx) # The core atoms: for new_idx, old_idx in self._core_new_to_old_map.items(): new_to_hybrid_idx = self._new_to_hybrid_map[new_idx] old_to_hybrid_idx = self._old_to_hybrid_map[old_idx] if new_to_hybrid_idx != old_to_hybrid_idx: - errmsg = (f"there is an index collision in hybrid indices of " - f"the core atom map: {self._core_new_to_old_map}") + errmsg = ( + f"there is an index collision in hybrid indices of " + f"the core atom map: {self._core_new_to_old_map}" + ) raise AssertionError(errmsg) - self._atom_classes['core_atoms'].add(new_to_hybrid_idx) + self._atom_classes["core_atoms"].add(new_to_hybrid_idx) # The environment atoms: for new_idx, old_idx in self._env_new_to_old_map.items(): new_to_hybrid_idx = 
self._new_to_hybrid_map[new_idx] old_to_hybrid_idx = self._old_to_hybrid_map[old_idx] if new_to_hybrid_idx != old_to_hybrid_idx: - errmsg = (f"there is an index collion in hybrid indices of " - f"the environment atom map: " - f"{self._env_new_to_old_map}") + errmsg = ( + f"there is an index collion in hybrid indices of " + f"the environment atom map: " + f"{self._env_new_to_old_map}" + ) raise AssertionError(errmsg) - self._atom_classes['environment_atoms'].add(new_to_hybrid_idx) + self._atom_classes["environment_atoms"].add(new_to_hybrid_idx) @staticmethod def _generate_dict_from_exceptions(force): @@ -477,31 +487,27 @@ def _validate_disjoint_sets(self): TODO: repeated code - condense """ for old_indices in self._old_system_exceptions.keys(): - hybrid_indices = (self._old_to_hybrid_map[old_indices[0]], - self._old_to_hybrid_map[old_indices[1]]) - old_env_intersection = set(old_indices).intersection( - self._atom_classes['environment_atoms']) + hybrid_indices = (self._old_to_hybrid_map[old_indices[0]], self._old_to_hybrid_map[old_indices[1]]) + old_env_intersection = set(old_indices).intersection(self._atom_classes["environment_atoms"]) if old_env_intersection: - if set(old_indices).intersection( - self._atom_classes['unique_old_atoms'] - ): - errmsg = (f"old index exceptions {old_indices} include " - "unique old and environment atoms, which is " - "disallowed") + if set(old_indices).intersection(self._atom_classes["unique_old_atoms"]): + errmsg = ( + f"old index exceptions {old_indices} include " + "unique old and environment atoms, which is " + "disallowed" + ) raise AssertionError(errmsg) for new_indices in self._new_system_exceptions.keys(): - hybrid_indices = (self._new_to_hybrid_map[new_indices[0]], - self._new_to_hybrid_map[new_indices[1]]) - new_env_intersection = set(hybrid_indices).intersection( - self._atom_classes['environment_atoms']) + hybrid_indices = (self._new_to_hybrid_map[new_indices[0]], self._new_to_hybrid_map[new_indices[1]]) + 
new_env_intersection = set(hybrid_indices).intersection(self._atom_classes["environment_atoms"]) if new_env_intersection: - if set(hybrid_indices).intersection( - self._atom_classes['unique_new_atoms'] - ): - errmsg = (f"new index exceptions {new_indices} include " - "unique new and environment atoms, which is " - "dissallowed") + if set(hybrid_indices).intersection(self._atom_classes["unique_new_atoms"]): + errmsg = ( + f"new index exceptions {new_indices} include " + "unique new and environment atoms, which is " + "dissallowed" + ) raise AssertionError def _handle_constraints(self): @@ -519,38 +525,34 @@ def _handle_constraints(self): # old system hybrid_map = self._old_to_hybrid_map for const_idx in range(self._old_system.getNumConstraints()): - at1, at2, length = self._old_system.getConstraintParameters( - const_idx) + at1, at2, length = self._old_system.getConstraintParameters(const_idx) hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) if hybrid_atoms not in constraint_lengths.keys(): - self._hybrid_system.addConstraint(hybrid_atoms[0], - hybrid_atoms[1], length) + self._hybrid_system.addConstraint(hybrid_atoms[0], hybrid_atoms[1], length) constraint_lengths[hybrid_atoms] = length else: if constraint_lengths[hybrid_atoms] != length: - raise AssertionError('constraint length is changing') + raise AssertionError("constraint length is changing") # new system hybrid_map = self._new_to_hybrid_map for const_idx in range(self._new_system.getNumConstraints()): - at1, at2, length = self._new_system.getConstraintParameters( - const_idx) + at1, at2, length = self._new_system.getConstraintParameters(const_idx) hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) if hybrid_atoms not in constraint_lengths.keys(): - self._hybrid_system.addConstraint(hybrid_atoms[0], - hybrid_atoms[1], length) + self._hybrid_system.addConstraint(hybrid_atoms[0], hybrid_atoms[1], length) constraint_lengths[hybrid_atoms] = length else: if 
constraint_lengths[hybrid_atoms] != length: - raise AssertionError('constraint length is changing') + raise AssertionError("constraint length is changing") @staticmethod def _copy_threeparticleavg(atm_map, env_atoms, vs): """ Helper method to copy a ThreeParticleAverageSite virtual site from two mapped Systems. - + Parameters ---------- atm_map : dict[int, int] @@ -559,7 +561,7 @@ def _copy_threeparticleavg(atm_map, env_atoms, vs): A list of environment atoms for the target System. This checks that no alchemical atoms are being tied to. vs : openmm.ThreeParticleAverageSite - + Returns ------- openmm.ThreeParticleAverageSite @@ -570,14 +572,17 @@ def _copy_threeparticleavg(atm_map, env_atoms, vs): particles[i] = atm_map[vs.getParticle(i)] weights[i] = vs.getWeight(i) if not all(i in env_atoms for i in particles.values()): - errmsg = ("Virtual sites bound to non-environment atoms " - "are not supported") + errmsg = "Virtual sites bound to non-environment atoms " "are not supported" raise ValueError(errmsg) return openmm.ThreeParticleAverageSite( - particles[0], particles[1], particles[2], - weights[0], weights[1], weights[2], + particles[0], + particles[1], + particles[2], + weights[0], + weights[1], + weights[2], ) - + def _handle_virtual_sites(self): """ Ensure that all virtual sites in old and new system are copied over to @@ -595,27 +600,22 @@ def _handle_virtual_sites(self): # If it's a virtual site, make sure it is not in the unique or # core atoms, since this is currently unsupported hybrid_idx = self._old_to_hybrid_map[particle_idx] - if hybrid_idx not in self._atom_classes['environment_atoms']: - errmsg = ("Virtual sites in changing residue are " - "unsupported.") + if hybrid_idx not in self._atom_classes["environment_atoms"]: + errmsg = "Virtual sites in changing residue are " "unsupported." 
raise ValueError(errmsg) else: - virtual_site = self._old_system.getVirtualSite( - particle_idx) - if isinstance( - virtual_site, openmm.ThreeParticleAverageSite): + virtual_site = self._old_system.getVirtualSite(particle_idx) + if isinstance(virtual_site, openmm.ThreeParticleAverageSite): vs_copy = self._copy_threeparticleavg( self._old_to_hybrid_map, - self._atom_classes['environment_atoms'], + self._atom_classes["environment_atoms"], virtual_site, ) else: - errmsg = ("Unsupported VirtualSite " - f"class: {virtual_site}") + errmsg = "Unsupported VirtualSite " f"class: {virtual_site}" raise ValueError(errmsg) - self._hybrid_system.setVirtualSite(hybrid_idx, - vs_copy) + self._hybrid_system.setVirtualSite(hybrid_idx, vs_copy) # new system - there should be nothing left to add # Loop through virtual sites @@ -624,14 +624,12 @@ def _handle_virtual_sites(self): # If it's a virtual site, make sure it is not in the unique or # core atoms, since this is currently unsupported hybrid_idx = self._new_to_hybrid_map[particle_idx] - if hybrid_idx not in self._atom_classes['environment_atoms']: - errmsg = ("Virtual sites in changing residue are " - "unsupported.") + if hybrid_idx not in self._atom_classes["environment_atoms"]: + errmsg = "Virtual sites in changing residue are " "unsupported." raise ValueError(errmsg) else: if not self._hybrid_system.isVirtualSite(hybrid_idx): - errmsg = ("Environment virtual site in new system " - "found not copied from old system") + errmsg = "Environment virtual site in new system " "found not copied from old system" raise ValueError(errmsg) def _add_bond_force_terms(self): @@ -645,29 +643,29 @@ def _add_bond_force_terms(self): ----- * User defined functions have been removed for now. 
""" - core_energy_expression = '(K/2)*(r-length)^2;' + core_energy_expression = "(K/2)*(r-length)^2;" # linearly interpolate spring constant - core_energy_expression += 'K = (1-lambda_bonds)*K1 + lambda_bonds*K2;' + core_energy_expression += "K = (1-lambda_bonds)*K1 + lambda_bonds*K2;" # linearly interpolate bond length - core_energy_expression += 'length = (1-lambda_bonds)*length1 + lambda_bonds*length2;' + core_energy_expression += "length = (1-lambda_bonds)*length1 + lambda_bonds*length2;" # Create the force and add the relevant parameters custom_core_force = openmm.CustomBondForce(core_energy_expression) - custom_core_force.addPerBondParameter('length1') # old bond length - custom_core_force.addPerBondParameter('K1') # old spring constant - custom_core_force.addPerBondParameter('length2') # new bond length - custom_core_force.addPerBondParameter('K2') # new spring constant + custom_core_force.addPerBondParameter("length1") # old bond length + custom_core_force.addPerBondParameter("K1") # old spring constant + custom_core_force.addPerBondParameter("length2") # new bond length + custom_core_force.addPerBondParameter("K2") # new spring constant - custom_core_force.addGlobalParameter('lambda_bonds', 0.0) + custom_core_force.addGlobalParameter("lambda_bonds", 0.0) self._hybrid_system.addForce(custom_core_force) - self._hybrid_system_forces['core_bond_force'] = custom_core_force + self._hybrid_system_forces["core_bond_force"] = custom_core_force # Add a bond force for environment and unique atoms (bonds are never # scaled for these): standard_bond_force = openmm.HarmonicBondForce() self._hybrid_system.addForce(standard_bond_force) - self._hybrid_system_forces['standard_bond_force'] = standard_bond_force + self._hybrid_system_forces["standard_bond_force"] = standard_bond_force def _add_angle_force_terms(self): """ @@ -680,34 +678,34 @@ def _add_angle_force_terms(self): * User defined functions have been removed for now. 
* Neglected angle terms have been removed for now. """ - energy_expression = '(K/2)*(theta-theta0)^2;' + energy_expression = "(K/2)*(theta-theta0)^2;" # linearly interpolate spring constant - energy_expression += 'K = (1.0-lambda_angles)*K_1 + lambda_angles*K_2;' + energy_expression += "K = (1.0-lambda_angles)*K_1 + lambda_angles*K_2;" # linearly interpolate equilibrium angle - energy_expression += 'theta0 = (1.0-lambda_angles)*theta0_1 + lambda_angles*theta0_2;' + energy_expression += "theta0 = (1.0-lambda_angles)*theta0_1 + lambda_angles*theta0_2;" # Create the force and add relevant parameters custom_core_force = openmm.CustomAngleForce(energy_expression) # molecule1 equilibrium angle - custom_core_force.addPerAngleParameter('theta0_1') + custom_core_force.addPerAngleParameter("theta0_1") # molecule1 spring constant - custom_core_force.addPerAngleParameter('K_1') + custom_core_force.addPerAngleParameter("K_1") # molecule2 equilibrium angle - custom_core_force.addPerAngleParameter('theta0_2') + custom_core_force.addPerAngleParameter("theta0_2") # molecule2 spring constant - custom_core_force.addPerAngleParameter('K_2') + custom_core_force.addPerAngleParameter("K_2") - custom_core_force.addGlobalParameter('lambda_angles', 0.0) + custom_core_force.addGlobalParameter("lambda_angles", 0.0) # Add the force to the system and the force dict. 
self._hybrid_system.addForce(custom_core_force) - self._hybrid_system_forces['core_angle_force'] = custom_core_force + self._hybrid_system_forces["core_angle_force"] = custom_core_force # Add an angle term for environment/unique interactions -- these are # never scaled standard_angle_force = openmm.HarmonicAngleForce() self._hybrid_system.addForce(standard_angle_force) - self._hybrid_system_forces['standard_angle_force'] = standard_angle_force + self._hybrid_system_forces["standard_angle_force"] = standard_angle_force def _add_torsion_force_terms(self): """ @@ -722,35 +720,35 @@ def _add_torsion_force_terms(self): add_unique_atom_torsion_force (default True) have been removed for now. """ - energy_expression = '(1-lambda_torsions)*U1 + lambda_torsions*U2;' - energy_expression += 'U1 = K1*(1+cos(periodicity1*theta-phase1));' - energy_expression += 'U2 = K2*(1+cos(periodicity2*theta-phase2));' + energy_expression = "(1-lambda_torsions)*U1 + lambda_torsions*U2;" + energy_expression += "U1 = K1*(1+cos(periodicity1*theta-phase1));" + energy_expression += "U2 = K2*(1+cos(periodicity2*theta-phase2));" # Create the force and add the relevant parameters custom_core_force = openmm.CustomTorsionForce(energy_expression) # molecule1 periodicity - custom_core_force.addPerTorsionParameter('periodicity1') + custom_core_force.addPerTorsionParameter("periodicity1") # molecule1 phase - custom_core_force.addPerTorsionParameter('phase1') + custom_core_force.addPerTorsionParameter("phase1") # molecule1 spring constant - custom_core_force.addPerTorsionParameter('K1') + custom_core_force.addPerTorsionParameter("K1") # molecule2 periodicity - custom_core_force.addPerTorsionParameter('periodicity2') + custom_core_force.addPerTorsionParameter("periodicity2") # molecule2 phase - custom_core_force.addPerTorsionParameter('phase2') + custom_core_force.addPerTorsionParameter("phase2") # molecule2 spring constant - custom_core_force.addPerTorsionParameter('K2') + 
custom_core_force.addPerTorsionParameter("K2") - custom_core_force.addGlobalParameter('lambda_torsions', 0.0) + custom_core_force.addGlobalParameter("lambda_torsions", 0.0) # Add the force to the system self._hybrid_system.addForce(custom_core_force) - self._hybrid_system_forces['custom_torsion_force'] = custom_core_force + self._hybrid_system_forces["custom_torsion_force"] = custom_core_force # Create and add the torsion term for unique/environment atoms unique_atom_torsion_force = openmm.PeriodicTorsionForce() self._hybrid_system.addForce(unique_atom_torsion_force) - self._hybrid_system_forces['unique_atom_torsion_force'] = unique_atom_torsion_force + self._hybrid_system_forces["unique_atom_torsion_force"] = unique_atom_torsion_force @staticmethod def _nonbonded_custom(v2): @@ -808,7 +806,9 @@ def _nonbonded_custom_sterics_common(): sterics_addition += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" sterics_addition += "sigma = (1-lambda_sterics)*sigmaA + lambda_sterics*sigmaB;" - sterics_addition += "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + sterics_addition += ( + "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + ) sterics_addition += "lambda_sterics = core_interaction*lambda_sterics_core + new_interaction*lambda_sterics_insert + old_interaction*lambda_sterics_delete;" sterics_addition += "core_interaction = delta(unique_old1+unique_old2+unique_new1+unique_new2);new_interaction = max(unique_new1, unique_new2);old_interaction = max(unique_old1, unique_old2);" @@ -860,9 +860,11 @@ def _translate_nonbonded_method_to_custom(standard_nonbonded_method): custom_nonbonded_method : openmm.CustomNonbondedForce.NonbondedMethod the nonbonded method for the equivalent customnonbonded force """ - if standard_nonbonded_method in [openmm.NonbondedForce.CutoffPeriodic, - openmm.NonbondedForce.PME, - openmm.NonbondedForce.Ewald]: + if 
standard_nonbonded_method in [ + openmm.NonbondedForce.CutoffPeriodic, + openmm.NonbondedForce.PME, + openmm.NonbondedForce.Ewald, + ]: return openmm.CustomNonbondedForce.CutoffPeriodic elif standard_nonbonded_method == openmm.NonbondedForce.NoCutoff: return openmm.CustomNonbondedForce.NoCutoff @@ -893,28 +895,24 @@ def _add_nonbonded_force_terms(self): # changing. standard_nonbonded_force = openmm.NonbondedForce() self._hybrid_system.addForce(standard_nonbonded_force) - self._hybrid_system_forces['standard_nonbonded_force'] = standard_nonbonded_force + self._hybrid_system_forces["standard_nonbonded_force"] = standard_nonbonded_force # Create a CustomNonbondedForce to handle alchemically interpolated # nonbonded parameters. # Select functional form based on nonbonded method. # TODO: check _nonbonded_custom_ewald and _nonbonded_custom_cutoff # since they take arguments that are never used... - r_cutoff = self._old_system_forces['NonbondedForce'].getCutoffDistance() + r_cutoff = self._old_system_forces["NonbondedForce"].getCutoffDistance() sterics_energy_expression = self._nonbonded_custom(self._softcore_LJ_v2) if self._nonbonded_method in [openmm.NonbondedForce.NoCutoff]: - sterics_energy_expression = self._nonbonded_custom( - self._softcore_LJ_v2) - elif self._nonbonded_method in [openmm.NonbondedForce.CutoffPeriodic, - openmm.NonbondedForce.CutoffNonPeriodic]: - epsilon_solvent = self._old_system_forces['NonbondedForce'].getReactionFieldDielectric() - standard_nonbonded_force.setReactionFieldDielectric( - epsilon_solvent) + sterics_energy_expression = self._nonbonded_custom(self._softcore_LJ_v2) + elif self._nonbonded_method in [openmm.NonbondedForce.CutoffPeriodic, openmm.NonbondedForce.CutoffNonPeriodic]: + epsilon_solvent = self._old_system_forces["NonbondedForce"].getReactionFieldDielectric() + standard_nonbonded_force.setReactionFieldDielectric(epsilon_solvent) standard_nonbonded_force.setCutoffDistance(r_cutoff) - elif self._nonbonded_method in 
[openmm.NonbondedForce.PME, - openmm.NonbondedForce.Ewald]: - [alpha_ewald, nx, ny, nz] = self._old_system_forces['NonbondedForce'].getPMEParameters() - delta = self._old_system_forces['NonbondedForce'].getEwaldErrorTolerance() + elif self._nonbonded_method in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald]: + [alpha_ewald, nx, ny, nz] = self._old_system_forces["NonbondedForce"].getPMEParameters() + delta = self._old_system_forces["NonbondedForce"].getEwaldErrorTolerance() standard_nonbonded_force.setPMEParameters(alpha_ewald, nx, ny, nz) standard_nonbonded_force.setEwaldErrorTolerance(delta) standard_nonbonded_force.setCutoffDistance(r_cutoff) @@ -928,22 +926,18 @@ def _add_nonbonded_force_terms(self): sterics_mixing_rules = self._nonbonded_custom_mixing_rules() - custom_nonbonded_method = self._translate_nonbonded_method_to_custom( - self._nonbonded_method) + custom_nonbonded_method = self._translate_nonbonded_method_to_custom(self._nonbonded_method) total_sterics_energy = "U_sterics;" + sterics_energy_expression + sterics_mixing_rules - sterics_custom_nonbonded_force = openmm.CustomNonbondedForce( - total_sterics_energy) + sterics_custom_nonbonded_force = openmm.CustomNonbondedForce(total_sterics_energy) # Match cutoff from non-custom NB forces sterics_custom_nonbonded_force.setCutoffDistance(r_cutoff) if self._softcore_LJ_v2: - sterics_custom_nonbonded_force.addGlobalParameter( - "softcore_alpha", self._softcore_LJ_v2_alpha) + sterics_custom_nonbonded_force.addGlobalParameter("softcore_alpha", self._softcore_LJ_v2_alpha) else: - sterics_custom_nonbonded_force.addGlobalParameter( - "softcore_alpha", self._softcore_alpha) + sterics_custom_nonbonded_force.addGlobalParameter("softcore_alpha", self._softcore_alpha) # Lennard-Jones sigma initial sterics_custom_nonbonded_force.addPerParticleParameter("sigmaA") @@ -958,32 +952,27 @@ def _add_nonbonded_force_terms(self): # 1 = hybrid new atom, 0 otherwise 
sterics_custom_nonbonded_force.addPerParticleParameter("unique_new") - sterics_custom_nonbonded_force.addGlobalParameter( - "lambda_sterics_core", 0.0) - sterics_custom_nonbonded_force.addGlobalParameter( - "lambda_electrostatics_core", 0.0) - sterics_custom_nonbonded_force.addGlobalParameter( - "lambda_sterics_insert", 0.0) - sterics_custom_nonbonded_force.addGlobalParameter( - "lambda_sterics_delete", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter("lambda_sterics_core", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter("lambda_electrostatics_core", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter("lambda_sterics_insert", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter("lambda_sterics_delete", 0.0) - sterics_custom_nonbonded_force.setNonbondedMethod( - custom_nonbonded_method) + sterics_custom_nonbonded_force.setNonbondedMethod(custom_nonbonded_method) self._hybrid_system.addForce(sterics_custom_nonbonded_force) - self._hybrid_system_forces['core_sterics_force'] = sterics_custom_nonbonded_force + self._hybrid_system_forces["core_sterics_force"] = sterics_custom_nonbonded_force # Set the use of dispersion correction to be the same between the new # nonbonded force and the old one: - if self._old_system_forces['NonbondedForce'].getUseDispersionCorrection(): - self._hybrid_system_forces['standard_nonbonded_force'].setUseDispersionCorrection(True) + if self._old_system_forces["NonbondedForce"].getUseDispersionCorrection(): + self._hybrid_system_forces["standard_nonbonded_force"].setUseDispersionCorrection(True) if self._use_dispersion_correction: sterics_custom_nonbonded_force.setUseLongRangeCorrection(True) else: - self._hybrid_system_forces['standard_nonbonded_force'].setUseDispersionCorrection(False) + self._hybrid_system_forces["standard_nonbonded_force"].setUseDispersionCorrection(False) - if self._old_system_forces['NonbondedForce'].getUseSwitchingFunction(): - switching_distance = 
self._old_system_forces['NonbondedForce'].getSwitchingDistance() + if self._old_system_forces["NonbondedForce"].getUseSwitchingFunction(): + switching_distance = self._old_system_forces["NonbondedForce"].getSwitchingDistance() standard_nonbonded_force.setUseSwitchingFunction(True) standard_nonbonded_force.setSwitchingDistance(switching_distance) sterics_custom_nonbonded_force.setUseSwitchingFunction(True) @@ -1038,8 +1027,8 @@ def _handle_harmonic_bonds(self): ----- * Bond softening logic has been removed for now. """ - old_system_bond_force = self._old_system_forces['HarmonicBondForce'] - new_system_bond_force = self._new_system_forces['HarmonicBondForce'] + old_system_bond_force = self._old_system_forces["HarmonicBondForce"] + new_system_bond_force = self._new_system_forces["HarmonicBondForce"] # First, loop through the old system bond forces and add relevant terms for bond_index in range(old_system_bond_force.getNumBonds()): @@ -1056,21 +1045,25 @@ def _handle_harmonic_bonds(self): # atoms are in the core) # If it is, we need to find the parameters in the old system so # that we can interpolate - if index_set.issubset(self._atom_classes['core_atoms']): + if index_set.issubset(self._atom_classes["core_atoms"]): index1_new = self._old_to_new_map[index1_old] index2_new = self._old_to_new_map[index2_old] - new_bond_parameters = self._find_bond_parameters( - new_system_bond_force, index1_new, index2_new) + new_bond_parameters = self._find_bond_parameters(new_system_bond_force, index1_new, index2_new) if not new_bond_parameters: r0_new = r0_old - k_new = 0.0*unit.kilojoule_per_mole/unit.angstrom**2 + k_new = 0.0 * unit.kilojoule_per_mole / unit.angstrom**2 else: # TODO - why is this being recalculated? 
[index1, index2, r0_new, k_new] = self._find_bond_parameters( - new_system_bond_force, index1_new, index2_new) - self._hybrid_system_forces['core_bond_force'].addBond( - index1_hybrid, index2_hybrid, - [r0_old, k_old, r0_new, k_new]) + new_system_bond_force, + index1_new, + index2_new, + ) + self._hybrid_system_forces["core_bond_force"].addBond( + index1_hybrid, + index2_hybrid, + [r0_old, k_old, r0_new, k_new], + ) # Check if the index set is a subset of anything besides # environment (in the case of environment, we just add the bond to @@ -1080,28 +1073,27 @@ def _handle_harmonic_bonds(self): # NOTE - These are currently all the same because we don't soften # TODO - work these out somewhere else, this is terribly difficult # to understand logic. - elif (index_set.issubset(self._atom_classes['unique_old_atoms']) or - (len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 1 - and len(index_set.intersection(self._atom_classes['core_atoms'])) == 1)): + elif index_set.issubset(self._atom_classes["unique_old_atoms"]) or ( + len(index_set.intersection(self._atom_classes["unique_old_atoms"])) == 1 + and len(index_set.intersection(self._atom_classes["core_atoms"])) == 1 + ): # We can just add it to the regular bond force. 
- self._hybrid_system_forces['standard_bond_force'].addBond( - index1_hybrid, index2_hybrid, r0_old, k_old) + self._hybrid_system_forces["standard_bond_force"].addBond(index1_hybrid, index2_hybrid, r0_old, k_old) - elif (len(index_set.intersection(self._atom_classes['environment_atoms'])) == 1 and - len(index_set.intersection(self._atom_classes['core_atoms'])) == 1): - self._hybrid_system_forces['standard_bond_force'].addBond( - index1_hybrid, index2_hybrid, r0_old, k_old) + elif ( + len(index_set.intersection(self._atom_classes["environment_atoms"])) == 1 + and len(index_set.intersection(self._atom_classes["core_atoms"])) == 1 + ): + self._hybrid_system_forces["standard_bond_force"].addBond(index1_hybrid, index2_hybrid, r0_old, k_old) # Otherwise, we just add the same parameters as those in the old # system (these are environment atoms, and the parameters are the # same) - elif index_set.issubset(self._atom_classes['environment_atoms']): - self._hybrid_system_forces['standard_bond_force'].addBond( - index1_hybrid, index2_hybrid, r0_old, k_old) + elif index_set.issubset(self._atom_classes["environment_atoms"]): + self._hybrid_system_forces["standard_bond_force"].addBond(index1_hybrid, index2_hybrid, r0_old, k_old) else: - errmsg = (f"hybrid index set {index_set} does not fit into a " - "canonical atom type") + errmsg = f"hybrid index set {index_set} does not fit into a " "canonical atom type" raise ValueError(errmsg) # Now loop through the new system to get the interactions that are @@ -1119,13 +1111,13 @@ def _handle_harmonic_bonds(self): # anything, the bond is unique to the new system and must be added # all other bonds in the new system have been accounted for already # NOTE - These are mostly all the same because we don't soften - if (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) == 2 or - (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) == 1 and - len(index_set.intersection(self._atom_classes['core_atoms'])) == 1)): + 
if len(index_set.intersection(self._atom_classes["unique_new_atoms"])) == 2 or ( + len(index_set.intersection(self._atom_classes["unique_new_atoms"])) == 1 + and len(index_set.intersection(self._atom_classes["core_atoms"])) == 1 + ): # If we aren't softening bonds, then just add it to the standard bond force - self._hybrid_system_forces['standard_bond_force'].addBond( - index1_hybrid, index2_hybrid, r0_new, k_new) + self._hybrid_system_forces["standard_bond_force"].addBond(index1_hybrid, index2_hybrid, r0_new, k_new) # If the bond is in the core, it has probably already been added # in the above loop. However, there are some circumstances @@ -1133,26 +1125,31 @@ def _handle_harmonic_bonds(self): # not been added and should be added here. # This has some peculiarities to be discussed... # TODO - Work out what the above peculiarities are... - elif index_set.issubset(self._atom_classes['core_atoms']): + elif index_set.issubset(self._atom_classes["core_atoms"]): if not self._find_bond_parameters( - self._hybrid_system_forces['core_bond_force'], - index1_hybrid, index2_hybrid): + self._hybrid_system_forces["core_bond_force"], + index1_hybrid, + index2_hybrid, + ): r0_old = r0_new - k_old = 0.0*unit.kilojoule_per_mole/unit.angstrom**2 - self._hybrid_system_forces['core_bond_force'].addBond( - index1_hybrid, index2_hybrid, - [r0_old, k_old, r0_new, k_new]) - elif index_set.issubset(self._atom_classes['environment_atoms']): + k_old = 0.0 * unit.kilojoule_per_mole / unit.angstrom**2 + self._hybrid_system_forces["core_bond_force"].addBond( + index1_hybrid, + index2_hybrid, + [r0_old, k_old, r0_new, k_new], + ) + elif index_set.issubset(self._atom_classes["environment_atoms"]): # Already been added pass - elif (len(index_set.intersection(self._atom_classes['environment_atoms'])) == 1 and - len(index_set.intersection(self._atom_classes['core_atoms'])) == 1): + elif ( + len(index_set.intersection(self._atom_classes["environment_atoms"])) == 1 + and 
len(index_set.intersection(self._atom_classes["core_atoms"])) == 1 + ): pass else: - errmsg = (f"hybrid index set {index_set} does not fit into a " - "canonical atom type") + errmsg = f"hybrid index set {index_set} does not fit into a " "canonical atom type" raise ValueError(errmsg) @staticmethod @@ -1182,8 +1179,7 @@ def _find_angle_parameters(angle_force, indices): # Get a set representing the angle indices angle_param_indices = angle_params[:3] - if (indices == angle_param_indices or - indices_reversed == angle_param_indices): + if indices == angle_param_indices or indices_reversed == angle_param_indices: return angle_params return [] # Return empty if no matching angle found @@ -1210,8 +1206,8 @@ def _handle_harmonic_angles(self): ----- * Removed softening and neglected angle functionality """ - old_system_angle_force = self._old_system_forces['HarmonicAngleForce'] - new_system_angle_force = self._new_system_forces['HarmonicAngleForce'] + old_system_angle_force = self._old_system_forces["HarmonicAngleForce"] + new_system_angle_force = self._new_system_forces["HarmonicAngleForce"] # First, loop through all the angles in the old system to determine # what to do with them. We will only use the @@ -1219,59 +1215,58 @@ def _handle_harmonic_angles(self): # are either unique to one system or never change. 
for angle_index in range(old_system_angle_force.getNumAngles()): - old_angle_parameters = old_system_angle_force.getAngleParameters( - angle_index) + old_angle_parameters = old_system_angle_force.getAngleParameters(angle_index) # Get the indices in the hybrid system - hybrid_index_list = [ - self._old_to_hybrid_map[old_atomid] for old_atomid in old_angle_parameters[:3] - ] + hybrid_index_list = [self._old_to_hybrid_map[old_atomid] for old_atomid in old_angle_parameters[:3]] hybrid_index_set = set(hybrid_index_list) # If all atoms are in the core, we'll need to find the # corresponding parameters in the old system and interpolate - if hybrid_index_set.issubset(self._atom_classes['core_atoms']): + if hybrid_index_set.issubset(self._atom_classes["core_atoms"]): # Get the new indices so we can get the new angle parameters - new_indices = [ - self._old_to_new_map[old_atomid] for old_atomid in old_angle_parameters[:3] - ] - new_angle_parameters = self._find_angle_parameters( - new_system_angle_force, new_indices - ) + new_indices = [self._old_to_new_map[old_atomid] for old_atomid in old_angle_parameters[:3]] + new_angle_parameters = self._find_angle_parameters(new_system_angle_force, new_indices) if not new_angle_parameters: new_angle_parameters = [ - 0, 0, 0, old_angle_parameters[3], - 0.0*unit.kilojoule_per_mole/unit.radian**2 + 0, + 0, + 0, + old_angle_parameters[3], + 0.0 * unit.kilojoule_per_mole / unit.radian**2, ] # Add to the hybrid force: # the parameters at indices 3 and 4 represent theta0 and k, # respectively. 
hybrid_force_parameters = [ - old_angle_parameters[3], old_angle_parameters[4], - new_angle_parameters[3], new_angle_parameters[4] + old_angle_parameters[3], + old_angle_parameters[4], + new_angle_parameters[3], + new_angle_parameters[4], ] - self._hybrid_system_forces['core_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_force_parameters + self._hybrid_system_forces["core_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_force_parameters, ) # Check if the atoms are neither all core nor all environment, # which would mean they involve unique old interactions - elif not hybrid_index_set.issubset( - self._atom_classes['environment_atoms']): + elif not hybrid_index_set.issubset(self._atom_classes["environment_atoms"]): # if there is an environment atom - if hybrid_index_set.intersection( - self._atom_classes['environment_atoms']): - if hybrid_index_set.intersection( - self._atom_classes['unique_old_atoms']): + if hybrid_index_set.intersection(self._atom_classes["environment_atoms"]): + if hybrid_index_set.intersection(self._atom_classes["unique_old_atoms"]): errmsg = "we disallow unique-environment terms" raise ValueError(errmsg) - self._hybrid_system_forces['standard_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], old_angle_parameters[3], - old_angle_parameters[4] + self._hybrid_system_forces["standard_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + old_angle_parameters[3], + old_angle_parameters[4], ) else: # There are no env atoms, so we can treat this term @@ -1279,82 +1274,80 @@ def _handle_harmonic_angles(self): # We don't soften so just add this to the standard angle # force - self._hybrid_system_forces['standard_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], old_angle_parameters[3], - old_angle_parameters[4] + 
self._hybrid_system_forces["standard_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + old_angle_parameters[3], + old_angle_parameters[4], ) # Otherwise, only environment atoms are in this interaction, so # add it to the standard angle force - elif hybrid_index_set.issubset( - self._atom_classes['environment_atoms']): - self._hybrid_system_forces['standard_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], old_angle_parameters[3], - old_angle_parameters[4] + elif hybrid_index_set.issubset(self._atom_classes["environment_atoms"]): + self._hybrid_system_forces["standard_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + old_angle_parameters[3], + old_angle_parameters[4], ) else: - errmsg = (f"handle_harmonic_angles: angle_index {angle_index} " - "does not fit a canonical form.") + errmsg = f"handle_harmonic_angles: angle_index {angle_index} " "does not fit a canonical form." raise ValueError(errmsg) # Finally, loop through the new system force to add any unique new # angles for angle_index in range(new_system_angle_force.getNumAngles()): - new_angle_parameters = new_system_angle_force.getAngleParameters( - angle_index) + new_angle_parameters = new_system_angle_force.getAngleParameters(angle_index) # Get the indices in the hybrid system - hybrid_index_list = [ - self._new_to_hybrid_map[new_atomid] for new_atomid in new_angle_parameters[:3] - ] + hybrid_index_list = [self._new_to_hybrid_map[new_atomid] for new_atomid in new_angle_parameters[:3]] hybrid_index_set = set(hybrid_index_list) # If the intersection of this hybrid set with the unique new atoms # is nonempty, it must be added: # TODO - there's a ton of len > 0 on sets, empty sets == False, # so we can simplify this logic. 
- if len(hybrid_index_set.intersection( - self._atom_classes['unique_new_atoms'])) > 0: - if hybrid_index_set.intersection( - self._atom_classes['environment_atoms']): - errmsg = ("we disallow angle terms with unique new and " - "environment atoms") + if len(hybrid_index_set.intersection(self._atom_classes["unique_new_atoms"])) > 0: + if hybrid_index_set.intersection(self._atom_classes["environment_atoms"]): + errmsg = "we disallow angle terms with unique new and " "environment atoms" raise ValueError(errmsg) # Not softening just add to the nonalchemical force - self._hybrid_system_forces['standard_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], new_angle_parameters[3], - new_angle_parameters[4] + self._hybrid_system_forces["standard_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + new_angle_parameters[3], + new_angle_parameters[4], ) - elif hybrid_index_set.issubset(self._atom_classes['core_atoms']): - if not self._find_angle_parameters(self._hybrid_system_forces['core_angle_force'], - hybrid_index_list): + elif hybrid_index_set.issubset(self._atom_classes["core_atoms"]): + if not self._find_angle_parameters(self._hybrid_system_forces["core_angle_force"], hybrid_index_list): hybrid_force_parameters = [ new_angle_parameters[3], - 0.0*unit.kilojoule_per_mole/unit.radian**2, - new_angle_parameters[3], new_angle_parameters[4] + 0.0 * unit.kilojoule_per_mole / unit.radian**2, + new_angle_parameters[3], + new_angle_parameters[4], ] - self._hybrid_system_forces['core_angle_force'].addAngle( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_force_parameters + self._hybrid_system_forces["core_angle_force"].addAngle( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_force_parameters, ) - elif hybrid_index_set.issubset(self._atom_classes['environment_atoms']): + elif 
hybrid_index_set.issubset(self._atom_classes["environment_atoms"]): # We have already added the appropriate environmental atom # terms pass - elif hybrid_index_set.intersection(self._atom_classes['environment_atoms']): - if hybrid_index_set.intersection(self._atom_classes['unique_new_atoms']): - errmsg = ("we disallow angle terms with unique new and " - "environment atoms") + elif hybrid_index_set.intersection(self._atom_classes["environment_atoms"]): + if hybrid_index_set.intersection(self._atom_classes["unique_new_atoms"]): + errmsg = "we disallow angle terms with unique new and " "environment atoms" raise ValueError(errmsg) else: - errmsg = (f"hybrid index list {hybrid_index_list} does not " - "fit into a canonical atom set") + errmsg = f"hybrid index list {hybrid_index_list} does not " "fit into a canonical atom set" raise ValueError(errmsg) @staticmethod @@ -1386,8 +1379,7 @@ def _find_torsion_parameters(torsion_force, indices): # Get a set representing the torsion indices: torsion_param_indices = torsion_params[:4] - if (indices == torsion_param_indices or - indices_reversed == torsion_param_indices): + if indices == torsion_param_indices or indices_reversed == torsion_param_indices: torsion_params_list.append(torsion_params) return torsion_params_list @@ -1410,8 +1402,8 @@ def _handle_periodic_torsion_force(self): ----- * Torsion flattening logic has been removed for now. """ - old_system_torsion_force = self._old_system_forces['PeriodicTorsionForce'] - new_system_torsion_force = self._new_system_forces['PeriodicTorsionForce'] + old_system_torsion_force = self._old_system_forces["PeriodicTorsionForce"] + new_system_torsion_force = self._new_system_forces["PeriodicTorsionForce"] auxiliary_custom_torsion_force = [] old_custom_torsions_to_standard = [] @@ -1423,92 +1415,119 @@ def _handle_periodic_torsion_force(self): # Is it necessary? Should we add this logic back in? 
for torsion_index in range(old_system_torsion_force.getNumTorsions()): - torsion_parameters = old_system_torsion_force.getTorsionParameters( - torsion_index) + torsion_parameters = old_system_torsion_force.getTorsionParameters(torsion_index) # Get the indices in the hybrid system - hybrid_index_list = [ - self._old_to_hybrid_map[old_index] for old_index in torsion_parameters[:4] - ] + hybrid_index_list = [self._old_to_hybrid_map[old_index] for old_index in torsion_parameters[:4]] hybrid_index_set = set(hybrid_index_list) # If all atoms are in the core, we'll need to find the # corresponding parameters in the old system and interpolate - if hybrid_index_set.intersection(self._atom_classes['unique_old_atoms']): + if hybrid_index_set.intersection(self._atom_classes["unique_old_atoms"]): # Then it goes to a standard force... - self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - torsion_parameters[4], torsion_parameters[5], - torsion_parameters[6] + self._hybrid_system_forces["unique_atom_torsion_force"].addTorsion( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + torsion_parameters[4], + torsion_parameters[5], + torsion_parameters[6], ) else: # It is a core-only term, an environment-only term, or a # core/env term; in any case, it goes to the core torsion_force # TODO - why are we even adding the 0.0, 0.0, 0.0 section? 
hybrid_force_parameters = [ - torsion_parameters[4], torsion_parameters[5], - torsion_parameters[6], 0.0, 0.0, 0.0 + torsion_parameters[4], + torsion_parameters[5], + torsion_parameters[6], + 0.0, + 0.0, + 0.0, ] auxiliary_custom_torsion_force.append( - [hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters[:3]] + [ + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + hybrid_force_parameters[:3], + ], ) for torsion_index in range(new_system_torsion_force.getNumTorsions()): torsion_parameters = new_system_torsion_force.getTorsionParameters(torsion_index) # Get the indices in the hybrid system: - hybrid_index_list = [ - self._new_to_hybrid_map[new_index] for new_index in torsion_parameters[:4]] + hybrid_index_list = [self._new_to_hybrid_map[new_index] for new_index in torsion_parameters[:4]] hybrid_index_set = set(hybrid_index_list) - if hybrid_index_set.intersection(self._atom_classes['unique_new_atoms']): + if hybrid_index_set.intersection(self._atom_classes["unique_new_atoms"]): # Then it goes to the custom torsion force (scaled to zero) - self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - torsion_parameters[4], torsion_parameters[5], - torsion_parameters[6] + self._hybrid_system_forces["unique_atom_torsion_force"].addTorsion( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + torsion_parameters[4], + torsion_parameters[5], + torsion_parameters[6], ) else: hybrid_force_parameters = [ - 0.0, 0.0, 0.0, torsion_parameters[4], - torsion_parameters[5], torsion_parameters[6]] + 0.0, + 0.0, + 0.0, + torsion_parameters[4], + torsion_parameters[5], + torsion_parameters[6], + ] # Check to see if this term is in the olds... 
- term = [hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters[3:]] + term = [ + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + hybrid_force_parameters[3:], + ] if term in auxiliary_custom_torsion_force: # Then this terms has to go to standard and be deleted... old_index = auxiliary_custom_torsion_force.index(term) old_custom_torsions_to_standard.append(old_index) - self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - torsion_parameters[4], torsion_parameters[5], - torsion_parameters[6] + self._hybrid_system_forces["unique_atom_torsion_force"].addTorsion( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + torsion_parameters[4], + torsion_parameters[5], + torsion_parameters[6], ) else: # Then this term has to go to the core force... - self._hybrid_system_forces['custom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters + self._hybrid_system_forces["custom_torsion_force"].addTorsion( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + hybrid_force_parameters, ) # Now we have to loop through the aux custom torsion force - for index in [q for q in range(len(auxiliary_custom_torsion_force)) - if q not in old_custom_torsions_to_standard]: + for index in [ + q for q in range(len(auxiliary_custom_torsion_force)) if q not in old_custom_torsions_to_standard + ]: terms = auxiliary_custom_torsion_force[index] hybrid_index_list = terms[:4] - hybrid_force_parameters = terms[4] + [0., 0., 0.] 
- self._hybrid_system_forces['custom_torsion_force'].addTorsion( - hybrid_index_list[0], hybrid_index_list[1], - hybrid_index_list[2], hybrid_index_list[3], - hybrid_force_parameters + hybrid_force_parameters = terms[4] + [0.0, 0.0, 0.0] + self._hybrid_system_forces["custom_torsion_force"].addTorsion( + hybrid_index_list[0], + hybrid_index_list[1], + hybrid_index_list[2], + hybrid_index_list[3], + hybrid_force_parameters, ) def _handle_nonbonded(self): @@ -1521,28 +1540,28 @@ def _handle_nonbonded(self): * A lot of this logic is duplicated, probably turn it into a couple of functions. """ + def _check_indices(idx1, idx2): if idx1 != idx2: - errmsg = ("Attempting to add incorrect particle to hybrid " - "system") + errmsg = "Attempting to add incorrect particle to hybrid " "system" raise ValueError(errmsg) - old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] - new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + old_system_nonbonded_force = self._old_system_forces["NonbondedForce"] + new_system_nonbonded_force = self._new_system_forces["NonbondedForce"] hybrid_to_old_map = self._hybrid_to_old_map hybrid_to_new_map = self._hybrid_to_new_map # Define new global parameters for NonbondedForce - self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter('lambda_electrostatics_core', 0.0) - self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter('lambda_sterics_core', 0.0) - self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter("lambda_electrostatics_delete", 0.0) - self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter("lambda_electrostatics_insert", 0.0) + self._hybrid_system_forces["standard_nonbonded_force"].addGlobalParameter("lambda_electrostatics_core", 0.0) + self._hybrid_system_forces["standard_nonbonded_force"].addGlobalParameter("lambda_sterics_core", 0.0) + 
self._hybrid_system_forces["standard_nonbonded_force"].addGlobalParameter("lambda_electrostatics_delete", 0.0) + self._hybrid_system_forces["standard_nonbonded_force"].addGlobalParameter("lambda_electrostatics_insert", 0.0) # We have to loop through the particles in the system, because # nonbonded force does not accept index for particle_index in range(self._hybrid_system.getNumParticles()): - if particle_index in self._atom_classes['unique_old_atoms']: + if particle_index in self._atom_classes["unique_old_atoms"]: # Get the parameters in the old system old_index = hybrid_to_old_map[particle_index] [charge, sigma, epsilon] = old_system_nonbonded_force.getParticleParameters(old_index) @@ -1550,15 +1569,17 @@ def _check_indices(idx1, idx2): # Add the particle to the hybrid custom sterics and # electrostatics. # turning off sterics in forward direction - check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( - [sigma, epsilon, sigma, 0.0*epsilon, 1, 0] + check_index = self._hybrid_system_forces["core_sterics_force"].addParticle( + [sigma, epsilon, sigma, 0.0 * epsilon, 1, 0], ) _check_indices(particle_index, check_index) # Add particle to the regular nonbonded force, but # Lennard-Jones will be handled by CustomNonbondedForce - check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( - charge, sigma, 0.0*epsilon + check_index = self._hybrid_system_forces["standard_nonbonded_force"].addParticle( + charge, + sigma, + 0.0 * epsilon, ) _check_indices(particle_index, check_index) @@ -1566,39 +1587,47 @@ def _check_indices(idx1, idx2): # lambda_electrostatics_delete = 0, on at # lambda_electrostatics_delete = 1; kill charge with # lambda_electrostatics_delete = 0 --> 1 - self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( - 'lambda_electrostatics_delete', particle_index, - -charge, 0*sigma, 0*epsilon + self._hybrid_system_forces["standard_nonbonded_force"].addParticleParameterOffset( + 
"lambda_electrostatics_delete", + particle_index, + -charge, + 0 * sigma, + 0 * epsilon, ) - elif particle_index in self._atom_classes['unique_new_atoms']: + elif particle_index in self._atom_classes["unique_new_atoms"]: # Get the parameters in the new system new_index = hybrid_to_new_map[particle_index] [charge, sigma, epsilon] = new_system_nonbonded_force.getParticleParameters(new_index) # Add the particle to the hybrid custom sterics and electrostatics # turning on sterics in forward direction - check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( - [sigma, 0.0*epsilon, sigma, epsilon, 0, 1] + check_index = self._hybrid_system_forces["core_sterics_force"].addParticle( + [sigma, 0.0 * epsilon, sigma, epsilon, 0, 1], ) _check_indices(particle_index, check_index) # Add particle to the regular nonbonded force, but # Lennard-Jones will be handled by CustomNonbondedForce - check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( - 0.0, sigma, 0.0 + check_index = self._hybrid_system_forces["standard_nonbonded_force"].addParticle( + 0.0, + sigma, + 0.0, ) # charge starts at zero _check_indices(particle_index, check_index) # Charge will be turned off at lambda_electrostatics_insert = 0 # on at lambda_electrostatics_insert = 1; # add charge with lambda_electrostatics_insert = 0 --> 1 - self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( - 'lambda_electrostatics_insert', particle_index, - +charge, 0, 0 + self._hybrid_system_forces["standard_nonbonded_force"].addParticleParameterOffset( + "lambda_electrostatics_insert", + particle_index, + +charge, + 0, + 0, ) - elif particle_index in self._atom_classes['core_atoms']: + elif particle_index in self._atom_classes["core_atoms"]: # Get the parameters in the new and old systems: old_index = hybrid_to_old_map[particle_index] [charge_old, sigma_old, epsilon_old] = old_system_nonbonded_force.getParticleParameters(old_index) @@ -1608,15 +1637,19 @@ def 
_check_indices(idx1, idx2): # Add the particle to the custom forces, interpolating between # the two parameters; add steric params and zero electrostatics # to core_sterics per usual - check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( - [sigma_old, epsilon_old, sigma_new, epsilon_new, 0, 0]) + check_index = self._hybrid_system_forces["core_sterics_force"].addParticle( + [sigma_old, epsilon_old, sigma_new, epsilon_new, 0, 0], + ) _check_indices(particle_index, check_index) # Still add the particle to the regular nonbonded force, but # with zeroed out parameters; add old charge to # standard_nonbonded and zero sterics - check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( - charge_old, 0.5*(sigma_old+sigma_new), 0.0) + check_index = self._hybrid_system_forces["standard_nonbonded_force"].addParticle( + charge_old, + 0.5 * (sigma_old + sigma_new), + 0.0, + ) _check_indices(particle_index, check_index) # Charge is charge_old at lambda_electrostatics = 0, @@ -1627,9 +1660,12 @@ def _check_indices(idx1, idx2): # Interpolate between old and new charge with # lambda_electrostatics core make sure to keep sterics off - self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( - 'lambda_electrostatics_core', particle_index, - (charge_new - charge_old), 0, 0 + self._hybrid_system_forces["standard_nonbonded_force"].addParticleParameterOffset( + "lambda_electrostatics_core", + particle_index, + (charge_new - charge_old), + 0, + 0, ) # Otherwise, the particle is in the environment @@ -1641,30 +1677,30 @@ def _check_indices(idx1, idx2): # Add the particle to the hybrid custom sterics, but they dont # change; electrostatics are ignored - self._hybrid_system_forces['core_sterics_force'].addParticle( - [sigma, epsilon, sigma, epsilon, 0, 0] - ) + self._hybrid_system_forces["core_sterics_force"].addParticle([sigma, epsilon, sigma, epsilon, 0, 0]) # Add the environment atoms to the regular nonbonded force as 
# well: should we be adding steric terms here, too? - self._hybrid_system_forces['standard_nonbonded_force'].addParticle( - charge, sigma, epsilon - ) + self._hybrid_system_forces["standard_nonbonded_force"].addParticle(charge, sigma, epsilon) # Now loop pairwise through (unique_old, unique_new) and add exceptions # so that they never interact electrostatically # (place into Nonbonded Force) - unique_old_atoms = self._atom_classes['unique_old_atoms'] - unique_new_atoms = self._atom_classes['unique_new_atoms'] + unique_old_atoms = self._atom_classes["unique_old_atoms"] + unique_new_atoms = self._atom_classes["unique_new_atoms"] for old in unique_old_atoms: for new in unique_new_atoms: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - old, new, 0.0*unit.elementary_charge**2, - 1.0*unit.nanometers, 0.0*unit.kilojoules_per_mole) + self._hybrid_system_forces["standard_nonbonded_force"].addException( + old, + new, + 0.0 * unit.elementary_charge**2, + 1.0 * unit.nanometers, + 0.0 * unit.kilojoules_per_mole, + ) # This is only necessary to avoid the 'All forces must have # identical exclusions' rule - self._hybrid_system_forces['core_sterics_force'].addExclusion(old, new) + self._hybrid_system_forces["core_sterics_force"].addExclusion(old, new) self._handle_interaction_groups() @@ -1693,34 +1729,29 @@ def _handle_interaction_groups(self): 8) Unique-old - Unique-old """ # Get the force objects for convenience: - sterics_custom_force = self._hybrid_system_forces['core_sterics_force'] + sterics_custom_force = self._hybrid_system_forces["core_sterics_force"] # Also prepare the atom classes - core_atoms = self._atom_classes['core_atoms'] - unique_old_atoms = self._atom_classes['unique_old_atoms'] - unique_new_atoms = self._atom_classes['unique_new_atoms'] - environment_atoms = self._atom_classes['environment_atoms'] + core_atoms = self._atom_classes["core_atoms"] + unique_old_atoms = self._atom_classes["unique_old_atoms"] + unique_new_atoms = 
self._atom_classes["unique_new_atoms"] + environment_atoms = self._atom_classes["environment_atoms"] sterics_custom_force.addInteractionGroup(unique_old_atoms, core_atoms) - sterics_custom_force.addInteractionGroup(unique_old_atoms, - environment_atoms) + sterics_custom_force.addInteractionGroup(unique_old_atoms, environment_atoms) - sterics_custom_force.addInteractionGroup(unique_new_atoms, - core_atoms) + sterics_custom_force.addInteractionGroup(unique_new_atoms, core_atoms) - sterics_custom_force.addInteractionGroup(unique_new_atoms, - environment_atoms) + sterics_custom_force.addInteractionGroup(unique_new_atoms, environment_atoms) sterics_custom_force.addInteractionGroup(core_atoms, environment_atoms) sterics_custom_force.addInteractionGroup(core_atoms, core_atoms) - sterics_custom_force.addInteractionGroup(unique_new_atoms, - unique_new_atoms) + sterics_custom_force.addInteractionGroup(unique_new_atoms, unique_new_atoms) - sterics_custom_force.addInteractionGroup(unique_old_atoms, - unique_old_atoms) + sterics_custom_force.addInteractionGroup(unique_old_atoms, unique_old_atoms) def _handle_hybrid_exceptions(self): """ @@ -1728,12 +1759,12 @@ def _handle_hybrid_exceptions(self): exceptions for interactions that were zeroed out but should occur. """ # TODO - are these actually used anywhere? 
Flake8 says no - old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] - new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + old_system_nonbonded_force = self._old_system_forces["NonbondedForce"] + new_system_nonbonded_force = self._new_system_forces["NonbondedForce"] # Prepare the atom classes - unique_old_atoms = self._atom_classes['unique_old_atoms'] - unique_new_atoms = self._atom_classes['unique_new_atoms'] + unique_old_atoms = self._atom_classes["unique_old_atoms"] + unique_new_atoms = self._atom_classes["unique_new_atoms"] # Get the list of interaction pairs for which we need to set exceptions unique_old_pairs = list(itertools.combinations(unique_old_atoms, 2)) @@ -1744,44 +1775,55 @@ def _handle_hybrid_exceptions(self): for atom_pair in unique_old_pairs: # Since the pairs are indexed in the dictionary by the old system # indices, we need to convert - old_index_atom_pair = (self._hybrid_to_old_map[atom_pair[0]], - self._hybrid_to_old_map[atom_pair[1]]) + old_index_atom_pair = (self._hybrid_to_old_map[atom_pair[0]], self._hybrid_to_old_map[atom_pair[1]]) # Now we check if the pair is in the exception dictionary if old_index_atom_pair in self._old_system_exceptions: [chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair] # if we are interpolating 1,4 exceptions then we have to if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd * 0.0, + sigma, + epsilon * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd, + sigma, + epsilon, ) # Add exclusion to ensure exceptions are 
consistent - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1] - ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(atom_pair[0], atom_pair[1]) # Check if the pair is in the reverse order and use that if so elif old_index_atom_pair[::-1] in self._old_system_exceptions: [chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair[::-1]] # If we are interpolating 1,4 exceptions then we have to if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd * 0.0, + sigma, + epsilon * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon) + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd, + sigma, + epsilon, + ) # Add exclusion to ensure exceptions are consistent - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1]) + self._hybrid_system_forces["core_sterics_force"].addExclusion(atom_pair[0], atom_pair[1]) # TODO: work out why there's a bunch of commented out code here # Exerpt: @@ -1796,43 +1838,51 @@ def _handle_hybrid_exceptions(self): for atom_pair in unique_new_pairs: # Since the pairs are indexed in the dictionary by the new system # indices, we need to convert - new_index_atom_pair = (self._hybrid_to_new_map[atom_pair[0]], - self._hybrid_to_new_map[atom_pair[1]]) + new_index_atom_pair = (self._hybrid_to_new_map[atom_pair[0]], self._hybrid_to_new_map[atom_pair[1]]) # Now we check if the pair is in the exception dictionary if new_index_atom_pair in self._new_system_exceptions: [chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair] if self._interpolate_14s: - 
self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd * 0.0, + sigma, + epsilon * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd, + sigma, + epsilon, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1] - ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(atom_pair[0], atom_pair[1]) # Check if the pair is present in the reverse order and use that if so elif new_index_atom_pair[::-1] in self._new_system_exceptions: [chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair[::-1]] if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd*0.0, - sigma, epsilon*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd * 0.0, + sigma, + epsilon * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + self._hybrid_system_forces["standard_nonbonded_force"].addException( + atom_pair[0], + atom_pair[1], + chargeProd, + sigma, + epsilon, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - atom_pair[0], atom_pair[1] - ) - + self._hybrid_system_forces["core_sterics_force"].addExclusion(atom_pair[0], atom_pair[1]) # TODO: work out why there's a bunch of commented out code here # If it's not handled by an exception in the original system, we @@ -1863,7 +1913,7 @@ def _find_exception(force, index1, index2): # Loop through the exceptions and try to find one 
matching the criteria for exception_idx in range(force.getNumExceptions()): exception_parameters = force.getExceptionParameters(exception_idx) - if index_set==set(exception_parameters[:2]): + if index_set == set(exception_parameters[:2]): return exception_parameters return [] @@ -1873,8 +1923,8 @@ def _handle_original_exceptions(self): present in the hybrid appropriately. """ # Get what we need to find the exceptions from the new and old systems: - old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] - new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + old_system_nonbonded_force = self._old_system_forces["NonbondedForce"] + new_system_nonbonded_force = self._new_system_forces["NonbondedForce"] hybrid_to_old_map = self._hybrid_to_old_map hybrid_to_new_map = self._hybrid_to_new_map @@ -1890,41 +1940,45 @@ def _handle_original_exceptions(self): index2_hybrid = self._old_to_hybrid_map[index2_old] index_set = {index1_hybrid, index2_hybrid} - # In this case, the interaction is only covered by the regular # nonbonded force, and as such will be copied to that force # In the unique-old case, it is handled elsewhere due to internal # peculiarities regarding exceptions - if index_set.issubset(self._atom_classes['environment_atoms']): - self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_old, - sigma_old, epsilon_old - ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - index1_hybrid, index2_hybrid + if index_set.issubset(self._atom_classes["environment_atoms"]): + self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_old, + sigma_old, + epsilon_old, ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(index1_hybrid, index2_hybrid) # We have already handled unique old - unique old exceptions - elif len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 2: + elif 
len(index_set.intersection(self._atom_classes["unique_old_atoms"])) == 2: continue # Otherwise, check if one of the atoms in the set is in the # unique_old_group and the other is not: - elif len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 1: + elif len(index_set.intersection(self._atom_classes["unique_old_atoms"])) == 1: if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_old*0.0, - sigma_old, epsilon_old*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_old * 0.0, + sigma_old, + epsilon_old * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_old, - sigma_old, epsilon_old + self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_old, + sigma_old, + epsilon_old, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - index1_hybrid, index2_hybrid - ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(index1_hybrid, index2_hybrid) # If the exception particles are neither solely old unique, solely # environment, nor contain any unique old atoms, they are either @@ -1936,38 +1990,47 @@ def _handle_original_exceptions(self): index1_new = hybrid_to_new_map[index1_hybrid] index2_new = hybrid_to_new_map[index2_hybrid] # Get the exception parameters: - new_exception_parms= self._find_exception( - new_system_nonbonded_force, - index1_new, index2_new) + new_exception_parms = self._find_exception(new_system_nonbonded_force, index1_new, index2_new) # If there's no new exception, then we should just set the # exception parameters to be the nonbonded parameters if not new_exception_parms: - [charge1_new, sigma1_new, epsilon1_new] = new_system_nonbonded_force.getParticleParameters(index1_new) - [charge2_new, sigma2_new, epsilon2_new] = 
new_system_nonbonded_force.getParticleParameters(index2_new) + [charge1_new, sigma1_new, epsilon1_new] = new_system_nonbonded_force.getParticleParameters( + index1_new, + ) + [charge2_new, sigma2_new, epsilon2_new] = new_system_nonbonded_force.getParticleParameters( + index2_new, + ) chargeProd_new = charge1_new * charge2_new sigma_new = 0.5 * (sigma1_new + sigma2_new) - epsilon_new = unit.sqrt(epsilon1_new*epsilon2_new) + epsilon_new = unit.sqrt(epsilon1_new * epsilon2_new) else: [index1_new, index2_new, chargeProd_new, sigma_new, epsilon_new] = new_exception_parms # Interpolate between old and new - exception_index = self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_old, - sigma_old, epsilon_old - ) - self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( - 'lambda_electrostatics_core', exception_index, - (chargeProd_new - chargeProd_old), 0, 0 + exception_index = self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_old, + sigma_old, + epsilon_old, ) - self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( - 'lambda_sterics_core', exception_index, 0, - (sigma_new - sigma_old), (epsilon_new - epsilon_old) + self._hybrid_system_forces["standard_nonbonded_force"].addExceptionParameterOffset( + "lambda_electrostatics_core", + exception_index, + (chargeProd_new - chargeProd_old), + 0, + 0, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - index1_hybrid, index2_hybrid + self._hybrid_system_forces["standard_nonbonded_force"].addExceptionParameterOffset( + "lambda_sterics_core", + exception_index, + 0, + (sigma_new - sigma_old), + (epsilon_new - epsilon_old), ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(index1_hybrid, index2_hybrid) # Now, loop through the new system to collect remaining interactions. 
# The only that remain here are uniquenew-uniquenew, uniquenew-core, @@ -1988,31 +2051,35 @@ def _handle_original_exceptions(self): # specified in the regular nonbonded force. However, this is # handled elsewhere as above due to pecularities with exception # handling - if index_set.issubset(self._atom_classes['unique_new_atoms']): + if index_set.issubset(self._atom_classes["unique_new_atoms"]): continue # Look for the final class- interactions between uniquenew-core and # uniquenew-environment. They are treated similarly: they are # simply on and constant the entire time (as a valence term) - elif len(index_set.intersection(self._atom_classes['unique_new_atoms'])) > 0: + elif len(index_set.intersection(self._atom_classes["unique_new_atoms"])) > 0: if self._interpolate_14s: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_new*0.0, - sigma_new, epsilon_new*0.0 + self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_new * 0.0, + sigma_new, + epsilon_new * 0.0, ) else: - self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, chargeProd_new, - sigma_new, epsilon_new + self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_new, + sigma_new, + epsilon_new, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - index1_hybrid, index2_hybrid - ) + self._hybrid_system_forces["core_sterics_force"].addExclusion(index1_hybrid, index2_hybrid) # However, there may be a core exception that exists in one system # but not the other (ring closure) - elif index_set.issubset(self._atom_classes['core_atoms']): + elif index_set.issubset(self._atom_classes["core_atoms"]): # Get the old indices try: @@ -2025,32 +2092,43 @@ def _handle_original_exceptions(self): # But if it's not, we need to interpolate if not 
self._find_exception(old_system_nonbonded_force, index1_old, index2_old): - [charge1_old, sigma1_old, epsilon1_old] = old_system_nonbonded_force.getParticleParameters(index1_old) - [charge2_old, sigma2_old, epsilon2_old] = old_system_nonbonded_force.getParticleParameters(index2_old) + [charge1_old, sigma1_old, epsilon1_old] = old_system_nonbonded_force.getParticleParameters( + index1_old, + ) + [charge2_old, sigma2_old, epsilon2_old] = old_system_nonbonded_force.getParticleParameters( + index2_old, + ) - chargeProd_old = charge1_old*charge2_old + chargeProd_old = charge1_old * charge2_old sigma_old = 0.5 * (sigma1_old + sigma2_old) - epsilon_old = unit.sqrt(epsilon1_old*epsilon2_old) - - exception_index = self._hybrid_system_forces['standard_nonbonded_force'].addException( - index1_hybrid, index2_hybrid, - chargeProd_old, sigma_old, - epsilon_old) - - self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( - 'lambda_electrostatics_core', exception_index, - (chargeProd_new - chargeProd_old), 0, 0 + epsilon_old = unit.sqrt(epsilon1_old * epsilon2_old) + + exception_index = self._hybrid_system_forces["standard_nonbonded_force"].addException( + index1_hybrid, + index2_hybrid, + chargeProd_old, + sigma_old, + epsilon_old, ) - self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( - 'lambda_sterics_core', exception_index, 0, - (sigma_new - sigma_old), (epsilon_new - epsilon_old) + self._hybrid_system_forces["standard_nonbonded_force"].addExceptionParameterOffset( + "lambda_electrostatics_core", + exception_index, + (chargeProd_new - chargeProd_old), + 0, + 0, ) - self._hybrid_system_forces['core_sterics_force'].addExclusion( - index1_hybrid, index2_hybrid + self._hybrid_system_forces["standard_nonbonded_force"].addExceptionParameterOffset( + "lambda_sterics_core", + exception_index, + 0, + (sigma_new - sigma_old), + (epsilon_new - epsilon_old), ) + 
self._hybrid_system_forces["core_sterics_force"].addExclusion(index1_hybrid, index2_hybrid) + def _handle_old_new_exceptions(self): """ Find the exceptions associated with old-old and old-core interactions, @@ -2076,62 +2154,53 @@ def _handle_old_new_exceptions(self): else: old_new_nonbonded_exceptions += "U_sterics = 4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;" old_new_nonbonded_exceptions += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" - old_new_nonbonded_exceptions += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" # effective softcore distance for sterics - old_new_nonbonded_exceptions += "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + old_new_nonbonded_exceptions += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" # effective softcore distance for sterics + old_new_nonbonded_exceptions += ( + "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + ) old_new_nonbonded_exceptions += "U_electrostatics = (lambda_electrostatics_insert * unique_new + unique_old * (1 - lambda_electrostatics_delete)) * ONE_4PI_EPS0*chargeProd/r;" old_new_nonbonded_exceptions += "ONE_4PI_EPS0 = %f;" % ONE_4PI_EPS0 - old_new_nonbonded_exceptions += "epsilon = (1-lambda_sterics)*epsilonA + lambda_sterics*epsilonB;" # interpolation + old_new_nonbonded_exceptions += ( + "epsilon = (1-lambda_sterics)*epsilonA + lambda_sterics*epsilonB;" # interpolation + ) old_new_nonbonded_exceptions += "sigma = (1-lambda_sterics)*sigmaA + lambda_sterics*sigmaB;" - old_new_nonbonded_exceptions += "lambda_sterics = new_interaction*lambda_sterics_insert + old_interaction*lambda_sterics_delete;" + old_new_nonbonded_exceptions += ( + "lambda_sterics = new_interaction*lambda_sterics_insert + old_interaction*lambda_sterics_delete;" + ) old_new_nonbonded_exceptions += "new_interaction = delta(1-unique_new); old_interaction = 
delta(1-unique_old);" - - nonbonded_exceptions_force = openmm.CustomBondForce( - old_new_nonbonded_exceptions) + nonbonded_exceptions_force = openmm.CustomBondForce(old_new_nonbonded_exceptions) name = f"{nonbonded_exceptions_force.__class__.__name__}_exceptions" nonbonded_exceptions_force.setName(name) self._hybrid_system.addForce(nonbonded_exceptions_force) # For reference, set name in force dict - self._hybrid_system_forces['old_new_exceptions_force'] = nonbonded_exceptions_force + self._hybrid_system_forces["old_new_exceptions_force"] = nonbonded_exceptions_force if self._softcore_LJ_v2: - nonbonded_exceptions_force.addGlobalParameter( - "softcore_alpha", self._softcore_LJ_v2_alpha - ) + nonbonded_exceptions_force.addGlobalParameter("softcore_alpha", self._softcore_LJ_v2_alpha) else: - nonbonded_exceptions_force.addGlobalParameter( - "softcore_alpha", self._softcore_alpha - ) + nonbonded_exceptions_force.addGlobalParameter("softcore_alpha", self._softcore_alpha) # electrostatics insert - nonbonded_exceptions_force.addGlobalParameter( - "lambda_electrostatics_insert", 0.0 - ) + nonbonded_exceptions_force.addGlobalParameter("lambda_electrostatics_insert", 0.0) # electrostatics delete - nonbonded_exceptions_force.addGlobalParameter( - "lambda_electrostatics_delete", 0.0 - ) + nonbonded_exceptions_force.addGlobalParameter("lambda_electrostatics_delete", 0.0) # sterics insert - nonbonded_exceptions_force.addGlobalParameter( - "lambda_sterics_insert", 0.0 - ) + nonbonded_exceptions_force.addGlobalParameter("lambda_sterics_insert", 0.0) # steric delete - nonbonded_exceptions_force.addGlobalParameter( - "lambda_sterics_delete", 0.0 - ) + nonbonded_exceptions_force.addGlobalParameter("lambda_sterics_delete", 0.0) - for parameter in ['chargeProd','sigmaA', 'epsilonA', 'sigmaB', - 'epsilonB', 'unique_old', 'unique_new']: + for parameter in ["chargeProd", "sigmaA", "epsilonA", "sigmaB", "epsilonB", "unique_old", "unique_new"]: 
nonbonded_exceptions_force.addPerBondParameter(parameter) # Prepare for exceptions loop by grabbing nonbonded forces, # hybrid_to_old/new maps - old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] - new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + old_system_nonbonded_force = self._old_system_forces["NonbondedForce"] + new_system_nonbonded_force = self._new_system_forces["NonbondedForce"] hybrid_to_old_map = self._hybrid_to_old_map hybrid_to_new_map = self._hybrid_to_new_map @@ -2149,21 +2218,20 @@ def _handle_old_new_exceptions(self): # Otherwise, check if one of the atoms in the set is in the # unique_old_group and the other is not: - if (len(index_set.intersection(self._atom_classes['unique_old_atoms'])) > 0 and - (chargeProd_old.value_in_unit_system(unit.md_unit_system) != 0.0 or - epsilon_old.value_in_unit_system(unit.md_unit_system) != 0.0)): + if len(index_set.intersection(self._atom_classes["unique_old_atoms"])) > 0 and ( + chargeProd_old.value_in_unit_system(unit.md_unit_system) != 0.0 + or epsilon_old.value_in_unit_system(unit.md_unit_system) != 0.0 + ): if self._interpolate_14s: # If we are interpolating 1,4s, then we anneal this term # off; otherwise, the exception force is constant and # already handled in the standard nonbonded force nonbonded_exceptions_force.addBond( - index1_hybrid, index2_hybrid, - [chargeProd_old, sigma_old, epsilon_old, sigma_old, - epsilon_old*0.0, 1, 0] + index1_hybrid, + index2_hybrid, + [chargeProd_old, sigma_old, epsilon_old, sigma_old, epsilon_old * 0.0, 1, 0], ) - - # Next, loop through the new system's exceptions and add them to the # hybrid appropriately for exception_pair, exception_parameters in self._new_system_exceptions.items(): @@ -2180,17 +2248,18 @@ def _handle_old_new_exceptions(self): # uniquenew-environment. 
They are treated # similarly: they are simply on and constant the entire time # (as a valence term) - if (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) > 0 and - (chargeProd_new.value_in_unit_system(unit.md_unit_system) != 0.0 or - epsilon_new.value_in_unit_system(unit.md_unit_system) != 0.0)): + if len(index_set.intersection(self._atom_classes["unique_new_atoms"])) > 0 and ( + chargeProd_new.value_in_unit_system(unit.md_unit_system) != 0.0 + or epsilon_new.value_in_unit_system(unit.md_unit_system) != 0.0 + ): if self._interpolate_14s: # If we are interpolating 1,4s, then we anneal this term # on; otherwise, the exception force is constant and # already handled in the standard nonbonded force nonbonded_exceptions_force.addBond( - index1_hybrid, index2_hybrid, - [chargeProd_new, sigma_new, epsilon_new*0.0, - sigma_new, epsilon_new, 0, 1] + index1_hybrid, + index2_hybrid, + [chargeProd_new, sigma_new, epsilon_new * 0.0, sigma_new, epsilon_new, 0, 1], ) def _compute_hybrid_positions(self): @@ -2209,10 +2278,8 @@ def _compute_hybrid_positions(self): Positions of the hybrid system, in nm """ # Get unitless positions - old_pos_without_units = np.array( - self._old_positions.value_in_unit(unit.nanometer)) - new_pos_without_units = np.array( - self._new_positions.value_in_unit(unit.nanometer)) + old_pos_without_units = np.array(self._old_positions.value_in_unit(unit.nanometer)) + new_pos_without_units = np.array(self._new_positions.value_in_unit(unit.nanometer)) # Determine the number of particles in the system n_atoms_hybrid = self._hybrid_system.getNumParticles() @@ -2270,45 +2337,43 @@ def _create_mdtraj_topology(self): # find mapped atoms new_system_atom_set = {atom.index for atom in new_system_res.atoms} - # Now, we find the subset of atoms that are mapped. These must be + # Now, we find the subset of atoms that are mapped. 
These must be # in the "core" category, since they are mapped and part of a # changing residue - mapped_new_atom_indices = core_atoms_new_indices.intersection( - new_system_atom_set) + mapped_new_atom_indices = core_atoms_new_indices.intersection(new_system_atom_set) # Now get the old indices of the above atoms so that we can find # the appropriate residue in the old system for this we can use the # new to old atom map - mapped_old_atom_indices = [self._new_to_old_map[atom_idx] for - atom_idx in mapped_new_atom_indices] + mapped_old_atom_indices = [self._new_to_old_map[atom_idx] for atom_idx in mapped_new_atom_indices] # We can just take the first one--they all have the same residue first_mapped_old_atom_index = mapped_old_atom_indices[0] # Get the atom object corresponding to this index from the hybrid # (which is a deepcopy of the old) - mapped_hybrid_system_atom = hybrid_topology.atom( - first_mapped_old_atom_index) + mapped_hybrid_system_atom = hybrid_topology.atom(first_mapped_old_atom_index) # Get the residue that is relevant to this atom mapped_residue = mapped_hybrid_system_atom.residue # Add the atom using the mapped residue added_atoms[new_particle_hybrid_idx] = hybrid_topology.add_atom( - new_system_atom.name, - new_system_atom.element, - mapped_residue) + new_system_atom.name, + new_system_atom.element, + mapped_residue, + ) # Now loop through the bonds in the new system, and if the bond # contains a unique new atom, then add it to the hybrid topology - for (atom1, atom2) in new_top.bonds: + for atom1, atom2 in new_top.bonds: at1_hybrid_idx = self._new_to_hybrid_map[atom1.index] at2_hybrid_idx = self._new_to_hybrid_map[atom2.index] # If at least one atom is in the unique new class, we need to add # it to the hybrid system - at1_uniq = at1_hybrid_idx in self._atom_classes['unique_new_atoms'] - at2_uniq = at2_hybrid_idx in self._atom_classes['unique_new_atoms'] + at1_uniq = at1_hybrid_idx in self._atom_classes["unique_new_atoms"] + at2_uniq = 
at2_hybrid_idx in self._atom_classes["unique_new_atoms"] if at1_uniq or at2_uniq: if at1_uniq: atom1_to_bond = added_atoms[at1_hybrid_idx] @@ -2326,7 +2391,6 @@ def _create_mdtraj_topology(self): return hybrid_topology - def _create_hybrid_topology(self): """ Create a hybrid openmm.app.Topology from the input old and new @@ -2355,7 +2419,7 @@ def _create_hybrid_topology(self): atom_list.append(list(self._new_topology.atoms())[idx]) # Now we loop over the atoms and add them in alongside chains & resids - + # Non ideal variables to track the previous set of residues & chains # without having to constantly search backwards prev_res = None @@ -2367,14 +2431,10 @@ def _create_hybrid_topology(self): prev_chain = at.residue.chain if at.residue != prev_res: - hybrid_residue = hybrid_top.addResidue( - at.residue.name, hybrid_chain, at.residue.id - ) + hybrid_residue = hybrid_top.addResidue(at.residue.name, hybrid_chain, at.residue.id) prev_res = at.residue - hybrid_atom = hybrid_top.addAtom( - at.name, at.element, hybrid_residue, at.id - ) + hybrid_atom = hybrid_top.addAtom(at.name, at.element, hybrid_residue, at.id) # Next we deal with bonds # First we add in all the old topology bonds @@ -2385,7 +2445,8 @@ def _create_hybrid_topology(self): hybrid_top.addBond( list(hybrid_top.atoms())[at1], list(hybrid_top.atoms())[at2], - bond.type, bond.order, + bond.type, + bond.order, ) # Finally we add in all the bonds from the unique atoms in the @@ -2393,12 +2454,12 @@ def _create_hybrid_topology(self): for bond in self._new_topology.bonds(): at1 = self.new_to_hybrid_atom_map[bond.atom1.index] at2 = self.new_to_hybrid_atom_map[bond.atom2.index] - if ((at1 in self._atom_classes['unique_new_atoms']) or - (at2 in self._atom_classes['unique_new_atoms'])): + if (at1 in self._atom_classes["unique_new_atoms"]) or (at2 in self._atom_classes["unique_new_atoms"]): hybrid_top.addBond( list(hybrid_top.atoms())[at1], list(hybrid_top.atoms())[at2], - bond.type, bond.order, + bond.type, + 
bond.order, ) return hybrid_top @@ -2421,10 +2482,8 @@ def old_positions(self, hybrid_positions): n_atoms_old = self._old_system.getNumParticles() # making sure hybrid positions are simtk.unit.Quantity objects if not isinstance(hybrid_positions, unit.Quantity): - hybrid_positions = unit.Quantity(hybrid_positions, - unit=unit.nanometer) - old_positions = unit.Quantity(np.zeros([n_atoms_old, 3]), - unit=unit.nanometer) + hybrid_positions = unit.Quantity(hybrid_positions, unit=unit.nanometer) + old_positions = unit.Quantity(np.zeros([n_atoms_old, 3]), unit=unit.nanometer) for idx in range(n_atoms_old): hyb_idx = self._new_to_hybrid_map[idx] old_positions[idx, :] = hybrid_positions[hyb_idx, :] @@ -2448,10 +2507,8 @@ def new_positions(self, hybrid_positions): n_atoms_new = self._new_system.getNumParticles # making sure hybrid positions are simtk.unit.Quantity objects if not isinstance(hybrid_positions, unit.Quantity): - hybrid_positions = unit.Quantity(hybrid_positions, - unit=unit.nanometer) - new_positions = unit.Quantity(np.zeros([n_atoms_new, 3]), - unit=unit.nanometer) + hybrid_positions = unit.Quantity(hybrid_positions, unit=unit.nanometer) + new_positions = unit.Quantity(np.zeros([n_atoms_new, 3]), unit=unit.nanometer) for idx in range(n_atoms_new): hyb_idx = self._new_to_hybrid_map[idx] new_positions[idx, :] = hybrid_positions[hyb_idx, :] @@ -2512,7 +2569,7 @@ def hybrid_positions(self): def hybrid_topology(self): """ An MDTraj hybrid topology for the purpose of writing out trajectories. - + Note that we do not expect this to be able to be parameterized by the openmm forcefield class. diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py b/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py index 41e83238c..181e7345b 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py @@ -4,21 +4,21 @@ # building toolsets. 
# LICENSE: MIT -from copy import deepcopy import itertools import logging -from typing import Union, Optional import warnings +from copy import deepcopy +from typing import Optional, Union import mdtraj as mdt -from mdtraj.core.residue_names import _SOLVENT_TYPES import numpy as np import numpy.typing as npt -from openmm import app, System, NonbondedForce -from openmm import unit as omm_unit +from mdtraj.core.residue_names import _SOLVENT_TYPES from openff.units import unit -from openfe import SolventComponent +from openmm import NonbondedForce, System, app +from openmm import unit as omm_unit +from openfe import SolventComponent logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def _get_ion_and_water_parameters( topology: app.Topology, system: System, ion_resname: str, - water_resname: str = 'HOH', + water_resname: str = "HOH", ): """ Get ion, and water (oxygen and hydrogen) atoms parameters. @@ -66,22 +66,24 @@ def _get_ion_and_water_parameters( ----------- Based on `perses.utils.charge_changing.get_ion_and_water_parameters`. """ + def _find_atom(topology, resname, elementname): for atom in topology.atoms(): if atom.residue.name == resname: - if (elementname is None or atom.element.symbol == elementname): + if elementname is None or atom.element.symbol == elementname: return atom.index - errmsg = ("Error encountered when attempting to explicitly handle " - "charge changes using an alchemical water. No residue " - f"named: {resname} found, with element {elementname}") + errmsg = ( + "Error encountered when attempting to explicitly handle " + "charge changes using an alchemical water. 
No residue " + f"named: {resname} found, with element {elementname}" + ) raise ValueError(errmsg) ion_index = _find_atom(topology, ion_resname, None) - oxygen_index = _find_atom(topology, water_resname, 'O') - hydrogen_index = _find_atom(topology, water_resname, 'H') + oxygen_index = _find_atom(topology, water_resname, "O") + hydrogen_index = _find_atom(topology, water_resname, "H") - nbf = [i for i in system.getForces() - if isinstance(i, NonbondedForce)][0] + nbf = [i for i in system.getForces() if isinstance(i, NonbondedForce)][0] ion_charge, ion_sigma, ion_epsilon = nbf.getParticleParameters(ion_index) o_charge, _, _ = nbf.getParticleParameters(oxygen_index) @@ -104,24 +106,26 @@ def _fix_alchemical_water_atom_mapping( b_idx : int The index of the state B particle. """ - a_idx = system_mapping['new_to_old_atom_map'][b_idx] + a_idx = system_mapping["new_to_old_atom_map"][b_idx] # Note, because these are already shared positions, we don't # append alchemical molecule indices in the new & old molecule # i.e. 
the `old_mol_indices` and `new_mol_indices` lists # remove atom from the environment atom map - system_mapping['old_to_new_env_atom_map'].pop(a_idx) - system_mapping['new_to_old_env_atom_map'].pop(b_idx) + system_mapping["old_to_new_env_atom_map"].pop(a_idx) + system_mapping["new_to_old_env_atom_map"].pop(b_idx) # add atom to the new_to_old_core atom maps - system_mapping['old_to_new_core_atom_map'][a_idx] = b_idx - system_mapping['new_to_old_core_atom_map'][b_idx] = a_idx + system_mapping["old_to_new_core_atom_map"][a_idx] = b_idx + system_mapping["new_to_old_core_atom_map"][b_idx] = a_idx def handle_alchemical_waters( - water_resids: list[int], topology: app.Topology, - system: System, system_mapping: dict, + water_resids: list[int], + topology: app.Topology, + system: System, + system_mapping: dict, charge_difference: int, solvent_component: SolventComponent, ): @@ -163,27 +167,30 @@ def handle_alchemical_waters( """ if abs(charge_difference) != len(water_resids): - errmsg = ("There should be as many alchemical water residues: " - f"{len(water_resids)} as the absolute charge " - f"difference: {abs(charge_difference)}") + errmsg = ( + "There should be as many alchemical water residues: " + f"{len(water_resids)} as the absolute charge " + f"difference: {abs(charge_difference)}" + ) raise ValueError(errmsg) if charge_difference > 0: - ion_resname = solvent_component.positive_ion.strip('-+').upper() + ion_resname = solvent_component.positive_ion.strip("-+").upper() elif charge_difference < 0: - ion_resname = solvent_component.negative_ion.strip('-+').upper() + ion_resname = solvent_component.negative_ion.strip("-+").upper() # if there's no charge difference then just skip altogether else: return None ion_charge, ion_sigma, ion_epsilon, o_charge, h_charge = _get_ion_and_water_parameters( - topology, system, ion_resname, - 'HOH', # Modeller always adds HOH waters + topology, + system, + ion_resname, + "HOH", # Modeller always adds HOH waters ) # get the nonbonded 
forces - nbfrcs = [i for i in system.getForces() - if isinstance(i, NonbondedForce)] + nbfrcs = [i for i in system.getForces() if isinstance(i, NonbondedForce)] if len(nbfrcs) > 1: raise ValueError("Too many NonbondedForce forces found") @@ -197,8 +204,10 @@ def handle_alchemical_waters( # if the number of atoms > 3, then we have virtual sites which are # not supported currently if len([at for at in res.atoms()]) > 3: - errmsg = ("Non 3-site waters (i.e. waters with virtual sites) " - "are not currently supported as alchemical waters") + errmsg = ( + "Non 3-site waters (i.e. waters with virtual sites) " + "are not currently supported as alchemical waters" + ) raise ValueError(errmsg) for at in res.atoms(): @@ -207,13 +216,10 @@ def handle_alchemical_waters( _fix_alchemical_water_atom_mapping(system_mapping, idx) if charge == o_charge: - nbf.setParticleParameters( - idx, ion_charge, ion_sigma, ion_epsilon - ) + nbf.setParticleParameters(idx, ion_charge, ion_sigma, ion_epsilon) else: if charge != h_charge: - errmsg = ("modifying an atom that doesn't match known " - "water parameters") + errmsg = "modifying an atom that doesn't match known " "water parameters" raise ValueError(errmsg) nbf.setParticleParameters(idx, 0.0, sigma, epsilon) @@ -257,43 +263,46 @@ def get_alchemical_waters( return [] # construct a new mdt trajectory - traj = mdt.Trajectory( - positions[np.newaxis, ...], - mdt.Topology.from_openmm(topology) - ) + traj = mdt.Trajectory(positions[np.newaxis, ...], mdt.Topology.from_openmm(topology)) water_atoms = traj.topology.select("water") solvent_residue_names = list(_SOLVENT_TYPES) - solute_atoms = [atom.index for atom in traj.topology.atoms - if atom.residue.name not in solvent_residue_names] + solute_atoms = [atom.index for atom in traj.topology.atoms if atom.residue.name not in solvent_residue_names] excluded_waters = mdt.compute_neighbors( - traj, distance_cutoff.to(unit.nanometer).m, - solute_atoms, haystack_indices=water_atoms, + traj, + 
distance_cutoff.to(unit.nanometer).m, + solute_atoms, + haystack_indices=water_atoms, periodic=True, )[0] - solvent_indices = set([ - atom.residue.index for atom in traj.topology.atoms + solvent_indices = { + atom.residue.index + for atom in traj.topology.atoms if (atom.index in water_atoms) and (atom.index not in excluded_waters) - ]) + } if len(solvent_indices) < 1: - errmsg = ("There are no waters outside of a " - f"{distance_cutoff.to(unit.nanometer)} nanometer distance " - "of the system solutes to be used as alchemical waters") + errmsg = ( + "There are no waters outside of a " + f"{distance_cutoff.to(unit.nanometer)} nanometer distance " + "of the system solutes to be used as alchemical waters" + ) raise ValueError(errmsg) # unlike the original perses approach, we stick to the first water index # in order to make sure we somewhat reproducibily pick the same water - chosen_residues = list(solvent_indices)[:abs(charge_difference)] + chosen_residues = list(solvent_indices)[: abs(charge_difference)] return chosen_residues -def combined_topology(topology1: app.Topology, - topology2: app.Topology, - exclude_resids: Optional[npt.NDArray] = None,): +def combined_topology( + topology1: app.Topology, + topology2: app.Topology, + exclude_resids: Optional[npt.NDArray] = None, +): """ Create a new topology combining these two topologies. 
@@ -321,21 +330,16 @@ def combined_topology(topology1: app.Topology, top = app.Topology() # create list of excluded residues from topology - excluded_res = [ - r for r in topology1.residues() if r.index in exclude_resids - ] + excluded_res = [r for r in topology1.residues() if r.index in exclude_resids] # get a list of all excluded atoms - excluded_atoms = set(itertools.chain.from_iterable( - r.atoms() for r in excluded_res) - ) + excluded_atoms = set(itertools.chain.from_iterable(r.atoms() for r in excluded_res)) # add new copies of selected chains, residues, and atoms; keep mapping # of old atoms to new for adding bonds later old_to_new_atom_map = {} appended_resids = [] - for chain_id, chain in enumerate( - itertools.chain(topology1.chains(), topology2.chains())): + for chain_id, chain in enumerate(itertools.chain(topology1.chains(), topology2.chains())): # TODO: is chain ID int or str? I recall it being int in MDTraj.... # are there any issues if we just add a blank chain? new_chain = top.addChain(chain_id) @@ -343,35 +347,29 @@ def combined_topology(topology1: app.Topology, if residue in excluded_res: continue - new_res = top.addResidue(residue.name, - new_chain, - residue.id) + new_res = top.addResidue(residue.name, new_chain, residue.id) # append the new resindex if it's part of topology2 if residue in list(topology2.residues()): appended_resids.append(new_res.index) for atom in residue.atoms(): - new_atom = top.addAtom(atom.name, - atom.element, - new_res, - atom.id) + new_atom = top.addAtom(atom.name, atom.element, new_res, atom.id) old_to_new_atom_map[atom] = new_atom # figure out which bonds to keep: drop any that involve removed atoms def atoms_for_bond(bond): return {bond.atom1, bond.atom2} - keep_bonds = (bond for bond in itertools.chain(topology1.bonds(), - topology2.bonds()) - if not (atoms_for_bond(bond) & excluded_atoms)) + keep_bonds = ( + bond + for bond in itertools.chain(topology1.bonds(), topology2.bonds()) + if not (atoms_for_bond(bond) & 
excluded_atoms) + ) # add bonds to topology for bond in keep_bonds: - top.addBond(old_to_new_atom_map[bond.atom1], - old_to_new_atom_map[bond.atom2], - bond.type, - bond.order) + top.addBond(old_to_new_atom_map[bond.atom1], old_to_new_atom_map[bond.atom2], bond.type, bond.order) # Copy over the box vectors top.setPeriodicBoxVectors(topology1.getPeriodicBoxVectors()) @@ -403,8 +401,7 @@ def _get_indices(topology, resids): return [at.index for at in top_atoms] -def _remove_constraints(old_to_new_atom_map, old_system, old_topology, - new_system, new_topology): +def _remove_constraints(old_to_new_atom_map, old_system, old_topology, new_system, new_topology): """ Adapted from Perses' Topology Proposal. Adjusts atom mapping to account for any bonds that are constrained but change in length. @@ -436,10 +433,12 @@ def _remove_constraints(old_to_new_atom_map, old_system, old_topology, no_const_old_to_new_atom_map = deepcopy(old_to_new_atom_map) h_elem = app.Element.getByAtomicNumber(1) - old_H_atoms = {i for i, atom in enumerate(old_topology.atoms()) - if atom.element == h_elem and i in old_to_new_atom_map} - new_H_atoms = {i for i, atom in enumerate(new_topology.atoms()) - if atom.element == h_elem and i in old_to_new_atom_map.values()} + old_H_atoms = { + i for i, atom in enumerate(old_topology.atoms()) if atom.element == h_elem and i in old_to_new_atom_map + } + new_H_atoms = { + i for i, atom in enumerate(new_topology.atoms()) if atom.element == h_elem and i in old_to_new_atom_map.values() + } def pick_H(i, j, x, y) -> int: """Identify which atom to remove to resolve constraint violation @@ -453,8 +452,7 @@ def pick_H(i, j, x, y) -> int: elif j in old_H_atoms or y in new_H_atoms: return j else: - raise ValueError(f"Couldn't resolve constraint demapping for atoms" - f" A: {i}-{j} B: {x}-{y}") + raise ValueError(f"Couldn't resolve constraint demapping for atoms" f" A: {i}-{j} B: {x}-{y}") old_constraints: dict[[int, int], float] = dict() for idx in 
range(old_system.getNumConstraints()): @@ -467,8 +465,7 @@ def pick_H(i, j, x, y) -> int: for idx in range(new_system.getNumConstraints()): atom1, atom2, length = new_system.getConstraintParameters(idx) - if (atom1 in old_to_new_atom_map.values() and - atom2 in old_to_new_atom_map.values()): + if atom1 in old_to_new_atom_map.values() and atom2 in old_to_new_atom_map.values(): new_constraints[atom1, atom2] = length # there are two reasons constraints would invalidate a mapping entry @@ -508,10 +505,16 @@ def pick_H(i, j, x, y) -> int: return no_const_old_to_new_atom_map -def get_system_mappings(old_to_new_atom_map, - old_system, old_topology, old_resids, - new_system, new_topology, new_resids, - fix_constraints=True): +def get_system_mappings( + old_to_new_atom_map, + old_system, + old_topology, + old_resids, + new_system, + new_topology, + new_resids, + fix_constraints=True, +): """ From a starting alchemical map between two molecules, get the mappings between two alchemical end state systems. @@ -575,7 +578,7 @@ def get_system_mappings(old_to_new_atom_map, # We assume that the atom indices are linear in the residue so we shift # by the index of the first atom in each residue adjusted_old_to_new_map = {} - for (key, value) in old_to_new_atom_map.items(): + for key, value in old_to_new_atom_map.items(): shift_old = old_at_indices[0] + key shift_new = new_at_indices[0] + value adjusted_old_to_new_map[shift_old] = shift_new @@ -584,14 +587,20 @@ def get_system_mappings(old_to_new_atom_map, # the atoms in the two systems. For now we are only doing the alchemical # residues. We might want to change this as necessary in the future. if not fix_constraints: - wmsg = ("Not attempting to fix atom mapping to account for " - "constraints. Please note that core atoms which have " - "constrained bonds and changing bond lengths are not allowed.") + wmsg = ( + "Not attempting to fix atom mapping to account for " + "constraints. 
Please note that core atoms which have " + "constrained bonds and changing bond lengths are not allowed." + ) warnings.warn(wmsg) else: adjusted_old_to_new_map = _remove_constraints( - adjusted_old_to_new_map, old_system, old_topology, - new_system, new_topology) + adjusted_old_to_new_map, + old_system, + old_topology, + new_system, + new_topology, + ) # We return a dictionary with all the necessary mappings (as they are # needed downstream). These include: @@ -639,21 +648,19 @@ def get_system_mappings(old_to_new_atom_map, # Now let's create our output dictionary mappings = {} - mappings['new_to_old_atom_map'] = new_to_old_all_map - mappings['old_to_new_atom_map'] = {v: k for k, v in new_to_old_all_map.items()} - mappings['new_to_old_core_atom_map'] = {v: k for k, v in adjusted_old_to_new_map.items()} - mappings['old_to_new_core_atom_map'] = adjusted_old_to_new_map - mappings['new_to_old_env_atom_map'] = new_to_old_env_map - mappings['old_to_new_env_atom_map'] = {v: k for k, v in new_to_old_env_map.items()} - mappings['old_mol_indices'] = old_at_indices - mappings['new_mol_indices'] = new_at_indices + mappings["new_to_old_atom_map"] = new_to_old_all_map + mappings["old_to_new_atom_map"] = {v: k for k, v in new_to_old_all_map.items()} + mappings["new_to_old_core_atom_map"] = {v: k for k, v in adjusted_old_to_new_map.items()} + mappings["old_to_new_core_atom_map"] = adjusted_old_to_new_map + mappings["new_to_old_env_atom_map"] = new_to_old_env_map + mappings["old_to_new_env_atom_map"] = {v: k for k, v in new_to_old_env_map.items()} + mappings["old_mol_indices"] = old_at_indices + mappings["new_mol_indices"] = new_at_indices return mappings -def set_and_check_new_positions(mapping, old_topology, new_topology, - old_positions, insert_positions, - tolerance=1.0): +def set_and_check_new_positions(mapping, old_topology, new_topology, old_positions, insert_positions, tolerance=1.0): """ Utility to create new positions given a mapping, the old positions and the positions of 
the molecule being inserted, defined by `insert_positions. @@ -687,9 +694,9 @@ def set_and_check_new_positions(mapping, old_topology, new_topology, new_pos_array = np.zeros((new_topology.getNumAtoms(), 3)) # get your mappings - new_idxs = list(mapping['old_to_new_atom_map'].values()) - old_idxs = list(mapping['old_to_new_atom_map'].keys()) - new_mol_idxs = mapping['new_mol_indices'] + new_idxs = list(mapping["old_to_new_atom_map"].values()) + old_idxs = list(mapping["old_to_new_atom_map"].keys()) + new_mol_idxs = mapping["new_mol_indices"] # copy over the old positions for mapped atoms new_pos_array[new_idxs, :] = old_pos_array[old_idxs, :] @@ -698,9 +705,8 @@ def set_and_check_new_positions(mapping, old_topology, new_topology, # loop through all mapped atoms and make sure we don't deviate by more than # tolerance - not super necessary, but it's a nice sanity check - for key, val in mapping['old_to_new_atom_map'].items(): - if np.any( - np.abs(new_pos_array[val] - old_pos_array[key]) > tolerance): + for key, val in mapping["old_to_new_atom_map"].items(): + if np.any(np.abs(new_pos_array[val] - old_pos_array[key]) > tolerance): wmsg = f"mapping {key} : {val} deviates by more than {tolerance}" warnings.warn(wmsg) logging.warning(wmsg) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 649da9ffe..6f5cf4a4f 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -21,73 +21,90 @@ """ from __future__ import annotations -import os +import json import logging -from collections import defaultdict +import os +import pathlib +import subprocess import uuid import warnings -import json +from collections import defaultdict +from collections.abc import Iterable from itertools import chain +from typing import Any, Optional, Union + +import gufe import matplotlib.pyplot as plt +import mdtraj import numpy as np import numpy.typing as npt -from 
openff.units import unit -from openff.units.openmm import to_openmm, from_openmm, ensure_quantity +import openmmtools +from gufe import ( + ChemicalSystem, + Component, + ComponentMapping, + LigandAtomMapping, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, +) from openff.toolkit.topology import Molecule as OFFMolecule -from openmmtools import multistate -from typing import Optional +from openff.units import unit +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmm import unit as omm_unit from openmm.app import PDBFile -import pathlib -from typing import Any, Iterable, Union -import openmmtools -import mdtraj -import subprocess +from openmmtools import multistate from rdkit import Chem -import gufe -from gufe import ( - settings, ChemicalSystem, LigandAtomMapping, Component, ComponentMapping, - SmallMoleculeComponent, ProteinComponent, SolventComponent, -) +from openfe.due import Doi, due +from openfe.protocols.openmm_utils.omm_settings import BasePartialChargeSettings -from .equil_rfe_settings import ( - RelativeHybridTopologyProtocolSettings, - OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, OutputSettings, - OpenFFPartialChargeSettings, -) -from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, -) +from ...analysis import plotting +from ...utils import log_system_probe, without_oechem_backend from ..openmm_utils import ( - system_validation, settings_validation, system_creation, - multistate_analysis, charge_generation + charge_generation, + multistate_analysis, + settings_validation, + system_creation, + system_validation, ) from . 
import _rfe_utils -from ...utils import without_oechem_backend, log_system_probe -from ...analysis import plotting -from openfe.due import due, Doi - +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + OutputSettings, + RelativeHybridTopologyProtocolSettings, +) logger = logging.getLogger(__name__) -due.cite(Doi("10.5281/zenodo.1297683"), - description="Perses", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.1297683"), + description="Perses", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) -due.cite(Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) -due.cite(Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) def _get_resname(off_mol) -> str: @@ -103,7 +120,7 @@ def _get_alchemical_charge_difference( mapping: LigandAtomMapping, nonbonded_method: str, explicit_charge_correction: bool, - solvent_component: SolventComponent + solvent_component: SolventComponent, ) -> int: """ Checks and returns the difference in formal charge between state A and B. @@ -135,40 +152,40 @@ def _get_alchemical_charge_difference( The formal charge difference between states A and B. 
This is defined as sum(charge state A) - sum(charge state B) """ - chg_A = Chem.rdmolops.GetFormalCharge( - mapping.componentA.to_rdkit() - ) - chg_B = Chem.rdmolops.GetFormalCharge( - mapping.componentB.to_rdkit() - ) + chg_A = Chem.rdmolops.GetFormalCharge(mapping.componentA.to_rdkit()) + chg_B = Chem.rdmolops.GetFormalCharge(mapping.componentB.to_rdkit()) difference = chg_A - chg_B if abs(difference) > 0: if explicit_charge_correction: if nonbonded_method.lower() != "pme": - errmsg = ("Explicit charge correction when not using PME is " - "not currently supported.") + errmsg = "Explicit charge correction when not using PME is " "not currently supported." raise ValueError(errmsg) if abs(difference) > 1: - errmsg = (f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported.") + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) raise ValueError(errmsg) - ion = {-1: solvent_component.positive_ion, - 1: solvent_component.negative_ion}[difference] - wmsg = (f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion") + ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference] + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) logger.warning(wmsg) warnings.warn(wmsg) else: - wmsg = (f"A charge difference of {difference} is observed " - "between the end states. 
No charge correction has " - "been requested, please account for this in your " - "final results.") + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) logger.warning(wmsg) warnings.warn(wmsg) @@ -177,7 +194,7 @@ def _get_alchemical_charge_difference( def _validate_alchemical_components( alchemical_components: dict[str, list[Component]], - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + mapping: ComponentMapping | list[ComponentMapping] | None, ): """ Checks that the alchemical components are suitable for the RFE protocol. @@ -214,10 +231,9 @@ def _validate_alchemical_components( raise ValueError(errmsg) # Check that all alchemical components are mapped & small molecules - mapped = {'stateA': [m.componentA for m in mapping], - 'stateB': [m.componentB for m in mapping]} + mapped = {"stateA": [m.componentA for m in mapping], "stateB": [m.componentB for m in mapping]} - for idx in ['stateA', 'stateB']: + for idx in ["stateA", "stateB"]: if len(alchemical_components[idx]) != len(mapped[idx]): errmsg = f"missing alchemical components in {idx}" raise ValueError(errmsg) @@ -225,9 +241,11 @@ def _validate_alchemical_components( if comp not in mapped[idx]: raise ValueError(f"Unmapped alchemical component {comp}") if not isinstance(comp, SmallMoleculeComponent): # pragma: no-cover - errmsg = ("Transformations involving non " - "SmallMoleculeComponent species {comp} " - "are not currently supported") + errmsg = ( + "Transformations involving non " + "SmallMoleculeComponent species {comp} " + "are not currently supported" + ) raise ValueError(errmsg) # Validate element changes in mappings @@ -244,13 +262,15 @@ def _validate_alchemical_components( f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" "No mass scaling is attempted in the hybrid topology, " "the average mass of the two atoms will be used in the " - 
"simulation") + "simulation" + ) logger.warning(wmsg) warnings.warn(wmsg) # TODO: remove this once logging is fixed class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" + def __init__(self, **data): super().__init__(**data) # data is mapping of str(repeat_id): list[protocolunitresults] @@ -268,7 +288,7 @@ def get_estimate(self) -> unit.Quantity: a Quantity defined with units. """ # TODO: Check this holds up completely for SAMS. - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -280,7 +300,7 @@ def get_uncertainty(self) -> unit.Quantity: """The uncertainty/error in the dG value: The std of the estimates of each independent repeat """ - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -299,12 +319,10 @@ def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]: estimates (first entry) and associated MBAR estimate errors (second entry). 
""" - dGs = [(pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error']) - for pus in self.data.values()] + dGs = [(pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) for pus in self.data.values()] return dGs - def get_forward_and_reverse_energy_analysis(self) -> list[dict[str, Union[npt.NDArray, unit.Quantity]]]: + def get_forward_and_reverse_energy_analysis(self) -> list[dict[str, npt.NDArray | unit.Quantity]]: """ Get a list of forward and reverse analysis of the free energies for each repeat using uncorrelated production samples. @@ -321,8 +339,7 @@ def get_forward_and_reverse_energy_analysis(self) -> list[dict[str, Union[npt.ND ------- forward_reverse : dict[str, Union[npt.NDArray, unit.Quantity]] """ - forward_reverse = [pus[0].outputs['forward_and_reverse_energies'] - for pus in self.data.values()] + forward_reverse = [pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values()] return forward_reverse @@ -342,8 +359,7 @@ def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]: state i in state j """ # Loop through and get the repeats and get the matrices - overlap_stats = [pus[0].outputs['unit_mbar_overlap'] - for pus in self.data.values()] + overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] return overlap_stats @@ -365,11 +381,9 @@ def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]: from state i to state j. """ try: - repex_stats = [pus[0].outputs['replica_exchange_statistics'] - for pus in self.data.values()] + repex_stats = [pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values()] except KeyError: - errmsg = ("Replica exchange statistics were not found, " - "did you run a repex calculation?") + errmsg = "Replica exchange statistics were not found, " "did you run a repex calculation?" 
raise ValueError(errmsg) return repex_stats @@ -383,6 +397,7 @@ def get_replica_states(self) -> list[npt.NDArray]: replica_states : List[npt.NDArray] List of replica states for each repeat """ + def is_file(filename: str): p = pathlib.Path(filename) if not p.exists(): @@ -393,15 +408,11 @@ def is_file(filename: str): replica_states = [] for pus in self.data.values(): - nc = is_file(pus[0].outputs['nc']) + nc = is_file(pus[0].outputs["nc"]) dir_path = nc.parents[0] - chk = is_file(dir_path / pus[0].outputs['last_checkpoint']).name - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode='r' - ) - replica_states.append( - np.asarray(reporter.read_replica_thermodynamic_states()) - ) + chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name + reporter = multistate.MultiStateReporter(storage=nc, checkpoint_storage=chk, open_mode="r") + replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) reporter.close() return replica_states @@ -415,8 +426,7 @@ def equilibration_iterations(self) -> list[float]: ------- equilibration_lengths : list[float] """ - equilibration_lengths = [pus[0].outputs['equilibration_iterations'] - for pus in self.data.values()] + equilibration_lengths = [pus[0].outputs["equilibration_iterations"] for pus in self.data.values()] return equilibration_lengths @@ -429,8 +439,7 @@ def production_iterations(self) -> list[float]: ------- production_lengths : list[float] """ - production_lengths = [pus[0].outputs['production_iterations'] - for pus in self.data.values()] + production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] return production_lengths @@ -461,7 +470,7 @@ def _default_settings(cls): ), partial_charge_settings=OpenFFPartialChargeSettings(), solvation_settings=OpenMMSolvationSettings(), - alchemical_settings=AlchemicalSettings(softcore_LJ='gapsys'), + alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), 
lambda_settings=LambdaSettings(), simulation_settings=MultiStateSimulationSettings( equilibration_length=1.0 * unit.nanosecond, @@ -476,17 +485,15 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], - extends: Optional[gufe.ProtocolDAGResult] = None, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? if extends: raise NotImplementedError("Can't extend simulations yet") # Get alchemical components & validate them + mapping - alchem_comps = system_validation.get_alchemical_components( - stateA, stateB - ) + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) _validate_alchemical_components(alchem_comps, mapping) ligandmapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore @@ -498,23 +505,26 @@ def _create( system_validation.validate_protein(stateA) # actually create and return Units - Anames = ','.join(c.name for c in alchem_comps['stateA']) - Bnames = ','.join(c.name for c in alchem_comps['stateB']) + Anames = ",".join(c.name for c in alchem_comps["stateA"]) + Bnames = ",".join(c.name for c in alchem_comps["stateB"]) # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats - units = [RelativeHybridTopologyProtocolUnit( - protocol=self, - stateA=stateA, stateB=stateB, - ligandmapping=ligandmapping, # type: ignore - generation=0, repeat_id=int(uuid.uuid4()), - name=f'{Anames} to {Bnames} repeat {i} generation 0') - for i in range(n_repeats)] + units = [ + RelativeHybridTopologyProtocolUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + ligandmapping=ligandmapping, # type: ignore + generation=0, + repeat_id=int(uuid.uuid4()), + name=f"{Anames} to {Bnames} repeat {i} generation 0", + ) + for i in range(n_repeats) + ] return units - def _gather( - self, 
protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: # result units will have a repeat_id and generations within this repeat_id # first group according to repeat_id unsorted_repeats = defaultdict(list) @@ -524,12 +534,12 @@ def _gather( if not pu.ok(): continue - unsorted_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) # then sort by generation within each repeat_id list repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) # returns a dict of repeat_id: sorted list of ProtocolUnitResult return repeats @@ -549,7 +559,7 @@ def __init__( ligandmapping: LigandAtomMapping, generation: int, repeat_id: int, - name: Optional[str] = None, + name: str | None = None, ): """ Parameters @@ -581,7 +591,7 @@ def __init__( stateB=stateB, ligandmapping=ligandmapping, repeat_id=repeat_id, - generation=generation + generation=generation, ) @staticmethod @@ -600,9 +610,7 @@ def _assign_partial_charges( Dictionary of dictionary of OpenFF Molecules to add, keyed by state and SmallMoleculeComponent. """ - for smc, mol in chain(off_small_mols['stateA'], - off_small_mols['stateB'], - off_small_mols['both']): + for smc, mol in chain(off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"]): charge_generation.assign_offmol_partial_charges( offmol=mol, overwrite=False, @@ -612,9 +620,7 @@ def _assign_partial_charges( nagl_model=charge_settings.nagl_model, ) - def run(self, *, dry=False, verbose=True, - scratch_basepath=None, - shared_basepath=None) -> dict[str, Any]: + def run(self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None) -> dict[str, Any]: """Run the relative free energy calculation. 
Parameters @@ -645,18 +651,18 @@ def run(self, *, dry=False, verbose=True, if verbose: self.logger.info("Preparing the hybrid topology simulation") if scratch_basepath is None: - scratch_basepath = pathlib.Path('.') + scratch_basepath = pathlib.Path(".") if shared_basepath is None: # use cwd - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") # 0. General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings - stateA = self._inputs['stateA'] - stateB = self._inputs['stateB'] - mapping = self._inputs['ligandmapping'] + protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs["protocol"].settings + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + mapping = self._inputs["ligandmapping"] forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings @@ -669,10 +675,7 @@ def run(self, *, dry=False, verbose=True, integrator_settings: IntegratorSettings = protocol_settings.integrator_settings # is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, - integrator_settings.timestep - ) + settings_validation.validate_timestep(forcefield_settings.hydrogen_mass, integrator_settings.timestep) # TODO: Also validate various conversions? 
# Convert various time based inputs to steps/iterations steps_per_iteration = settings_validation.convert_steps_per_iteration( @@ -713,10 +716,9 @@ def run(self, *, dry=False, verbose=True, # and keep the molecule around to maintain the partial charges off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] off_small_mols = { - 'stateA': [(mapping.componentA, mapping.componentA.to_openff())], - 'stateB': [(mapping.componentB, mapping.componentB.to_openff())], - 'both': [(m, m.to_openff()) for m in small_mols - if (m != mapping.componentA and m != mapping.componentB)] + "stateA": [(mapping.componentA, mapping.componentA.to_openff())], + "stateB": [(mapping.componentB, mapping.componentB.to_openff())], + "both": [(m, m.to_openff()) for m in small_mols if (m != mapping.componentA and m != mapping.componentB)], } self._assign_partial_charges(charge_settings, off_small_mols) @@ -741,18 +743,14 @@ def run(self, *, dry=False, verbose=True, # c. force the creation of parameters # This is necessary because we need to have the FF templates # registered ahead of solvating the system. - for smc, mol in chain(off_small_mols['stateA'], - off_small_mols['stateB'], - off_small_mols['both']): - system_generator.create_system(mol.to_topology().to_openmm(), - molecules=[mol]) + for smc, mol in chain(off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"]): + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # c. get OpenMM Modeller + a dictionary of resids for each component stateA_modeller, comp_resids = system_creation.get_omm_modeller( protein_comp=protein_comp, solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols['stateA'], - off_small_mols['both'])), + small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), omm_forcefield=system_generator.forcefield, solvent_settings=solvation_settings, ) @@ -760,9 +758,7 @@ def run(self, *, dry=False, verbose=True, # d. 
get topology & positions # Note: roundtrip positions to remove vec3 issues stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm( - from_openmm(stateA_modeller.getPositions()) - ) + stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) # e. create the stateA System # Block out oechem backend in system_generator calls to avoid @@ -770,8 +766,7 @@ def run(self, *, dry=False, verbose=True, with without_oechem_backend(): stateA_system = system_generator.create_system( stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols['stateA'], - off_small_mols['both'])], + molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], ) # 2. Get stateB system @@ -779,7 +774,7 @@ def run(self, *, dry=False, verbose=True, stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( stateA_topology, # zeroth item (there's only one) then get the OFF representation - off_small_mols['stateB'][0][1].to_topology().to_openmm(), + off_small_mols["stateB"][0][1].to_topology().to_openmm(), exclude_resids=comp_resids[mapping.componentA], ) @@ -789,15 +784,18 @@ def run(self, *, dry=False, verbose=True, with without_oechem_backend(): stateB_system = system_generator.create_system( stateB_topology, - molecules=[m for _, m in chain(off_small_mols['stateB'], - off_small_mols['both'])], + molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], ) # c. 
Define correspondence mappings between the two systems ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( mapping.componentA_to_componentB, - stateA_system, stateA_topology, comp_resids[mapping.componentA], - stateB_system, stateB_topology, stateB_alchem_resids, + stateA_system, + stateA_topology, + comp_resids[mapping.componentA], + stateB_system, + stateB_topology, + stateB_alchem_resids, # These are non-optional settings for this method fix_constraints=True, ) @@ -806,35 +804,45 @@ def run(self, *, dry=False, verbose=True, # and transform them if alchem_settings.explicit_charge_correction: alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, stateA_positions, + stateA_topology, + stateA_positions, charge_difference, alchem_settings.explicit_charge_correction_cutoff, ) _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, stateB_topology, stateB_system, - ligand_mappings, charge_difference, + alchem_water_resids, + stateB_topology, + stateB_system, + ligand_mappings, + charge_difference, solvent_comp, ) # e. Finally get the positions stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, stateA_topology, stateB_topology, - old_positions=ensure_quantity(stateA_positions, 'openmm'), - insert_positions=ensure_quantity(off_small_mols['stateB'][0][1].conformers[0], 'openmm'), + ligand_mappings, + stateA_topology, + stateB_topology, + old_positions=ensure_quantity(stateA_positions, "openmm"), + insert_positions=ensure_quantity(off_small_mols["stateB"][0][1].conformers[0], "openmm"), ) # 3. Create the hybrid topology # a. Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == 'gapsys': + if alchem_settings.softcore_LJ.lower() == "gapsys": softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == 'beutler': + elif alchem_settings.softcore_LJ.lower() == "beutler": softcore_LJ_v2 = False # b. 
Get hybrid topology factory hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( - stateA_system, stateA_positions, stateA_topology, - stateB_system, stateB_positions, stateB_topology, - old_to_new_atom_map=ligand_mappings['old_to_new_atom_map'], - old_to_new_core_atom_map=ligand_mappings['old_to_new_core_atom_map'], + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], use_dispersion_correction=alchem_settings.use_dispersion_correction, softcore_alpha=alchem_settings.softcore_alpha, softcore_LJ_v2=softcore_LJ_v2, @@ -847,22 +855,22 @@ def run(self, *, dry=False, verbose=True, # ability to print the schedule directly in settings? lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( functions=lambda_settings.lambda_functions, - windows=lambda_settings.lambda_windows + windows=lambda_settings.lambda_windows, ) # PR #125 temporarily pin lambda schedule spacing to n_replicas n_replicas = sampler_settings.n_replicas if n_replicas != len(lambdas.lambda_schedule): - errmsg = (f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}") + errmsg = ( + f"Number of replicas {n_replicas} " + f"does not equal the number of lambda windows " + f"{len(lambdas.lambda_schedule)}" + ) raise ValueError(errmsg) # 9. Create the multistate reporter # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select( - output_settings.output_indices - ) + selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) # a. Create the multistate reporter # convert checkpoint_interval from time to iterations @@ -882,24 +890,22 @@ def run(self, *, dry=False, verbose=True, # b. 
Write out a PDB containing the subsampled hybrid state bfactors = np.zeros_like(selection_indices, dtype=float) # solvent - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes["unique_old_atoms"]))] = 0.25 # lig A + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes["core_atoms"]))] = 0.50 # core + bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes["unique_new_atoms"]))] = 0.75 # lig B # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor if len(selection_indices) > 0: traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), + hybrid_factory.hybrid_positions[selection_indices, :], + hybrid_factory.hybrid_topology.subset(selection_indices), ).save_pdb( shared_basepath / output_settings.output_structure, bfactors=bfactors, ) # 10. Get platform - platform = _rfe_utils.compute.get_openmm_platform( - protocol_settings.engine_settings.compute_platform - ) + platform = _rfe_utils.compute.get_openmm_platform(protocol_settings.engine_settings.compute_platform) # 11. Set the integrator # a. Validate integrator settings for current system @@ -907,8 +913,7 @@ def run(self, *, dry=False, verbose=True, # there are virtual sites in the system if hybrid_factory.has_virtual_sites: if not integrator_settings.reassign_velocities: - errmsg = ("Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools") + errmsg = "Simulations with virtual sites without velocity " "reassignments are unstable in openmmtools" raise ValueError(errmsg) # b. 
create langevin integrator @@ -949,7 +954,7 @@ def run(self, *, dry=False, verbose=True, flatness_criteria=sampler_settings.sams_flatness_criteria, gamma0=sampler_settings.sams_gamma0, ) - elif sampler_settings.sampler_method.lower() == 'independent': + elif sampler_settings.sampler_method.lower() == "independent": sampler = _rfe_utils.multistate.HybridMultiStateSampler( mcmc_moves=integrator, hybrid_factory=hybrid_factory, @@ -973,11 +978,15 @@ def run(self, *, dry=False, verbose=True, try: # Create context caches (energy + sampler) energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler.energy_context_cache = energy_context_cache @@ -994,17 +1003,13 @@ def run(self, *, dry=False, verbose=True, if verbose: self.logger.info("Running equilibration phase") - sampler.equilibrate( - int(equil_steps / steps_per_iteration) # type: ignore - ) + sampler.equilibrate(int(equil_steps / steps_per_iteration)) # type: ignore # production if verbose: self.logger.info("Running production phase") - sampler.extend( - int(prod_steps / steps_per_iteration) # type: ignore - ) + sampler.extend(int(prod_steps / steps_per_iteration)) # type: ignore self.logger.info("Production phase complete") @@ -1021,8 +1026,10 @@ def run(self, *, dry=False, verbose=True, else: # clean up the reporter file - fns = [shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename] + fns = [ + shared_basepath / output_settings.output_filename, + shared_basepath / output_settings.checkpoint_storage_filename, + ] for fn in fns: os.remove(fn) finally: @@ -1038,8 +1045,7 @@ def run(self, *, dry=False, verbose=True, for context in 
list(sampler_context_cache._lru._data.keys()): del sampler_context_cache._lru._data[context] # cautiously clear out the global context cache too - for context in list( - openmmtools.cache.global_context_cache._lru._data.keys()): + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] del sampler_context_cache, energy_context_cache @@ -1048,58 +1054,55 @@ def run(self, *, dry=False, verbose=True, del integrator, sampler if not dry: # pragma: no-cover - return { - 'nc': nc, - 'last_checkpoint': chk, - **analyzer.unit_results_dict - } + return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} else: - return {'debug': {'sampler': sampler}} + return {"debug": {"sampler": sampler}} @staticmethod def analyse(where) -> dict: # don't put energy analysis in here, it uses the open file reporter # whereas structural stuff requires that the file handle is closed - analysis_out = where / 'structural_analysis.json' + analysis_out = where / "structural_analysis.json" - ret = subprocess.run(['openfe_analysis', 'RFE_analysis', - str(where), str(analysis_out)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + ret = subprocess.run( + ["openfe_analysis", "RFE_analysis", str(where), str(analysis_out)], + capture_output=True, + ) if ret.returncode: - return {'structural_analysis_error': ret.stderr} + return {"structural_analysis_error": ret.stderr} - with open(analysis_out, 'rb') as f: + with open(analysis_out, "rb") as f: data = json.load(f) savedir = pathlib.Path(where) - if d := data['protein_2D_RMSD']: + if d := data["protein_2D_RMSD"]: fig = plotting.plot_2D_rmsd(d) fig.savefig(savedir / "protein_2D_RMSD.png") plt.close(fig) - f2 = plotting.plot_ligand_COM_drift(data['time(ps)'], data['ligand_wander']) + f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) f2.savefig(savedir / "ligand_COM_drift.png") plt.close(f2) - f3 = 
plotting.plot_ligand_RMSD(data['time(ps)'], data['ligand_RMSD']) + f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) f3.savefig(savedir / "ligand_RMSD.png") plt.close(f3) - return {'structural_analysis': data} + return {"structural_analysis": data} def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) analysis_outputs = self.analyse(ctx.shared) return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], **outputs, **analysis_outputs, } diff --git a/openfe/protocols/openmm_rfe/equil_rfe_settings.py b/openfe/protocols/openmm_rfe/equil_rfe_settings.py index 4d3de222f..6c4fc1cfc 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_settings.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_settings.py @@ -9,22 +9,18 @@ from __future__ import annotations from typing import Literal -from openff.units import unit + +from gufe.settings import OpenMMSystemGeneratorFFSettings, Settings, SettingsBaseModel, ThermoSettings from openff.models.types import FloatQuantity +from openff.units import unit -from gufe.settings import ( - Settings, - SettingsBaseModel, - OpenMMSystemGeneratorFFSettings, - ThermoSettings, -) from openfe.protocols.openmm_utils.omm_settings import ( IntegratorSettings, MultiStateSimulationSettings, + OpenFFPartialChargeSettings, OpenMMEngineSettings, OpenMMSolvationSettings, OutputSettings, - OpenFFPartialChargeSettings, ) try: @@ -35,15 +31,15 @@ class LambdaSettings(SettingsBaseModel): class Config: - extra = 'ignore' + extra = "ignore" arbitrary_types_allowed = True """Lambda schedule settings. 
- - Settings controlling the lambda schedule, these include the switching + + Settings controlling the lambda schedule, these include the switching function type, and the number of windows. """ - lambda_functions = 'default' + lambda_functions = "default" """ Key of which switching functions to use for alchemical mutation. Default 'default'. @@ -54,7 +50,7 @@ class Config: class AlchemicalSettings(SettingsBaseModel): class Config: - extra = 'ignore' + extra = "ignore" arbitrary_types_allowed = True """Settings for the alchemical protocol @@ -74,9 +70,9 @@ class Config: Whether to use dispersion correction in the hybrid topology state. Default False. """ - softcore_LJ: Literal['gapsys', 'beutler'] + softcore_LJ: Literal["gapsys", "beutler"] """ - Whether to use the LJ softcore function as defined by Gapsys et al. + Whether to use the LJ softcore function as defined by Gapsys et al. JCTC 2012, or the one by Beutler et al. Chem. Phys. Lett. 1994. Default 'gapsys'. """ @@ -85,7 +81,7 @@ class Config: turn_off_core_unique_exceptions = False """ Whether to turn off interactions for new exceptions (not just 1,4s) - at lambda 0 and old exceptions at lambda 1 between unique atoms and core + at lambda 0 and old exceptions at lambda 1 between unique atoms and core atoms. If False they are present in the nonbonded force. Default False. """ explicit_charge_correction = False @@ -101,7 +97,7 @@ class Config: Default False. """ - explicit_charge_correction_cutoff: FloatQuantity['nanometer'] = 0.8 * unit.nanometer + explicit_charge_correction_cutoff: FloatQuantity["nanometer"] = 0.8 * unit.nanometer """ The minimum distance from the system solutes from which an alchemical water can be chosen. Default 0.8 * unit.nanometer. @@ -111,12 +107,12 @@ class Config: class RelativeHybridTopologyProtocolSettings(Settings): protocol_repeats: int """ - The number of completely independent repeats of the entire sampling - process.
The mean of the repeats defines the final estimate of FE - difference, while the variance between repeats is used as the uncertainty. + The number of completely independent repeats of the entire sampling + process. The mean of the repeats defines the final estimate of FE + difference, while the variance between repeats is used as the uncertainty. """ - @validator('protocol_repeats') + @validator("protocol_repeats") def must_be_positive(cls, v): if v <= 0: errmsg = f"protocol_repeats must be a positive value, got {v}." diff --git a/openfe/protocols/openmm_utils/__init__.py b/openfe/protocols/openmm_utils/__init__.py index 8b1378917..e69de29bb 100644 --- a/openfe/protocols/openmm_utils/__init__.py +++ b/openfe/protocols/openmm_utils/__init__.py @@ -1 +0,0 @@ - diff --git a/openfe/protocols/openmm_utils/charge_generation.py b/openfe/protocols/openmm_utils/charge_generation.py index e588cd0e7..1a5287e5b 100644 --- a/openfe/protocols/openmm_utils/charge_generation.py +++ b/openfe/protocols/openmm_utils/charge_generation.py @@ -4,19 +4,16 @@ Reusable utilities for assigning partial charges to ChemicalComponents. 
""" import copy -from typing import Union, Optional, Literal, Callable import sys import warnings +from typing import Callable, Literal, Optional, Union + import numpy as np -from openff.units import unit from openff.toolkit import Molecule as OFFMol from openff.toolkit.utils.base_wrapper import ToolkitWrapper -from openff.toolkit.utils.toolkits import ( - AmberToolsToolkitWrapper, - OpenEyeToolkitWrapper, - RDKitToolkitWrapper -) from openff.toolkit.utils.toolkit_registry import ToolkitRegistry +from openff.toolkit.utils.toolkits import AmberToolsToolkitWrapper, OpenEyeToolkitWrapper, RDKitToolkitWrapper +from openff.units import unit try: import openeye @@ -26,21 +23,15 @@ HAS_OPENEYE = True try: - from openff.toolkit.utils.toolkit_registry import ( - toolkit_registry_manager, - ) + from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager except ImportError: # toolkit_registry_manager was made non private in 0.14.4 - from openff.toolkit.utils.toolkit_registry import ( - _toolkit_registry_manager as toolkit_registry_manager - ) + from openff.toolkit.utils.toolkit_registry import _toolkit_registry_manager as toolkit_registry_manager try: + from openff.nagl_models import get_models_by_type, validate_nagl_model_path from openff.toolkit.utils.nagl_wrapper import NAGLToolkitWrapper - from openff.nagl_models import ( - get_models_by_type, validate_nagl_model_path - ) except ImportError: HAS_NAGL = False else: @@ -66,10 +57,7 @@ } -def assign_offmol_espaloma_charges( - offmol: OFFMol, - toolkit_registry: ToolkitRegistry -) -> None: +def assign_offmol_espaloma_charges(offmol: OFFMol, toolkit_registry: ToolkitRegistry) -> None: """ Assign Espaloma charges using the OpenFF toolkit. @@ -84,12 +72,10 @@ def assign_offmol_espaloma_charges( assignment stage. 
""" if not HAS_ESPALOMA: - errmsg = ("The Espaloma ToolkiWrapper is not available, " - "please install espaloma_charge") + errmsg = "The Espaloma ToolkiWrapper is not available, " "please install espaloma_charge" raise ImportError(errmsg) - warnings.warn("Using espaloma to assign charges is not well tested", - category=RuntimeWarning) + warnings.warn("Using espaloma to assign charges is not well tested", category=RuntimeWarning) # make a copy to remove conformers as espaloma enforces # a 0 conformer check @@ -101,7 +87,7 @@ def assign_offmol_espaloma_charges( # https://github.com/openforcefield/openff-nagl/issues/69 with toolkit_registry_manager(toolkit_registry): offmol_copy.assign_partial_charges( - partial_charge_method='espaloma-am1bcc', + partial_charge_method="espaloma-am1bcc", toolkit_registry=EspalomaChargeToolkitWrapper(), ) @@ -131,21 +117,23 @@ def assign_offmol_nagl_charges( If ``None``, will fetch the latest production "am1bcc" model. """ if not HAS_NAGL: - errmsg = ("The NAGL toolkit is not available, you may " - "be using an older version of the OpenFF " - "toolkit - you need v0.14.4 or above") + errmsg = ( + "The NAGL toolkit is not available, you may " + "be using an older version of the OpenFF " + "toolkit - you need v0.14.4 or above" + ) raise ImportError(errmsg) if nagl_model is None: - prod_models = get_models_by_type( - model_type='am1bcc', production_only=True - ) + prod_models = get_models_by_type(model_type="am1bcc", production_only=True) # Currently there are no production models so expect an IndexError try: nagl_model = prod_models[-1] except IndexError: - errmsg = ("No production am1bcc NAGL models are current available " - "please manually select a candidate release model") + errmsg = ( + "No production am1bcc NAGL models are current available " + "please manually select a candidate release model" + ) raise ValueError(errmsg) model_path = validate_nagl_model_path(nagl_model) @@ -162,7 +150,7 @@ def assign_offmol_nagl_charges( def 
assign_offmol_am1bcc_charges( offmol: OFFMol, - partial_charge_method: Literal['am1bcc', 'am1bccelf10'], + partial_charge_method: Literal["am1bcc", "am1bccelf10"], toolkit_registry: ToolkitRegistry, ) -> None: """ @@ -199,7 +187,7 @@ def assign_offmol_am1bcc_charges( offmol.assign_partial_charges( partial_charge_method=partial_charge_method, use_conformers=offmol.conformers, - toolkit_registry=toolkit_registry + toolkit_registry=toolkit_registry, ) @@ -237,33 +225,35 @@ def _generate_offmol_conformers( # Check number of conformers if generate_n_conformers is None and return if generate_n_conformers is None: if offmol.n_conformers == 0: - errmsg = ("No conformers are associated with input OpenFF " - "Molecule. Need at least one for partial charge " - "assignment") + errmsg = ( + "No conformers are associated with input OpenFF " + "Molecule. Need at least one for partial charge " + "assignment" + ) raise ValueError(errmsg) if offmol.n_conformers > max_conf: - errmsg = ("OpenFF Molecule has too many conformers: " - f"{offmol.n_conformers}, selected partial charge " - f"method can only support a maximum of {max_conf} " - "conformers.") + errmsg = ( + "OpenFF Molecule has too many conformers: " + f"{offmol.n_conformers}, selected partial charge " + f"method can only support a maximum of {max_conf} " + "conformers." + ) raise ValueError(errmsg) return - # Check that generate_n_conformers < max_conf if generate_n_conformers > max_conf: - errmsg = (f"{generate_n_conformers} conformers were requested " - "for partial charge generation, but the selected " - "method only supports up to {max_conf} conformers.") + errmsg = ( + f"{generate_n_conformers} conformers were requested " + "for partial charge generation, but the selected " + "method only supports up to {max_conf} conformers." 
+ ) raise ValueError(errmsg) # Generate conformers # OpenEye tk needs cis carboxylic acids - make_carbox_cis = any( - [isinstance(i, OpenEyeToolkitWrapper) - for i in toolkit_registry.registered_toolkits] - ) + make_carbox_cis = any([isinstance(i, OpenEyeToolkitWrapper) for i in toolkit_registry.registered_toolkits]) # We are being overly cautious by both passing the # registry and applying the manager here - this is @@ -281,8 +271,8 @@ def _generate_offmol_conformers( def assign_offmol_partial_charges( offmol: OFFMol, overwrite: bool, - method: Literal['am1bcc', 'am1bccelf10', 'nagl', 'espaloma'], - toolkit_backend: Literal['ambertools', 'openeye', 'rdkit'], + method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"], + toolkit_backend: Literal["ambertools", "openeye", "rdkit"], generate_n_conformers: Optional[int], nagl_model: Optional[str], ) -> None: @@ -326,7 +316,7 @@ def assign_offmol_partial_charges( """ # If you have non-zero charges and not overwriting, just return - if (offmol.partial_charges is not None and np.any(offmol.partial_charges)): + if offmol.partial_charges is not None and np.any(offmol.partial_charges): if not overwrite: return @@ -346,28 +336,28 @@ def assign_offmol_partial_charges( "am1bcc": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_am1bcc_charges, - "backends": ['ambertools', 'openeye'], + "backends": ["ambertools", "openeye"], "max_conf": 1, - "charge_extra_kwargs": {'partial_charge_method': 'am1bcc'}, + "charge_extra_kwargs": {"partial_charge_method": "am1bcc"}, }, "am1bccelf10": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_am1bcc_charges, - "backends": ['openeye'], + "backends": ["openeye"], "max_conf": sys.maxsize, - "charge_extra_kwargs": {'partial_charge_method': 'am1bccelf10'}, + "charge_extra_kwargs": {"partial_charge_method": "am1bccelf10"}, }, "nagl": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_nagl_charges, - "backends": 
['openeye', 'rdkit', 'ambertools'], + "backends": ["openeye", "rdkit", "ambertools"], "max_conf": 1, "charge_extra_kwargs": {"nagl_model": nagl_model}, }, "espaloma": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_espaloma_charges, - "backends": ['rdkit', 'ambertools'], + "backends": ["rdkit", "ambertools"], "max_conf": 1, "charge_extra_kwargs": {}, }, @@ -375,44 +365,44 @@ def assign_offmol_partial_charges( # Grab the backends and also check our method try: - backends = CHARGE_METHODS[method.lower()]['backends'] + backends = CHARGE_METHODS[method.lower()]["backends"] except KeyError: errmsg = f"Unknown partial charge method {method}" raise ValueError(errmsg) # Check our method actually supports the toolkit backend selected if toolkit_backend.lower() not in backends: # type: ignore - errmsg = (f"Selected toolkit_backend ({toolkit_backend}) cannot " - f"be used with the selected method ({method}). " - f"Available backends are: {backends}") + errmsg = ( + f"Selected toolkit_backend ({toolkit_backend}) cannot " + f"be used with the selected method ({method}). 
" + f"Available backends are: {backends}" + ) raise ValueError(errmsg) # OpenEye is the only optional dependency in the toolkit backends - if toolkit_backend.lower() == 'openeye' and not HAS_OPENEYE: + if toolkit_backend.lower() == "openeye" and not HAS_OPENEYE: errmsg = "OpenEye is not available and cannot be selected as a backend" raise ImportError(errmsg) - toolkits = ToolkitRegistry( - [i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]] - ) + toolkits = ToolkitRegistry([i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]]) # We make a copy of the molecule since we're going to modify conformers offmol_copy = copy.deepcopy(offmol) # Generate conformers - note this method may differ based on the partial # charge method employed - CHARGE_METHODS[method.lower()]['confgen_func']( + CHARGE_METHODS[method.lower()]["confgen_func"]( offmol=offmol_copy, - max_conf=CHARGE_METHODS[method.lower()]['max_conf'], + max_conf=CHARGE_METHODS[method.lower()]["max_conf"], toolkit_registry=toolkits, generate_n_conformers=generate_n_conformers, ) # type: ignore # Call selected method to assign partial charges - CHARGE_METHODS[method.lower()]['charge_func']( + CHARGE_METHODS[method.lower()]["charge_func"]( offmol=offmol_copy, toolkit_registry=toolkits, - **CHARGE_METHODS[method.lower()]['charge_extra_kwargs'], + **CHARGE_METHODS[method.lower()]["charge_extra_kwargs"], ) # type: ignore # Copy partial charges back diff --git a/openfe/protocols/openmm_utils/multistate_analysis.py b/openfe/protocols/openmm_utils/multistate_analysis.py index e5d7d8be7..7ffd195b8 100644 --- a/openfe/protocols/openmm_utils/multistate_analysis.py +++ b/openfe/protocols/openmm_utils/multistate_analysis.py @@ -3,16 +3,18 @@ """ Reusable utility methods to analyze results from multistate calculations. 
""" -from pathlib import Path import warnings +from pathlib import Path +from typing import Optional, Union + import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +from openff.units import ensure_quantity, unit from openmmtools import multistate -from openff.units import unit, ensure_quantity from pymbar.utils import ParameterError + from openfe.analysis import plotting -from typing import Optional, Union class MultistateEquilFEAnalysis: @@ -40,13 +42,18 @@ class MultistateEquilFEAnalysis: The number of samples to use in the foward and reverse analysis of the free energies. Default 10. """ - def __init__(self, reporter: multistate.MultiStateReporter, - sampling_method: str, result_units: unit.Quantity, - forward_reverse_samples: int = 10): + + def __init__( + self, + reporter: multistate.MultiStateReporter, + sampling_method: str, + result_units: unit.Quantity, + forward_reverse_samples: int = 10, + ): self.analyzer = multistate.MultiStateSamplerAnalyzer(reporter) self.units = result_units - if sampling_method.lower() not in ['repex', 'sams', 'independent']: + if sampling_method.lower() not in ["repex", "sams", "independent"]: wmsg = f"Unknown sampling method {sampling_method}" warnings.warn(wmsg) self.sampling_method = sampling_method.lower() @@ -72,43 +79,29 @@ def plot(self, filepath: Path, filename_prefix: str): A prefix for the written filenames. 
""" # MBAR overlap matrix - ax = plotting.plot_lambda_transition_matrix(self.free_energy_overlaps['matrix']) - ax.set_title('MBAR overlap matrix') - ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'mbar_overlap_matrix.png') - ) + ax = plotting.plot_lambda_transition_matrix(self.free_energy_overlaps["matrix"]) + ax.set_title("MBAR overlap matrix") + ax.figure.savefig(filepath / (filename_prefix + "mbar_overlap_matrix.png")) # type: ignore plt.close(ax.figure) # type: ignore # Reverse and forward analysis if self.forward_and_reverse_free_energies is not None: - ax = plotting.plot_convergence( - self.forward_and_reverse_free_energies, self.units - ) - ax.set_title('Forward and Reverse free energy convergence') - ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'forward_reverse_convergence.png') - ) + ax = plotting.plot_convergence(self.forward_and_reverse_free_energies, self.units) + ax.set_title("Forward and Reverse free energy convergence") + ax.figure.savefig(filepath / (filename_prefix + "forward_reverse_convergence.png")) # type: ignore plt.close(ax.figure) # type: ignore # Replica state timeseries plot - ax = plotting.plot_replica_timeseries( - self.replica_states, self.equilibration_iterations - ) - ax.set_title('Change in replica state over time') - ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'replica_state_timeseries.png') - ) + ax = plotting.plot_replica_timeseries(self.replica_states, self.equilibration_iterations) + ax.set_title("Change in replica state over time") + ax.figure.savefig(filepath / (filename_prefix + "replica_state_timeseries.png")) # type: ignore plt.close(ax.figure) # type: ignore # Replica exchange transition matrix - if self.sampling_method == 'repex': - ax = plotting.plot_lambda_transition_matrix( - self.replica_exchange_statistics['matrix'] - ) - ax.set_title('Replica exchange transition matrix') - ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 
'replica_exchange_matrix.png') - ) + if self.sampling_method == "repex": + ax = plotting.plot_lambda_transition_matrix(self.replica_exchange_statistics["matrix"]) + ax.set_title("Replica exchange transition matrix") + ax.figure.savefig(filepath / (filename_prefix + "replica_exchange_matrix.png")) # type: ignore plt.close(ax.figure) # type: ignore def _analyze(self, forward_reverse_samples: int): @@ -141,9 +134,7 @@ def _analyze(self, forward_reverse_samples: int): self._free_energy, self._free_energy_err = self.get_equil_free_energy() # forward and reverse analysis - self._forward_reverse = self.get_forward_and_reverse_analysis( - forward_reverse_samples - ) + self._forward_reverse = self.get_forward_and_reverse_analysis(forward_reverse_samples) # Gather overlap matrix self._overlap_matrix = self.get_overlap_matrix() @@ -151,13 +142,14 @@ def _analyze(self, forward_reverse_samples: int): # Gather exchange transition matrix # Note we only generate these for replica exchange calculations # TODO: consider if this would also work for SAMS - if self.sampling_method == 'repex': + if self.sampling_method == "repex": self._exchange_matrix = self.get_exchanges() @staticmethod def _get_free_energy( analyzer: multistate.MultiStateSamplerAnalyzer, - u_ln: npt.NDArray, N_l: npt.NDArray, + u_ln: npt.NDArray, + N_l: npt.NDArray, return_units: unit.Quantity, ) -> tuple[unit.Quantity, unit.Quantity]: """ @@ -196,14 +188,13 @@ def _get_free_energy( DF_ij, dDF_ij = mbar.getFreeEnergyDifferences() except AttributeError: r = mbar.compute_free_energy_differences() - DF_ij = r['Delta_f'] - dDF_ij = r['dDelta_f'] + DF_ij = r["Delta_f"] + dDF_ij = r["dDelta_f"] DG = DF_ij[0, -1] * analyzer.kT dDG = dDF_ij[0, -1] * analyzer.kT - return (ensure_quantity(DG, 'openff').to(return_units), - ensure_quantity(dDG, 'openff').to(return_units)) + return (ensure_quantity(DG, "openff").to(return_units), ensure_quantity(dDG, "openff").to(return_units)) def get_equil_free_energy(self) -> 
tuple[unit.Quantity, unit.Quantity]: """ @@ -220,14 +211,13 @@ def get_equil_free_energy(self) -> tuple[unit.Quantity, unit.Quantity]: u_ln_decorr = self.analyzer._unbiased_decorrelated_u_ln N_l_decorr = self.analyzer._unbiased_decorrelated_N_l - DG, dDG = self._get_free_energy( - self.analyzer, u_ln_decorr, N_l_decorr, self.units - ) + DG, dDG = self._get_free_energy(self.analyzer, u_ln_decorr, N_l_decorr, self.units) return DG, dDG def get_forward_and_reverse_analysis( - self, num_samples: int = 10 + self, + num_samples: int = 10, ) -> Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]: """ Calculate free energies with a progressively larger @@ -257,14 +247,12 @@ def get_forward_and_reverse_analysis( # Check that the N_l is the same across all states if not np.all(N_l == N_l[0]): - errmsg = ("The number of samples is not equivalent across all " - f"states {N_l}") + errmsg = "The number of samples is not equivalent across all " f"states {N_l}" raise ValueError(errmsg) # Get the chunks of N_l going from 10% to ~ 100% # Note: you always lose out a few data points but it's fine - chunks = [max(int(N_l[0] / num_samples * i), 1) - for i in range(1, num_samples + 1)] + chunks = [max(int(N_l[0] / num_samples * i), 1) for i in range(1, num_samples + 1)] forward_DGs = [] forward_dDGs = [] @@ -279,7 +267,8 @@ def get_forward_and_reverse_analysis( # Forward DG, dDG = self._get_free_energy( self.analyzer, - u_ln[:, :samples], new_N_l, + u_ln[:, :samples], + new_N_l, self.units, ) forward_DGs.append(DG) @@ -288,7 +277,8 @@ def get_forward_and_reverse_analysis( # Reverse DG, dDG = self._get_free_energy( self.analyzer, - u_ln[:, -samples:], new_N_l, + u_ln[:, -samples:], + new_N_l, self.units, ) reverse_DGs.append(DG) @@ -299,11 +289,11 @@ def get_forward_and_reverse_analysis( return None forward_reverse = { - 'fractions': np.array(fractions), - 'forward_DGs': unit.Quantity.from_list(forward_DGs), - 'forward_dDGs': unit.Quantity.from_list(forward_dDGs), - 'reverse_DGs': 
unit.Quantity.from_list(reverse_DGs), - 'reverse_dDGs': unit.Quantity.from_list(reverse_dDGs) + "fractions": np.array(fractions), + "forward_DGs": unit.Quantity.from_list(forward_DGs), + "forward_dDGs": unit.Quantity.from_list(forward_dDGs), + "reverse_DGs": unit.Quantity.from_list(reverse_DGs), + "reverse_dDGs": unit.Quantity.from_list(reverse_dDGs), } return forward_reverse @@ -325,7 +315,7 @@ def get_overlap_matrix(self) -> dict[str, npt.NDArray]: # pymbar 3 overlap_matrix = self.analyzer.mbar.computeOverlap() # convert matrix to np array - overlap_matrix['matrix'] = np.array(overlap_matrix['matrix']) + overlap_matrix["matrix"] = np.array(overlap_matrix["matrix"]) except AttributeError: overlap_matrix = self.analyzer.mbar.compute_overlap() @@ -347,8 +337,7 @@ def get_exchanges(self) -> dict[str, npt.NDArray]: """ # Get replica mixing statistics mixing_stats = self.analyzer.generate_mixing_statistics() - transition_matrix = {'eigenvalues': mixing_stats.eigenvalues, - 'matrix': mixing_stats.transition_matrix} + transition_matrix = {"eigenvalues": mixing_stats.eigenvalues, "matrix": mixing_stats.transition_matrix} return transition_matrix @property @@ -408,26 +397,25 @@ def replica_exchange_statistics(self): A dictionary containing the estimated replica exchange matrix and corresponding eigenvalues. """ - if hasattr(self, '_exchange_matrix'): + if hasattr(self, "_exchange_matrix"): return self._exchange_matrix else: - errmsg = ("Exchange matrix was not generated, this is likely " - f"{self.sampling_method} is not repex.") + errmsg = "Exchange matrix was not generated, this is likely " f"{self.sampling_method} is not repex." 
raise ValueError(errmsg) @property def unit_results_dict(self): results_dict = { - 'unit_estimate': self.free_energy, - 'unit_estimate_error': self.free_energy_error, - 'unit_mbar_overlap': self.free_energy_overlaps, - 'forward_and_reverse_energies': self.forward_and_reverse_free_energies, - 'production_iterations': self.production_iterations, - 'equilibration_iterations': self.equilibration_iterations, + "unit_estimate": self.free_energy, + "unit_estimate_error": self.free_energy_error, + "unit_mbar_overlap": self.free_energy_overlaps, + "forward_and_reverse_energies": self.forward_and_reverse_free_energies, + "production_iterations": self.production_iterations, + "equilibration_iterations": self.equilibration_iterations, } - if hasattr(self, '_exchange_matrix'): - results_dict['replica_exchange_statistics'] = self.replica_exchange_statistics + if hasattr(self, "_exchange_matrix"): + results_dict["replica_exchange_statistics"] = self.replica_exchange_statistics return results_dict diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index b07d57d6b..86ac2fa10 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -9,17 +9,11 @@ """ from __future__ import annotations -from typing import Optional, Literal -from openff.units import unit -from openff.models.types import FloatQuantity - -from gufe.settings import ( - Settings, - SettingsBaseModel, - OpenMMSystemGeneratorFFSettings, - ThermoSettings, -) +from typing import Literal, Optional +from gufe.settings import OpenMMSystemGeneratorFFSettings, Settings, SettingsBaseModel, ThermoSettings +from openff.models.types import FloatQuantity +from openff.units import unit try: from pydantic.v1 import validator @@ -31,6 +25,7 @@ class BaseSolvationSettings(SettingsBaseModel): """ Base class for SolvationSettings objects """ + class Config: arbitrary_types_allowed = True @@ -43,21 +38,21 @@ class 
OpenMMSolvationSettings(BaseSolvationSettings): No solvation will happen if a SolventComponent is not passed. """ - solvent_model: Literal['tip3p', 'spce', 'tip4pew', 'tip5p'] = 'tip3p' + + solvent_model: Literal["tip3p", "spce", "tip4pew", "tip5p"] = "tip3p" """ Force field water model to use. Allowed values are; `tip3p`, `spce`, `tip4pew`, and `tip5p`. """ - solvent_padding: FloatQuantity['nanometer'] = 1.2 * unit.nanometer + solvent_padding: FloatQuantity[nanometer] = 1.2 * unit.nanometer """Minimum distance from any solute atoms to the solvent box edge.""" - @validator('solvent_padding') + @validator("solvent_padding") def is_positive_distance(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.nanometer): - raise ValueError("solvent_padding must be in distance units " - "(i.e. nanometers)") + raise ValueError("solvent_padding must be in distance units " "(i.e. nanometers)") if v < 0: errmsg = "solvent_padding must be a positive value" raise ValueError(errmsg) @@ -68,6 +63,7 @@ class BasePartialChargeSettings(SettingsBaseModel): """ Base class for partial charge assignment. """ + class Config: arbitrary_types_allowed = True @@ -76,7 +72,8 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): """ Settings for controlling partial charge assignment using the OpenFF tooling """ - partial_charge_method: Literal['am1bcc', 'am1bccelf10', 'nagl', 'espaloma'] = 'am1bcc' + + partial_charge_method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"] = "am1bcc" """ Selection of method for partial charge generation. @@ -110,7 +107,7 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): are supported. A maximum of one conformer is allowed. """ - off_toolkit_backend: Literal['ambertools', 'openeye', 'rdkit'] = 'ambertools' + off_toolkit_backend: Literal["ambertools", "openeye", "rdkit"] = "ambertools" """ The OpenFF toolkit registry backend to use for partial charge generation. 
@@ -129,7 +126,7 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): am1bcc partial charge generation, but is usually used in combination with the ``nagl`` or ``espaloma`` ``partial_charge_method`` selections. """ - number_of_conformers: Optional[int] = None + number_of_conformers: int | None = None """ Number of conformers to generate as part of the partial charge assignement. @@ -140,7 +137,7 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): partial charges through ``am1bccelf10``. See ``partial_charge_method``'s ``Description of options`` documentation. """ - nagl_model: Optional[str] = None + nagl_model: str | None = None """ The `NAGL `_ model to use for partial charge assignment. @@ -159,7 +156,7 @@ class OpenMMEngineSettings(SettingsBaseModel): * In the future make precision and deterministic forces user defined too. """ - compute_platform: Optional[str] = None + compute_platform: str | None = None """ OpenMM compute platform to perform MD integration with. If None, will choose fastest available platform. Default None. @@ -172,9 +169,9 @@ class IntegratorSettings(SettingsBaseModel): class Config: arbitrary_types_allowed = True - timestep: FloatQuantity['femtosecond'] = 4 * unit.femtosecond + timestep: FloatQuantity[femtosecond] = 4 * unit.femtosecond """Size of the simulation timestep. Default 4 * unit.femtosecond.""" - langevin_collision_rate: FloatQuantity['1/picosecond'] = 1.0 / unit.picosecond + langevin_collision_rate: FloatQuantity[1 / picosecond] = 1.0 / unit.picosecond """Collision frequency. Default 1.0 / unit.pisecond.""" reassign_velocities = False """ @@ -198,35 +195,31 @@ class Config: Whether or not to remove the center of mass motion. Default False. 
""" - @validator('langevin_collision_rate', 'n_restart_attempts') + @validator("langevin_collision_rate", "n_restart_attempts") def must_be_positive_or_zero(cls, v): if v < 0: - errmsg = ("langevin_collision_rate, and n_restart_attempts must be" - f" zero or positive values, got {v}.") + errmsg = "langevin_collision_rate, and n_restart_attempts must be" f" zero or positive values, got {v}." raise ValueError(errmsg) return v - @validator('timestep', 'constraint_tolerance') + @validator("timestep", "constraint_tolerance") def must_be_positive(cls, v): if v <= 0: - errmsg = ("timestep, and constraint_tolerance " - f"must be positive values, got {v}.") + errmsg = "timestep, and constraint_tolerance " f"must be positive values, got {v}." raise ValueError(errmsg) return v - @validator('timestep') + @validator("timestep") def is_time(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.picosecond): - raise ValueError("timestep must be in time units " - "(i.e. picoseconds)") + raise ValueError("timestep must be in time units " "(i.e. picoseconds)") return v - @validator('langevin_collision_rate') + @validator("langevin_collision_rate") def must_be_inverse_time(cls, v): if not v.is_compatible_with(1 / unit.picosecond): - raise ValueError("langevin collision_rate must be in inverse time " - "(i.e. 1/picoseconds)") + raise ValueError("langevin collision_rate must be in inverse time " "(i.e. 1/picoseconds)") return v @@ -235,39 +228,40 @@ class OutputSettings(SettingsBaseModel): Settings for simulation output settings, writing to disk, etc... """ + class Config: arbitrary_types_allowed = True # reporter settings - output_filename = 'simulation.nc' + output_filename = "simulation.nc" """Path to the trajectory storage file. Default 'simulation.nc'.""" - output_structure = 'hybrid_system.pdb' + output_structure = "hybrid_system.pdb" """ Path of the output hybrid topology structure file. This is used to visualise and further manipulate the system. 
Default 'hybrid_system.pdb'. """ - output_indices = 'not water' + output_indices = "not water" """ Selection string for which part of the system to write coordinates for. Default 'not water'. """ - checkpoint_interval: FloatQuantity['picosecond'] = 1 * unit.picosecond + checkpoint_interval: FloatQuantity[picosecond] = 1 * unit.picosecond """ Frequency to write the checkpoint file. Default 1 * unit.picosecond. """ - checkpoint_storage_filename = 'checkpoint.chk' + checkpoint_storage_filename = "checkpoint.chk" """ Separate filename for the checkpoint file. Note, this should not be a full path, just a filename. Default 'checkpoint.chk'. """ - forcefield_cache: Optional[str] = 'db.json' + forcefield_cache: str | None = "db.json" """ Filename for caching small molecule residue templates so they can be later reused. """ - @validator('checkpoint_interval') + @validator("checkpoint_interval") def must_be_positive(cls, v): if v <= 0: errmsg = f"Checkpoint intervals must be positive, got {v}." @@ -279,12 +273,13 @@ class SimulationSettings(SettingsBaseModel): """ Settings for simulation control, including lengths, etc... """ + class Config: arbitrary_types_allowed = True minimization_steps = 5000 """Number of minimization steps to perform. Default 5000.""" - equilibration_length: FloatQuantity['nanosecond'] + equilibration_length: FloatQuantity[nanosecond] """ Length of the equilibration phase in units of time. The total number of steps from this equilibration length @@ -292,7 +287,7 @@ class Config: must be a multiple of the value defined for :class:`AlchemicalSamplerSettings.steps_per_iteration`. """ - production_length: FloatQuantity['nanosecond'] + production_length: FloatQuantity[nanosecond] """ Length of the production phase in units of time. The total number of steps from this production length (i.e. @@ -300,19 +295,17 @@ class Config: a multiple of the value defined for :class:`IntegratorSettings.nsteps`. 
""" - @validator('equilibration_length', 'production_length') + @validator("equilibration_length", "production_length") def is_time(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.picosecond): raise ValueError("Durations must be in time units") return v - @validator('minimization_steps', 'equilibration_length', - 'production_length') + @validator("minimization_steps", "equilibration_length", "production_length") def must_be_positive(cls, v): if v <= 0: - errmsg = ("Minimization steps, and MD lengths must be positive, " - f"got {v}") + errmsg = "Minimization steps, and MD lengths must be positive, " f"got {v}" raise ValueError(errmsg) return v @@ -346,12 +339,12 @@ class Config: or `independent` (independently sampled lambda windows). Default `repex`. """ - time_per_iteration: FloatQuantity['picosecond'] = 1 * unit.picosecond + time_per_iteration: FloatQuantity[picosecond] = 1 * unit.picosecond # todo: Add validators in the protocol """ Simulation time between each MCMC move attempt. Default 1 * unit.picosecond. """ - real_time_analysis_interval: Optional[FloatQuantity['picosecond']] = 250 * unit.picosecond + real_time_analysis_interval: FloatQuantity[picosecond] | None = 250 * unit.picosecond # todo: Add validators in the protocol """ Time interval at which to perform an analysis of the free energies. @@ -369,29 +362,29 @@ class Config: Must be a multiple of ``OutputSettings.checkpoint_interval`` Default `250`. - + """ - early_termination_target_error: Optional[FloatQuantity['kcal/mol']] = 0.0 * unit.kilocalorie_per_mole + early_termination_target_error: FloatQuantity[kcal / mol] | None = 0.0 * unit.kilocalorie_per_mole # todo: have default ``None`` or ``0.0 * unit.kilocalorie_per_mole`` # (later would give an example of unit). """ - Target error for the real time analysis measured in kcal/mol. Once the MBAR - error of the free energy is at or below this value, the simulation will be - considered complete. 
+ Target error for the real time analysis measured in kcal/mol. Once the MBAR + error of the free energy is at or below this value, the simulation will be + considered complete. A suggested value of 0.12 * `unit.kilocalorie_per_mole` has shown to be effective in both hydration and binding free energy benchmarks. Default ``None``, i.e. no early termination will occur. """ - real_time_analysis_minimum_time: FloatQuantity['picosecond'] = 500 * unit.picosecond + real_time_analysis_minimum_time: FloatQuantity[picosecond] = 500 * unit.picosecond # todo: Add validators in the protocol """ Simulation time which must pass before real time analysis is - carried out. - + carried out. + Default 500 * unit.picosecond. """ - sams_flatness_criteria = 'logZ-flatness' + sams_flatness_criteria = "logZ-flatness" """ SAMS only. Method for assessing when to switch to asymptomatically optimal scheme. @@ -403,41 +396,36 @@ class Config: n_replicas = 11 """Number of replicas to use. Default 11.""" - @validator('sams_flatness_criteria') + @validator("sams_flatness_criteria") def supported_flatness(cls, v): - supported = [ - 'logz-flatness', 'minimum-visits', 'histogram-flatness' - ] + supported = ["logz-flatness", "minimum-visits", "histogram-flatness"] if v.lower() not in supported: - errmsg = ("Only the following sams_flatness_criteria are " - f"supported: {supported}") + errmsg = "Only the following sams_flatness_criteria are " f"supported: {supported}" raise ValueError(errmsg) return v - @validator('sampler_method') + @validator("sampler_method") def supported_sampler(cls, v): - supported = ['repex', 'sams', 'independent'] + supported = ["repex", "sams", "independent"] if v.lower() not in supported: - errmsg = ("Only the following sampler_method values are " - f"supported: {supported}") + errmsg = "Only the following sampler_method values are " f"supported: {supported}" raise ValueError(errmsg) return v - @validator('n_replicas', 'time_per_iteration') + @validator("n_replicas", 
"time_per_iteration") def must_be_positive(cls, v): if v <= 0: - errmsg = "n_replicas and steps_per_iteration must be positive " \ - f"values, got {v}." + errmsg = "n_replicas and steps_per_iteration must be positive " f"values, got {v}." raise ValueError(errmsg) return v - @validator('early_termination_target_error', - 'real_time_analysis_minimum_time', 'sams_gamma0', - 'n_replicas') + @validator("early_termination_target_error", "real_time_analysis_minimum_time", "sams_gamma0", "n_replicas") def must_be_zero_or_positive(cls, v): if v < 0: - errmsg = ("Early termination target error, minimum iteration and" - f" SAMS gamma0 must be 0 or positive values, got {v}.") + errmsg = ( + "Early termination target error, minimum iteration and" + f" SAMS gamma0 must be 0 or positive values, got {v}." + ) raise ValueError(errmsg) return v @@ -446,42 +434,44 @@ class MDSimulationSettings(SimulationSettings): """ Settings for simulation control for plain MD simulations """ + class Config: arbitrary_types_allowed = True equilibration_length_nvt: unit.Quantity """ - Length of the equilibration phase in the NVT ensemble in units of time. + Length of the equilibration phase in the NVT ensemble in units of time. The total number of steps from this equilibration length (i.e. ``equilibration_length_nvt`` / :class:`IntegratorSettings.timestep`). """ class MDOutputSettings(OutputSettings): - """ Settings for simulation output settings for plain MD simulations.""" + """Settings for simulation output settings for plain MD simulations.""" + class Config: arbitrary_types_allowed = True # reporter settings - production_trajectory_filename = 'simulation.xtc' + production_trajectory_filename = "simulation.xtc" """Path to the storage file for analysis. Default 'simulation.xtc'.""" - trajectory_write_interval: FloatQuantity['picosecond'] = 20 * unit.picosecond + trajectory_write_interval: FloatQuantity[picosecond] = 20 * unit.picosecond """ Frequency to write the xtc file. 
Default 5000 * unit.timestep. """ - preminimized_structure = 'system.pdb' - """Path to the pdb file of the full pre-minimized system. + preminimized_structure = "system.pdb" + """Path to the pdb file of the full pre-minimized system. Default 'system.pdb'.""" - minimized_structure = 'minimized.pdb' - """Path to the pdb file of the system after minimization. + minimized_structure = "minimized.pdb" + """Path to the pdb file of the system after minimization. Only the specified atom subset is saved. Default 'minimized.pdb'.""" - equil_NVT_structure = 'equil_NVT.pdb' - """Path to the pdb file of the system after NVT equilibration. + equil_NVT_structure = "equil_NVT.pdb" + """Path to the pdb file of the system after NVT equilibration. Only the specified atom subset is saved. Default 'equil_NVT.pdb'.""" - equil_NPT_structure = 'equil_NPT.pdb' - """Path to the pdb file of the system after NPT equilibration. + equil_NPT_structure = "equil_NPT.pdb" + """Path to the pdb file of the system after NPT equilibration. Only the specified atom subset is saved. Default 'equil_NPT.pdb'.""" - log_output = 'simulation.log' + log_output = "simulation.log" """ Filename for writing the log of the MD simulation, including timesteps, energies, density, etc. diff --git a/openfe/protocols/openmm_utils/settings_validation.py b/openfe/protocols/openmm_utils/settings_validation.py index 0526f95f2..55282d022 100644 --- a/openfe/protocols/openmm_utils/settings_validation.py +++ b/openfe/protocols/openmm_utils/settings_validation.py @@ -4,12 +4,11 @@ Reusable utility methods to validate input settings to OpenMM-based alchemical Protocols. 
""" -from openff.units import unit from typing import Optional -from .omm_settings import ( - IntegratorSettings, - MultiStateSimulationSettings, -) + +from openff.units import unit + +from .omm_settings import IntegratorSettings, MultiStateSimulationSettings def validate_timestep(hmass: float, timestep: unit.Quantity): @@ -37,8 +36,7 @@ def validate_timestep(hmass: float, timestep: unit.Quantity): raise ValueError(errmsg) -def get_simsteps(sim_length: unit.Quantity, - timestep: unit.Quantity, mc_steps: int) -> int: +def get_simsteps(sim_length: unit.Quantity, timestep: unit.Quantity, mc_steps: int) -> int: """ Gets and validates the number of simulation steps. @@ -57,17 +55,19 @@ def get_simsteps(sim_length: unit.Quantity, The number of simulation timesteps. """ - sim_time = round(sim_length.to('attosecond').m) - ts = round(timestep.to('attosecond').m) + sim_time = round(sim_length.to("attosecond").m) + ts = round(timestep.to("attosecond").m) sim_steps, mod = divmod(sim_time, ts) if mod != 0: raise ValueError("Simulation time not divisible by timestep") if (sim_steps % mc_steps) != 0: - errmsg = (f"Simulation time {sim_time/1000000} ps should contain a " - "number of steps divisible by the number of integrator " - f"timesteps between MC moves {mc_steps}") + errmsg = ( + f"Simulation time {sim_time/1000000} ps should contain a " + "number of steps divisible by the number of integrator " + f"timesteps between MC moves {mc_steps}" + ) raise ValueError(errmsg) return sim_steps @@ -102,8 +102,12 @@ def divmod_time( return iterations, remainder -def divmod_time_and_check(numerator: unit.Quantity, denominator: unit.Quantity, - numerator_name: str, denominator_name: str) -> int: +def divmod_time_and_check( + numerator: unit.Quantity, + denominator: unit.Quantity, + numerator_name: str, + denominator_name: str, +) -> int: """Perform a division of time, failing if there is a remainder For example numerator 20.0 ps and denominator 4.0 fs gives 5000 @@ -129,9 +133,11 @@ def 
divmod_time_and_check(numerator: unit.Quantity, denominator: unit.Quantity, its, rem = divmod_time(numerator, denominator) if rem: - errmsg = (f"The {numerator_name} ({numerator}) " - "does not evenly divide by the " - f"{denominator_name} ({denominator})") + errmsg = ( + f"The {numerator_name} ({numerator}) " + "does not evenly divide by the " + f"{denominator_name} ({denominator})" + ) raise ValueError(errmsg) return its @@ -161,9 +167,10 @@ def convert_checkpoint_interval_to_iterations( The number of iterations per checkpoint. """ return divmod_time_and_check( - checkpoint_interval, time_per_iteration, + checkpoint_interval, + time_per_iteration, "amount of time per checkpoint", - "amount of time per state MCM move attempt" + "amount of time per state MCM move attempt", ) diff --git a/openfe/protocols/openmm_utils/system_creation.py b/openfe/protocols/openmm_utils/system_creation.py index 77ae8274d..253d2597c 100644 --- a/openfe/protocols/openmm_utils/system_creation.py +++ b/openfe/protocols/openmm_utils/system_creation.py @@ -4,23 +4,20 @@ Reusable utility methods to create Systems for OpenMM-based alchemical Protocols. 
""" +from pathlib import Path +from typing import Optional + import numpy as np import numpy.typing as npt -from openmm import app, MonteCarloBarostat -from openmm import unit as omm_unit +from gufe import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent +from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings from openff.toolkit import Molecule as OFFMol -from openff.units.openmm import to_openmm, ensure_quantity +from openff.units.openmm import ensure_quantity, to_openmm +from openmm import MonteCarloBarostat, app +from openmm import unit as omm_unit from openmmforcefields.generators import SystemGenerator -from typing import Optional -from pathlib import Path -from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings -from gufe import ( - Component, ProteinComponent, SolventComponent, SmallMoleculeComponent -) -from openfe.protocols.openmm_utils.omm_settings import ( - IntegratorSettings, - OpenMMSolvationSettings, -) + +from openfe.protocols.openmm_utils.omm_settings import IntegratorSettings, OpenMMSolvationSettings def get_system_generator( @@ -63,28 +60,28 @@ def get_system_generator( """ # get the right constraint constraints = { - 'hbonds': app.HBonds, - 'none': None, - 'allbonds': app.AllBonds, - 'hangles': app.HAngles + "hbonds": app.HBonds, + "none": None, + "allbonds": app.AllBonds, + "hangles": app.HAngles, # vvv can be None so string it }[str(forcefield_settings.constraints).lower()] # create forcefield_kwargs entry forcefield_kwargs = { - 'constraints': constraints, - 'rigidWater': forcefield_settings.rigid_water, - 'removeCMMotion': integrator_settings.remove_com, - 'hydrogenMass': forcefield_settings.hydrogen_mass * omm_unit.amu, + "constraints": constraints, + "rigidWater": forcefield_settings.rigid_water, + "removeCMMotion": integrator_settings.remove_com, + "hydrogenMass": forcefield_settings.hydrogen_mass * omm_unit.amu, } # get the right nonbonded method nonbonded_method = { - 'pme': 
app.PME, - 'nocutoff': app.NoCutoff, - 'cutoffnonperiodic': app.CutoffNonPeriodic, - 'cutoffperiodic': app.CutoffPeriodic, - 'ewald': app.Ewald + "pme": app.PME, + "nocutoff": app.NoCutoff, + "cutoffnonperiodic": app.CutoffNonPeriodic, + "cutoffperiodic": app.CutoffPeriodic, + "ewald": app.Ewald, }[forcefield_settings.nonbonded_method.lower()] nonbonded_cutoff = to_openmm( @@ -93,15 +90,15 @@ def get_system_generator( # create the periodic_kwarg entry periodic_kwargs = { - 'nonbondedMethod': nonbonded_method, - 'nonbondedCutoff': nonbonded_cutoff, + "nonbondedMethod": nonbonded_method, + "nonbondedCutoff": nonbonded_cutoff, } # Currently the else is a dead branch, we will want to investigate the # possibility of using CutoffNonPeriodic at some point though (for RF) if nonbonded_method is not app.CutoffNonPeriodic: nonperiodic_kwargs = { - 'nonbondedMethod': app.NoCutoff, + "nonbondedMethod": app.NoCutoff, } else: # pragma: no-cover nonperiodic_kwargs = periodic_kwargs @@ -110,8 +107,8 @@ def get_system_generator( # TODO: move this to its own place where we can handle membranes if has_solvent: barostat = MonteCarloBarostat( - ensure_quantity(thermo_settings.pressure, 'openmm'), - ensure_quantity(thermo_settings.temperature, 'openmm'), + ensure_quantity(thermo_settings.pressure, "openmm"), + ensure_quantity(thermo_settings.temperature, "openmm"), integrator_settings.barostat_frequency.m, ) else: @@ -137,8 +134,8 @@ def get_omm_modeller( protein_comp: Optional[ProteinComponent], solvent_comp: Optional[SolventComponent], small_mols: dict[SmallMoleculeComponent, OFFMol], - omm_forcefield : app.ForceField, - solvent_settings : OpenMMSolvationSettings + omm_forcefield: app.ForceField, + solvent_settings: OpenMMSolvationSettings, ) -> ModellerReturn: """ Generate an OpenMM Modeller class based on a potential input ProteinComponent, @@ -167,19 +164,13 @@ def get_omm_modeller( """ component_resids = {} - def _add_small_mol(comp, - mol, - system_modeller: app.Modeller, - 
comp_resids: dict[Component, npt.NDArray]): + def _add_small_mol(comp, mol, system_modeller: app.Modeller, comp_resids: dict[Component, npt.NDArray]): """ Helper method to add OFFMol to an existing Modeller object and update a dictionary tracking residue indices for each component. """ omm_top = mol.to_topology().to_openmm() - system_modeller.add( - omm_top, - ensure_quantity(mol.conformers[0], 'openmm') - ) + system_modeller.add(omm_top, ensure_quantity(mol.conformers[0], "openmm")) nres = omm_top.getNumResidues() resids = [res.index for res in system_modeller.topology.residues()] @@ -190,19 +181,16 @@ def _add_small_mol(comp, # If there's a protein in the system, we add it first to the Modeller if protein_comp is not None: - system_modeller.add(protein_comp.to_openmm_topology(), - protein_comp.to_openmm_positions()) + system_modeller.add(protein_comp.to_openmm_topology(), protein_comp.to_openmm_positions()) # add missing virtual particles (from crystal waters) system_modeller.addExtraParticles(omm_forcefield) - component_resids[protein_comp] = np.array( - [r.index for r in system_modeller.topology.residues()] - ) + component_resids[protein_comp] = np.array([r.index for r in system_modeller.topology.residues()]) # if we solvate temporarily rename water molecules to 'WAT' # see openmm issue #4103 if solvent_comp is not None: for r in system_modeller.topology.residues(): - if r.name == 'HOH': - r.name = 'WAT' + if r.name == "HOH": + r.name = "WAT" # Now loop through small mols for comp, mol in small_mols.items(): @@ -218,26 +206,20 @@ def _add_small_mol(comp, omm_forcefield, model=solvent_settings.solvent_model, padding=to_openmm(solvent_settings.solvent_padding), - positiveIon=pos, negativeIon=neg, + positiveIon=pos, + negativeIon=neg, ionicStrength=to_openmm(conc), neutralize=solvent_comp.neutralize, ) - all_resids = np.array( - [r.index for r in system_modeller.topology.residues()] - ) + all_resids = np.array([r.index for r in 
system_modeller.topology.residues()]) - existing_resids = np.concatenate( - [resids for resids in component_resids.values()] - ) + existing_resids = np.concatenate([resids for resids in component_resids.values()]) - component_resids[solvent_comp] = np.setdiff1d( - all_resids, existing_resids - ) + component_resids[solvent_comp] = np.setdiff1d(all_resids, existing_resids) # undo rename of pre-existing waters for r in system_modeller.topology.residues(): - if r.name == 'WAT': - r.name = 'HOH' + if r.name == "WAT": + r.name = "HOH" return system_modeller, component_resids - diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index f1710d93e..65eb4e866 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -5,16 +5,14 @@ Protocols. """ from typing import Optional, Tuple + +from gufe import ChemicalSystem, Component, ProteinComponent, SmallMoleculeComponent, SolventComponent from openff.toolkit import Molecule as OFFMol -from gufe import ( - Component, ChemicalSystem, SolventComponent, ProteinComponent, - SmallMoleculeComponent -) def get_alchemical_components( - stateA: ChemicalSystem, - stateB: ChemicalSystem, + stateA: ChemicalSystem, + stateB: ChemicalSystem, ) -> dict[str, list[Component]]: """ Checks the equality between Components of two end state ChemicalSystems @@ -37,9 +35,10 @@ def get_alchemical_components( ValueError If there are any duplicate components in states A or B. 
""" - matched_components: dict[Component, Component] = {} + matched_components: dict[Component, Component] = {} alchemical_components: dict[str, list[Component]] = { - 'stateA': [], 'stateB': [], + "stateA": [], + "stateB": [], } for keyA, valA in stateA.components.items(): @@ -50,19 +49,18 @@ def get_alchemical_components( else: # Could be that either we have a duplicate component # in stateA or in stateB - errmsg = (f"state A components {keyA}: {valA} matches " - "multiple components in stateA or stateB") + errmsg = f"state A components {keyA}: {valA} matches " "multiple components in stateA or stateB" raise ValueError(errmsg) # populate stateA alchemical components for valA in stateA.components.values(): if valA not in matched_components.keys(): - alchemical_components['stateA'].append(valA) + alchemical_components["stateA"].append(valA) # populate stateB alchemical components for valB in stateB.components.values(): if valB not in matched_components.values(): - alchemical_components['stateB'].append(valB) + alchemical_components["stateB"].append(valB) return alchemical_components @@ -87,14 +85,13 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. 
""" - solv = [comp for comp in state.values() - if isinstance(comp, SolventComponent)] + solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": errmsg = "nocutoff cannot be used for solvent transformations" raise ValueError(errmsg) - if len(solv) == 0 and nonbonded_method.lower() == 'pme': + if len(solv) == 0 and nonbonded_method.lower() == "pme": errmsg = "PME cannot be used for vacuum transform" raise ValueError(errmsg) @@ -102,7 +99,7 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): errmsg = "Multiple SolventComponent found, only one is supported" raise ValueError(errmsg) - if len(solv) > 0 and solv[0].smiles != 'O': + if len(solv) > 0 and solv[0].smiles != "O": errmsg = "Non water solvent is not currently supported" raise ValueError(errmsg) @@ -122,16 +119,16 @@ def validate_protein(state: ChemicalSystem): ValueError If there are multiple ProteinComponent in the ChemicalSystem. """ - nprot = sum(1 for comp in state.values() - if isinstance(comp, ProteinComponent)) + nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) if nprot > 1: errmsg = "Multiple ProteinComponent found, only one is supported" raise ValueError(errmsg) -ParseCompRet = Tuple[ - Optional[SolventComponent], Optional[ProteinComponent], +ParseCompRet = tuple[ + Optional[SolventComponent], + Optional[ProteinComponent], list[SmallMoleculeComponent], ] @@ -153,21 +150,17 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: If it exists, the ProteinComponent for the state, otherwise None. 
small_mols : list[SmallMoleculeComponent] """ + def _get_single_comps(comp_list, comptype): - ret_comps = [comp for comp in comp_list - if isinstance(comp, comptype)] + ret_comps = [comp for comp in comp_list if isinstance(comp, comptype)] if ret_comps: return ret_comps[0] else: return None - solvent_comp: Optional[SolventComponent] = _get_single_comps( - list(state.values()), SolventComponent - ) + solvent_comp: Optional[SolventComponent] = _get_single_comps(list(state.values()), SolventComponent) - protein_comp: Optional[ProteinComponent] = _get_single_comps( - list(state.values()), ProteinComponent - ) + protein_comp: Optional[ProteinComponent] = _get_single_comps(list(state.values()), ProteinComponent) small_mols = [] for comp in state.components.values(): diff --git a/openfe/setup/__init__.py b/openfe/setup/__init__.py index 2411bb300..8f5e9dd6a 100644 --- a/openfe/setup/__init__.py +++ b/openfe/setup/__init__.py @@ -2,13 +2,16 @@ # For details, see https://github.com/OpenFreeEnergy/openfe -from .atom_mapping import (LigandAtomMapping, - LigandAtomMapper, - LomapAtomMapper, lomap_scorers, - PersesAtomMapper, perses_scorers, - KartografAtomMapper,) - from gufe import LigandNetwork -from . import ligand_network_planning -from .alchemical_network_planner import RHFEAlchemicalNetworkPlanner, RBFEAlchemicalNetworkPlanner \ No newline at end of file +from . 
import ligand_network_planning +from .alchemical_network_planner import RBFEAlchemicalNetworkPlanner, RHFEAlchemicalNetworkPlanner +from .atom_mapping import ( + KartografAtomMapper, + LigandAtomMapper, + LigandAtomMapping, + LomapAtomMapper, + PersesAtomMapper, + lomap_scorers, + perses_scorers, +) diff --git a/openfe/setup/alchemical_network_planner/__init__.py b/openfe/setup/alchemical_network_planner/__init__.py index 7a67e50a9..1bd15fea5 100644 --- a/openfe/setup/alchemical_network_planner/__init__.py +++ b/openfe/setup/alchemical_network_planner/__init__.py @@ -1,7 +1,4 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from .relative_alchemical_network_planner import ( - RHFEAlchemicalNetworkPlanner, - RBFEAlchemicalNetworkPlanner, -) +from .relative_alchemical_network_planner import RBFEAlchemicalNetworkPlanner, RHFEAlchemicalNetworkPlanner diff --git a/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py b/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py index dfbde4990..2d1b7fce5 100644 --- a/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py +++ b/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py @@ -2,7 +2,8 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import abc -from typing import Iterable +from collections.abc import Iterable + from gufe import AlchemicalNetwork diff --git a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py index 3dc060e91..9b4b0eb08 100644 --- a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py +++ b/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py @@ -2,38 +2,29 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import abc import copy -from 
typing import Iterable, Callable, Type, Optional +from collections.abc import Iterable +from typing import Callable, Optional, Type from gufe import ( - Protocol, AlchemicalNetwork, - LigandAtomMapping, - Transformation, ChemicalSystem, -) -from gufe import ( - SmallMoleculeComponent, ProteinComponent, SolventComponent, + LigandAtomMapping, LigandNetwork, + ProteinComponent, + Protocol, + SmallMoleculeComponent, + SolventComponent, + Transformation, ) - -from .abstract_alchemical_network_planner import ( - AbstractAlchemicalNetworkPlanner, -) - +from ...protocols.openmm_rfe.equil_rfe_methods import RelativeHybridTopologyProtocol from .. import LomapAtomMapper from ..atom_mapping.ligandatommapper import LigandAtomMapper from ..atom_mapping.lomap_scorers import default_lomap_score +from ..chemicalsystem_generator import EasyChemicalSystemGenerator, RFEComponentLabels +from ..chemicalsystem_generator.abstract_chemicalsystem_generator import AbstractChemicalSystemGenerator from ..ligand_network_planning import generate_minimal_spanning_network -from ..chemicalsystem_generator.abstract_chemicalsystem_generator import ( - AbstractChemicalSystemGenerator, -) -from ..chemicalsystem_generator import ( - EasyChemicalSystemGenerator, - RFEComponentLabels, -) -from ...protocols.openmm_rfe.equil_rfe_methods import RelativeHybridTopologyProtocol - +from .abstract_alchemical_network_planner import AbstractAlchemicalNetworkPlanner # TODO: move/or find better structure for protocol_generator combintations! 
PROTOCOL_GENERATOR = { @@ -41,16 +32,14 @@ } -class RelativeAlchemicalNetworkPlanner( - AbstractAlchemicalNetworkPlanner, abc.ABC -): +class RelativeAlchemicalNetworkPlanner(AbstractAlchemicalNetworkPlanner, abc.ABC): _chemical_system_generator: AbstractChemicalSystemGenerator def __init__( self, name: str = "easy_rfe_calculation", mappers: Optional[Iterable[LigandAtomMapper]] = None, - mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, + mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, ligand_network_planner: Callable = generate_minimal_spanning_network, protocol: Optional[Protocol] = None, ): @@ -74,21 +63,17 @@ def __init__( if protocol is None: protocol = RelativeHybridTopologyProtocol(RelativeHybridTopologyProtocol.default_settings()) if mappers is None: - mappers = [LomapAtomMapper(time=20, threed=True, - element_change=False, max3d=1)] + mappers = [LomapAtomMapper(time=20, threed=True, element_change=False, max3d=1)] self.name = name self._mappers = mappers self._mapping_scorer = mapping_scorer self._ligand_network_planner = ligand_network_planner self._protocol = protocol - self._chemical_system_generator_type = PROTOCOL_GENERATOR[ - protocol.__class__ - ] + self._chemical_system_generator_type = PROTOCOL_GENERATOR[protocol.__class__] @abc.abstractmethod - def __call__(self, *args, **kwargs) -> AlchemicalNetwork: - ... # -no-cov- + def __call__(self, *args, **kwargs) -> AlchemicalNetwork: ... 
# -no-cov- @property def mappers(self) -> Iterable[LigandAtomMapper]: @@ -109,15 +94,11 @@ def transformation_protocol(self) -> Protocol: @property def chemical_system_generator_type( self, - ) -> Type[AbstractChemicalSystemGenerator]: + ) -> type[AbstractChemicalSystemGenerator]: return self._chemical_system_generator_type - def _construct_ligand_network( - self, ligands: Iterable[SmallMoleculeComponent] - ) -> LigandNetwork: - ligand_network = self._ligand_network_planner( - ligands=ligands, mappers=self.mappers, scorer=self.mapping_scorer - ) + def _construct_ligand_network(self, ligands: Iterable[SmallMoleculeComponent]) -> LigandNetwork: + ligand_network = self._ligand_network_planner(ligands=ligands, mappers=self.mappers, scorer=self.mapping_scorer) return ligand_network @@ -162,25 +143,19 @@ def _build_transformations( end_state_nodes.extend([stateA_env, stateB_env]) # Todo: make the code here more stable in future: Name doubling check - all_transformation_labels = list( - map(lambda x: x.name, transformation_edges) - ) + all_transformation_labels = list(map(lambda x: x.name, transformation_edges)) - if len(all_transformation_labels) != len( - set(all_transformation_labels) - ): + if len(all_transformation_labels) != len(set(all_transformation_labels)): raise ValueError( "There were multiple transformations with the same edge label! This might lead to overwritting your files. 
\n labels: " + str(len(all_transformation_labels)) + "\nunique: " + str(len(set(all_transformation_labels))) + "\ngot: \n\t" - + "\n\t".join(all_transformation_labels) + + "\n\t".join(all_transformation_labels), ) - alchemical_network = AlchemicalNetwork( - nodes=end_state_nodes, edges=transformation_edges, name=self.name - ) + alchemical_network = AlchemicalNetwork(nodes=end_state_nodes, edges=transformation_edges, name=self.name) return alchemical_network def _build_transformation( @@ -211,9 +186,7 @@ def _build_transformation( if "vacuum" in transformation_name: protocol_settings.forcefield_settings.nonbonded_method = "nocutoff" - transformation_protocol = transformation_protocol.__class__( - settings=protocol_settings - ) + transformation_protocol = transformation_protocol.__class__(settings=protocol_settings) return Transformation( stateA=stateA, @@ -237,7 +210,7 @@ def __init__( self, name: str = "easy_rhfe", mappers: Optional[Iterable[LigandAtomMapper]] = None, - mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, + mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, ligand_network_planner: Callable = generate_minimal_spanning_network, protocol: Optional[Protocol] = None, ): @@ -274,7 +247,8 @@ def __call__( # Prepare system generation self._chemical_system_generator = self._chemical_system_generator_type( - solvent=solvent, do_vacuum=True, + solvent=solvent, + do_vacuum=True, ) # Build transformations @@ -295,11 +269,12 @@ class RBFEAlchemicalNetworkPlanner(RelativeAlchemicalNetworkPlanner): network planning scheme, then call it on a collection of ligands, protein, solvent, and co-factors to create the network. 
""" + def __init__( self, name: str = "easy_rbfe", mappers: Optional[Iterable[LigandAtomMapper]] = None, - mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, + mapping_scorer: Callable[[LigandAtomMapping], float] = default_lomap_score, ligand_network_planner: Callable = generate_minimal_spanning_network, protocol: Optional[Protocol] = None, ): @@ -342,7 +317,9 @@ def __call__( # Prepare system generation self._chemical_system_generator = self._chemical_system_generator_type( - solvent=solvent, protein=protein, cofactors=cofactors, + solvent=solvent, + protein=protein, + cofactors=cofactors, ) # Build transformations diff --git a/openfe/setup/atom_mapping/__init__.py b/openfe/setup/atom_mapping/__init__.py index 4f9ce5350..e5914757a 100644 --- a/openfe/setup/atom_mapping/__init__.py +++ b/openfe/setup/atom_mapping/__init__.py @@ -1,9 +1,7 @@ from gufe import LigandAtomMapping -from .ligandatommapper import LigandAtomMapper +from kartograf import KartografAtomMapper +from . import lomap_scorers, perses_scorers +from .ligandatommapper import LigandAtomMapper from .lomap_mapper import LomapAtomMapper from .perses_mapper import PersesAtomMapper -from kartograf import KartografAtomMapper - -from . import perses_scorers -from . import lomap_scorers diff --git a/openfe/setup/atom_mapping/ligandatommapper.py b/openfe/setup/atom_mapping/ligandatommapper.py index e35dfb445..17a493b90 100644 --- a/openfe/setup/atom_mapping/ligandatommapper.py +++ b/openfe/setup/atom_mapping/ligandatommapper.py @@ -1,11 +1,12 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import abc -from typing import Iterable +from collections.abc import Iterable +import gufe from gufe import SmallMoleculeComponent + from . 
import LigandAtomMapping -import gufe class LigandAtomMapper(gufe.AtomMapper): @@ -15,11 +16,13 @@ class LigandAtomMapper(gufe.AtomMapper): Subclasses will typically implement the ``_mappings_generator`` method, which returns an iterable of :class:`.LigandAtomMapping` suggestions. """ + @abc.abstractmethod - def _mappings_generator(self, - componentA: SmallMoleculeComponent, - componentB: SmallMoleculeComponent - ) -> Iterable[dict[int, int]]: + def _mappings_generator( + self, + componentA: SmallMoleculeComponent, + componentB: SmallMoleculeComponent, + ) -> Iterable[dict[int, int]]: """ Suggest mapping options for the input molecules. @@ -35,8 +38,10 @@ def _mappings_generator(self, """ ... - def suggest_mappings(self, componentA: SmallMoleculeComponent, - componentB: SmallMoleculeComponent + def suggest_mappings( + self, + componentA: SmallMoleculeComponent, + componentB: SmallMoleculeComponent, ) -> Iterable[LigandAtomMapping]: """ Suggest :class:`.LigandAtomMapping` options for the input molecules. diff --git a/openfe/setup/atom_mapping/lomap_scorers.py b/openfe/setup/atom_mapping/lomap_scorers.py index eb4cc46ef..4c34cc0ff 100644 --- a/openfe/setup/atom_mapping/lomap_scorers.py +++ b/openfe/setup/atom_mapping/lomap_scorers.py @@ -1,7 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -from lomap.gufe_bindings.scorers import ( +from lomap.gufe_bindings.scorers import ( # looks like we gotta make it detailed for mypy and RTD atomic_number_score, default_lomap_score, ecr_score, @@ -13,5 +13,4 @@ tmcsr_score, transmuting_methyl_into_ring_score, transmuting_ring_sizes_score, -) # looks like we gotta make it detailed for mypy and RTD - +) diff --git a/openfe/setup/atom_mapping/perses_mapper.py b/openfe/setup/atom_mapping/perses_mapper.py index 8e246dc4d..bdebcf5be 100644 --- a/openfe/setup/atom_mapping/perses_mapper.py +++ b/openfe/setup/atom_mapping/perses_mapper.py @@ -7,16 +7,16 @@ """ from openmm import unit + from openfe.utils import requires_package from ...utils.silence_root_logging import silence_root_logging + try: with silence_root_logging(): - from perses.rjmc.atom_mapping import ( - AtomMapper, InvalidMappingException - ) + from perses.rjmc.atom_mapping import AtomMapper, InvalidMappingException except ImportError: - pass # Don't throw error, will happen later + pass # Don't throw error, will happen later from .ligandatommapper import LigandAtomMapper @@ -27,10 +27,13 @@ class PersesAtomMapper(LigandAtomMapper): use_positions: bool @requires_package("perses") - def __init__(self, allow_ring_breaking: bool = True, - preserve_chirality: bool = True, - use_positions: bool = True, - coordinate_tolerance: float = 0.25 * unit.angstrom): + def __init__( + self, + allow_ring_breaking: bool = True, + preserve_chirality: bool = True, + use_positions: bool = True, + coordinate_tolerance: float = 0.25 * unit.angstrom, + ): """ Suggest atom mappings with the Perses atom mapper. 
@@ -58,12 +61,15 @@ def _mappings_generator(self, componentA, componentB): _atom_mapper = AtomMapper( use_positions=self.use_positions, coordinate_tolerance=self.coordinate_tolerance, - allow_ring_breaking=self.allow_ring_breaking) + allow_ring_breaking=self.allow_ring_breaking, + ) # Try generating a mapping try: _atom_mappings = _atom_mapper.get_all_mappings( - old_mol=componentA.to_openff(), new_mol=componentB.to_openff()) + old_mol=componentA.to_openff(), + new_mol=componentB.to_openff(), + ) except InvalidMappingException: return diff --git a/openfe/setup/atom_mapping/perses_scorers.py b/openfe/setup/atom_mapping/perses_scorers.py index 341194dad..0ba93082c 100644 --- a/openfe/setup/atom_mapping/perses_scorers.py +++ b/openfe/setup/atom_mapping/perses_scorers.py @@ -6,27 +6,22 @@ from openfe.utils import requires_package from ...utils.silence_root_logging import silence_root_logging + try: with silence_root_logging(): from perses.rjmc.atom_mapping import AtomMapper, AtomMapping except ImportError: - pass # Don't throw error, will happen later + pass # Don't throw error, will happen later from . 
import LigandAtomMapping # Helpfer Function / reducing code amount -def _get_all_mapped_atoms_with(oeyMolA, - oeyMolB, - numMaxPossibleMappingAtoms: int, - criterium: Callable) -> int: - molA_allAtomsWith = len( - list(filter(criterium, oeyMolA.GetAtoms()))) - molB_allAtomsWith = len( - list(filter(criterium, oeyMolB.GetAtoms()))) - - if (molA_allAtomsWith > molB_allAtomsWith and - molA_allAtomsWith <= numMaxPossibleMappingAtoms): +def _get_all_mapped_atoms_with(oeyMolA, oeyMolB, numMaxPossibleMappingAtoms: int, criterium: Callable) -> int: + molA_allAtomsWith = len(list(filter(criterium, oeyMolA.GetAtoms()))) + molB_allAtomsWith = len(list(filter(criterium, oeyMolB.GetAtoms()))) + + if molA_allAtomsWith > molB_allAtomsWith and molA_allAtomsWith <= numMaxPossibleMappingAtoms: numMaxPossibleMappings = molA_allAtomsWith else: numMaxPossibleMappings = molB_allAtomsWith @@ -35,9 +30,7 @@ def _get_all_mapped_atoms_with(oeyMolA, @requires_package("perses") -def default_perses_scorer(mapping: LigandAtomMapping, - use_positions: bool = False, - normalize: bool = True) -> float: +def default_perses_scorer(mapping: LigandAtomMapping, use_positions: bool = False, normalize: bool = True) -> float: """ Score an atom mapping with the default Perses score function. 
@@ -66,45 +59,54 @@ def default_perses_scorer(mapping: LigandAtomMapping, float """ score = AtomMapper(use_positions=use_positions).score_mapping( - AtomMapping(old_mol=mapping.componentA.to_openff(), - new_mol=mapping.componentB.to_openff(), - old_to_new_atom_map=mapping.componentA_to_componentB)) + AtomMapping( + old_mol=mapping.componentA.to_openff(), + new_mol=mapping.componentB.to_openff(), + old_to_new_atom_map=mapping.componentA_to_componentB, + ), + ) # normalize - if (normalize): + if normalize: oeyMolA = mapping.componentA.to_openff().to_openeye() oeyMolB = mapping.componentB.to_openff().to_openeye() - if (use_positions): - raise NotImplementedError("normalizing using positions is " - "not currently implemented") + if use_positions: + raise NotImplementedError("normalizing using positions is " "not currently implemented") else: - smallerMolecule = oeyMolA if ( - oeyMolA.NumAtoms() < oeyMolB.NumAtoms()) else oeyMolB + smallerMolecule = oeyMolA if (oeyMolA.NumAtoms() < oeyMolB.NumAtoms()) else oeyMolB numMaxPossibleMappingAtoms = smallerMolecule.NumAtoms() # Max possible Aromatic mappings numMaxPossibleAromaticMappings = _get_all_mapped_atoms_with( - oeyMolA=oeyMolA, oeyMolB=oeyMolB, + oeyMolA=oeyMolA, + oeyMolB=oeyMolB, numMaxPossibleMappingAtoms=numMaxPossibleMappingAtoms, - criterium=lambda x: x.IsAromatic()) + criterium=lambda x: x.IsAromatic(), + ) # Max possible heavy mappings numMaxPossibleHeavyAtomMappings = _get_all_mapped_atoms_with( - oeyMolA=oeyMolA, oeyMolB=oeyMolB, + oeyMolA=oeyMolA, + oeyMolB=oeyMolB, numMaxPossibleMappingAtoms=numMaxPossibleMappingAtoms, - criterium=lambda x: x.GetAtomicNum() > 1) + criterium=lambda x: x.GetAtomicNum() > 1, + ) # Max possible ring mappings numMaxPossibleRingMappings = _get_all_mapped_atoms_with( - oeyMolA=oeyMolA, oeyMolB=oeyMolB, + oeyMolA=oeyMolA, + oeyMolB=oeyMolB, numMaxPossibleMappingAtoms=numMaxPossibleMappingAtoms, - criterium=lambda x: x.IsInRing()) + criterium=lambda x: x.IsInRing(), + ) # These 
weights are totally arbitrary - normalize_score = (1.0 * numMaxPossibleMappingAtoms + - 0.8 * numMaxPossibleAromaticMappings + - 0.5 * numMaxPossibleHeavyAtomMappings + - 0.4 * numMaxPossibleRingMappings) + normalize_score = ( + 1.0 * numMaxPossibleMappingAtoms + + 0.8 * numMaxPossibleAromaticMappings + + 0.5 * numMaxPossibleHeavyAtomMappings + + 0.4 * numMaxPossibleRingMappings + ) score /= normalize_score # final normalize score diff --git a/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py b/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py index 822ea0fbc..cd96d0c23 100644 --- a/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py +++ b/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py @@ -1,13 +1,14 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import abc +from collections.abc import Iterable from enum import Enum -from typing import Iterable from gufe import ChemicalSystem # Todo: connect to protocols - use this for labels? + class RFEComponentLabels(str, Enum): PROTEIN = "protein" LIGAND = "ligand" diff --git a/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py b/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py index 2b5907dc8..be138b682 100644 --- a/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py +++ b/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py @@ -1,19 +1,12 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -from .abstract_chemicalsystem_generator import ( - AbstractChemicalSystemGenerator, - RFEComponentLabels, -) -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional -from gufe import ( - Component, - SmallMoleculeComponent, - ProteinComponent, - SolventComponent, - ChemicalSystem, -) +from gufe import ChemicalSystem, Component, ProteinComponent, SmallMoleculeComponent, SolventComponent + +from .abstract_chemicalsystem_generator import AbstractChemicalSystemGenerator, RFEComponentLabels class EasyChemicalSystemGenerator(AbstractChemicalSystemGenerator): @@ -65,12 +58,10 @@ def __init__( if solvent is None and protein is None and not do_vacuum: raise ValueError( - "Chemical system generator is unable to generate any chemical systems with neither protein nor solvent nor do_vacuum" + "Chemical system generator is unable to generate any chemical systems with neither protein nor solvent nor do_vacuum", ) - def __call__( - self, component: SmallMoleculeComponent - ) -> Iterable[ChemicalSystem]: + def __call__(self, component: SmallMoleculeComponent) -> Iterable[ChemicalSystem]: """Generate systems around the given :class:`SmallMoleculeComponent`. 
Parameters @@ -112,12 +103,10 @@ def __call__( RFEComponentLabels.PROTEIN: self.protein, } for i, c in enumerate(self.cofactors): - components.update({f'{RFEComponentLabels.COFACTOR}{i+1}': c}) + components.update({f"{RFEComponentLabels.COFACTOR}{i+1}": c}) if self.solvent is not None: components.update({RFEComponentLabels.SOLVENT: self.solvent}) - chem_sys = ChemicalSystem( - components=components, name=component.name + "_complex" - ) + chem_sys = ChemicalSystem(components=components, name=component.name + "_complex") yield chem_sys return diff --git a/openfe/setup/ligand_network_planning.py b/openfe/setup/ligand_network_planning.py index 984d496b0..c5e9de184 100644 --- a/openfe/setup/ligand_network_planning.py +++ b/openfe/setup/ligand_network_planning.py @@ -1,39 +1,41 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import math -from pathlib import Path -from typing import Iterable, Callable, Optional, Union -import itertools -from collections import Counter import functools +import itertools +import math import warnings +from collections import Counter +from collections.abc import Iterable +from pathlib import Path +from typing import Callable, Optional, Union import networkx as nx +from gufe import AtomMapper, SmallMoleculeComponent +from lomap import LomapAtomMapper, generate_lomap_network +from lomap.dbmol import _find_common_core from tqdm.auto import tqdm -from gufe import SmallMoleculeComponent, AtomMapper from openfe.setup import LigandNetwork from openfe.setup.atom_mapping import LigandAtomMapping -from lomap import generate_lomap_network, LomapAtomMapper -from lomap.dbmol import _find_common_core - def _hasten_lomap(mapper, ligands): - """take a mapper and some ligands, put a common core arg into the mapper """ + """take a mapper and some ligands, put a common core arg into the mapper""" if mapper.seed: return mapper try: - core = _find_common_core([m.to_rdkit() for m 
in ligands], - element_change=mapper.element_change) + core = _find_common_core([m.to_rdkit() for m in ligands], element_change=mapper.element_change) except RuntimeError: # in case MCS throws a hissy fit core = "" return LomapAtomMapper( - time=mapper.time, threed=mapper.threed, max3d=mapper.max3d, - element_change=mapper.element_change, seed=core, - shift=mapper.shift + time=mapper.time, + threed=mapper.threed, + max3d=mapper.max3d, + element_change=mapper.element_change, + seed=core, + shift=mapper.shift, ) @@ -80,24 +82,24 @@ def generate_radial_network( """ if isinstance(mappers, AtomMapper): mappers = [mappers] - mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) - else m for m in mappers] + mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] edges = [] for ligand in ligands: if ligand == central_ligand: - wmsg = (f"The central_ligand {ligand.name} was also found in " - "the list of ligands to arrange around the " - "central_ligand this will be ignored.") + wmsg = ( + f"The central_ligand {ligand.name} was also found in " + "the list of ligands to arrange around the " + "central_ligand this will be ignored." 
+ ) warnings.warn(wmsg) continue best_score = 0.0 best_mapping = None for mapping in itertools.chain.from_iterable( - mapper.suggest_mappings(central_ligand, ligand) - for mapper in mappers + mapper.suggest_mappings(central_ligand, ligand) for mapper in mappers ): if not scorer: best_mapping = mapping @@ -151,8 +153,7 @@ def generate_maximal_network( """ if isinstance(mappers, AtomMapper): mappers = [mappers] - mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) - else m for m in mappers] + mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] nodes = list(ligands) @@ -161,7 +162,10 @@ def generate_maximal_network( total = len(nodes) * (len(nodes) - 1) // 2 progress = functools.partial(tqdm, total=total, delay=1.5) elif progress is False: - def progress(x): return x + + def progress(x): + return x + # otherwise, it should be a user-defined callable mapping_generator = itertools.chain.from_iterable( @@ -170,8 +174,7 @@ def progress(x): return x for mapper in mappers ) if scorer: - mappings = [mapping.with_annotations({'score': scorer(mapping)}) - for mapping in mapping_generator] + mappings = [mapping.with_annotations({"score": scorer(mapping)}) for mapping in mapping_generator] else: mappings = list(mapping_generator) @@ -205,8 +208,7 @@ def generate_minimal_spanning_network( """ if isinstance(mappers, AtomMapper): mappers = [mappers] - mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) - else m for m in mappers] + mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] # First create a network with all the proposed mappings (scored) network = generate_maximal_network(ligands, mappers, scorer, progress) @@ -214,18 +216,17 @@ def generate_minimal_spanning_network( # Flip network scores so we can use minimal algorithm g2 = nx.MultiGraph() for e1, e2, d in network.graph.edges(data=True): - g2.add_edge(e1, e2, weight=-d['score'], object=d['object']) 
+ g2.add_edge(e1, e2, weight=-d["score"], object=d["object"]) # Next analyze that network to create minimal spanning network. Because # we carry the original (directed) LigandAtomMapping, we don't lose # direction information when converting to an undirected graph. min_edges = nx.minimum_spanning_edges(g2) - min_mappings = [edge_data['object'] for _, _, _, edge_data in min_edges] + min_mappings = [edge_data["object"] for _, _, _, edge_data in min_edges] min_network = LigandNetwork(min_mappings) missing_nodes = set(network.nodes) - set(min_network.nodes) if missing_nodes: - raise RuntimeError("Unable to create edges to some nodes: " - f"{list(missing_nodes)}") + raise RuntimeError("Unable to create edges to some nodes: " f"{list(missing_nodes)}") return min_network @@ -265,8 +266,7 @@ def generate_minimal_redundant_network( """ if isinstance(mappers, AtomMapper): mappers = [mappers] - mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) - else m for m in mappers] + mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper) else m for m in mappers] # First create a network with all the proposed mappings (scored) network = generate_maximal_network(ligands, mappers, scorer, progress) @@ -274,7 +274,7 @@ def generate_minimal_redundant_network( # Flip network scores so we can use minimal algorithm g2 = nx.MultiGraph() for e1, e2, d in network.graph.edges(data=True): - g2.add_edge(e1, e2, weight=-d['score'], object=d['object']) + g2.add_edge(e1, e2, weight=-d["score"], object=d["object"]) # As in .generate_minimal_spanning_network(), use nx to get the minimal # network. 
But now also remove those edges from the fully-connected @@ -287,21 +287,20 @@ def generate_minimal_redundant_network( g2.remove_edges_from(current_best_edges) for _, _, _, edge_data in current_best_edges: - mappings.append(edge_data['object']) + mappings.append(edge_data["object"]) redund_network = LigandNetwork(mappings) missing_nodes = set(network.nodes) - set(redund_network.nodes) if missing_nodes: - raise RuntimeError("Unable to create edges to some nodes: " - f"{list(missing_nodes)}") + raise RuntimeError("Unable to create edges to some nodes: " f"{list(missing_nodes)}") return redund_network def generate_network_from_names( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - names: list[tuple[str, str]], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + names: list[tuple[str, str]], ) -> LigandNetwork: """ Generate a :class:`.LigandNetwork` by specifying edges as tuples of names. @@ -332,26 +331,24 @@ def generate_network_from_names( nm2idx = {l.name: i for i, l in enumerate(ligands)} if len(nm2idx) < len(ligands): - dupes = Counter((l.name for l in ligands)) + dupes = Counter(l.name for l in ligands) dupe_names = [k for k, v in dupes.items() if v > 1] raise ValueError(f"Duplicate names: {dupe_names}") try: ids = [(nm2idx[nm1], nm2idx[nm2]) for nm1, nm2 in names] except KeyError: - badnames = [nm for nm in itertools.chain.from_iterable(names) - if nm not in nm2idx] + badnames = [nm for nm in itertools.chain.from_iterable(names) if nm not in nm2idx] available = [ligand.name for ligand in ligands] - raise KeyError(f"Invalid name(s) requested {badnames}. " - f"Available: {available}") + raise KeyError(f"Invalid name(s) requested {badnames}. 
" f"Available: {available}") return generate_network_from_indices(ligands, mapper, ids) def generate_network_from_indices( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - indices: list[tuple[int, int]], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + indices: list[tuple[int, int]], ) -> LigandNetwork: """ Generate a :class:`.LigandNetwork` by specifying edges as tuples of indices. @@ -382,8 +379,7 @@ def generate_network_from_indices( try: m1, m2 = ligands[i], ligands[j] except IndexError: - raise IndexError(f"Invalid ligand id, requested {i} {j} " - f"with {len(ligands)} available") + raise IndexError(f"Invalid ligand id, requested {i} {j} " f"with {len(ligands)} available") mapping = next(mapper.suggest_mappings(m1, m2)) @@ -398,9 +394,9 @@ def generate_network_from_indices( def load_orion_network( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - network_file: Union[str, Path], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an Orion NES network file. @@ -423,15 +419,13 @@ def load_orion_network( If an unexpected line format is encountered. 
""" - with open(network_file, 'r') as f: - network_lines = [l.strip().split(' ') for l in f - if not l.startswith('#')] + with open(network_file) as f: + network_lines = [l.strip().split(" ") for l in f if not l.startswith("#")] names = [] for entry in network_lines: if len(entry) != 3 or entry[1] != ">>": - errmsg = ("line does not match expected name >> name format: " - f"{entry}") + errmsg = "line does not match expected name >> name format: " f"{entry}" raise KeyError(errmsg) names.append((entry[0], entry[2])) @@ -440,9 +434,9 @@ def load_orion_network( def load_fepplus_network( - ligands: list[SmallMoleculeComponent], - mapper: AtomMapper, - network_file: Union[str, Path], + ligands: list[SmallMoleculeComponent], + mapper: AtomMapper, + network_file: Union[str, Path], ) -> LigandNetwork: """Load a :class:`.LigandNetwork` from an FEP+ edges network file. @@ -465,15 +459,13 @@ def load_fepplus_network( If an unexpected line format is encountered. """ - with open(network_file, 'r') as f: + with open(network_file) as f: network_lines = [l.split() for l in f.readlines()] names = [] for entry in network_lines: - if len(entry) != 5 or entry[1] != '#' or entry[3] != '->': - errmsg = ("line does not match expected format " - f"hash:hash # name -> name\n" - "line format: {entry}") + if len(entry) != 5 or entry[1] != "#" or entry[3] != "->": + errmsg = "line does not match expected format " f"hash:hash # name -> name\n" "line format: {entry}" raise KeyError(errmsg) names.append((entry[2], entry[4])) diff --git a/openfe/storage/metadatastore.py b/openfe/storage/metadatastore.py index 60de6b4c2..5916777dd 100644 --- a/openfe/storage/metadatastore.py +++ b/openfe/storage/metadatastore.py @@ -1,17 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/gufe -import json import abc import collections +import json +from typing import Dict, Tuple -from typing import Tuple, Dict - +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError from gufe.storage.externalresource.base import Metadata -from gufe.storage.errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) - class MetadataStore(collections.abc.Mapping): def __init__(self, external_store): @@ -23,7 +19,7 @@ def store_metadata(self, location: str, metadata: Metadata): raise NotImplementedError() @abc.abstractmethod - def load_all_metadata(self) -> Dict[str, Metadata]: + def load_all_metadata(self) -> dict[str, Metadata]: raise NotImplementedError() @abc.abstractmethod @@ -45,24 +41,22 @@ class JSONMetadataStore(MetadataStore): # require any external dependencies. It is NOT the right way to go in # the long term. API will probably stay the same, though. def _dump_file(self): - metadata_dict = {key: val.to_dict() - for key, val in self._metadata_cache.items()} - metadata_bytes = json.dumps(metadata_dict).encode('utf-8') - self.external_store.store_bytes('metadata.json', metadata_bytes) + metadata_dict = {key: val.to_dict() for key, val in self._metadata_cache.items()} + metadata_bytes = json.dumps(metadata_dict).encode("utf-8") + self.external_store.store_bytes("metadata.json", metadata_bytes) def store_metadata(self, location: str, metadata: Metadata): self._metadata_cache[location] = metadata self._dump_file() def load_all_metadata(self): - if not self.external_store.exists('metadata.json'): + if not self.external_store.exists("metadata.json"): return {} - with self.external_store.load_stream('metadata.json') as json_f: - all_metadata_dict = json.loads(json_f.read().decode('utf-8')) + with self.external_store.load_stream("metadata.json") as json_f: + all_metadata_dict = json.loads(json_f.read().decode("utf-8")) - all_metadata = {key: Metadata(**val) 
- for key, val in all_metadata_dict.items()} + all_metadata = {key: Metadata(**val) for key, val in all_metadata_dict.items()} return all_metadata @@ -81,10 +75,10 @@ def store_metadata(self, location: str, metadata: Metadata): self._metadata_cache[location] = metadata path = self._metadata_path(location) dct = { - 'path': location, - 'metadata': metadata.to_dict(), + "path": location, + "metadata": metadata.to_dict(), } - metadata_bytes = json.dumps(dct).encode('utf-8') + metadata_bytes = json.dumps(dct).encode("utf-8") self.external_store.store_bytes(path, metadata_bytes) def load_all_metadata(self): @@ -93,12 +87,11 @@ def load_all_metadata(self): for location in self.external_store.iter_contents(prefix=prefix): if location.endswith(".json"): with self.external_store.load_stream(location) as f: - dct = json.loads(f.read().decode('utf-8')) + dct = json.loads(f.read().decode("utf-8")) if set(dct) != {"path", "metadata"}: - raise ChangedExternalResourceError("Bad metadata file: " - f"'{location}'") - metadata_cache[dct['path']] = Metadata(**dct['metadata']) + raise ChangedExternalResourceError("Bad metadata file: " f"'{location}'") + metadata_cache[dct["path"]] = Metadata(**dct["metadata"]) return metadata_cache diff --git a/openfe/storage/resultclient.py b/openfe/storage/resultclient.py index 012c12f66..51001b1ae 100644 --- a/openfe/storage/resultclient.py +++ b/openfe/storage/resultclient.py @@ -5,33 +5,26 @@ import re from typing import Any -from .resultserver import ResultServer -from .metadatastore import JSONMetadataStore - -from gufe.tokenization import ( - get_all_gufe_objs, key_decode_dependencies, from_dict, JSON_HANDLER, -) +from gufe.tokenization import JSON_HANDLER, from_dict, get_all_gufe_objs, key_decode_dependencies +from .metadatastore import JSONMetadataStore +from .resultserver import ResultServer -GUFEKEY_JSON_REGEX = re.compile( - '":gufe-key:": "(?P[A-Za-z0-9_]+-[0-9a-f]+)"' -) +GUFEKEY_JSON_REGEX = re.compile('":gufe-key:": 
"(?P[A-Za-z0-9_]+-[0-9a-f]+)"') class _ResultContainer(abc.ABC): """ Abstract class, represents all data under some level of the heirarchy. """ + def __init__(self, parent, path_component): self.parent = parent self._path_component = self._to_path_component(path_component) self._cache = {} def __eq__(self, other): - return ( - isinstance(other, self.__class__) - and self.path == other.path - ) + return isinstance(other, self.__class__) and self.path == other.path @staticmethod def _to_path_component(item: Any) -> str: @@ -122,8 +115,8 @@ def _gufe_key_to_storage_key(prefix: str, key: str): storage key (string identifier used by storage to locate this object) """ - pref = prefix.split('/') # remove this if we switch to tuples - cls, token = key.split('-') + pref = prefix.split("/") # remove this if we switch to tuples + cls, token = key.split("-") tup = tuple(list(pref) + [cls, f"{token}.json"]) # right now we're using strings, but we've talked about switching # that to tuples @@ -137,9 +130,7 @@ def _store_gufe_tokenizable(self, prefix, obj): # we trust that if we get the same key, it's the same object, so # we only store on keys that we don't already know if key not in self.result_server: - data = json.dumps(o.to_keyed_dict(), - cls=JSON_HANDLER.encoder, - sort_keys=True).encode('utf-8') + data = json.dumps(o.to_keyed_dict(), cls=JSON_HANDLER.encoder, sort_keys=True).encode("utf-8") self.result_server.store_bytes(key, data) def store_transformation(self, transformation): @@ -174,7 +165,7 @@ def recursive_build_object_cache(gufe_key): # (they are cached on creation). 
storage_key = self._gufe_key_to_storage_key(prefix, gufe_key) with self.load_stream(storage_key) as f: - keyencoded_json = f.read().decode('utf-8') + keyencoded_json = f.read().decode("utf-8") dct = json.loads(keyencoded_json, cls=JSON_HANDLER.decoder) # this implementation may seem strange, but it will be a @@ -243,7 +234,7 @@ def _load_next_level(self, transformation): # the recursive chain @property def path(self): - return 'transformations' + return "transformations" @property def result_server(self): diff --git a/openfe/storage/resultserver.py b/openfe/storage/resultserver.py index da217d95f..80f80426b 100644 --- a/openfe/storage/resultserver.py +++ b/openfe/storage/resultserver.py @@ -1,12 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import warnings - from typing import ClassVar -from gufe.storage.errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError class ResultServer: @@ -15,6 +12,7 @@ class ResultServer: At this level, we provide an abstraction where client code no longer needs to be aware of the nature of the metadata, or even that it exists. """ + def __init__(self, external_store, metadata_store): self.external_store = external_store self.metadata_store = metadata_store @@ -39,17 +37,12 @@ def validate(self, location, allow_changed=False): try: metadata = self.metadata_store[location] except KeyError: - raise MissingExternalResourceError(f"Metadata for '{location}' " - "not found") + raise MissingExternalResourceError(f"Metadata for '{location}' " "not found") if not self.external_store.get_metadata(location) == metadata: - msg = (f"Metadata mismatch for {location}: this object " - "may have changed.") + msg = f"Metadata mismatch for {location}: this object " "may have changed." 
if not allow_changed: - raise ChangedExternalResourceError( - msg + " To allow this, set ExternalStorage." - "allow_changed = True" - ) + raise ChangedExternalResourceError(msg + " To allow this, set ExternalStorage." "allow_changed = True") else: warnings.warn(msg) diff --git a/openfe/tests/conftest.py b/openfe/tests/conftest.py index 51cfb598b..38268989d 100644 --- a/openfe/tests/conftest.py +++ b/openfe/tests/conftest.py @@ -1,16 +1,17 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import os import importlib -import pytest +import os from importlib import resources + +import gufe +import pytest +from gufe import LigandAtomMapping, SmallMoleculeComponent +from openff.units import unit from rdkit import Chem from rdkit.Chem import AllChem -from openff.units import unit -import gufe import openfe -from gufe import SmallMoleculeComponent, LigandAtomMapping class SlowTests: @@ -46,13 +47,13 @@ class SlowTests: To run the `slow` tests, either use the `--runslow` flag when invoking pytest, or set the environment variable `OFE_SLOW_TESTS` to `true` """ + def __init__(self, config): self.config = config @staticmethod def _modify_slow(items, config): - msg = ("need --runslow pytest cli option or the environment variable " - "`OFE_SLOW_TESTS` set to `True` to run") + msg = "need --runslow pytest cli option or the environment variable " "`OFE_SLOW_TESTS` set to `True` to run" skip_slow = pytest.mark.skip(reason=msg) for item in items: if "slow" in item.keywords: @@ -60,19 +61,19 @@ def _modify_slow(items, config): @staticmethod def _modify_integration(items, config): - msg = ("need --integration pytest cli option or the environment " - "variable `OFE_INTEGRATION_TESTS` set to `True` to run") + msg = ( + "need --integration pytest cli option or the environment " + "variable `OFE_INTEGRATION_TESTS` set to `True` to run" + ) skip_int = pytest.mark.skip(reason=msg) for item in items: if 
"integration" in item.keywords: item.add_marker(skip_int) def pytest_collection_modifyitems(self, items, config): - if (config.getoption('--integration') or - os.getenv("OFE_INTEGRATION_TESTS", default="false").lower() == 'true'): + if config.getoption("--integration") or os.getenv("OFE_INTEGRATION_TESTS", default="false").lower() == "true": return - elif (config.getoption('--runslow') or - os.getenv("OFE_SLOW_TESTS", default="false").lower() == 'true'): + elif config.getoption("--runslow") or os.getenv("OFE_SLOW_TESTS", default="false").lower() == "true": self._modify_integration(items, config) else: self._modify_integration(items, config) @@ -82,11 +83,11 @@ def pytest_collection_modifyitems(self, items, config): # allow for optional slow tests # See: https://docs.pytest.org/en/latest/example/simple.html def pytest_addoption(parser): + parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") parser.addoption( - "--runslow", action="store_true", default=False, help="run slow tests" - ) - parser.addoption( - "--integration", action="store_true", default=False, + "--integration", + action="store_true", + default=False, help="run long integration tests", ) @@ -94,8 +95,7 @@ def pytest_addoption(parser): def pytest_configure(config): config.pluginmanager.register(SlowTests(config), "slow") config.addinivalue_line("markers", "slow: mark test as slow") - config.addinivalue_line( - "markers", "integration: mark test as long integration test") + config.addinivalue_line("markers", "integration: mark test as long integration test") def mol_from_smiles(smiles: str) -> Chem.Mol: @@ -105,12 +105,12 @@ def mol_from_smiles(smiles: str) -> Chem.Mol: return m -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def ethane(): - return SmallMoleculeComponent(mol_from_smiles('CC')) + return SmallMoleculeComponent(mol_from_smiles("CC")) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def simple_mapping(): 
"""Disappearing oxygen on end @@ -118,15 +118,15 @@ def simple_mapping(): C C """ - molA = SmallMoleculeComponent(mol_from_smiles('CCO')) - molB = SmallMoleculeComponent(mol_from_smiles('CC')) + molA = SmallMoleculeComponent(mol_from_smiles("CCO")) + molB = SmallMoleculeComponent(mol_from_smiles("CC")) m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 1: 1}) return m -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def other_mapping(): """Disappearing middle carbon @@ -134,8 +134,8 @@ def other_mapping(): C C """ - molA = SmallMoleculeComponent(mol_from_smiles('CCO')) - molB = SmallMoleculeComponent(mol_from_smiles('CC')) + molA = SmallMoleculeComponent(mol_from_smiles("CCO")) + molB = SmallMoleculeComponent(mol_from_smiles("CC")) m = LigandAtomMapping(molA, molB, componentA_to_componentB={0: 0, 2: 1}) @@ -145,50 +145,51 @@ def other_mapping(): @pytest.fixture() def lomap_basic_test_files_dir(tmpdir_factory): # for lomap, which wants the files in a directory - lomap_files = tmpdir_factory.mktemp('lomap_files') - lomap_basic = 'openfe.tests.data.lomap_basic' + lomap_files = tmpdir_factory.mktemp("lomap_files") + lomap_basic = "openfe.tests.data.lomap_basic" for f in importlib.resources.contents(lomap_basic): - if not f.endswith('mol2'): + if not f.endswith("mol2"): continue stuff = importlib.resources.read_binary(lomap_basic, f) - with open(str(lomap_files.join(f)), 'wb') as fout: + with open(str(lomap_files.join(f)), "wb") as fout: fout.write(stuff) yield str(lomap_files) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def atom_mapping_basic_test_files(): # a dict of {filenames.strip(mol2): SmallMoleculeComponent} for a simple # set of ligands files = {} for f in [ - '1,3,7-trimethylnaphthalene', - '1-butyl-4-methylbenzene', - '2,6-dimethylnaphthalene', - '2-methyl-6-propylnaphthalene', - '2-methylnaphthalene', - '2-naftanol', - 'methylcyclohexane', - 'toluene']: - with 
importlib.resources.files('openfe.tests.data.lomap_basic') as d: - fn = str(d / (f + '.mol2')) + "1,3,7-trimethylnaphthalene", + "1-butyl-4-methylbenzene", + "2,6-dimethylnaphthalene", + "2-methyl-6-propylnaphthalene", + "2-methylnaphthalene", + "2-naftanol", + "methylcyclohexane", + "toluene", + ]: + with importlib.resources.files("openfe.tests.data.lomap_basic") as d: + fn = str(d / (f + ".mol2")) mol = Chem.MolFromMol2File(fn, removeHs=False) files[f] = SmallMoleculeComponent(mol, name=f) return files -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_modifications(): files = {} - with importlib.resources.files('openfe.tests.data') as d: - fn = str(d / 'benzene_modifications.sdf') + with importlib.resources.files("openfe.tests.data") as d: + fn = str(d / "benzene_modifications.sdf") supp = Chem.SDMolSupplier(str(fn), removeHs=False) for rdmol in supp: - files[rdmol.GetProp('_Name')] = SmallMoleculeComponent(rdmol) + files[rdmol.GetProp("_Name")] = SmallMoleculeComponent(rdmol) return files @@ -197,27 +198,27 @@ def serialization_template(): def inner(filename): loc = "openfe.tests.data.serialization" tmpl = importlib.resources.read_text(loc, filename) - return tmpl.replace('{OFE_VERSION}', openfe.__version__) + return tmpl.replace("{OFE_VERSION}", openfe.__version__) return inner -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_transforms(): # a dict of Molecules for benzene transformations mols = {} - with resources.files('openfe.tests.data') as d: - fn = str(d / 'benzene_modifications.sdf') + with resources.files("openfe.tests.data") as d: + fn = str(d / "benzene_modifications.sdf") supplier = Chem.SDMolSupplier(fn, removeHs=False) for mol in supplier: - mols[mol.GetProp('_Name')] = SmallMoleculeComponent(mol) + mols[mol.GetProp("_Name")] = SmallMoleculeComponent(mol) return mols -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def T4_protein_component(): - with 
resources.files('openfe.tests.data') as d: - fn = str(d / '181l_only.pdb') + with resources.files("openfe.tests.data") as d: + fn = str(d / "181l_only.pdb") comp = gufe.ProteinComponent.from_pdb_file(fn, name="T4_protein") return comp @@ -225,20 +226,20 @@ def T4_protein_component(): @pytest.fixture() def eg5_protein_pdb(): - with resources.files('openfe.tests.data.eg5') as d: - yield str(d / 'eg5_protein.pdb') + with resources.files("openfe.tests.data.eg5") as d: + yield str(d / "eg5_protein.pdb") @pytest.fixture() def eg5_ligands_sdf(): - with resources.files('openfe.tests.data.eg5') as d: - yield str(d / 'eg5_ligands.sdf') + with resources.files("openfe.tests.data.eg5") as d: + yield str(d / "eg5_ligands.sdf") @pytest.fixture() def eg5_cofactor_sdf(): - with resources.files('openfe.tests.data.eg5') as d: - yield str(d / 'eg5_cofactor.sdf') + with resources.files("openfe.tests.data.eg5") as d: + yield str(d / "eg5_cofactor.sdf") @pytest.fixture() @@ -248,8 +249,7 @@ def eg5_protein(eg5_protein_pdb) -> openfe.ProteinComponent: @pytest.fixture() def eg5_ligands(eg5_ligands_sdf) -> list[SmallMoleculeComponent]: - return [SmallMoleculeComponent(m) - for m in Chem.SDMolSupplier(eg5_ligands_sdf, removeHs=False)] + return [SmallMoleculeComponent(m) for m in Chem.SDMolSupplier(eg5_ligands_sdf, removeHs=False)] @pytest.fixture() @@ -259,14 +259,14 @@ def eg5_cofactor(eg5_cofactor_sdf) -> SmallMoleculeComponent: @pytest.fixture() def orion_network(): - with resources.files('openfe.tests.data.external_formats') as d: - yield str(d / 'somebenzenes_nes.dat') + with resources.files("openfe.tests.data.external_formats") as d: + yield str(d / "somebenzenes_nes.dat") @pytest.fixture() def fepplus_network(): - with resources.files('openfe.tests.data.external_formats') as d: - yield str(d / 'somebenzenes_edges.edge') + with resources.files("openfe.tests.data.external_formats") as d: + yield str(d / "somebenzenes_edges.edge") @pytest.fixture() @@ -274,8 +274,8 @@ def CN_molecule(): 
""" A basic CH3NH2 molecule for quick testing. """ - with resources.files('openfe.tests.data') as d: - fn = str(d / 'CN.sdf') + with resources.files("openfe.tests.data") as d: + fn = str(d / "CN.sdf") supp = Chem.SDMolSupplier(str(fn), removeHs=False) smc = [SmallMoleculeComponent(i) for i in supp][0] @@ -283,24 +283,12 @@ def CN_molecule(): return smc -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def am1bcc_ref_charges(): ref_chgs = { - 'ambertools': [ - 0.146957, -0.918943, 0.025557, 0.025557, - 0.025557, 0.347657, 0.347657 - ] * unit.elementary_charge, - 'openeye': [ - 0.14713, -0.92016, 0.02595, 0.02595, - 0.02595, 0.34759, 0.34759 - ] * unit.elementary_charge, - 'nagl': [ - 0.170413, -0.930417, 0.021593, 0.021593, - 0.021593, 0.347612, 0.347612 - ] * unit.elementary_charge, - 'espaloma': [ - 0.017702, -0.966793, 0.063076, 0.063076, - 0.063076, 0.379931, 0.379931 - ] * unit.elementary_charge, + "ambertools": [0.146957, -0.918943, 0.025557, 0.025557, 0.025557, 0.347657, 0.347657] * unit.elementary_charge, + "openeye": [0.14713, -0.92016, 0.02595, 0.02595, 0.02595, 0.34759, 0.34759] * unit.elementary_charge, + "nagl": [0.170413, -0.930417, 0.021593, 0.021593, 0.021593, 0.347612, 0.347612] * unit.elementary_charge, + "espaloma": [0.017702, -0.966793, 0.063076, 0.063076, 0.063076, 0.379931, 0.379931] * unit.elementary_charge, } return ref_chgs diff --git a/openfe/tests/data/external_formats/__init__.py b/openfe/tests/data/external_formats/__init__.py index 8b1378917..e69de29bb 100644 --- a/openfe/tests/data/external_formats/__init__.py +++ b/openfe/tests/data/external_formats/__init__.py @@ -1 +0,0 @@ - diff --git a/openfe/tests/data/openmm_afe/__init__.py b/openfe/tests/data/openmm_afe/__init__.py index 8b1378917..e69de29bb 100644 --- a/openfe/tests/data/openmm_afe/__init__.py +++ b/openfe/tests/data/openmm_afe/__init__.py @@ -1 +0,0 @@ - diff --git a/openfe/tests/data/openmm_md/__init__.py b/openfe/tests/data/openmm_md/__init__.py index 
8b1378917..e69de29bb 100644 --- a/openfe/tests/data/openmm_md/__init__.py +++ b/openfe/tests/data/openmm_md/__init__.py @@ -1 +0,0 @@ - diff --git a/openfe/tests/data/openmm_rfe/reference.xml b/openfe/tests/data/openmm_rfe/reference.xml index 779614ca4..f3480d2f6 100644 --- a/openfe/tests/data/openmm_rfe/reference.xml +++ b/openfe/tests/data/openmm_rfe/reference.xml @@ -1551,4 +1551,4 @@ - \ No newline at end of file + diff --git a/openfe/tests/dev/serialization_test_templates.py b/openfe/tests/dev/serialization_test_templates.py index dd0c50b5b..c0cdd3341 100644 --- a/openfe/tests/dev/serialization_test_templates.py +++ b/openfe/tests/dev/serialization_test_templates.py @@ -12,7 +12,8 @@ from rdkit import Chem from rdkit.Chem import AllChem -from openfe import SmallMoleculeComponent, LigandNetwork, LigandAtomMapping + +from openfe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent # multi_molecule.sdf mol1 = Chem.MolFromSmiles("CCO") diff --git a/openfe/tests/protocols/conftest.py b/openfe/tests/protocols/conftest.py index 743352215..0f0a2d51e 100644 --- a/openfe/tests/protocols/conftest.py +++ b/openfe/tests/protocols/conftest.py @@ -1,29 +1,34 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe import gzip -import pytest from importlib import resources + +import pooch +import pytest +from openff.units import unit from rdkit import Chem from rdkit.Geometry import Point3D + import openfe -from openff.units import unit -import pooch @pytest.fixture def benzene_vacuum_system(benzene_modifications): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['benzene']}, + {"ligand": benzene_modifications["benzene"]}, ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_system(benzene_modifications): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['benzene'], - 'solvent': openfe.SolventComponent( - positive_ion='Na', negative_ion='Cl', - ion_concentration=0.15 * unit.molar) + { + "ligand": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent( + positive_ion="Na", + negative_ion="Cl", + ion_concentration=0.15 * unit.molar, + ), }, ) @@ -31,28 +36,35 @@ def benzene_system(benzene_modifications): @pytest.fixture def benzene_complex_system(benzene_modifications, T4_protein_component): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['benzene'], - 'solvent': openfe.SolventComponent( - positive_ion='Na', negative_ion='Cl', - ion_concentration=0.15 * unit.molar), - 'protein': T4_protein_component,} + { + "ligand": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent( + positive_ion="Na", + negative_ion="Cl", + ion_concentration=0.15 * unit.molar, + ), + "protein": T4_protein_component, + }, ) @pytest.fixture def toluene_vacuum_system(benzene_modifications): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['toluene']}, + {"ligand": benzene_modifications["toluene"]}, ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def toluene_system(benzene_modifications): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['toluene'], - 'solvent': openfe.SolventComponent( - 
positive_ion='Na', negative_ion='Cl', - ion_concentration=0.15 * unit.molar), + { + "ligand": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent( + positive_ion="Na", + negative_ion="Cl", + ion_concentration=0.15 * unit.molar, + ), }, ) @@ -60,20 +72,24 @@ def toluene_system(benzene_modifications): @pytest.fixture def toluene_complex_system(benzene_modifications, T4_protein_component): return openfe.ChemicalSystem( - {'ligand': benzene_modifications['toluene'], - 'solvent': openfe.SolventComponent( - positive_ion='Na', negative_ion='Cl', - ion_concentration=0.15 * unit.molar), - 'protein': T4_protein_component,} + { + "ligand": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent( + positive_ion="Na", + negative_ion="Cl", + ion_concentration=0.15 * unit.molar, + ), + "protein": T4_protein_component, + }, ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_to_toluene_mapping(benzene_modifications): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_modifications['benzene'] - molB = benzene_modifications['toluene'] + molA = benzene_modifications["benzene"] + molB = benzene_modifications["toluene"] return next(mapper.suggest_mappings(molA, molB)) @@ -81,116 +97,112 @@ def benzene_to_toluene_mapping(benzene_modifications): @pytest.fixture def benzene_charges(): files = {} - with resources.files('openfe.tests.data.openmm_rfe') as d: - fn = str(d / 'charged_benzenes.sdf') + with resources.files("openfe.tests.data.openmm_rfe") as d: + fn = str(d / "charged_benzenes.sdf") supp = Chem.SDMolSupplier(str(fn), removeHs=False) for rdmol in supp: - files[rdmol.GetProp('_Name')] = openfe.SmallMoleculeComponent(rdmol) + files[rdmol.GetProp("_Name")] = openfe.SmallMoleculeComponent(rdmol) return files @pytest.fixture def benzene_to_benzoic_mapping(benzene_charges): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_charges['benzene'] - molB = 
benzene_charges['benzoic_acid'] + molA = benzene_charges["benzene"] + molB = benzene_charges["benzoic_acid"] return next(mapper.suggest_mappings(molA, molB)) @pytest.fixture def benzoic_to_benzene_mapping(benzene_charges): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_charges['benzoic_acid'] - molB = benzene_charges['benzene'] + molA = benzene_charges["benzoic_acid"] + molB = benzene_charges["benzene"] return next(mapper.suggest_mappings(molA, molB)) @pytest.fixture def benzene_to_aniline_mapping(benzene_charges): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_charges['benzene'] - molB = benzene_charges['aniline'] + molA = benzene_charges["benzene"] + molB = benzene_charges["aniline"] return next(mapper.suggest_mappings(molA, molB)) @pytest.fixture def aniline_to_benzene_mapping(benzene_charges): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_charges['aniline'] - molB = benzene_charges['benzene'] + molA = benzene_charges["aniline"] + molB = benzene_charges["benzene"] return next(mapper.suggest_mappings(molA, molB)) @pytest.fixture def aniline_to_benzoic_mapping(benzene_charges): mapper = openfe.setup.LomapAtomMapper(element_change=False) - molA = benzene_charges['aniline'] - molB = benzene_charges['benzoic_acid'] + molA = benzene_charges["aniline"] + molB = benzene_charges["benzoic_acid"] return next(mapper.suggest_mappings(molA, molB)) @pytest.fixture def benzene_many_solv_system(benzene_modifications): - rdmol_phenol = benzene_modifications['phenol'].to_rdkit() - rdmol_benzo = benzene_modifications['benzonitrile'].to_rdkit() + rdmol_phenol = benzene_modifications["phenol"].to_rdkit() + rdmol_benzo = benzene_modifications["benzonitrile"].to_rdkit() conf_phenol = rdmol_phenol.GetConformer() conf_benzo = rdmol_benzo.GetConformer() for atm in range(rdmol_phenol.GetNumAtoms()): x, y, z = conf_phenol.GetAtomPosition(atm) - conf_phenol.SetAtomPosition(atm, Point3D(x+30, y, z)) + 
conf_phenol.SetAtomPosition(atm, Point3D(x + 30, y, z)) for atm in range(rdmol_benzo.GetNumAtoms()): x, y, z = conf_benzo.GetAtomPosition(atm) - conf_benzo.SetAtomPosition(atm, Point3D(x, y+30, z)) + conf_benzo.SetAtomPosition(atm, Point3D(x, y + 30, z)) - phenol = openfe.SmallMoleculeComponent.from_rdkit( - rdmol_phenol, name='phenol' - ) + phenol = openfe.SmallMoleculeComponent.from_rdkit(rdmol_phenol, name="phenol") - benzo = openfe.SmallMoleculeComponent.from_rdkit( - rdmol_benzo, name='benzonitrile' - ) + benzo = openfe.SmallMoleculeComponent.from_rdkit(rdmol_benzo, name="benzonitrile") return openfe.ChemicalSystem( - {'whatligand': benzene_modifications['benzene'], - "foo": phenol, - "bar": benzo, - "solvent": openfe.SolventComponent()}, + { + "whatligand": benzene_modifications["benzene"], + "foo": phenol, + "bar": benzo, + "solvent": openfe.SolventComponent(), + }, ) @pytest.fixture def toluene_many_solv_system(benzene_modifications): - rdmol_phenol = benzene_modifications['phenol'].to_rdkit() - rdmol_benzo = benzene_modifications['benzonitrile'].to_rdkit() + rdmol_phenol = benzene_modifications["phenol"].to_rdkit() + rdmol_benzo = benzene_modifications["benzonitrile"].to_rdkit() conf_phenol = rdmol_phenol.GetConformer() conf_benzo = rdmol_benzo.GetConformer() for atm in range(rdmol_phenol.GetNumAtoms()): x, y, z = conf_phenol.GetAtomPosition(atm) - conf_phenol.SetAtomPosition(atm, Point3D(x+30, y, z)) + conf_phenol.SetAtomPosition(atm, Point3D(x + 30, y, z)) for atm in range(rdmol_benzo.GetNumAtoms()): x, y, z = conf_benzo.GetAtomPosition(atm) - conf_benzo.SetAtomPosition(atm, Point3D(x, y+30, z)) + conf_benzo.SetAtomPosition(atm, Point3D(x, y + 30, z)) - phenol = openfe.SmallMoleculeComponent.from_rdkit( - rdmol_phenol, name='phenol' - ) + phenol = openfe.SmallMoleculeComponent.from_rdkit(rdmol_phenol, name="phenol") - benzo = openfe.SmallMoleculeComponent.from_rdkit( - rdmol_benzo, name='benzonitrile' - ) + benzo = 
openfe.SmallMoleculeComponent.from_rdkit(rdmol_benzo, name="benzonitrile") return openfe.ChemicalSystem( - {'whatligand': benzene_modifications['toluene'], - "foo": phenol, - "bar": benzo, - "solvent": openfe.SolventComponent()}, + { + "whatligand": benzene_modifications["toluene"], + "foo": phenol, + "bar": benzo, + "solvent": openfe.SolventComponent(), + }, ) @@ -200,9 +212,9 @@ def rfe_transformation_json() -> str: generated with gen-serialized-results.py """ - d = resources.files('openfe.tests.data.openmm_rfe') + d = resources.files("openfe.tests.data.openmm_rfe") - with gzip.open((d / 'RHFEProtocol_json_results.gz').as_posix(), 'r') as f: # type: ignore + with gzip.open((d / "RHFEProtocol_json_results.gz").as_posix(), "r") as f: # type: ignore return f.read().decode() # type: ignore @@ -213,10 +225,10 @@ def afe_solv_transformation_json() -> str: generated with gen-serialized-results.py """ - d = resources.files('openfe.tests.data.openmm_afe') + d = resources.files("openfe.tests.data.openmm_afe") fname = "AHFEProtocol_json_results.gz" - - with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore + + with gzip.open((d / fname).as_posix(), "r") as f: # type: ignore return f.read().decode() # type: ignore @@ -227,10 +239,10 @@ def md_json() -> str: generated with gen-serialized-results.py """ - d = resources.files('openfe.tests.data.openmm_md') + d = resources.files("openfe.tests.data.openmm_md") fname = "MDProtocol_json_results.gz" - with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore + with gzip.open((d / fname).as_posix(), "r") as f: # type: ignore return f.read().decode() # type: ignore @@ -243,7 +255,7 @@ def md_json() -> str: "hybrid_system.pdb": "07203679cb14b840b36e4320484df2360f45e323faadb02d6eacac244fddd517", "simulation.nc": "92361a0864d4359a75399470135f56642b72c605069a4c33dbc4be6f91f28b31", "simulation_real_time_analysis.yaml": "65706002f371fafba96037f29b054fd7e050e442915205df88567f48f5e5e1cf", - } + }, ) diff --git 
a/openfe/tests/protocols/test_openmm_afe_slow.py b/openfe/tests/protocols/test_openmm_afe_slow.py index 020be4efa..8ba1f8e37 100644 --- a/openfe/tests/protocols/test_openmm_afe_slow.py +++ b/openfe/tests/protocols/test_openmm_afe_slow.py @@ -1,12 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from gufe.protocols import execute_DAG +import os +import pathlib + import pytest +from gufe.protocols import execute_DAG from openff.units import unit from openmm import Platform -import os -import pathlib import openfe from openfe.protocols import openmm_afe @@ -21,25 +22,22 @@ def available_platforms() -> set[str]: def set_openmm_threads_1(): # for vacuum sims, we want to limit threads to one # this fixture sets OPENMM_CPU_THREADS='1' for a single test, then reverts to previously held value - previous: str | None = os.environ.get('OPENMM_CPU_THREADS') + previous: str | None = os.environ.get("OPENMM_CPU_THREADS") try: - os.environ['OPENMM_CPU_THREADS'] = '1' + os.environ["OPENMM_CPU_THREADS"] = "1" yield finally: if previous is None: - del os.environ['OPENMM_CPU_THREADS'] + del os.environ["OPENMM_CPU_THREADS"] else: - os.environ['OPENMM_CPU_THREADS'] = previous + os.environ["OPENMM_CPU_THREADS"] = previous @pytest.mark.integration # takes too long to be a slow test ~ 4 mins locally @pytest.mark.flaky(reruns=3) # pytest-rerunfailures; we can get bad minimisation -@pytest.mark.parametrize('platform', ['CPU', 'CUDA']) -def test_openmm_run_engine(platform, - available_platforms, - benzene_modifications, - set_openmm_threads_1, tmpdir): +@pytest.mark.parametrize("platform", ["CPU", "CUDA"]) +def test_openmm_run_engine(platform, available_platforms, benzene_modifications, set_openmm_threads_1, tmpdir): if platform not in available_platforms: pytest.skip(f"OpenMM Platform: {platform} not available") @@ -54,31 +52,67 @@ def test_openmm_run_engine(platform, 
s.vacuum_engine_settings.compute_platform = platform s.solvent_engine_settings.compute_platform = platform s.vacuum_simulation_settings.time_per_iteration = 20 * unit.femtosecond - s.solvent_simulation_settings.time_per_iteration = 20 * unit.femtosecond + s.solvent_simulation_settings.time_per_iteration = 20 * unit.femtosecond s.vacuum_output_settings.checkpoint_interval = 20 * unit.femtosecond s.solvent_output_settings.checkpoint_interval = 20 * unit.femtosecond s.vacuum_simulation_settings.n_replicas = 20 s.solvent_simulation_settings.n_replicas = 20 - s.lambda_settings.lambda_elec = \ - [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - s.lambda_settings.lambda_vdw = \ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, - 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0] - + s.lambda_settings.lambda_elec = [ + 0.0, + 0.25, + 0.5, + 0.75, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ] + s.lambda_settings.lambda_vdw = [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.05, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + 1.0, + ] protocol = openmm_afe.AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - stateA = openfe.ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': openfe.SolventComponent() - }) + stateA = openfe.ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": openfe.SolventComponent()}) - stateB = openfe.ChemicalSystem({ - 'solvent': openfe.SolventComponent(), - }) + stateB = openfe.ChemicalSystem( + { + "solvent": openfe.SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first solvent unit @@ -88,20 +122,18 @@ def test_openmm_run_engine(platform, mapping=None, ) - cwd = pathlib.Path(str(tmpdir)) - r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, - 
keep_shared=True) + r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() for pur in r.protocol_unit_results: unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" assert unit_shared.exists() assert pathlib.Path(unit_shared).is_dir() - checkpoint = pur.outputs['last_checkpoint'] + checkpoint = pur.outputs["last_checkpoint"] assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc" assert (unit_shared / checkpoint).exists() - nc = pur.outputs['nc'] + nc = pur.outputs["nc"] assert nc == unit_shared / f"{pur.outputs['simtype']}.nc" assert nc.exists() @@ -109,5 +141,5 @@ def test_openmm_run_engine(platform, results = protocol.gather([r]) states = results.get_replica_states() assert len(states.items()) == 2 - assert len(states['solvent']) == 1 - assert states['solvent'][0].shape[1] == 20 + assert len(states["solvent"]) == 1 + assert states["solvent"][0].shape[1] == 20 diff --git a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py index 7a3a09a15..08df24d04 100644 --- a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py +++ b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py @@ -3,29 +3,28 @@ import itertools import json import sys -import pytest from unittest import mock -from openmm import NonbondedForce, CustomNonbondedForce -from openmmtools.multistate.multistatesampler import MultiStateSampler -from openff.units import unit as offunit -from openff.units.openmm import ensure_quantity, from_openmm + +import gufe import mdtraj as mdt import numpy as np +import pytest from numpy.testing import assert_allclose -import gufe +from openff.units import unit as offunit +from openff.units.openmm import ensure_quantity, from_openmm +from openmm import CustomNonbondedForce, NonbondedForce +from openmmtools.multistate.multistatesampler import MultiStateSampler + import openfe from openfe import ChemicalSystem, SolventComponent from 
openfe.protocols import openmm_afe from openfe.protocols.openmm_afe import ( + AbsoluteSolvationProtocol, AbsoluteSolvationSolventUnit, AbsoluteSolvationVacuumUnit, - AbsoluteSolvationProtocol, ) - from openfe.protocols.openmm_utils import system_validation -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_NAGL, HAS_OPENEYE, HAS_ESPALOMA -) +from openfe.protocols.openmm_utils.charge_generation import HAS_ESPALOMA, HAS_NAGL, HAS_OPENEYE @pytest.fixture() @@ -38,43 +37,54 @@ def test_create_default_settings(): assert settings -@pytest.mark.parametrize('val', [ - {'elec': [0.0, -1], 'vdw': [0.0, 1.0], 'restraints': [0.0, 1.0]}, - {'elec': [0.0, 1.5], 'vdw': [0.0, 1.5], 'restraints': [-0.1, 1.0]} -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]}, + {"elec": [0.0, 1.5], "vdw": [0.0, 1.5], "restraints": [-0.1, 1.0]}, + ], +) def test_incorrect_window_settings(val, default_settings): errmsg = "Lambda windows must be between 0 and 1." lambda_settings = default_settings.lambda_settings with pytest.raises(ValueError, match=errmsg): - lambda_settings.lambda_elec = val['elec'] - lambda_settings.lambda_vdw = val['vdw'] - lambda_settings.lambda_restraints = val['restraints'] + lambda_settings.lambda_elec = val["elec"] + lambda_settings.lambda_vdw = val["vdw"] + lambda_settings.lambda_restraints = val["restraints"] -@pytest.mark.parametrize('val', [ - {'elec': [0.0, 0.1, 0.0], 'vdw': [0.0, 1.0, 1.0], 'restraints': [0.0, 1.0, 1.0]}, -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [0.0, 0.1, 0.0], "vdw": [0.0, 1.0, 1.0], "restraints": [0.0, 1.0, 1.0]}, + ], +) def test_monotonic_lambda_windows(val, default_settings): errmsg = "The lambda schedule is not monotonic." 
lambda_settings = default_settings.lambda_settings with pytest.raises(ValueError, match=errmsg): - lambda_settings.lambda_elec = val['elec'] - lambda_settings.lambda_vdw = val['vdw'] - lambda_settings.lambda_restraints = val['restraints'] + lambda_settings.lambda_elec = val["elec"] + lambda_settings.lambda_vdw = val["vdw"] + lambda_settings.lambda_restraints = val["restraints"] -@pytest.mark.parametrize('val', [ - {'elec': [0.0, 1.0], 'vdw': [1.0, 1.0], 'restraints': [0.0, 0.0]}, -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [0.0, 1.0], "vdw": [1.0, 1.0], "restraints": [0.0, 0.0]}, + ], +) def test_validate_lambda_schedule_naked_charge(val, default_settings): - errmsg = ("There are states along this lambda schedule " - "where there are atoms with charges but no LJ " - f"interactions: lambda 0: " - f"elec {val['elec'][0]} vdW {val['vdw'][0]}") - default_settings.lambda_settings.lambda_elec = val['elec'] - default_settings.lambda_settings.lambda_vdw = val['vdw'] - default_settings.lambda_settings.lambda_restraints = val['restraints'] + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: lambda 0: " + f"elec {val['elec'][0]} vdW {val['vdw'][0]}" + ) + default_settings.lambda_settings.lambda_elec = val["elec"] + default_settings.lambda_settings.lambda_vdw = val["vdw"] + default_settings.lambda_settings.lambda_restraints = val["restraints"] default_settings.vacuum_simulation_settings.n_replicas = 2 default_settings.solvent_simulation_settings.n_replicas = 2 with pytest.raises(ValueError, match=errmsg): @@ -89,17 +99,19 @@ def test_validate_lambda_schedule_naked_charge(val, default_settings): ) -@pytest.mark.parametrize('val', [ - {'elec': [1.0, 1.0], 'vdw': [0.0, 1.0], 'restraints': [0.0, 0.0]}, -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]}, + ], +) def test_validate_lambda_schedule_nreplicas(val, default_settings): - 
default_settings.lambda_settings.lambda_elec = val['elec'] - default_settings.lambda_settings.lambda_vdw = val['vdw'] - default_settings.lambda_settings.lambda_restraints = val['restraints'] + default_settings.lambda_settings.lambda_elec = val["elec"] + default_settings.lambda_settings.lambda_vdw = val["vdw"] + default_settings.lambda_settings.lambda_restraints = val["restraints"] n_replicas = 3 default_settings.vacuum_simulation_settings.n_replicas = n_replicas - errmsg = (f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(val['vdw'])}") + errmsg = f"Number of replicas {n_replicas} does not equal the" f" number of lambda windows {len(val['vdw'])}" with pytest.raises(ValueError, match=errmsg): AbsoluteSolvationProtocol._validate_lambda_schedule( default_settings.lambda_settings, @@ -107,19 +119,23 @@ def test_validate_lambda_schedule_nreplicas(val, default_settings): ) -@pytest.mark.parametrize('val', [ - {'elec': [1.0, 1.0, 1.0], 'vdw': [0.0, 1.0], 'restraints': [0.0, 0.0]}, -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]}, + ], +) def test_validate_lambda_schedule_nwindows(val, default_settings): - default_settings.lambda_settings.lambda_elec = val['elec'] - default_settings.lambda_settings.lambda_vdw = val['vdw'] - default_settings.lambda_settings.lambda_restraints = val['restraints'] + default_settings.lambda_settings.lambda_elec = val["elec"] + default_settings.lambda_settings.lambda_vdw = val["vdw"] + default_settings.lambda_settings.lambda_restraints = val["restraints"] n_replicas = 3 default_settings.vacuum_simulation_settings.n_replicas = n_replicas errmsg = ( "Components elec and vdw must have equal amount" f" of lambda windows. Got {len(val['elec'])} elec lambda" - f" windows and {len(val['vdw'])} vdw lambda windows.") + f" windows and {len(val['vdw'])} vdw lambda windows." 
+ ) with pytest.raises(ValueError, match=errmsg): AbsoluteSolvationProtocol._validate_lambda_schedule( default_settings.lambda_settings, @@ -127,16 +143,21 @@ def test_validate_lambda_schedule_nwindows(val, default_settings): ) -@pytest.mark.parametrize('val', [ - {'elec': [1.0, 1.0], 'vdw': [1.0, 1.0], 'restraints': [0.0, 1.0]}, -]) +@pytest.mark.parametrize( + "val", + [ + {"elec": [1.0, 1.0], "vdw": [1.0, 1.0], "restraints": [0.0, 1.0]}, + ], +) def test_validate_lambda_schedule_nonzero_restraints(val, default_settings): - wmsg = ("Non-zero restraint lambdas applied. The absolute " - "solvation protocol doesn't apply restraints, " - "therefore restraints won't be applied.") - default_settings.lambda_settings.lambda_elec = val['elec'] - default_settings.lambda_settings.lambda_vdw = val['vdw'] - default_settings.lambda_settings.lambda_restraints = val['restraints'] + wmsg = ( + "Non-zero restraint lambdas applied. The absolute " + "solvation protocol doesn't apply restraints, " + "therefore restraints won't be applied." 
+ ) + default_settings.lambda_settings.lambda_elec = val["elec"] + default_settings.lambda_settings.lambda_vdw = val["vdw"] + default_settings.lambda_settings.lambda_restraints = val["restraints"] default_settings.vacuum_simulation_settings.n_replicas = 2 with pytest.warns(UserWarning, match=wmsg): AbsoluteSolvationProtocol._validate_lambda_schedule( @@ -163,153 +184,140 @@ def test_serialize_protocol(default_settings): assert protocol == ret -def test_validate_solvent_endstates_protcomp( - benzene_modifications, T4_protein_component -): - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'protein': T4_protein_component, - 'solvent': SolventComponent() - }) +def test_validate_solvent_endstates_protcomp(benzene_modifications, T4_protein_component): + stateA = ChemicalSystem( + {"benzene": benzene_modifications["benzene"], "protein": T4_protein_component, "solvent": SolventComponent()}, + ) - stateB = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'phenol': benzene_modifications['phenol'], - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "phenol": benzene_modifications["phenol"], + "solvent": SolventComponent(), + }, + ) with pytest.raises(ValueError, match="Protein components are not allowed"): AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) -def test_validate_solvent_endstates_nosolvcomp_stateA( - benzene_modifications, T4_protein_component -): - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - }) +def test_validate_solvent_endstates_nosolvcomp_stateA(benzene_modifications, T4_protein_component): + stateA = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + }, + ) - stateB = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'phenol': benzene_modifications['phenol'], - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "benzene": 
benzene_modifications["benzene"], + "phenol": benzene_modifications["phenol"], + "solvent": SolventComponent(), + }, + ) - with pytest.raises( - ValueError, match="No SolventComponent found in stateA" - ): + with pytest.raises(ValueError, match="No SolventComponent found in stateA"): AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) -def test_validate_solvent_endstates_nosolvcomp_stateB( - benzene_modifications, T4_protein_component -): - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent(), - }) +def test_validate_solvent_endstates_nosolvcomp_stateB(benzene_modifications, T4_protein_component): + stateA = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "solvent": SolventComponent(), + }, + ) - stateB = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'phenol': benzene_modifications['phenol'], - }) + stateB = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "phenol": benzene_modifications["phenol"], + }, + ) - with pytest.raises( - ValueError, match="No SolventComponent found in stateB" - ): + with pytest.raises(ValueError, match="No SolventComponent found in stateB"): AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) def test_validate_alchem_comps_appearingB(benzene_modifications): - stateA = ChemicalSystem({ - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateB = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - with pytest.raises(ValueError, match='Components appearing in state B'): + with pytest.raises(ValueError, match="Components appearing in state B"): AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) def 
test_validate_alchem_comps_multi(benzene_modifications): - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'toluene': benzene_modifications['toluene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "toluene": benzene_modifications["toluene"], + "solvent": SolventComponent(), + }, + ) - stateB = ChemicalSystem({ - 'solvent': SolventComponent() - }) + stateB = ChemicalSystem({"solvent": SolventComponent()}) alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - assert len(alchem_comps['stateA']) == 2 + assert len(alchem_comps["stateA"]) == 2 - with pytest.raises(ValueError, match='More than one alchemical'): + with pytest.raises(ValueError, match="More than one alchemical"): AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) def test_validate_alchem_nonsmc(benzene_modifications): - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - }) + stateB = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + }, + ) alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - with pytest.raises(ValueError, match='Non SmallMoleculeComponent'): + with pytest.raises(ValueError, match="Non SmallMoleculeComponent"): AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) def test_vac_bad_nonbonded(benzene_modifications): settings = openmm_afe.AbsoluteSolvationProtocol.default_settings() - settings.vacuum_forcefield_settings.nonbonded_method = 'pme' + settings.vacuum_forcefield_settings.nonbonded_method = "pme" protocol = openmm_afe.AbsoluteSolvationProtocol(settings=settings) - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': 
SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) - with pytest.raises(ValueError, match='Only the nocutoff'): + with pytest.raises(ValueError, match="Only the nocutoff"): protocol.create(stateA=stateA, stateB=stateB, mapping=None) -@pytest.mark.parametrize('method', [ - 'repex', 'sams', 'independent', 'InDePeNdENT' -]) -def test_dry_run_vac_benzene(benzene_modifications, - method, tmpdir): +@pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) +def test_dry_run_vac_benzene(benzene_modifications, method, tmpdir): s = openmm_afe.AbsoluteSolvationProtocol.default_settings() s.protocol_repeats = 1 s.vacuum_simulation_settings.sampler_method = method protocol = openmm_afe.AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first vacuum unit @@ -319,23 +327,21 @@ def test_dry_run_vac_benzene(benzene_modifications, mapping=None, ) prot_units = list(dag.protocol_units) - + assert len(prot_units) == 2 - vac_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationVacuumUnit)] - sol_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationSolventUnit)] + vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] + sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] assert len(vac_unit) == 1 assert len(sol_unit) == 1 with 
tmpdir.as_cwd(): - vac_sampler = vac_unit[0].run(dry=True)['debug']['sampler'] + vac_sampler = vac_unit[0].run(dry=True)["debug"]["sampler"] assert not vac_sampler.is_periodic -def test_confgen_fail_AFE(benzene_modifications, tmpdir): +def test_confgen_fail_AFE(benzene_modifications, tmpdir): # check system parametrisation works even if confgen fails s = openmm_afe.AbsoluteSolvationProtocol.default_settings() s.protocol_repeats = 1 @@ -344,14 +350,13 @@ def test_confgen_fail_AFE(benzene_modifications, tmpdir): settings=s, ) - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first vacuum unit @@ -361,12 +366,11 @@ def test_confgen_fail_AFE(benzene_modifications, tmpdir): mapping=None, ) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationVacuumUnit)] + vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] with tmpdir.as_cwd(): - with mock.patch('rdkit.Chem.AllChem.EmbedMultipleConfs', return_value=0): - vac_sampler = vac_unit[0].run(dry=True)['debug']['sampler'] + with mock.patch("rdkit.Chem.AllChem.EmbedMultipleConfs", return_value=0): + vac_sampler = vac_unit[0].run(dry=True)["debug"]["sampler"] assert vac_sampler @@ -377,17 +381,16 @@ def test_dry_run_solv_benzene(benzene_modifications, tmpdir): s.solvent_output_settings.output_indices = "resname UNK" protocol = openmm_afe.AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": 
benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first solvent unit @@ -400,19 +403,17 @@ def test_dry_run_solv_benzene(benzene_modifications, tmpdir): assert len(prot_units) == 2 - vac_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationVacuumUnit)] - sol_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationSolventUnit)] + vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] + sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] assert len(vac_unit) == 1 assert len(sol_unit) == 1 with tmpdir.as_cwd(): - sol_sampler = sol_unit[0].run(dry=True)['debug']['sampler'] + sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"] assert sol_sampler.is_periodic - pdb = mdt.load_pdb('hybrid_system.pdb') + pdb = mdt.load_pdb("hybrid_system.pdb") assert pdb.n_atoms == 12 @@ -420,30 +421,29 @@ def test_dry_run_solv_benzene_tip4p(benzene_modifications, tmpdir): s = AbsoluteSolvationProtocol.default_settings() s.protocol_repeats = 1 s.vacuum_forcefield_settings.forcefields = [ - "amber/ff14SB.xml", # ff14SB protein force field + "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE TPO ] s.solvent_forcefield_settings.forcefields = [ - "amber/ff14SB.xml", # ff14SB protein force field + "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE TPO ] - s.solvation_settings.solvent_model = 'tip4pew' + s.solvation_settings.solvent_model = "tip4pew" s.integrator_settings.reassign_velocities = True protocol = AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - 
stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first solvent unit @@ -454,11 +454,10 @@ def test_dry_run_solv_benzene_tip4p(benzene_modifications, tmpdir): ) prot_units = list(dag.protocol_units) - sol_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationSolventUnit)] + sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] with tmpdir.as_cwd(): - sol_sampler = sol_unit[0].run(dry=True)['debug']['sampler'] + sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"] assert sol_sampler.is_periodic @@ -472,7 +471,7 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, tmpdir): s.protocol_repeats = 1 protocol = openmm_afe.AbsoluteSolvationProtocol( - settings=s, + settings=s, ) def assign_fictitious_charges(offmol): @@ -483,42 +482,42 @@ def assign_fictitious_charges(offmol): rand_arr[-1] = -sum(rand_arr[:-1]) return rand_arr * offunit.elementary_charge - benzene_offmol = benzene_modifications['benzene'].to_openff() + benzene_offmol = benzene_modifications["benzene"].to_openff() offmol_pchgs = assign_fictitious_charges(benzene_offmol) benzene_offmol.partial_charges = offmol_pchgs benzene_smc = openfe.SmallMoleculeComponent.from_openff(benzene_offmol) # check propchgs - prop_chgs = benzene_smc.to_dict()['molprops']['atom.dprop.PartialCharge'] + prop_chgs = benzene_smc.to_dict()["molprops"]["atom.dprop.PartialCharge"] prop_chgs = np.array(prop_chgs.split(), dtype=float) np.testing.assert_allclose(prop_chgs, offmol_pchgs) # Create ChemicalSystems - stateA = ChemicalSystem({ - 'benzene': benzene_smc, - 'solvent': 
SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_smc, "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first solvent unit - dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None,) + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationVacuumUnit)][0] - sol_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationSolventUnit)][0] + vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0] + sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)][0] # check sol_unit charges with tmpdir.as_cwd(): - sampler = sol_unit.run(dry=True)['debug']['sampler'] + sampler = sol_unit.run(dry=True)["debug"]["sampler"] system = sampler._thermodynamic_states[0].system - nonbond = [f for f in system.getForces() - if isinstance(f, NonbondedForce)] + nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)] assert len(nonbond) == 1 @@ -526,53 +525,59 @@ def assign_fictitious_charges(offmol): # partial charge is stored in the offset for i in range(12): offsets = nonbond[0].getParticleParameterOffset(i) - c = ensure_quantity(offsets[2], 'openff') + c = ensure_quantity(offsets[2], "openff") assert pytest.approx(c) == prop_chgs[i] # check vac_unit charges with tmpdir.as_cwd(): - sampler = vac_unit.run(dry=True)['debug']['sampler'] + sampler = vac_unit.run(dry=True)["debug"]["sampler"] system = sampler._thermodynamic_states[0].system - nonbond = [f for f in system.getForces() - if isinstance(f, CustomNonbondedForce)] + nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)] assert len(nonbond) == 4 - custom_elec = [ - n 
for n in nonbond if - n.getGlobalParameterName(0) == 'lambda_electrostatics'][0] + custom_elec = [n for n in nonbond if n.getGlobalParameterName(0) == "lambda_electrostatics"][0] # loop through the 12 benzene atoms for i in range(12): c, s = custom_elec.getParticleParameters(i) - c = ensure_quantity(c, 'openff') + c = ensure_quantity(c, "openff") assert pytest.approx(c) == prop_chgs[i] -@pytest.mark.parametrize('method, backend, ref_key', [ - ('am1bcc', 'ambertools', 'ambertools'), - pytest.param( - 'am1bcc', 'openeye', 'openeye', - marks=pytest.mark.skipif( - not HAS_OPENEYE, reason='needs oechem', +@pytest.mark.parametrize( + "method, backend, ref_key", + [ + ("am1bcc", "ambertools", "ambertools"), + pytest.param( + "am1bcc", + "openeye", + "openeye", + marks=pytest.mark.skipif( + not HAS_OPENEYE, + reason="needs oechem", + ), ), - ), - pytest.param( - 'nagl', 'rdkit', 'nagl', - marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith('darwin'), - reason='needs NAGL and/or on macos', + pytest.param( + "nagl", + "rdkit", + "nagl", + marks=pytest.mark.skipif( + not HAS_NAGL or sys.platform.startswith("darwin"), + reason="needs NAGL and/or on macos", + ), ), - ), - pytest.param( - 'espaloma', 'rdkit', 'espaloma', - marks=pytest.mark.skipif( - not HAS_ESPALOMA, reason='needs espaloma', + pytest.param( + "espaloma", + "rdkit", + "espaloma", + marks=pytest.mark.skipif( + not HAS_ESPALOMA, + reason="needs espaloma", + ), ), - ), -]) -def test_dry_run_charge_backends( - CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges -): + ], +) +def test_dry_run_charge_backends(CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges): """ Check that partial charge generation with different backends works as expected. 
@@ -581,39 +586,34 @@ def test_dry_run_charge_backends( s.protocol_repeats = 1 s.partial_charge_settings.partial_charge_method = method s.partial_charge_settings.off_toolkit_backend = backend - s.partial_charge_settings.nagl_model = 'openff-gnn-am1bcc-0.1.0-rc.1.pt' + s.partial_charge_settings.nagl_model = "openff-gnn-am1bcc-0.1.0-rc.1.pt" protocol = openmm_afe.AbsoluteSolvationProtocol(settings=s) # Create ChemicalSystems - stateA = ChemicalSystem({ - 'benzene': CN_molecule, - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": CN_molecule, "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) # Create DAG from protocol, get the vacuum and solvent units # and eventually dry run the first solvent unit dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units - if isinstance(u, AbsoluteSolvationVacuumUnit)][0] + vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0] # check vac_unit charges with tmpdir.as_cwd(): - sampler = vac_unit.run(dry=True)['debug']['sampler'] + sampler = vac_unit.run(dry=True)["debug"]["sampler"] system = sampler._thermodynamic_states[0].system - nonbond = [f for f in system.getForces() - if isinstance(f, CustomNonbondedForce)] + nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)] assert len(nonbond) == 4 - custom_elec = [ - n for n in nonbond if - n.getGlobalParameterName(0) == 'lambda_electrostatics'][0] + custom_elec = [n for n in nonbond if n.getGlobalParameterName(0) == "lambda_electrostatics"][0] charges = [] for i in range(system.getNumParticles()): @@ -634,17 +634,16 @@ def test_high_timestep(benzene_modifications, tmpdir): s.vacuum_forcefield_settings.hydrogen_mass = 1.0 protocol = AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - stateA = 
ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) dag = protocol.create( stateA=stateA, @@ -664,17 +663,16 @@ def benzene_solvation_dag(benzene_modifications): s = AbsoluteSolvationProtocol.default_settings() protocol = openmm_afe.AbsoluteSolvationProtocol( - settings=s, + settings=s, ) - stateA = ChemicalSystem({ - 'benzene': benzene_modifications['benzene'], - 'solvent': SolventComponent() - }) + stateA = ChemicalSystem({"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}) - stateB = ChemicalSystem({ - 'solvent': SolventComponent(), - }) + stateB = ChemicalSystem( + { + "solvent": SolventComponent(), + }, + ) return protocol.create(stateA=stateA, stateB=stateB, mapping=None) @@ -685,10 +683,14 @@ def test_unit_tagging(benzene_solvation_dag, tmpdir): dag_units = benzene_solvation_dag.protocol_units with ( - mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), - mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + mock.patch( + "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + ), + mock.patch( + "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + ), ): results = [] for u in dag_units: @@ -699,11 +701,11 @@ def test_unit_tagging(benzene_solvation_dag, tmpdir): vac_repeats = set() for ret in results: assert isinstance(ret, 
gufe.ProtocolUnitResult) - assert ret.outputs['generation'] == 0 - if ret.outputs['simtype'] == 'vacuum': - vac_repeats.add(ret.outputs['repeat_id']) + assert ret.outputs["generation"] == 0 + if ret.outputs["simtype"] == "vacuum": + vac_repeats.add(ret.outputs["repeat_id"]) else: - solv_repeats.add(ret.outputs['repeat_id']) + solv_repeats.add(ret.outputs["repeat_id"]) # Repeat ids are random ints so just check their lengths assert len(vac_repeats) == len(solv_repeats) == 3 @@ -711,15 +713,21 @@ def test_unit_tagging(benzene_solvation_dag, tmpdir): def test_gather(benzene_solvation_dag, tmpdir): # check that .gather behaves as expected with ( - mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), - mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + mock.patch( + "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + ), + mock.patch( + "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + ), ): - dagres = gufe.protocols.execute_DAG(benzene_solvation_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True) + dagres = gufe.protocols.execute_DAG( + benzene_solvation_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) protocol = AbsoluteSolvationProtocol( settings=AbsoluteSolvationProtocol.default_settings(), @@ -733,18 +741,16 @@ def test_gather(benzene_solvation_dag, tmpdir): class TestProtocolResult: @pytest.fixture() def protocolresult(self, afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, - cls=gufe.tokenization.JSON_HANDLER.decoder) + d = 
json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openfe.ProtocolResult.from_dict(d['protocol_result']) + pr = openfe.ProtocolResult.from_dict(d["protocol_result"]) return pr def test_reload_protocol_result(self, afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, - cls=gufe.tokenization.JSON_HANDLER.decoder) + d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d['protocol_result']) + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) assert pr @@ -768,14 +774,14 @@ def test_get_individual(self, protocolresult): inds = protocolresult.get_individual_estimates() assert isinstance(inds, dict) - assert isinstance(inds['solvent'], list) - assert isinstance(inds['vacuum'], list) - assert len(inds['solvent']) == len(inds['vacuum']) == 3 - for e, u in itertools.chain(inds['solvent'], inds['vacuum']): + assert isinstance(inds["solvent"], list) + assert isinstance(inds["vacuum"], list) + assert len(inds["solvent"]) == len(inds["vacuum"]) == 3 + for e, u in itertools.chain(inds["solvent"], inds["vacuum"]): assert e.is_compatible_with(offunit.kilojoule_per_mole) assert u.is_compatible_with(offunit.kilojoule_per_mole) - @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) def test_get_forwards_etc(self, key, protocolresult): far = protocolresult.get_forward_and_reverse_energy_analysis() @@ -784,14 +790,13 @@ def test_get_forwards_etc(self, key, protocolresult): far1 = far[key][0] assert isinstance(far1, dict) - for k in ['fractions', 'forward_DGs', 'forward_dDGs', - 'reverse_DGs', 'reverse_dDGs']: + for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]: assert k in far1 - if k == 'fractions': + if k == "fractions": assert isinstance(far1[k], np.ndarray) - @pytest.mark.parametrize('key', ['solvent', 
'vacuum']) + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) def test_get_overlap_matrices(self, key, protocolresult): ovp = protocolresult.get_overlap_matrices() @@ -800,10 +805,10 @@ def test_get_overlap_matrices(self, key, protocolresult): assert len(ovp[key]) == 3 ovp1 = ovp[key][0] - assert isinstance(ovp1['matrix'], np.ndarray) - assert ovp1['matrix'].shape == (14, 14) + assert isinstance(ovp1["matrix"], np.ndarray) + assert ovp1["matrix"].shape == (14, 14) - @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) def test_get_replica_transition_statistics(self, key, protocolresult): rpx = protocolresult.get_replica_transition_statistics() @@ -811,12 +816,12 @@ def test_get_replica_transition_statistics(self, key, protocolresult): assert isinstance(rpx[key], list) assert len(rpx[key]) == 3 rpx1 = rpx[key][0] - assert 'eigenvalues' in rpx1 - assert 'matrix' in rpx1 - assert rpx1['eigenvalues'].shape == (14,) - assert rpx1['matrix'].shape == (14, 14) + assert "eigenvalues" in rpx1 + assert "matrix" in rpx1 + assert rpx1["eigenvalues"].shape == (14,) + assert rpx1["matrix"].shape == (14, 14) - @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) def test_equilibration_iterations(self, key, protocolresult): eq = protocolresult.equilibration_iterations() @@ -825,7 +830,7 @@ def test_equilibration_iterations(self, key, protocolresult): assert len(eq[key]) == 3 assert all(isinstance(v, float) for v in eq[key]) - @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) def test_production_iterations(self, key, protocolresult): prod = protocolresult.production_iterations() diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 4f55b9738..0a916148e 100644 --- 
a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -2,59 +2,53 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import copy import json +import sys import xml.etree.ElementTree as ET from importlib import resources from unittest import mock -import sys import gufe import mdtraj as mdt import numpy as np import pytest +from kartograf import KartografAtomMapper +from kartograf.atom_aligner import align_mol_shape from openff.toolkit import Molecule from openff.units import unit -from openff.units.openmm import ensure_quantity -from openff.units.openmm import to_openmm, from_openmm -from openmm import ( - app, XmlSerializer, MonteCarloBarostat, - NonbondedForce, CustomNonbondedForce -) +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmm import CustomNonbondedForce, MonteCarloBarostat, NonbondedForce, XmlSerializer, app from openmm import unit as omm_unit from openmmforcefields.generators import SMIRNOFFTemplateGenerator from openmmtools.multistate.multistatesampler import MultiStateSampler from rdkit import Chem from rdkit.Geometry import Point3D -from kartograf.atom_aligner import align_mol_shape -from kartograf import KartografAtomMapper - import openfe from openfe import setup from openfe.protocols import openmm_rfe from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _validate_alchemical_components, _get_alchemical_charge_difference + _get_alchemical_charge_difference, + _validate_alchemical_components, ) from openfe.protocols.openmm_utils import system_creation -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_NAGL, HAS_OPENEYE, HAS_ESPALOMA -) +from openfe.protocols.openmm_utils.charge_generation import HAS_ESPALOMA, HAS_NAGL, HAS_OPENEYE def test_compute_platform_warn(): with pytest.warns(UserWarning, match="Non-GPU platform selected: CPU"): - 
openmm_rfe._rfe_utils.compute.get_openmm_platform('CPU') + openmm_rfe._rfe_utils.compute.get_openmm_platform("CPU") def test_append_topology(benzene_complex_system, toluene_complex_system): mod = app.Modeller( - benzene_complex_system['protein'].to_openmm_topology(), - benzene_complex_system['protein'].to_openmm_positions(), + benzene_complex_system["protein"].to_openmm_topology(), + benzene_complex_system["protein"].to_openmm_positions(), ) - lig1 = benzene_complex_system['ligand'].to_openff() + lig1 = benzene_complex_system["ligand"].to_openff() mod.add( lig1.to_topology().to_openmm(), - ensure_quantity(lig1.conformers[0], 'openmm'), + ensure_quantity(lig1.conformers[0], "openmm"), ) top1 = mod.topology @@ -62,10 +56,11 @@ def test_append_topology(benzene_complex_system, toluene_complex_system): assert len(list(top1.atoms())) == 2625 assert len(list(top1.bonds())) == 2645 - lig2 = toluene_complex_system['ligand'].to_openff() + lig2 = toluene_complex_system["ligand"].to_openff() top2, appended_resids = openmm_rfe._rfe_utils.topologyhelpers.combined_topology( - top1, lig2.to_topology().to_openmm(), + top1, + lig2.to_topology().to_openmm(), exclude_resids=np.asarray(list(top1.residues())[-1].index), ) @@ -74,16 +69,15 @@ def test_append_topology(benzene_complex_system, toluene_complex_system): assert appended_resids[0] == len(list(top1.residues())) - 1 -def test_append_topology_no_exclude(benzene_complex_system, - toluene_complex_system): +def test_append_topology_no_exclude(benzene_complex_system, toluene_complex_system): mod = app.Modeller( - benzene_complex_system['protein'].to_openmm_topology(), - benzene_complex_system['protein'].to_openmm_positions(), + benzene_complex_system["protein"].to_openmm_topology(), + benzene_complex_system["protein"].to_openmm_positions(), ) - lig1 = benzene_complex_system['ligand'].to_openff() + lig1 = benzene_complex_system["ligand"].to_openff() mod.add( lig1.to_topology().to_openmm(), - ensure_quantity(lig1.conformers[0], 
'openmm'), + ensure_quantity(lig1.conformers[0], "openmm"), ) top1 = mod.topology @@ -91,10 +85,11 @@ def test_append_topology_no_exclude(benzene_complex_system, assert len(list(top1.atoms())) == 2625 assert len(list(top1.bonds())) == 2645 - lig2 = toluene_complex_system['ligand'].to_openff() + lig2 = toluene_complex_system["ligand"].to_openff() top2, appended_resids = openmm_rfe._rfe_utils.topologyhelpers.combined_topology( - top1, lig2.to_topology().to_openmm(), + top1, + lig2.to_topology().to_openmm(), exclude_resids=None, ) @@ -136,7 +131,7 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag1 = protocol.create( stateA=benzene_system, @@ -152,45 +147,56 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t repeat_ids = set() u: openmm_rfe.RelativeHybridTopologyProtocolUnit for u in dag1.protocol_units: - repeat_ids.add(u.inputs['repeat_id']) + repeat_ids.add(u.inputs["repeat_id"]) for u in dag2.protocol_units: - repeat_ids.add(u.inputs['repeat_id']) + repeat_ids.add(u.inputs["repeat_id"]) assert len(repeat_ids) == 6 -@pytest.mark.parametrize('mapping', [ - None, [], ['A', 'B'], -]) +@pytest.mark.parametrize( + "mapping", + [ + None, + [], + ["A", "B"], + ], +) def test_validate_alchemical_components_wrong_mappings(mapping): with pytest.raises(ValueError, match="A single LigandAtomMapping"): - _validate_alchemical_components( - {'stateA': [], 'stateB': []}, mapping - ) + _validate_alchemical_components({"stateA": [], "stateB": []}, mapping) -def test_validate_alchemical_components_missing_alchem_comp( - benzene_to_toluene_mapping): - alchem_comps = {'stateA': [openfe.SolventComponent(), ], 'stateB': []} +def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): + alchem_comps = { + "stateA": [ + 
openfe.SolventComponent(), + ], + "stateB": [], + } with pytest.raises(ValueError, match="Unmapped alchemical component"): _validate_alchemical_components( - alchem_comps, benzene_to_toluene_mapping, + alchem_comps, + benzene_to_toluene_mapping, ) -@pytest.mark.parametrize('method', [ - 'repex', 'sams', 'independent', 'InDePeNdENT' -]) -def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, - benzene_to_toluene_mapping, method, tmpdir): +@pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) +def test_dry_run_default_vacuum( + benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + method, + tmpdir, +): vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.simulation_settings.sampler_method = method vac_settings.protocol_repeats = 1 protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, + settings=vac_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -202,7 +208,7 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] assert isinstance(sampler, MultiStateSampler) assert not sampler.is_periodic assert sampler._thermodynamic_states[0].barostat is None @@ -223,18 +229,17 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, toluene_vacuum_system, assert len(list(ret_top.bonds())) == 16 # check that our PDB has the right number of atoms - pdb = mdt.load_pdb('hybrid_system.pdb') + pdb = mdt.load_pdb("hybrid_system.pdb") assert pdb.n_atoms == 16 -def test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, - benzene_to_toluene_mapping, tmpdir): +def 
test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, tmpdir): vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' - vac_settings.forcefield_settings.small_molecule_forcefield = 'gaff-2.11' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" + vac_settings.forcefield_settings.small_molecule_forcefield = "gaff-2.11" protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, + settings=vac_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -246,13 +251,15 @@ def test_dry_run_gaff_vacuum(benzene_vacuum_system, toluene_vacuum_system, unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)['debug']['sampler'] + sampler = unit.run(dry=True)["debug"]["sampler"] @pytest.mark.slow def test_dry_many_molecules_solvent( - benzene_many_solv_system, toluene_many_solv_system, - benzene_to_toluene_mapping, tmpdir + benzene_many_solv_system, + toluene_many_solv_system, + benzene_to_toluene_mapping, + tmpdir, ): """ A basic test flushing "will it work if you pass multiple molecules" @@ -260,7 +267,7 @@ def test_dry_many_molecules_solvent( settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -272,7 +279,7 @@ def test_dry_many_molecules_solvent( unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)['debug']['sampler'] + sampler = unit.run(dry=True)["debug"]["sampler"] BENZ = """\ @@ -346,53 +353,56 @@ def test_dry_core_element_change(tmpdir): benz = openfe.SmallMoleculeComponent(Chem.MolFromMolBlock(BENZ, removeHs=False)) pyr = openfe.SmallMoleculeComponent(Chem.MolFromMolBlock(PYRIDINE, removeHs=False)) - mapping = 
openfe.LigandAtomMapping( - benz, pyr, - {0: 0, 1: 10, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 8: 9, 9: 8, 10: 7, 11: 6} - ) + mapping = openfe.LigandAtomMapping(benz, pyr, {0: 0, 1: 10, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 8: 9, 9: 8, 10: 7, 11: 6}) settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - settings.forcefield_settings.nonbonded_method = 'nocutoff' + settings.forcefield_settings.nonbonded_method = "nocutoff" protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( - stateA=openfe.ChemicalSystem({'ligand': benz, }), - stateB=openfe.ChemicalSystem({'ligand': pyr, }), + stateA=openfe.ChemicalSystem( + { + "ligand": benz, + }, + ), + stateB=openfe.ChemicalSystem( + { + "ligand": pyr, + }, + ), mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] system = sampler._hybrid_factory.hybrid_system assert system.getNumParticles() == 12 # Average mass between nitrogen and carbon assert system.getParticleMass(1) == 12.0127235 * omm_unit.amu # Get out the CustomNonbondedForce - cnf = [f for f in system.getForces() - if f.__class__.__name__ == 'CustomNonbondedForce'][0] + cnf = [f for f in system.getForces() if f.__class__.__name__ == "CustomNonbondedForce"][0] # there should be no new unique atoms assert cnf.getInteractionGroupParameters(6) == [(), ()] # there should be one old unique atom (spare hydrogen from the benzene) assert cnf.getInteractionGroupParameters(7) == [(7,), (7,)] -@pytest.mark.parametrize('method', ['repex', 'sams', 'independent']) -def test_dry_run_ligand(benzene_system, toluene_system, - benzene_to_toluene_mapping, method, tmpdir): +@pytest.mark.parametrize("method", ["repex", "sams", "independent"]) +def test_dry_run_ligand(benzene_system, toluene_system, benzene_to_toluene_mapping, method, tmpdir): # this might be a bit time consuming settings 
= openmm_rfe.RelativeHybridTopologyProtocol.default_settings() settings.simulation_settings.sampler_method = method settings.protocol_repeats = 1 - settings.output_settings.output_indices = 'resname UNK' + settings.output_settings.output_indices = "resname UNK" protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_system, @@ -402,20 +412,18 @@ def test_dry_run_ligand(benzene_system, toluene_system, dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] assert isinstance(sampler, MultiStateSampler) assert sampler.is_periodic - assert isinstance(sampler._thermodynamic_states[0].barostat, - MonteCarloBarostat) + assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloBarostat) assert sampler._thermodynamic_states[1].pressure == 1 * omm_unit.bar # Check we have the right number of atoms in the PDB - pdb = mdt.load_pdb('hybrid_system.pdb') + pdb = mdt.load_pdb("hybrid_system.pdb") assert pdb.n_atoms == 16 -def test_confgen_mocked_fail(benzene_system, toluene_system, - benzene_to_toluene_mapping, tmpdir): +def test_confgen_mocked_fail(benzene_system, toluene_system, benzene_to_toluene_mapping, tmpdir): """ Check that even if conformer generation fails, we can still perform a sim """ @@ -424,38 +432,34 @@ def test_confgen_mocked_fail(benzene_system, toluene_system, protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) - dag = protocol.create(stateA=benzene_system, stateB=toluene_system, - mapping=benzene_to_toluene_mapping) + dag = protocol.create(stateA=benzene_system, stateB=toluene_system, mapping=benzene_to_toluene_mapping) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - with mock.patch('rdkit.Chem.AllChem.EmbedMultipleConfs', return_value=0): + with mock.patch("rdkit.Chem.AllChem.EmbedMultipleConfs", return_value=0): sampler 
= dag_unit.run(dry=True) assert sampler -@pytest.fixture(scope='session') -def tip4p_hybrid_factory( - benzene_system, toluene_system, - benzene_to_toluene_mapping, tmp_path_factory -): +@pytest.fixture(scope="session") +def tip4p_hybrid_factory(benzene_system, toluene_system, benzene_to_toluene_mapping, tmp_path_factory): """ Hybrid system with virtual sites in the environment (waters) """ settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() settings.forcefield_settings.forcefields = [ - "amber/ff14SB.xml", # ff14SB protein force field + "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE TPO ] settings.solvation_settings.solvent_padding = 1.0 * unit.nanometer settings.forcefield_settings.nonbonded_cutoff = 0.9 * unit.nanometer - settings.solvation_settings.solvent_model = 'tip4pew' + settings.solvation_settings.solvent_model = "tip4pew" settings.integrator_settings.reassign_velocities = True protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_system, @@ -468,12 +472,12 @@ def tip4p_hybrid_factory( scratch_temp = tmp_path_factory.mktemp("tip4p_scratch") dag_unit_result = dag_unit.run( - dry=True, - scratch_basepath=scratch_temp, - shared_basepath=shared_temp, + dry=True, + scratch_basepath=scratch_temp, + shared_basepath=shared_temp, ) - return dag_unit_result['debug']['sampler']._factory + return dag_unit_result["debug"]["sampler"]._factory def test_tip4p_particle_count(tip4p_hybrid_factory): @@ -499,13 +503,8 @@ def test_tip4p_num_waters(tip4p_hybrid_factory): htf = tip4p_hybrid_factory # Test 2 - num_waters = len( - [r for r in htf._old_topology.residues() if r.name =='HOH'] - ) - virtual_sites = [ - ix for ix in range(htf.hybrid_system.getNumParticles()) if - htf.hybrid_system.isVirtualSite(ix) - ] + num_waters = len([r for r in htf._old_topology.residues() if 
r.name == "HOH"]) + virtual_sites = [ix for ix in range(htf.hybrid_system.getNumParticles()) if htf.hybrid_system.isVirtualSite(ix)] assert num_waters == len(virtual_sites) @@ -517,32 +516,25 @@ def test_tip4p_check_vsite_parameters(tip4p_hybrid_factory): htf = tip4p_hybrid_factory - virtual_sites = [ - ix for ix in range(htf.hybrid_system.getNumParticles()) if - htf.hybrid_system.isVirtualSite(ix) - ] + virtual_sites = [ix for ix in range(htf.hybrid_system.getNumParticles()) if htf.hybrid_system.isVirtualSite(ix)] # get the standard and custom nonbonded forces - one of each - nonbond = [f for f in htf.hybrid_system.getForces() - if isinstance(f, NonbondedForce)][0] + nonbond = [f for f in htf.hybrid_system.getForces() if isinstance(f, NonbondedForce)][0] - cust_nonbond = [f for f in htf.hybrid_system.getForces() - if isinstance(f, CustomNonbondedForce)][0] + cust_nonbond = [f for f in htf.hybrid_system.getForces() if isinstance(f, CustomNonbondedForce)][0] # loop through every virtual site and check that they have the # expected tip4p parameters for entry in virtual_sites: vs = htf.hybrid_system.getVirtualSite(entry) vs_mass = htf.hybrid_system.getParticleMass(entry) - assert ensure_quantity(vs_mass, 'openff').m == pytest.approx(0) + assert ensure_quantity(vs_mass, "openff").m == pytest.approx(0) vs_weights = [vs.getWeight(ix) for ix in range(vs.getNumParticles())] - np.testing.assert_allclose( - vs_weights, [0.786646558, 0.106676721, 0.106676721] - ) + np.testing.assert_allclose(vs_weights, [0.786646558, 0.106676721, 0.106676721]) c, s, e = nonbond.getParticleParameters(entry) - assert ensure_quantity(c, 'openff').m == pytest.approx(-1.04844) - assert ensure_quantity(s, 'openff').m == 1 - assert ensure_quantity(e, 'openff').m == 0 + assert ensure_quantity(c, "openff").m == pytest.approx(-1.04844) + assert ensure_quantity(s, "openff").m == 1 + assert ensure_quantity(e, "openff").m == 0 s1, e1, s2, e2, i, j = cust_nonbond.getParticleParameters(entry) @@ -552,14 
+544,8 @@ def test_tip4p_check_vsite_parameters(tip4p_hybrid_factory): @pytest.mark.slow -@pytest.mark.parametrize('cutoff', - [1.0 * unit.nanometer, - 12.0 * unit.angstrom, - 0.9 * unit.nanometer] -) -def test_dry_run_ligand_system_cutoff( - cutoff, benzene_system, toluene_system, benzene_to_toluene_mapping, tmpdir -): +@pytest.mark.parametrize("cutoff", [1.0 * unit.nanometer, 12.0 * unit.angstrom, 0.9 * unit.nanometer]) +def test_dry_run_ligand_system_cutoff(cutoff, benzene_system, toluene_system, benzene_to_toluene_mapping, tmpdir): """ Test that the right nonbonded cutoff is propagated to the hybrid system. """ @@ -568,7 +554,7 @@ def test_dry_run_ligand_system_cutoff( settings.forcefield_settings.nonbonded_cutoff = cutoff protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_system, @@ -578,86 +564,94 @@ def test_dry_run_ligand_system_cutoff( dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] hs = sampler._factory.hybrid_system - nbfs = [f for f in hs.getForces() if - isinstance(f, CustomNonbondedForce) or - isinstance(f, NonbondedForce)] + nbfs = [f for f in hs.getForces() if isinstance(f, CustomNonbondedForce) or isinstance(f, NonbondedForce)] for f in nbfs: f_cutoff = from_openmm(f.getCutoffDistance()) assert f_cutoff == cutoff -@pytest.mark.parametrize('method, backend, ref_key', [ - ('am1bcc', 'ambertools', 'ambertools'), - pytest.param( - 'am1bcc', 'openeye', 'openeye', - marks=pytest.mark.skipif( - not HAS_OPENEYE, reason='needs oechem', +@pytest.mark.parametrize( + "method, backend, ref_key", + [ + ("am1bcc", "ambertools", "ambertools"), + pytest.param( + "am1bcc", + "openeye", + "openeye", + marks=pytest.mark.skipif( + not HAS_OPENEYE, + reason="needs oechem", + ), ), - ), - pytest.param( - 'nagl', 'rdkit', 'nagl', - marks=pytest.mark.skipif( - not 
HAS_NAGL or sys.platform.startswith('darwin'), - reason='needs NAGL and/or on macos', + pytest.param( + "nagl", + "rdkit", + "nagl", + marks=pytest.mark.skipif( + not HAS_NAGL or sys.platform.startswith("darwin"), + reason="needs NAGL and/or on macos", + ), ), - ), - pytest.param( - 'espaloma', 'rdkit', 'espaloma', - marks=pytest.mark.skipif( - not HAS_ESPALOMA, reason='needs espaloma', + pytest.param( + "espaloma", + "rdkit", + "espaloma", + marks=pytest.mark.skipif( + not HAS_ESPALOMA, + reason="needs espaloma", + ), ), - ), -]) -def test_dry_run_charge_backends( - CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges -): + ], +) +def test_dry_run_charge_backends(CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges): vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.protocol_repeats = 1 vac_settings.partial_charge_settings.partial_charge_method = method vac_settings.partial_charge_settings.off_toolkit_backend = backend - vac_settings.partial_charge_settings.nagl_model = 'openff-gnn-am1bcc-0.1.0-rc.1.pt' + vac_settings.partial_charge_settings.nagl_model = "openff-gnn-am1bcc-0.1.0-rc.1.pt" protocol = openmm_rfe.RelativeHybridTopologyProtocol( settings=vac_settings, ) # make stateB molecule - offmolB = Molecule.from_smiles('CCN') + offmolB = Molecule.from_smiles("CCN") offmolB.generate_conformers() molB = openfe.SmallMoleculeComponent.from_openff(offmolB) a_molB = align_mol_shape(molB, ref_mol=CN_molecule) mapper = KartografAtomMapper(atom_map_hydrogens=True) mapping = next(mapper.suggest_mappings(CN_molecule, a_molB)) - systemA = openfe.ChemicalSystem({'l': CN_molecule}) - systemB = openfe.ChemicalSystem({'l': a_molB}) + systemA = openfe.ChemicalSystem({"l": CN_molecule}) + systemB = openfe.ChemicalSystem({"l": a_molB}) dag = protocol.create( - stateA=systemA, 
stateB=systemB, mapping=mapping, + stateA=systemA, + stateB=systemB, + mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] htf = sampler._factory hybrid_system = htf.hybrid_system # get the standard nonbonded force - nonbond = [f for f in hybrid_system.getForces() - if isinstance(f, NonbondedForce)] + nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] assert len(nonbond) == 1 # get the particle parameter offsets c_offsets = {} for i in range(nonbond[0].getNumParticleParameterOffsets()): offset = nonbond[0].getParticleParameterOffset(i) - c_offsets[offset[1]] = ensure_quantity(offset[2], 'openff') + c_offsets[offset[1]] = ensure_quantity(offset[2], "openff") # See the user charges test below for an idea of what we're doing here # In this particular case we are solely checking that the old atoms @@ -665,17 +659,17 @@ def test_dry_run_charge_backends( for i in range(hybrid_system.getNumParticles()): c, s, e = nonbond[0].getParticleParameters(i) # get the particle charge (c) - c = ensure_quantity(c, 'openff') + c = ensure_quantity(c, "openff") # particle charge (c) is equal to molA particle charge # offset (c_offsets) is equal to -(molA particle charge) - if i in htf._atom_classes['unique_old_atoms']: + if i in htf._atom_classes["unique_old_atoms"]: idx = htf._hybrid_to_old_map[i] ref = am1bcc_ref_charges[ref_key][idx] np.testing.assert_allclose(c, ref, rtol=1e-4) np.testing.assert_allclose(c_offsets[i], -ref, rtol=1e-4) # particle charge (c) is equal to molA particle charge # offset (c_offsets) is equal to difference between molB and molA - elif i in htf._atom_classes['core_atoms']: + elif i in htf._atom_classes["core_atoms"]: old_i = htf._hybrid_to_old_map[i] ref = am1bcc_ref_charges[ref_key][i] np.testing.assert_allclose(c, ref, rtol=1e-4) @@ -689,7 +683,7 @@ def 
test_dry_run_user_charges(benzene_modifications, tmpdir): hybrid topology. """ vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.protocol_repeats = 1 protocol = openmm_rfe.RelativeHybridTopologyProtocol( @@ -712,13 +706,13 @@ def check_propchgs(smc, charge_array): Check that the partial charges we assigned to our offmol from which the smc was constructed are present and the right ones. """ - prop_chgs = smc.to_dict()['molprops']['atom.dprop.PartialCharge'] + prop_chgs = smc.to_dict()["molprops"]["atom.dprop.PartialCharge"] prop_chgs = np.array(prop_chgs.split(), dtype=float) np.testing.assert_allclose(prop_chgs, charge_array.m) # Create new smc with overriden charges - benzene_offmol = benzene_modifications['benzene'].to_openff() - toluene_offmol = benzene_modifications['toluene'].to_openff() + benzene_offmol = benzene_modifications["benzene"].to_openff() + toluene_offmol = benzene_modifications["toluene"].to_openff() benzene_rand_chg = assign_fictitious_charges(benzene_offmol) toluene_rand_chg = assign_fictitious_charges(toluene_offmol) benzene_offmol.partial_charges = benzene_rand_chg @@ -736,27 +730,34 @@ def check_propchgs(smc, charge_array): # create DAG from protocol and take first (and only) work unit from within dag = protocol.create( - stateA=openfe.ChemicalSystem({'l': benzene_smc, }), - stateB=openfe.ChemicalSystem({'l': toluene_smc, }), + stateA=openfe.ChemicalSystem( + { + "l": benzene_smc, + }, + ), + stateB=openfe.ChemicalSystem( + { + "l": toluene_smc, + }, + ), mapping=mapping, ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] htf = sampler._factory hybrid_system = htf.hybrid_system # get the standard nonbonded force - nonbond = [f for f in 
hybrid_system.getForces() - if isinstance(f, NonbondedForce)] + nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] assert len(nonbond) == 1 # get the particle parameter offsets c_offsets = {} for i in range(nonbond[0].getNumParticleParameterOffsets()): offset = nonbond[0].getParticleParameterOffset(i) - c_offsets[offset[1]] = ensure_quantity(offset[2], 'openff') + c_offsets[offset[1]] = ensure_quantity(offset[2], "openff") # Here is a bit of exposition on what we're doing # HTF creates two sets of nonbonded forces, a standard one (for the @@ -785,22 +786,22 @@ def check_propchgs(smc, charge_array): for i in range(hybrid_system.getNumParticles()): c, s, e = nonbond[0].getParticleParameters(i) # get the particle charge (c) - c = ensure_quantity(c, 'openff') + c = ensure_quantity(c, "openff") # particle charge (c) is equal to molA particle charge # offset (c_offsets) is equal to -(molA particle charge) - if i in htf._atom_classes['unique_old_atoms']: + if i in htf._atom_classes["unique_old_atoms"]: idx = htf._hybrid_to_old_map[i] np.testing.assert_allclose(c, benzene_rand_chg[idx]) np.testing.assert_allclose(c_offsets[i], -benzene_rand_chg[idx]) # particle charge (c) is equal to 0 # offset (c_offsets) is equal to molB particle charge - elif i in htf._atom_classes['unique_new_atoms']: + elif i in htf._atom_classes["unique_new_atoms"]: idx = htf._hybrid_to_new_map[i] np.testing.assert_allclose(c, 0 * unit.elementary_charge) np.testing.assert_allclose(c_offsets[i], toluene_rand_chg[idx]) # particle charge (c) is equal to molA particle charge # offset (c_offsets) is equal to difference between molB and molA - elif i in htf._atom_classes['core_atoms']: + elif i in htf._atom_classes["core_atoms"]: old_i = htf._hybrid_to_old_map[i] new_i = htf._hybrid_to_new_map[i] c_exp = toluene_rand_chg[new_i] - benzene_rand_chg[old_i] @@ -808,8 +809,7 @@ def check_propchgs(smc, charge_array): np.testing.assert_allclose(c_offsets[i], c_exp) -def 
test_virtual_sites_no_reassign(benzene_system, toluene_system, - benzene_to_toluene_mapping, tmpdir): +def test_virtual_sites_no_reassign(benzene_system, toluene_system, benzene_to_toluene_mapping, tmpdir): """ Because of some as-of-yet not fully identified issue, not reassigning velocities will cause systems to NaN. @@ -817,17 +817,17 @@ def test_virtual_sites_no_reassign(benzene_system, toluene_system, """ settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() settings.forcefield_settings.forcefields = [ - "amber/ff14SB.xml", # ff14SB protein force field + "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE TPO ] settings.solvation_settings.solvent_padding = 1.0 * unit.nanometer settings.forcefield_settings.nonbonded_cutoff = 0.9 * unit.nanometer - settings.solvation_settings.solvent_model = 'tip4pew' + settings.solvation_settings.solvent_model = "tip4pew" settings.integrator_settings.reassign_velocities = False protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_system, @@ -843,17 +843,16 @@ def test_virtual_sites_no_reassign(benzene_system, toluene_system, @pytest.mark.slow -@pytest.mark.parametrize('method', ['repex', 'sams', 'independent']) -def test_dry_run_complex(benzene_complex_system, toluene_complex_system, - benzene_to_toluene_mapping, method, tmpdir): +@pytest.mark.parametrize("method", ["repex", "sams", "independent"]) +def test_dry_run_complex(benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping, method, tmpdir): # this will be very time consuming settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() settings.simulation_settings.sampler_method = method settings.protocol_repeats = 1 - settings.output_settings.output_indices = 'protein or resname UNK' + settings.output_settings.output_indices = "protein or resname 
UNK" protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_complex_system, @@ -863,39 +862,35 @@ def test_dry_run_complex(benzene_complex_system, toluene_complex_system, dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)['debug']['sampler'] + sampler = dag_unit.run(dry=True)["debug"]["sampler"] assert isinstance(sampler, MultiStateSampler) assert sampler.is_periodic - assert isinstance(sampler._thermodynamic_states[0].barostat, - MonteCarloBarostat) + assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloBarostat) assert sampler._thermodynamic_states[1].pressure == 1 * omm_unit.bar # Check we have the right number of atoms in the PDB - pdb = mdt.load_pdb('hybrid_system.pdb') + pdb = mdt.load_pdb("hybrid_system.pdb") assert pdb.n_atoms == 2629 def test_lambda_schedule_default(): - lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions='default') + lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions="default") assert len(lambdas.lambda_schedule) == 10 -@pytest.mark.parametrize('windows', [11, 6, 9000]) +@pytest.mark.parametrize("windows", [11, 6, 9000]) def test_lambda_schedule(windows): - lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol( - functions='default', windows=windows) + lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions="default", windows=windows) assert len(lambdas.lambda_schedule) == windows -def test_hightimestep(benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, tmpdir): +def test_hightimestep(benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, tmpdir): settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() settings.forcefield_settings.hydrogen_mass = 1.0 - settings.forcefield_settings.nonbonded_method = 'nocutoff' + settings.forcefield_settings.nonbonded_method = "nocutoff" p 
= openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = p.create( @@ -911,23 +906,20 @@ def test_hightimestep(benzene_vacuum_system, dag_unit.run(dry=True) -def test_n_replicas_not_n_windows(benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, tmpdir): +def test_n_replicas_not_n_windows(benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, tmpdir): # For PR #125 we pin such that the number of lambda windows # equals the numbers of replicas used - TODO: remove limitation settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() # default lambda windows is 11 settings.simulation_settings.n_replicas = 13 - settings.forcefield_settings.nonbonded_method = 'nocutoff' + settings.forcefield_settings.nonbonded_method = "nocutoff" - errmsg = ("Number of replicas 13 does not equal the number of " - "lambda windows 11") + errmsg = "Number of replicas 13 does not equal the number of " "lambda windows 11" with tmpdir.as_cwd(): with pytest.raises(ValueError, match=errmsg): p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=settings, + settings=settings, ) dag = p.create( stateA=benzene_vacuum_system, @@ -940,7 +932,7 @@ def test_n_replicas_not_n_windows(benzene_vacuum_system, def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): # state B doesn't have a ligand component - stateB = openfe.ChemicalSystem({'solvent': openfe.SolventComponent()}) + stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -955,10 +947,9 @@ def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): ) -def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, - benzene_to_toluene_mapping): +def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping): # state B doesn't have a solvent 
component (i.e. its vacuum) - stateB = openfe.ChemicalSystem({'ligand': benzene_modifications['toluene']}) + stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -972,13 +963,13 @@ def test_vaccuum_PME_error(benzene_vacuum_system, benzene_modifications, ) -def test_incompatible_solvent(benzene_system, benzene_modifications, - benzene_to_toluene_mapping): +def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): # the solvents are different stateB = openfe.ChemicalSystem( - {'ligand': benzene_modifications['toluene'], - 'solvent': openfe.SolventComponent( - positive_ion='K', negative_ion='Cl')} + { + "ligand": benzene_modifications["toluene"], + "solvent": openfe.SolventComponent(positive_ion="K", negative_ion="Cl"), + }, ) p = openmm_rfe.RelativeHybridTopologyProtocol( @@ -995,19 +986,18 @@ def test_incompatible_solvent(benzene_system, benzene_modifications, ) -def test_mapping_mismatch_A(benzene_system, toluene_system, - benzene_modifications): +def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): # the atom mapping doesn't refer to the ligands in the systems mapping = setup.LigandAtomMapping( - componentA=benzene_system.components['ligand'], - componentB=benzene_modifications['phenol'], - componentA_to_componentB=dict()) + componentA=benzene_system.components["ligand"], + componentB=benzene_modifications["phenol"], + componentA_to_componentB=dict(), + ) p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), ) - errmsg = (r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=toluene\)") + errmsg = r"Unmapped alchemical component " r"SmallMoleculeComponent\(name=toluene\)" with pytest.raises(ValueError, match=errmsg): _ = p.create( stateA=benzene_system, @@ -1016,18 +1006,17 @@ 
def test_mapping_mismatch_A(benzene_system, toluene_system, ) -def test_mapping_mismatch_B(benzene_system, toluene_system, - benzene_modifications): +def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): mapping = setup.LigandAtomMapping( - componentA=benzene_modifications['phenol'], - componentB=toluene_system.components['ligand'], - componentA_to_componentB=dict()) + componentA=benzene_modifications["phenol"], + componentB=toluene_system.components["ligand"], + componentA_to_componentB=dict(), + ) p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), ) - errmsg = (r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=benzene\)") + errmsg = r"Unmapped alchemical component " r"SmallMoleculeComponent\(name=benzene\)" with pytest.raises(ValueError, match=errmsg): _ = p.create( stateA=benzene_system, @@ -1036,8 +1025,7 @@ def test_mapping_mismatch_B(benzene_system, toluene_system, ) -def test_complex_mismatch(benzene_system, toluene_complex_system, - benzene_to_toluene_mapping): +def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): # only one complex p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -1050,8 +1038,7 @@ def test_complex_mismatch(benzene_system, toluene_complex_system, ) -def test_too_many_specified_mappings(benzene_system, toluene_system, - benzene_to_toluene_mapping): +def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): # mapping dict requires 'ligand' key p = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), @@ -1061,21 +1048,16 @@ def test_too_many_specified_mappings(benzene_system, toluene_system, _ = p.create( stateA=benzene_system, stateB=toluene_system, - mapping=[benzene_to_toluene_mapping, - benzene_to_toluene_mapping], 
+ mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], ) -def test_protein_mismatch(benzene_complex_system, toluene_complex_system, - benzene_to_toluene_mapping): +def test_protein_mismatch(benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping): # hack one protein to be labelled differently - prot = toluene_complex_system['protein'] - alt_prot = openfe.ProteinComponent(prot.to_rdkit(), - name='Mickey Mouse') + prot = toluene_complex_system["protein"] + alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") alt_toluene_complex_system = openfe.ChemicalSystem( - {'ligand': toluene_complex_system['ligand'], - 'solvent': toluene_complex_system['solvent'], - 'protein': alt_prot} + {"ligand": toluene_complex_system["ligand"], "solvent": toluene_complex_system["solvent"], "protein": alt_prot}, ) p = openmm_rfe.RelativeHybridTopologyProtocol( @@ -1091,17 +1073,17 @@ def test_protein_mismatch(benzene_complex_system, toluene_complex_system, def test_element_change_warning(atom_mapping_basic_test_files): # check a mapping with element change gets rejected early - l1 = atom_mapping_basic_test_files['2-methylnaphthalene'] - l2 = atom_mapping_basic_test_files['2-naftanol'] + l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] + l2 = atom_mapping_basic_test_files["2-naftanol"] mapper = setup.LomapAtomMapper() mapping = next(mapper.suggest_mappings(l1, l2)) sys1 = openfe.ChemicalSystem( - {'ligand': l1, 'solvent': openfe.SolventComponent()}, + {"ligand": l1, "solvent": openfe.SolventComponent()}, ) sys2 = openfe.ChemicalSystem( - {'ligand': l2, 'solvent': openfe.SolventComponent()}, + {"ligand": l2, "solvent": openfe.SolventComponent()}, ) p = openmm_rfe.RelativeHybridTopologyProtocol( @@ -1109,46 +1091,43 @@ def test_element_change_warning(atom_mapping_basic_test_files): ) with pytest.warns(UserWarning, match="Element change"): _ = p.create( - stateA=sys1, stateB=sys2, + stateA=sys1, + stateB=sys2, mapping=mapping, ) -def 
test_ligand_overlap_warning(benzene_vacuum_system, toluene_vacuum_system, - benzene_to_toluene_mapping, tmpdir): +def test_ligand_overlap_warning(benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, tmpdir): vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, + settings=vac_settings, ) # update atom positions sysA = benzene_vacuum_system - rdmol = benzene_vacuum_system['ligand'].to_rdkit() + rdmol = benzene_vacuum_system["ligand"].to_rdkit() conf = rdmol.GetConformer() for atm in range(rdmol.GetNumAtoms()): x, y, z = conf.GetAtomPosition(atm) - conf.SetAtomPosition(atm, Point3D(x+3, y, z)) + conf.SetAtomPosition(atm, Point3D(x + 3, y, z)) - new_ligand = openfe.SmallMoleculeComponent.from_rdkit( - rdmol, name=benzene_vacuum_system['ligand'].name - ) + new_ligand = openfe.SmallMoleculeComponent.from_rdkit(rdmol, name=benzene_vacuum_system["ligand"].name) components = dict(benzene_vacuum_system.components) - components['ligand'] = new_ligand + components["ligand"] = new_ligand sysA = openfe.ChemicalSystem(components) - mapping = benzene_to_toluene_mapping.copy_with_replacements( - componentA=new_ligand - ) + mapping = benzene_to_toluene_mapping.copy_with_replacements(componentA=new_ligand) # Specifically check that the first pair throws a warning - with pytest.warns(UserWarning, match='0 : 4 deviates'): + with pytest.warns(UserWarning, match="0 : 4 deviates"): dag = protocol.create( - stateA=sysA, stateB=toluene_vacuum_system, + stateA=sysA, + stateB=toluene_vacuum_system, mapping=mapping, - ) + ) dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): dag_unit.run(dry=True) @@ -1161,7 +1140,8 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp settings=settings, ) return 
protocol.create( - stateA=benzene_system, stateB=toluene_system, + stateA=benzene_system, + stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) @@ -1169,8 +1149,10 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp def test_unit_tagging(solvent_protocol_dag, tmpdir): # test that executing the Units includes correct generation and repeat info dag_units = solvent_protocol_dag.protocol_units - with mock.patch('openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}): + with mock.patch( + "openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + ): results = [] for u in dag_units: ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) @@ -1178,23 +1160,27 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): repeats = set() for ret in results: assert isinstance(ret, gufe.ProtocolUnitResult) - assert ret.outputs['generation'] == 0 - repeats.add(ret.outputs['repeat_id']) + assert ret.outputs["generation"] == 0 + repeats.add(ret.outputs["repeat_id"]) # repeats are random ints, so check we got 3 individual numbers assert len(repeats) == 3 def test_gather(solvent_protocol_dag, tmpdir): # check .gather behaves as expected - with mock.patch('openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}): - dagres = gufe.protocols.execute_DAG(solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True) + with mock.patch( + "openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + ): + dagres = gufe.protocols.execute_DAG( + solvent_protocol_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) prot = 
openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), ) res = prot.gather([dagres]) @@ -1204,33 +1190,31 @@ def test_gather(solvent_protocol_dag, tmpdir): class TestConstraintRemoval: @staticmethod - def make_systems(ligA: openfe.SmallMoleculeComponent, - ligB: openfe.SmallMoleculeComponent, - constraints): + def make_systems(ligA: openfe.SmallMoleculeComponent, ligB: openfe.SmallMoleculeComponent, constraints): """Make vacuum system for each, return Topology and System for each""" - omm_forcefield_A = app.ForceField('tip3p.xml') + omm_forcefield_A = app.ForceField("tip3p.xml") smirnoff_A = SMIRNOFFTemplateGenerator( - forcefield='openff-2.0.0.offxml', + forcefield="openff-2.0.0.offxml", molecules=[ligA.to_openff()], ) omm_forcefield_A.registerTemplateGenerator(smirnoff_A.generator) - omm_forcefield_B = app.ForceField('tip3p.xml') + omm_forcefield_B = app.ForceField("tip3p.xml") smirnoff_B = SMIRNOFFTemplateGenerator( - forcefield='openff-2.0.0.offxml', + forcefield="openff-2.0.0.offxml", molecules=[ligB.to_openff()], ) omm_forcefield_B.registerTemplateGenerator(smirnoff_B.generator) stateA_modeller = app.Modeller( ligA.to_openff().to_topology().to_openmm(), - ensure_quantity(ligA.to_openff().conformers[0], 'openmm') + ensure_quantity(ligA.to_openff().conformers[0], "openmm"), ) stateA_topology = stateA_modeller.getTopology() stateA_system = omm_forcefield_A.createSystem( stateA_topology, nonbondedMethod=app.CutoffNonPeriodic, - nonbondedCutoff=ensure_quantity(1.1 * unit.nm, 'openmm'), + nonbondedCutoff=ensure_quantity(1.1 * unit.nm, "openmm"), constraints=constraints, rigidWater=True, hydrogenMass=None, @@ -1240,7 +1224,7 @@ def make_systems(ligA: openfe.SmallMoleculeComponent, stateB_topology, _ = openmm_rfe._rfe_utils.topologyhelpers.combined_topology( stateA_topology, ligB.to_openff().to_topology().to_openmm(), - 
exclude_resids=np.array([r.index for r in list(stateA_topology.residues())]) + exclude_resids=np.array([r.index for r in list(stateA_topology.residues())]), ) # since we're doing a swap of the only molecule, this is equivalent: # stateB_topology = app.Modeller( @@ -1251,7 +1235,7 @@ def make_systems(ligA: openfe.SmallMoleculeComponent, stateB_system = omm_forcefield_B.createSystem( stateB_topology, nonbondedMethod=app.CutoffNonPeriodic, - nonbondedCutoff=ensure_quantity(1.1 * unit.nm, 'openmm'), + nonbondedCutoff=ensure_quantity(1.1 * unit.nm, "openmm"), constraints=constraints, rigidWater=True, hydrogenMass=None, @@ -1260,17 +1244,15 @@ def make_systems(ligA: openfe.SmallMoleculeComponent, return stateA_topology, stateA_system, stateB_topology, stateB_system - @pytest.mark.parametrize('reverse', [False, True]) - def test_remove_constraints_lengthchange(self, benzene_modifications, - reverse): + @pytest.mark.parametrize("reverse", [False, True]) + def test_remove_constraints_lengthchange(self, benzene_modifications, reverse): # check that mappings are correctly corrected to avoid changes in # constraint length # use a phenol->toluene transform to test - ligA = benzene_modifications['phenol'] - ligB = benzene_modifications['toluene'] + ligA = benzene_modifications["phenol"] + ligB = benzene_modifications["toluene"] - mapping = {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, - 7: 11, 8: 12, 9: 13, 10: 1, 11: 14, 12: 2} + mapping = {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, 7: 11, 8: 12, 9: 13, 10: 1, 11: 14, 12: 2} expected = 10 # this should get removed from mapping @@ -1289,14 +1271,19 @@ def test_remove_constraints_lengthchange(self, benzene_modifications, ) stateA_topology, stateA_system, stateB_topology, stateB_system = self.make_systems( - ligA, ligB, constraints=app.HBonds) + ligA, + ligB, + constraints=app.HBonds, + ) # this normally requires global indices, however as ligandA/B is only thing # in system, this mapping is still correct ret = 
openmm_rfe._rfe_utils.topologyhelpers._remove_constraints( mapping.componentA_to_componentB, - stateA_system, stateA_topology, - stateB_system, stateB_topology, + stateA_system, + stateA_topology, + stateB_system, + stateB_topology, ) # all of this just to check that an entry was removed from the mapping @@ -1305,13 +1292,12 @@ def test_remove_constraints_lengthchange(self, benzene_modifications, # but only one constraint should be removed assert len(ret) == len(mapping.componentA_to_componentB) - 1 - @pytest.mark.parametrize('reverse', [False, True]) + @pytest.mark.parametrize("reverse", [False, True]) def test_constraint_to_harmonic(self, benzene_modifications, reverse): - ligA = benzene_modifications['benzene'] - ligB = benzene_modifications['toluene'] + ligA = benzene_modifications["benzene"] + ligB = benzene_modifications["toluene"] expected = 10 - mapping = {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, - 6: 10, 7: 11, 8: 12, 9: 13, 10: 2, 11: 14} + mapping = {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, 7: 11, 8: 12, 9: 13, 10: 2, 11: 14} if reverse: ligA, ligB = ligB, ligA expected = mapping[expected] @@ -1319,56 +1305,62 @@ def test_constraint_to_harmonic(self, benzene_modifications, reverse): # this maps a -H to a -C, so the constraint on -H turns into a C-C bond # H constraint is A(4, 10) and C-C is B(8, 2) - mapping = setup.LigandAtomMapping( - componentA=ligA, componentB=ligB, - componentA_to_componentB=mapping - ) + mapping = setup.LigandAtomMapping(componentA=ligA, componentB=ligB, componentA_to_componentB=mapping) stateA_topology, stateA_system, stateB_topology, stateB_system = self.make_systems( - ligA, ligB, constraints=app.HBonds) + ligA, + ligB, + constraints=app.HBonds, + ) ret = openmm_rfe._rfe_utils.topologyhelpers._remove_constraints( mapping.componentA_to_componentB, - stateA_system, stateA_topology, - stateB_system, stateB_topology, + stateA_system, + stateA_topology, + stateB_system, + stateB_topology, ) assert expected not in ret assert len(ret) == 
len(mapping.componentA_to_componentB) - 1 - @pytest.mark.parametrize('reverse', [False, True]) - def test_constraint_to_harmonic_nitrile(self, benzene_modifications, - reverse): + @pytest.mark.parametrize("reverse", [False, True]) + def test_constraint_to_harmonic_nitrile(self, benzene_modifications, reverse): # same as previous test, but ligands are swapped # this follows a slightly different code path - ligA = benzene_modifications['toluene'] - ligB = benzene_modifications['benzonitrile'] + ligA = benzene_modifications["toluene"] + ligB = benzene_modifications["benzonitrile"] if reverse: ligA, ligB = ligB, ligA - mapping = {0: 0, 2: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8, - 11: 9, 12: 10, 13: 11, 14: 12} + mapping = {0: 0, 2: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8, 11: 9, 12: 10, 13: 11, 14: 12} if reverse: mapping = {v: k for k, v in mapping.items()} mapping = setup.LigandAtomMapping( - componentA=ligA, componentB=ligB, + componentA=ligA, + componentB=ligB, componentA_to_componentB=mapping, ) stateA_topology, stateA_system, stateB_topology, stateB_system = self.make_systems( - ligA, ligB, constraints=app.HBonds) + ligA, + ligB, + constraints=app.HBonds, + ) ret = openmm_rfe._rfe_utils.topologyhelpers._remove_constraints( mapping.componentA_to_componentB, - stateA_system, stateA_topology, - stateB_system, stateB_topology, + stateA_system, + stateA_topology, + stateB_system, + stateB_topology, ) assert 0 not in ret assert len(ret) == len(mapping.componentA_to_componentB) - 1 - @pytest.mark.parametrize('reverse', [False, True]) + @pytest.mark.parametrize("reverse", [False, True]) def test_non_H_constraint_fail(self, benzene_modifications, reverse): # here we specify app.AllBonds constraints # in this transform, the C-C[#N] to C-C[=O] constraint changes length @@ -1376,83 +1368,116 @@ def test_non_H_constraint_fail(self, benzene_modifications, reverse): # there's no Hydrogen involved so we can't trivially figure out the # best atom to remove from mapping 
# (but it would be 2 [& 1] in this case..) - ligA = benzene_modifications['toluene'] - ligB = benzene_modifications['benzonitrile'] + ligA = benzene_modifications["toluene"] + ligB = benzene_modifications["benzonitrile"] - mapping = {0: 0, 2: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8, - 11: 9, 12: 10, 13: 11, 14: 12} + mapping = {0: 0, 2: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8, 11: 9, 12: 10, 13: 11, 14: 12} if reverse: ligA, ligB = ligB, ligA mapping = {v: k for k, v in mapping.items()} mapping = setup.LigandAtomMapping( - componentA=ligA, componentB=ligB, + componentA=ligA, + componentB=ligB, componentA_to_componentB=mapping, ) stateA_topology, stateA_system, stateB_topology, stateB_system = self.make_systems( - ligA, ligB, constraints=app.AllBonds) + ligA, + ligB, + constraints=app.AllBonds, + ) - with pytest.raises(ValueError, match='resolve constraint') as e: + with pytest.raises(ValueError, match="resolve constraint") as e: _ = openmm_rfe._rfe_utils.topologyhelpers._remove_constraints( mapping.componentA_to_componentB, - stateA_system, stateA_topology, - stateB_system, stateB_topology, + stateA_system, + stateA_topology, + stateB_system, + stateB_topology, ) if not reverse: - assert 'A: 2-8 B: 1-6' in str(e) + assert "A: 2-8 B: 1-6" in str(e) else: - assert 'A: 1-6 B: 2-8' in str(e) + assert "A: 1-6 B: 2-8" in str(e) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def tyk2_xml(tmp_path_factory): - with resources.files('openfe.tests.data.openmm_rfe') as d: - fn1 = str(d / 'ligand_23.sdf') - fn2 = str(d / 'ligand_55.sdf') + with resources.files("openfe.tests.data.openmm_rfe") as d: + fn1 = str(d / "ligand_23.sdf") + fn2 = str(d / "ligand_55.sdf") lig23 = openfe.SmallMoleculeComponent.from_sdf_file(fn1) lig55 = openfe.SmallMoleculeComponent.from_sdf_file(fn2) mapping = setup.LigandAtomMapping( - componentA=lig23, componentB=lig55, + componentA=lig23, + componentB=lig55, # perses mapper output - componentA_to_componentB={0: 0, 1: 1, 
2: 2, 3: 3, 4: 4, 5: 5, 6: 6, - 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, - 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, - 18: 18, 23: 19, 26: 20, 27: 21, 28: 22, - 29: 23, 30: 24, 31: 25, 32: 26, 33: 27} - ) - - settings: openmm_rfe.RelativeHybridTopologyProtocolSettings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - settings.forcefield_settings.small_molecule_forcefield = 'openff-2.0.0' - settings.forcefield_settings.nonbonded_method = 'nocutoff' + componentA_to_componentB={ + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 11: 11, + 12: 12, + 13: 13, + 14: 14, + 15: 15, + 16: 16, + 17: 17, + 18: 18, + 23: 19, + 26: 20, + 27: 21, + 28: 22, + 29: 23, + 30: 24, + 31: 25, + 32: 26, + 33: 27, + }, + ) + + settings: openmm_rfe.RelativeHybridTopologyProtocolSettings = ( + openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + ) + settings.forcefield_settings.small_molecule_forcefield = "openff-2.0.0" + settings.forcefield_settings.nonbonded_method = "nocutoff" settings.forcefield_settings.hydrogen_mass = 3.0 settings.protocol_repeats = 1 protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings) dag = protocol.create( - stateA=openfe.ChemicalSystem({'ligand': lig23}), - stateB=openfe.ChemicalSystem({'ligand': lig55}), + stateA=openfe.ChemicalSystem({"ligand": lig23}), + stateB=openfe.ChemicalSystem({"ligand": lig55}), mapping=mapping, ) pu = list(dag.protocol_units)[0] - tmp = tmp_path_factory.mktemp('xml_reg') + tmp = tmp_path_factory.mktemp("xml_reg") dryrun = pu.run(dry=True, shared_basepath=tmp) - system = dryrun['debug']['sampler']._hybrid_factory.hybrid_system + system = dryrun["debug"]["sampler"]._hybrid_factory.hybrid_system return ET.fromstring(XmlSerializer.serialize(system)) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def tyk2_reference_xml(): - with resources.files('openfe.tests.data.openmm_rfe') as d: - f = d / 'reference.xml' - with open(f, 'r') as i: + with 
resources.files("openfe.tests.data.openmm_rfe") as d: + f = d / "reference.xml" + with open(f) as i: xmldata = i.read() return ET.fromstring(xmldata) @@ -1460,45 +1485,44 @@ def tyk2_reference_xml(): @pytest.mark.slow class TestTyk2XmlRegression: """Generates Hybrid system XML and performs regression test""" + @staticmethod def test_particles(tyk2_xml, tyk2_reference_xml): # < Particle mass = "10.018727" / > - particles = tyk2_xml.find('Particles') + particles = tyk2_xml.find("Particles") assert particles - ref_particles = tyk2_reference_xml.find('Particles') + ref_particles = tyk2_reference_xml.find("Particles") for a, b in zip(particles, ref_particles): - assert float(a.get('mass')) == pytest.approx(float(b.get('mass'))) + assert float(a.get("mass")) == pytest.approx(float(b.get("mass"))) @staticmethod def test_constraints(tyk2_xml, tyk2_reference_xml): # - constraints = tyk2_xml.find('Constraints') + constraints = tyk2_xml.find("Constraints") assert constraints - ref_constraints = tyk2_reference_xml.find('Constraints') + ref_constraints = tyk2_reference_xml.find("Constraints") for a, b in zip(constraints, ref_constraints): - assert a.get('p1') == b.get('p1') - assert a.get('p2') == b.get('p2') - assert float(a.get('d')) == pytest.approx(float(b.get('d'))) + assert a.get("p1") == b.get("p1") + assert a.get("p2") == b.get("p2") + assert float(a.get("d")) == pytest.approx(float(b.get("d"))) class TestProtocolResult: @pytest.fixture() def protocolresult(self, rfe_transformation_json): - d = json.loads(rfe_transformation_json, - cls=gufe.tokenization.JSON_HANDLER.decoder) + d = json.loads(rfe_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openfe.ProtocolResult.from_dict(d['protocol_result']) + pr = openfe.ProtocolResult.from_dict(d["protocol_result"]) return pr def test_reload_protocol_result(self, rfe_transformation_json): - d = json.loads(rfe_transformation_json, - cls=gufe.tokenization.JSON_HANDLER.decoder) + d = 
json.loads(rfe_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result']) + pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d["protocol_result"]) assert pr @@ -1533,11 +1557,10 @@ def test_get_forwards_etc(self, protocolresult): assert isinstance(far, list) far1 = far[0] assert isinstance(far1, dict) - for k in ['fractions', 'forward_DGs', 'forward_dDGs', - 'reverse_DGs', 'reverse_dDGs']: + for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]: assert k in far1 - if k == 'fractions': + if k == "fractions": assert isinstance(far1[k], np.ndarray) else: assert isinstance(far1[k], unit.Quantity) @@ -1550,8 +1573,8 @@ def test_get_overlap_matrices(self, protocolresult): assert len(ovp) == 3 ovp1 = ovp[0] - assert isinstance(ovp1['matrix'], np.ndarray) - assert ovp1['matrix'].shape == (11,11) + assert isinstance(ovp1["matrix"], np.ndarray) + assert ovp1["matrix"].shape == (11, 11) def test_get_replica_transition_statistics(self, protocolresult): rpx = protocolresult.get_replica_transition_statistics() @@ -1559,10 +1582,10 @@ def test_get_replica_transition_statistics(self, protocolresult): assert isinstance(rpx, list) assert len(rpx) == 3 rpx1 = rpx[0] - assert 'eigenvalues' in rpx1 - assert 'matrix' in rpx1 - assert rpx1['eigenvalues'].shape == (11,) - assert rpx1['matrix'].shape == (11, 11) + assert "eigenvalues" in rpx1 + assert "matrix" in rpx1 + assert rpx1["eigenvalues"].shape == (11,) + assert rpx1["matrix"].shape == (11, 11) def test_equilibration_iterations(self, protocolresult): eq = protocolresult.equilibration_iterations() @@ -1584,28 +1607,30 @@ def test_filenotfound_replica_states(self, protocolresult): with pytest.raises(ValueError, match=errmsg): protocolresult.get_replica_states() -@pytest.mark.parametrize('mapping_name,result', [ - ["benzene_to_toluene_mapping", 0], - ["benzene_to_benzoic_mapping", 1], - 
["benzene_to_aniline_mapping", -1], - ["aniline_to_benzene_mapping", 1], -]) + +@pytest.mark.parametrize( + "mapping_name,result", + [ + ["benzene_to_toluene_mapping", 0], + ["benzene_to_benzoic_mapping", 1], + ["benzene_to_aniline_mapping", -1], + ["aniline_to_benzene_mapping", 1], + ], +) def test_get_charge_difference(mapping_name, result, request): mapping = request.getfixturevalue(mapping_name) if result != 0: - ion = 'Na\+' if result == -1 else 'Cl\-' - wmsg = (f"A charge difference of {result} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion") + ion = r"Na\+" if result == -1 else r"Cl\-" + wmsg = ( + f"A charge difference of {result} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) with pytest.warns(UserWarning, match=wmsg): - val = _get_alchemical_charge_difference( - mapping, 'pme', True, openfe.SolventComponent() - ) + val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) assert result == pytest.approx(result) else: - val = _get_alchemical_charge_difference( - mapping, 'pme', True, openfe.SolventComponent() - ) + val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) assert result == pytest.approx(result) @@ -1614,17 +1639,20 @@ def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): with pytest.raises(ValueError, match=errmsg): _get_alchemical_charge_difference( benzene_to_benzoic_mapping, - 'nocutoff', True, openfe.SolventComponent(), + "nocutoff", + True, + openfe.SolventComponent(), ) def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): - wmsg = ("A charge difference of 1 is observed between the end states. " - "No charge correction has been requested") + wmsg = "A charge difference of 1 is observed between the end states. 
" "No charge correction has been requested" with pytest.warns(UserWarning, match=wmsg): _get_alchemical_charge_difference( benzene_to_benzoic_mapping, - 'pme', False, openfe.SolventComponent(), + "pme", + False, + openfe.SolventComponent(), ) @@ -1633,13 +1661,15 @@ def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): with pytest.raises(ValueError, match=errmsg): _get_alchemical_charge_difference( aniline_to_benzoic_mapping, - 'pme', True, openfe.SolventComponent(), + "pme", + True, + openfe.SolventComponent(), ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_solvent_openmm_system(benzene_modifications): - smc = benzene_modifications['benzene'] + smc = benzene_modifications["benzene"] offmol = smc.to_openff() settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() @@ -1666,23 +1696,18 @@ def benzene_solvent_openmm_system(benzene_modifications): topology = modeller.getTopology() positions = to_openmm(from_openmm(modeller.getPositions())) - system = system_generator.create_system( - topology, - molecules=[offmol] - ) + system = system_generator.create_system(topology, molecules=[offmol]) return system, topology, positions -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_tip4p_solvent_openmm_system(benzene_modifications): - smc = benzene_modifications['benzene'] + smc = benzene_modifications["benzene"] offmol = smc.to_openff() settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - settings.forcefield_settings.forcefields = [ - 'amber/ff14SB.xml', 'amber/tip4pew_standard.xml', 'amber/phosaa10.xml' - ] - settings.solvation_settings.solvent_model = 'tip4pew' + settings.forcefield_settings.forcefields = ["amber/ff14SB.xml", "amber/tip4pew_standard.xml", "amber/phosaa10.xml"] + settings.solvation_settings.solvent_model = "tip4pew" system_generator = system_creation.get_system_generator( forcefield_settings=settings.forcefield_settings, @@ -1707,10 
+1732,7 @@ def benzene_tip4p_solvent_openmm_system(benzene_modifications): topology = modeller.getTopology() positions = to_openmm(from_openmm(modeller.getPositions())) - system = system_generator.create_system( - topology, - molecules=[offmol] - ) + system = system_generator.create_system(topology, molecules=[offmol]) return system, topology, positions @@ -1725,41 +1747,39 @@ def benzene_self_system_mapping(benzene_solvent_openmm_system): system, topology, positions = benzene_solvent_openmm_system res = [r for r in topology.residues()] - benzene_res = [r for r in res if r.name == 'UNK'][0] + benzene_res = [r for r in res if r.name == "UNK"][0] benzene_ids = [a.index for a in benzene_res.atoms()] env_ids = [a.index for a in topology.atoms() if a.index not in benzene_ids] all_ids = [a.index for a in topology.atoms()] system_mapping = { - 'new_to_old_atom_map': {i: i for i in all_ids}, - 'old_to_new_atom_map': {i: i for i in all_ids}, - 'new_to_old_core_atom_map': {i: i for i in benzene_ids}, - 'old_to_new_core_atom_map': {i: i for i in benzene_ids}, - 'old_to_new_env_atom_map': {i: i for i in env_ids}, - 'new_to_old_env_atom_map': {i: i for i in env_ids}, - 'old_mol_indices': benzene_ids, - 'new_mol_indices': benzene_ids, + "new_to_old_atom_map": {i: i for i in all_ids}, + "old_to_new_atom_map": {i: i for i in all_ids}, + "new_to_old_core_atom_map": {i: i for i in benzene_ids}, + "old_to_new_core_atom_map": {i: i for i in benzene_ids}, + "old_to_new_env_atom_map": {i: i for i in env_ids}, + "new_to_old_env_atom_map": {i: i for i in env_ids}, + "old_mol_indices": benzene_ids, + "new_mol_indices": benzene_ids, } return system_mapping -@pytest.mark.parametrize('ion, water', [ - ['NA', 'SOL'], - ['NX', 'WAT'], -]) -def test_get_ion_water_parameters_unknownresname( - ion, water, benzene_solvent_openmm_system -): +@pytest.mark.parametrize( + "ion, water", + [ + ["NA", "SOL"], + ["NX", "WAT"], + ], +) +def test_get_ion_water_parameters_unknownresname(ion, water, 
benzene_solvent_openmm_system): system, topology, positions = benzene_solvent_openmm_system errmsg = "Error encountered when attempting to explicitly handle" with pytest.raises(ValueError, match=errmsg): - topologyhelpers._get_ion_and_water_parameters( - topology, system, - ion_resname=ion, water_resname=water - ) + topologyhelpers._get_ion_and_water_parameters(topology, system, ion_resname=ion, water_resname=water) def test_get_alchemical_waters_no_waters( @@ -1771,8 +1791,10 @@ def test_get_alchemical_waters_no_waters( with pytest.raises(ValueError, match=errmsg): topologyhelpers.get_alchemical_waters( - topology, positions, charge_difference=1, - distance_cutoff=2.0 * unit.nanometer + topology, + positions, + charge_difference=1, + distance_cutoff=2.0 * unit.nanometer, ) @@ -1812,7 +1834,9 @@ def test_handle_alchemwats_too_many_nbf( with pytest.raises(ValueError, match=errmsg): topologyhelpers.handle_alchemical_waters( - water_resids=[1,], + water_resids=[ + 1, + ], topology=topology, system=new_system, system_mapping={}, @@ -1834,7 +1858,9 @@ def test_handle_alchemwats_vsite_water( with pytest.raises(ValueError, match=errmsg): topologyhelpers.handle_alchemical_waters( - water_resids=[1,], + water_resids=[ + 1, + ], topology=topology, system=system, system_mapping={}, @@ -1862,7 +1888,9 @@ def test_handle_alchemwats_incorrect_atom( with pytest.raises(ValueError, match=errmsg): topologyhelpers.handle_alchemical_waters( - water_resids=[5,], + water_resids=[ + 5, + ], topology=topology, system=new_system, system_mapping=benzene_self_system_mapping, @@ -1877,11 +1905,13 @@ def test_handle_alchemical_wats( ): system, topology, positions = benzene_solvent_openmm_system - n_env = len(benzene_self_system_mapping['old_to_new_env_atom_map']) - n_core = len(benzene_self_system_mapping['old_to_new_core_atom_map']) + n_env = len(benzene_self_system_mapping["old_to_new_env_atom_map"]) + n_core = len(benzene_self_system_mapping["old_to_new_core_atom_map"]) 
topologyhelpers.handle_alchemical_waters( - water_resids=[5,], + water_resids=[ + 5, + ], topology=topology, system=system, system_mapping=benzene_self_system_mapping, @@ -1890,12 +1920,12 @@ def test_handle_alchemical_wats( ) # check the mappings - old_new_env = benzene_self_system_mapping['old_to_new_env_atom_map'] - old_new_core = benzene_self_system_mapping['old_to_new_core_atom_map'] + old_new_env = benzene_self_system_mapping["old_to_new_env_atom_map"] + old_new_core = benzene_self_system_mapping["old_to_new_core_atom_map"] assert len(old_new_env) == n_env - 3 - assert old_new_env == benzene_self_system_mapping['new_to_old_env_atom_map'] + assert old_new_env == benzene_self_system_mapping["new_to_old_env_atom_map"] assert len(old_new_core) == n_core + 3 - assert old_new_core == benzene_self_system_mapping['new_to_old_core_atom_map'] + assert old_new_core == benzene_self_system_mapping["new_to_old_core_atom_map"] expected_old_new_core = {i: i for i in range(12)} | {24: 24, 25: 25, 26: 26} assert old_new_core == expected_old_new_core @@ -1903,7 +1933,10 @@ def test_handle_alchemical_wats( nbf = [i for i in system.getForces() if isinstance(i, NonbondedForce)][0] # check the oxygen parameters i_chg, i_sig, i_eps, o_chg, h_chg = topologyhelpers._get_ion_and_water_parameters( - topology, system, 'NA', 'HOH', + topology, + system, + "NA", + "HOH", ) charge, sigma, epsilon = nbf.getParticleParameters(24) @@ -1918,15 +1951,13 @@ def test_handle_alchemical_wats( def _assert_total_charge(system, atom_classes, chgA, chgB): - nonbond = [ - f for f in system.getForces() if isinstance(f, NonbondedForce) - ] + nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)] offsets = {} for i in range(nonbond[0].getNumParticleParameterOffsets()): offset = nonbond[0].getParticleParameterOffset(i) assert len(offset) == 5 - offsets[offset[1]] = ensure_quantity(offset[2], 'openff') + offsets[offset[1]] = ensure_quantity(offset[2], "openff") stateA_charges = 
np.zeros(system.getNumParticles()) stateB_charges = np.zeros(system.getNumParticles()) @@ -1934,24 +1965,24 @@ def _assert_total_charge(system, atom_classes, chgA, chgB): for i in range(system.getNumParticles()): # get the particle charge (c) and the chargeScale offset (c_offset) c, s, e = nonbond[0].getParticleParameters(i) - c = ensure_quantity(c, 'openff') + c = ensure_quantity(c, "openff") # particle charge (c) is equal to molA particle charge # offset (c_offset) is equal to -(molA particle charge) - if i in atom_classes['unique_old_atoms']: + if i in atom_classes["unique_old_atoms"]: stateA_charges[i] = c.m # particle charge (c) is equal to 0 # offset (c_offset) is equal to molB particle charge - elif i in atom_classes['unique_new_atoms']: + elif i in atom_classes["unique_new_atoms"]: stateB_charges[i] = offsets[i].m # particle charge (c) is equal to molA particle charge # offset (c_offset) is equal to difference between molB and molA - elif i in atom_classes['core_atoms']: + elif i in atom_classes["core_atoms"]: stateA_charges[i] = c.m stateB_charges[i] = c.m + offsets[i].m # an environment atom else: - assert i in atom_classes['environment_atoms'] + assert i in atom_classes["environment_atoms"] stateA_charges[i] = c.m stateB_charges[i] = c.m @@ -1961,17 +1992,15 @@ def _assert_total_charge(system, atom_classes, chgA, chgB): def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): stateA_system = openfe.ChemicalSystem( - {'ligand': benzene_to_benzoic_mapping.componentA, - 'solvent': openfe.SolventComponent()} + {"ligand": benzene_to_benzoic_mapping.componentA, "solvent": openfe.SolventComponent()}, ) stateB_system = openfe.ChemicalSystem( - {'ligand': benzene_to_benzoic_mapping.componentB, - 'solvent': openfe.SolventComponent()} + {"ligand": benzene_to_benzoic_mapping.componentB, "solvent": openfe.SolventComponent()}, ) solv_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() 
solv_settings.alchemical_settings.explicit_charge_correction = True protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=solv_settings, + settings=solv_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -1983,48 +2012,54 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)['debug']['sampler'] + sampler = unit.run(dry=True)["debug"]["sampler"] htf = sampler._factory - _assert_total_charge(htf.hybrid_system, - htf._atom_classes, 0, 0) + _assert_total_charge(htf.hybrid_system, htf._atom_classes, 0, 0) - assert len(htf._atom_classes['core_atoms']) == 14 - assert len(htf._atom_classes['unique_new_atoms']) == 3 - assert len(htf._atom_classes['unique_old_atoms']) == 1 + assert len(htf._atom_classes["core_atoms"]) == 14 + assert len(htf._atom_classes["unique_new_atoms"]) == 3 + assert len(htf._atom_classes["unique_old_atoms"]) == 1 @pytest.mark.slow -@pytest.mark.parametrize('mapping_name,chgA,chgB,correction,core_atoms,new_uniq,old_uniq', [ - ['benzene_to_aniline_mapping', 0, 1, False, 11, 4, 1], - ['aniline_to_benzene_mapping', 0, 0, True, 14, 1, 4], - ['aniline_to_benzene_mapping', 0, -1, False, 11, 1, 4], - ['benzene_to_benzoic_mapping', 0, 0, True, 14, 3, 1], - ['benzene_to_benzoic_mapping', 0, -1, False, 11, 3, 1], - ['benzoic_to_benzene_mapping', 0, 0, True, 14, 1, 3], - ['benzoic_to_benzene_mapping', 0, 1, False, 11, 1, 3], -]) +@pytest.mark.parametrize( + "mapping_name,chgA,chgB,correction,core_atoms,new_uniq,old_uniq", + [ + ["benzene_to_aniline_mapping", 0, 1, False, 11, 4, 1], + ["aniline_to_benzene_mapping", 0, 0, True, 14, 1, 4], + ["aniline_to_benzene_mapping", 0, -1, False, 11, 1, 4], + ["benzene_to_benzoic_mapping", 0, 0, True, 14, 3, 1], + ["benzene_to_benzoic_mapping", 0, -1, False, 11, 3, 1], + ["benzoic_to_benzene_mapping", 0, 0, True, 14, 1, 3], + ["benzoic_to_benzene_mapping", 0, 1, False, 
11, 1, 3], + ], +) def test_dry_run_complex_alchemwater_totcharge( - mapping_name, chgA, chgB, correction, core_atoms, - new_uniq, old_uniq, tmpdir, request, T4_protein_component, + mapping_name, + chgA, + chgB, + correction, + core_atoms, + new_uniq, + old_uniq, + tmpdir, + request, + T4_protein_component, ): mapping = request.getfixturevalue(mapping_name) stateA_system = openfe.ChemicalSystem( - {'ligand': mapping.componentA, - 'solvent': openfe.SolventComponent(), - 'protein': T4_protein_component} + {"ligand": mapping.componentA, "solvent": openfe.SolventComponent(), "protein": T4_protein_component}, ) stateB_system = openfe.ChemicalSystem( - {'ligand': mapping.componentB, - 'solvent': openfe.SolventComponent(), - 'protein': T4_protein_component} + {"ligand": mapping.componentB, "solvent": openfe.SolventComponent(), "protein": T4_protein_component}, ) solv_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() solv_settings.alchemical_settings.explicit_charge_correction = correction protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=solv_settings, + settings=solv_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -2036,11 +2071,10 @@ def test_dry_run_complex_alchemwater_totcharge( unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)['debug']['sampler'] + sampler = unit.run(dry=True)["debug"]["sampler"] htf = sampler._factory - _assert_total_charge(htf.hybrid_system, - htf._atom_classes, chgA, chgB) + _assert_total_charge(htf.hybrid_system, htf._atom_classes, chgA, chgB) - assert len(htf._atom_classes['core_atoms']) == core_atoms - assert len(htf._atom_classes['unique_new_atoms']) == new_uniq - assert len(htf._atom_classes['unique_old_atoms']) == old_uniq + assert len(htf._atom_classes["core_atoms"]) == core_atoms + assert len(htf._atom_classes["unique_new_atoms"]) == new_uniq + assert len(htf._atom_classes["unique_old_atoms"]) == old_uniq diff --git 
a/openfe/tests/protocols/test_openmm_plain_md_protocols.py b/openfe/tests/protocols/test_openmm_plain_md_protocols.py index 0fb03a762..de6bb7474 100644 --- a/openfe/tests/protocols/test_openmm_plain_md_protocols.py +++ b/openfe/tests/protocols/test_openmm_plain_md_protocols.py @@ -1,27 +1,24 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe +import json +import logging +import pathlib import sys +from unittest import mock + import gufe import pytest -from unittest import mock from numpy.testing import assert_allclose from openff.units import unit +from openff.units.openmm import from_openmm, to_openmm +from openmm import MonteCarloBarostat, NonbondedForce from openmm import unit as omm_unit -from openmm import NonbondedForce -from openff.units.openmm import to_openmm, from_openmm from openmmtools.states import ThermodynamicState -from openmm import MonteCarloBarostat -from openfe.protocols.openmm_md.plain_md_methods import ( - PlainMDProtocol, PlainMDProtocolUnit, PlainMDProtocolResult, -) -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_NAGL, HAS_OPENEYE, HAS_ESPALOMA -) -import json + import openfe from openfe.protocols import openmm_md -import pathlib -import logging +from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocol, PlainMDProtocolResult, PlainMDProtocolUnit +from openfe.protocols.openmm_utils.charge_generation import HAS_ESPALOMA, HAS_NAGL, HAS_OPENEYE def test_create_default_settings(): @@ -58,7 +55,7 @@ def test_create_independent_repeat_ids(benzene_system): # Default protocol is 1 repeat, change to 3 repeats settings.protocol_repeats = 3 protocol = PlainMDProtocol( - settings=settings, + settings=settings, ) dag1 = protocol.create( stateA=benzene_system, @@ -74,9 +71,9 @@ def test_create_independent_repeat_ids(benzene_system): repeat_ids = set() u: PlainMDProtocolUnit for u in dag1.protocol_units: - 
repeat_ids.add(u.inputs['repeat_id']) + repeat_ids.add(u.inputs["repeat_id"]) for u in dag2.protocol_units: - repeat_ids.add(u.inputs['repeat_id']) + repeat_ids.add(u.inputs["repeat_id"]) assert len(repeat_ids) == 6 @@ -84,10 +81,10 @@ def test_create_independent_repeat_ids(benzene_system): def test_dry_run_default_vacuum(benzene_vacuum_system, tmpdir): vac_settings = PlainMDProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" protocol = PlainMDProtocol( - settings=vac_settings, + settings=vac_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -99,23 +96,27 @@ def test_dry_run_default_vacuum(benzene_vacuum_system, tmpdir): dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sim = dag_unit.run(dry=True, verbose=True)['debug']['system'] - assert not ThermodynamicState(sim, temperature=to_openmm( - protocol.settings.thermo_settings.temperature)).is_periodic - assert ThermodynamicState(sim, temperature=to_openmm( - protocol.settings.thermo_settings.temperature)).barostat is None + sim = dag_unit.run(dry=True, verbose=True)["debug"]["system"] + assert not ThermodynamicState( + sim, + temperature=to_openmm(protocol.settings.thermo_settings.temperature), + ).is_periodic + assert ( + ThermodynamicState(sim, temperature=to_openmm(protocol.settings.thermo_settings.temperature)).barostat + is None + ) def test_dry_run_logger_output(benzene_vacuum_system, tmpdir, caplog): vac_settings = PlainMDProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.simulation_settings.equilibration_length_nvt = 1 * unit.picosecond vac_settings.simulation_settings.equilibration_length = 1 * unit.picosecond vac_settings.simulation_settings.production_length = 1 * unit.picosecond protocol = PlainMDProtocol( - settings=vac_settings, 
+ settings=vac_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -140,11 +141,11 @@ def test_dry_run_logger_output(benzene_vacuum_system, tmpdir, caplog): def test_dry_run_ffcache_none_vacuum(benzene_vacuum_system, tmpdir): vac_settings = PlainMDProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.output_settings.forcefield_cache = None protocol = PlainMDProtocol( - settings=vac_settings, + settings=vac_settings, ) assert protocol.settings.output_settings.forcefield_cache is None @@ -157,16 +158,16 @@ def test_dry_run_ffcache_none_vacuum(benzene_vacuum_system, tmpdir): dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - dag_unit.run(dry=True)['debug']['system'] + dag_unit.run(dry=True)["debug"]["system"] def test_dry_run_gaff_vacuum(benzene_vacuum_system, tmpdir): vac_settings = PlainMDProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' - vac_settings.forcefield_settings.small_molecule_forcefield = 'gaff-2.11' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" + vac_settings.forcefield_settings.small_molecule_forcefield = "gaff-2.11" protocol = PlainMDProtocol( - settings=vac_settings, + settings=vac_settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -178,52 +179,60 @@ def test_dry_run_gaff_vacuum(benzene_vacuum_system, tmpdir): unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - system = unit.run(dry=True)['debug']['system'] - - -@pytest.mark.parametrize('method, backend, ref_key', [ - ('am1bcc', 'ambertools', 'ambertools'), - pytest.param( - 'am1bcc', 'openeye', 'openeye', - marks=pytest.mark.skipif( - not HAS_OPENEYE, reason='needs oechem', + system = unit.run(dry=True)["debug"]["system"] + + +@pytest.mark.parametrize( + "method, backend, ref_key", + [ + ("am1bcc", "ambertools", "ambertools"), + 
pytest.param( + "am1bcc", + "openeye", + "openeye", + marks=pytest.mark.skipif( + not HAS_OPENEYE, + reason="needs oechem", + ), ), - ), - pytest.param( - 'nagl', 'rdkit', 'nagl', - marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith('darwin'), - reason='needs NAGL and/or on macos', + pytest.param( + "nagl", + "rdkit", + "nagl", + marks=pytest.mark.skipif( + not HAS_NAGL or sys.platform.startswith("darwin"), + reason="needs NAGL and/or on macos", + ), ), - ), - pytest.param( - 'espaloma', 'rdkit', 'espaloma', - marks=pytest.mark.skipif( - not HAS_ESPALOMA, reason='needs espaloma', + pytest.param( + "espaloma", + "rdkit", + "espaloma", + marks=pytest.mark.skipif( + not HAS_ESPALOMA, + reason="needs espaloma", + ), ), - ), -]) -def test_dry_run_charge_backends( - CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges -): + ], +) +def test_dry_run_charge_backends(CN_molecule, tmpdir, method, backend, ref_key, am1bcc_ref_charges): vac_settings = PlainMDProtocol.default_settings() - vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.forcefield_settings.nonbonded_method = "nocutoff" vac_settings.partial_charge_settings.partial_charge_method = method vac_settings.partial_charge_settings.off_toolkit_backend = backend vac_settings.partial_charge_settings.nagl_model = "openff-gnn-am1bcc-0.1.0-rc.1.pt" protocol = PlainMDProtocol(settings=vac_settings) - csystem = openfe.ChemicalSystem({'ligand': CN_molecule}) + csystem = openfe.ChemicalSystem({"ligand": CN_molecule}) dag = protocol.create(stateA=csystem, stateB=csystem, mapping=None) md_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - system = md_unit.run(dry=True)['debug']['system'] + system = md_unit.run(dry=True)["debug"]["system"] - nonbond = [f for f in system.getForces() - if isinstance(f, NonbondedForce)][0] + nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)][0] charges = [] for i in range(system.getNumParticles()): @@ -235,16 
+244,14 @@ def test_dry_run_charge_backends( assert_allclose(am1bcc_ref_charges[ref_key], charges, rtol=1e-4) -def test_dry_many_molecules_solvent( - benzene_many_solv_system, tmpdir -): +def test_dry_many_molecules_solvent(benzene_many_solv_system, tmpdir): """ A basic test flushing "will it work if you pass multiple molecules" """ settings = PlainMDProtocol.default_settings() protocol = PlainMDProtocol( - settings=settings, + settings=settings, ) # create DAG from protocol and take first (and only) work unit from within @@ -256,7 +263,7 @@ def test_dry_many_molecules_solvent( unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - system = unit.run(dry=True)['debug']['system'] + system = unit.run(dry=True)["debug"]["system"] BENZ = """\ @@ -332,17 +339,17 @@ def test_dry_run_ligand_tip4p(benzene_system, tmpdir): """ settings = PlainMDProtocol.default_settings() settings.forcefield_settings.forcefields = [ - "amber/ff14SB.xml", # ff14SB protein force field + "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS "amber/phosaa10.xml", # Handles THE TPO ] settings.solvation_settings.solvent_padding = 1.0 * unit.nanometer settings.forcefield_settings.nonbonded_cutoff = 0.9 * unit.nanometer - settings.solvation_settings.solvent_model = 'tip4pew' + settings.solvation_settings.solvent_model = "tip4pew" settings.integrator_settings.reassign_velocities = True protocol = PlainMDProtocol( - settings=settings, + settings=settings, ) dag = protocol.create( stateA=benzene_system, @@ -352,7 +359,7 @@ def test_dry_run_ligand_tip4p(benzene_system, tmpdir): dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - system = dag_unit.run(dry=True)['debug']['system'] + system = dag_unit.run(dry=True)["debug"]["system"] assert system @@ -362,7 +369,7 @@ def test_dry_run_complex(benzene_complex_system, tmpdir): settings = PlainMDProtocol.default_settings() protocol = PlainMDProtocol( - settings=settings, + 
settings=settings, ) dag = protocol.create( stateA=benzene_complex_system, @@ -372,23 +379,25 @@ def test_dry_run_complex(benzene_complex_system, tmpdir): dag_unit = list(dag.protocol_units)[0] with tmpdir.as_cwd(): - sim = dag_unit.run(dry=True)['debug']['system'] - assert ThermodynamicState(sim, temperature= - to_openmm(protocol.settings.thermo_settings.temperature)).is_periodic - assert isinstance(ThermodynamicState(sim, temperature= - to_openmm(protocol.settings.thermo_settings.temperature)).barostat, - MonteCarloBarostat) - assert ThermodynamicState(sim, temperature= - to_openmm(protocol.settings.thermo_settings.temperature)).pressure == 1 * omm_unit.bar + sim = dag_unit.run(dry=True)["debug"]["system"] + assert ThermodynamicState(sim, temperature=to_openmm(protocol.settings.thermo_settings.temperature)).is_periodic + assert isinstance( + ThermodynamicState(sim, temperature=to_openmm(protocol.settings.thermo_settings.temperature)).barostat, + MonteCarloBarostat, + ) + assert ( + ThermodynamicState(sim, temperature=to_openmm(protocol.settings.thermo_settings.temperature)).pressure + == 1 * omm_unit.bar + ) def test_hightimestep(benzene_vacuum_system, tmpdir): settings = PlainMDProtocol.default_settings() settings.forcefield_settings.hydrogen_mass = 1.0 - settings.forcefield_settings.nonbonded_method = 'nocutoff' + settings.forcefield_settings.nonbonded_method = "nocutoff" p = PlainMDProtocol( - settings=settings, + settings=settings, ) dag = p.create( @@ -427,7 +436,8 @@ def solvent_protocol_dag(benzene_system): ) return protocol.create( - stateA=benzene_system, stateB=benzene_system, + stateA=benzene_system, + stateB=benzene_system, mapping=None, ) @@ -435,8 +445,10 @@ def solvent_protocol_dag(benzene_system): def test_unit_tagging(solvent_protocol_dag, tmpdir): # test that executing the Units includes correct generation and repeat info dag_units = solvent_protocol_dag.protocol_units - with 
mock.patch('openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}): + with mock.patch( + "openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + ): results = [] for u in dag_units: ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) @@ -445,26 +457,28 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): repeats = set() for ret in results: assert isinstance(ret, gufe.ProtocolUnitResult) - assert ret.outputs['generation'] == 0 - repeats.add(ret.outputs['repeat_id']) + assert ret.outputs["generation"] == 0 + repeats.add(ret.outputs["repeat_id"]) # repeats are random ints, so check we got 3 individual numbers assert len(repeats) == 3 def test_gather(solvent_protocol_dag, tmpdir): # check .gather behaves as expected - with mock.patch('openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run', - return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}): - dagres = gufe.protocols.execute_DAG(solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True) + with mock.patch( + "openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run", + return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + ): + dagres = gufe.protocols.execute_DAG( + solvent_protocol_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) settings = PlainMDProtocol.default_settings() settings.protocol_repeats = 3 - prot = PlainMDProtocol( - settings=settings - ) + prot = PlainMDProtocol(settings=settings) res = prot.gather([dagres]) @@ -476,14 +490,14 @@ class TestProtocolResult: def protocolresult(self, md_json): d = json.loads(md_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openfe.ProtocolResult.from_dict(d['protocol_result']) + pr = openfe.ProtocolResult.from_dict(d["protocol_result"]) return pr def test_reload_protocol_result(self, md_json): d = 
json.loads(md_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openmm_md.plain_md_methods.PlainMDProtocolResult.from_dict(d['protocol_result']) + pr = openmm_md.plain_md_methods.PlainMDProtocolResult.from_dict(d["protocol_result"]) assert pr @@ -492,7 +506,6 @@ def test_get_estimate(self, protocolresult): assert est is None - def test_get_uncertainty(self, protocolresult): est = protocolresult.get_uncertainty() diff --git a/openfe/tests/protocols/test_openmm_rfe_slow.py b/openfe/tests/protocols/test_openmm_rfe_slow.py index 54c1d6f62..04fe6d74f 100644 --- a/openfe/tests/protocols/test_openmm_rfe_slow.py +++ b/openfe/tests/protocols/test_openmm_rfe_slow.py @@ -1,12 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from gufe.protocols import execute_DAG +import os +import pathlib + import pytest +from gufe.protocols import execute_DAG from openff.units import unit from openmm import Platform -import os -import pathlib import openfe from openfe.protocols import openmm_rfe @@ -21,24 +22,29 @@ def available_platforms() -> set[str]: def set_openmm_threads_1(): # for vacuum sims, we want to limit threads to one # this fixture sets OPENMM_CPU_THREADS='1' for a single test, then reverts to previously held value - previous: str | None = os.environ.get('OPENMM_CPU_THREADS') + previous: str | None = os.environ.get("OPENMM_CPU_THREADS") try: - os.environ['OPENMM_CPU_THREADS'] = '1' + os.environ["OPENMM_CPU_THREADS"] = "1" yield finally: if previous is None: - del os.environ['OPENMM_CPU_THREADS'] + del os.environ["OPENMM_CPU_THREADS"] else: - os.environ['OPENMM_CPU_THREADS'] = previous + os.environ["OPENMM_CPU_THREADS"] = previous @pytest.mark.slow @pytest.mark.flaky(reruns=3) # pytest-rerunfailures; we can get bad minimisation -@pytest.mark.parametrize('platform', ['CPU', 'CUDA']) -def test_openmm_run_engine(benzene_vacuum_system, platform, - available_platforms, 
benzene_modifications, - set_openmm_threads_1, tmpdir): +@pytest.mark.parametrize("platform", ["CPU", "CUDA"]) +def test_openmm_run_engine( + benzene_vacuum_system, + platform, + available_platforms, + benzene_modifications, + set_openmm_threads_1, + tmpdir, +): if platform not in available_platforms: pytest.skip(f"OpenMM Platform: {platform} not available") # this test actually runs MD @@ -48,40 +54,35 @@ def test_openmm_run_engine(benzene_vacuum_system, platform, s.simulation_settings.equilibration_length = 0.1 * unit.picosecond s.simulation_settings.production_length = 0.1 * unit.picosecond s.simulation_settings.time_per_iteration = 20 * unit.femtosecond - s.forcefield_settings.nonbonded_method = 'nocutoff' + s.forcefield_settings.nonbonded_method = "nocutoff" s.protocol_repeats = 1 s.engine_settings.compute_platform = platform s.output_settings.checkpoint_interval = 20 * unit.femtosecond p = openmm_rfe.RelativeHybridTopologyProtocol(s) - b = benzene_vacuum_system['ligand'] + b = benzene_vacuum_system["ligand"] # make a copy with a different name - rdmol = benzene_modifications['benzene'].to_rdkit() - b_alt = openfe.SmallMoleculeComponent.from_rdkit(rdmol, name='alt') - benzene_vacuum_alt_system = openfe.ChemicalSystem({ - 'ligand': b_alt - }) + rdmol = benzene_modifications["benzene"].to_rdkit() + b_alt = openfe.SmallMoleculeComponent.from_rdkit(rdmol, name="alt") + benzene_vacuum_alt_system = openfe.ChemicalSystem({"ligand": b_alt}) - m = openfe.LigandAtomMapping(componentA=b, componentB=b_alt, - componentA_to_componentB={i: i for i in range(12)}) - dag = p.create(stateA=benzene_vacuum_system, stateB=benzene_vacuum_alt_system, - mapping=[m]) + m = openfe.LigandAtomMapping(componentA=b, componentB=b_alt, componentA_to_componentB={i: i for i in range(12)}) + dag = p.create(stateA=benzene_vacuum_system, stateB=benzene_vacuum_alt_system, mapping=[m]) cwd = pathlib.Path(str(tmpdir)) - r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, - keep_shared=True) 
+ r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() for pur in r.protocol_unit_results: unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" assert unit_shared.exists() assert pathlib.Path(unit_shared).is_dir() - checkpoint = pur.outputs['last_checkpoint'] + checkpoint = pur.outputs["last_checkpoint"] assert checkpoint == "checkpoint.chk" assert (unit_shared / checkpoint).exists() - nc = pur.outputs['nc'] + nc = pur.outputs["nc"] assert nc == unit_shared / "simulation.nc" assert nc.exists() assert (unit_shared / "structural_analysis.json").exists() @@ -112,26 +113,25 @@ def test_run_eg5_sim(eg5_protein, eg5_ligands, eg5_cofactor, tmpdir): p = openmm_rfe.RelativeHybridTopologyProtocol(s) base_sys = { - 'protein': eg5_protein, - 'cofactor': eg5_cofactor, - 'solvent': openfe.SolventComponent(), + "protein": eg5_protein, + "cofactor": eg5_cofactor, + "solvent": openfe.SolventComponent(), } # this is just a simple (unmapped) *-H -> *-F switch l1, l2 = eg5_ligands[0], eg5_ligands[1] m = openfe.LigandAtomMapping( - componentA=l1, componentB=l2, + componentA=l1, + componentB=l2, # a bit lucky, first 51 atoms map to each other, H->F swap is at 52 - componentA_to_componentB={i: i for i in range(51)} + componentA_to_componentB={i: i for i in range(51)}, ) - sys1 = openfe.ChemicalSystem(components={**base_sys, 'ligand': l1}) - sys2 = openfe.ChemicalSystem(components={**base_sys, 'ligand': l2}) + sys1 = openfe.ChemicalSystem(components={**base_sys, "ligand": l1}) + sys2 = openfe.ChemicalSystem(components={**base_sys, "ligand": l2}) - dag = p.create(stateA=sys1, stateB=sys2, - mapping=[m]) + dag = p.create(stateA=sys1, stateB=sys2, mapping=[m]) cwd = pathlib.Path(str(tmpdir)) - r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, - keep_shared=True) + r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() diff --git a/openfe/tests/protocols/test_openmm_settings.py 
b/openfe/tests/protocols/test_openmm_settings.py index c1937ab4a..e210f44b3 100644 --- a/openfe/tests/protocols/test_openmm_settings.py +++ b/openfe/tests/protocols/test_openmm_settings.py @@ -5,6 +5,7 @@ from openff.units import unit from openfe.protocols.openmm_rfe import equil_rfe_settings + # afe settings currently have no FloatQuantity values from openfe.protocols.openmm_utils import omm_settings @@ -14,14 +15,14 @@ class TestOMMSettingsFromStrings: def test_system_settings(self): s = omm_settings.OpenMMSystemGeneratorFFSettings() - s.nonbonded_cutoff = '1.1 nm' + s.nonbonded_cutoff = "1.1 nm" assert s.nonbonded_cutoff == 1.1 * unit.nanometer def test_solvation_settings(self): s = omm_settings.OpenMMSolvationSettings() - s.solvent_padding = '1.1 nm' + s.solvent_padding = "1.1 nm" assert s.solvent_padding == 1.1 * unit.nanometer @@ -32,11 +33,11 @@ def test_alchemical_sampler_settings(self): def test_integator_settings(self): s = omm_settings.IntegratorSettings() - s.timestep = '3 fs' + s.timestep = "3 fs" assert s.timestep == 3.0 * unit.femtosecond - s.langevin_collision_rate = '1.1 / ps' + s.langevin_collision_rate = "1.1 / ps" assert s.langevin_collision_rate == 1.1 / unit.picosecond @@ -48,8 +49,8 @@ def test_simulation_settings(self): production_length=5.0 * unit.nanosecond, ) - s.equilibration_length = '2.5 ns' - s.production_length = '10 ns' + s.equilibration_length = "2.5 ns" + s.production_length = "10 ns" assert s.equilibration_length == 2.5 * unit.nanosecond assert s.production_length == 10.0 * unit.nanosecond @@ -57,10 +58,8 @@ def test_simulation_settings(self): class TestEquilRFESettingsFromString: def test_alchemical_settings(self): - s = equil_rfe_settings.AlchemicalSettings(softcore_LJ='gapsys') + s = equil_rfe_settings.AlchemicalSettings(softcore_LJ="gapsys") - s.explicit_charge_correction_cutoff = '0.85 nm' + s.explicit_charge_correction_cutoff = "0.85 nm" assert s.explicit_charge_correction_cutoff == 0.85 * unit.nanometer - - diff --git 
a/openfe/tests/protocols/test_openmmutils.py b/openfe/tests/protocols/test_openmmutils.py index 2ae5bd172..9e08fa76e 100644 --- a/openfe/tests/protocols/test_openmmutils.py +++ b/openfe/tests/protocols/test_openmmutils.py @@ -1,35 +1,36 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from importlib import resources import copy -from pathlib import Path -import pytest import sys -from pymbar.utils import ParameterError +from importlib import resources +from pathlib import Path +from unittest import mock + import numpy as np -from numpy.testing import assert_equal, assert_allclose -from openmm import app, MonteCarloBarostat, NonbondedForce -from openmm import unit as ommunit -from openmmtools import multistate +import pytest +from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings +from numpy.testing import assert_allclose, assert_equal from openff.toolkit import Molecule as OFFMol -from openff.toolkit.utils.toolkits import RDKitToolkitWrapper from openff.toolkit.utils.toolkit_registry import ToolkitRegistry +from openff.toolkit.utils.toolkits import RDKitToolkitWrapper from openff.units import unit from openff.units.openmm import ensure_quantity -from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings +from openmm import MonteCarloBarostat, NonbondedForce, app +from openmm import unit as ommunit +from openmmtools import multistate +from pymbar.utils import ParameterError + import openfe +from openfe.protocols.openmm_rfe.equil_rfe_settings import IntegratorSettings, OpenMMSolvationSettings from openfe.protocols.openmm_utils import ( - settings_validation, system_validation, system_creation, - multistate_analysis, omm_settings, charge_generation -) -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_NAGL, HAS_ESPALOMA, HAS_OPENEYE -) -from openfe.protocols.openmm_rfe.equil_rfe_settings import ( - IntegratorSettings, - 
OpenMMSolvationSettings, + charge_generation, + multistate_analysis, + omm_settings, + settings_validation, + system_creation, + system_validation, ) -from unittest import mock +from openfe.protocols.openmm_utils.charge_generation import HAS_ESPALOMA, HAS_NAGL, HAS_OPENEYE def test_validate_timestep(): @@ -37,11 +38,14 @@ def test_validate_timestep(): settings_validation.validate_timestep(2.0, 4.0 * unit.femtoseconds) -@pytest.mark.parametrize('s,ts,mc,es', [ - [5 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 1250000], - [1 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 250000], - [1 * unit.picoseconds, 2 * unit.femtoseconds, 250, 500], -]) +@pytest.mark.parametrize( + "s,ts,mc,es", + [ + [5 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 1250000], + [1 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 250000], + [1 * unit.picoseconds, 2 * unit.femtoseconds, 250, 500], + ], +) def test_get_simsteps(s, ts, mc, es): sim_steps = settings_validation.get_simsteps(s, ts, mc) @@ -59,36 +63,46 @@ def test_mc_indivisible(): errmsg = "Simulation time 1.0 ps should contain" timelength = 1 * unit.picoseconds with pytest.raises(ValueError, match=errmsg): - settings_validation.get_simsteps( - timelength, 2 * unit.femtoseconds, 1000) + settings_validation.get_simsteps(timelength, 2 * unit.femtoseconds, 1000) -def test_get_alchemical_components(benzene_modifications, - T4_protein_component): +def test_get_alchemical_components(benzene_modifications, T4_protein_component): - stateA = openfe.ChemicalSystem({'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'], - 'P': T4_protein_component, - 'S': openfe.SolventComponent(smiles='C')}) - stateB = openfe.ChemicalSystem({'A': benzene_modifications['benzene'], - 'B': benzene_modifications['benzonitrile'], - 'P': T4_protein_component, - 'S': openfe.SolventComponent()}) + stateA = openfe.ChemicalSystem( + { + "A": benzene_modifications["benzene"], + "B": benzene_modifications["toluene"], + "P": 
T4_protein_component, + "S": openfe.SolventComponent(smiles="C"), + }, + ) + stateB = openfe.ChemicalSystem( + { + "A": benzene_modifications["benzene"], + "B": benzene_modifications["benzonitrile"], + "P": T4_protein_component, + "S": openfe.SolventComponent(), + }, + ) alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - assert len(alchem_comps['stateA']) == 2 - assert benzene_modifications['toluene'] in alchem_comps['stateA'] - assert openfe.SolventComponent(smiles='C') in alchem_comps['stateA'] - assert len(alchem_comps['stateB']) == 2 - assert benzene_modifications['benzonitrile'] in alchem_comps['stateB'] - assert openfe.SolventComponent() in alchem_comps['stateB'] + assert len(alchem_comps["stateA"]) == 2 + assert benzene_modifications["toluene"] in alchem_comps["stateA"] + assert openfe.SolventComponent(smiles="C") in alchem_comps["stateA"] + assert len(alchem_comps["stateB"]) == 2 + assert benzene_modifications["benzonitrile"] in alchem_comps["stateB"] + assert openfe.SolventComponent() in alchem_comps["stateB"] def test_duplicate_chemical_components(benzene_modifications): - stateA = openfe.ChemicalSystem({'A': benzene_modifications['toluene'], - 'B': benzene_modifications['toluene'], }) - stateB = openfe.ChemicalSystem({'A': benzene_modifications['toluene']}) + stateA = openfe.ChemicalSystem( + { + "A": benzene_modifications["toluene"], + "B": benzene_modifications["toluene"], + }, + ) + stateB = openfe.ChemicalSystem({"A": benzene_modifications["toluene"]}) errmsg = "state A components B:" @@ -98,36 +112,33 @@ def test_duplicate_chemical_components(benzene_modifications): def test_validate_solvent_nocutoff(benzene_modifications): - state = openfe.ChemicalSystem({'A': benzene_modifications['toluene'], - 'S': openfe.SolventComponent()}) + state = openfe.ChemicalSystem({"A": benzene_modifications["toluene"], "S": openfe.SolventComponent()}) with pytest.raises(ValueError, match="nocutoff cannot be used"): - 
system_validation.validate_solvent(state, 'nocutoff') + system_validation.validate_solvent(state, "nocutoff") def test_validate_solvent_multiple_solvent(benzene_modifications): - state = openfe.ChemicalSystem({'A': benzene_modifications['toluene'], - 'S': openfe.SolventComponent(), - 'S2': openfe.SolventComponent()}) + state = openfe.ChemicalSystem( + {"A": benzene_modifications["toluene"], "S": openfe.SolventComponent(), "S2": openfe.SolventComponent()}, + ) with pytest.raises(ValueError, match="Multiple SolventComponent"): - system_validation.validate_solvent(state, 'pme') + system_validation.validate_solvent(state, "pme") def test_not_water_solvent(benzene_modifications): - state = openfe.ChemicalSystem({'A': benzene_modifications['toluene'], - 'S': openfe.SolventComponent(smiles='C')}) + state = openfe.ChemicalSystem({"A": benzene_modifications["toluene"], "S": openfe.SolventComponent(smiles="C")}) with pytest.raises(ValueError, match="Non water solvent"): - system_validation.validate_solvent(state, 'pme') + system_validation.validate_solvent(state, "pme") def test_multiple_proteins(T4_protein_component): - state = openfe.ChemicalSystem({'A': T4_protein_component, - 'B': T4_protein_component}) + state = openfe.ChemicalSystem({"A": T4_protein_component, "B": T4_protein_component}) with pytest.raises(ValueError, match="Multiple ProteinComponent"): system_validation.validate_protein(state) @@ -135,8 +146,12 @@ def test_multiple_proteins(T4_protein_component): def test_get_components_gas(benzene_modifications): - state = openfe.ChemicalSystem({'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'], }) + state = openfe.ChemicalSystem( + { + "A": benzene_modifications["benzene"], + "B": benzene_modifications["toluene"], + }, + ) s, p, mols = system_validation.get_components(state) @@ -147,9 +162,13 @@ def test_get_components_gas(benzene_modifications): def test_components_solvent(benzene_modifications): - state = openfe.ChemicalSystem({'S': 
openfe.SolventComponent(), - 'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'], }) + state = openfe.ChemicalSystem( + { + "S": openfe.SolventComponent(), + "A": benzene_modifications["benzene"], + "B": benzene_modifications["toluene"], + }, + ) s, p, mols = system_validation.get_components(state) @@ -160,10 +179,14 @@ def test_components_solvent(benzene_modifications): def test_components_complex(T4_protein_component, benzene_modifications): - state = openfe.ChemicalSystem({'S': openfe.SolventComponent(), - 'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'], - 'P': T4_protein_component,}) + state = openfe.ChemicalSystem( + { + "S": openfe.SolventComponent(), + "A": benzene_modifications["benzene"], + "B": benzene_modifications["toluene"], + "P": T4_protein_component, + }, + ) s, p, mols = system_validation.get_components(state) @@ -172,7 +195,7 @@ def test_components_complex(T4_protein_component, benzene_modifications): assert len(mols) == 2 -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def get_settings(): forcefield_settings = OpenMMSystemGeneratorFFSettings() integrator_settings = IntegratorSettings() @@ -187,77 +210,101 @@ def get_settings(): class TestFEAnalysis: # Note: class scope _will_ cause this to segfault - the reporter has to close - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def reporter(self): - with resources.files('openfe.tests.data.openmm_rfe') as d: - ncfile = str(d / 'vacuum_nocoord.nc') + with resources.files("openfe.tests.data.openmm_rfe") as d: + ncfile = str(d / "vacuum_nocoord.nc") - with resources.files('openfe.tests.data.openmm_rfe') as d: - chkfile = str(d / 'vacuum_nocoord_checkpoint.nc') + with resources.files("openfe.tests.data.openmm_rfe") as d: + chkfile = str(d / "vacuum_nocoord_checkpoint.nc") - r = multistate.MultiStateReporter( - storage=ncfile, checkpoint_storage=chkfile - ) + r = multistate.MultiStateReporter(storage=ncfile, 
checkpoint_storage=chkfile) try: yield r finally: r.close() - + @pytest.fixture() def analyzer(self, reporter): return multistate_analysis.MultistateEquilFEAnalysis( - reporter, sampling_method='repex', + reporter, + sampling_method="repex", result_units=unit.kilocalorie_per_mole, ) - + def test_free_energies(self, analyzer): ret_dict = analyzer.unit_results_dict assert len(ret_dict.items()) == 7 - assert pytest.approx(ret_dict['unit_estimate'].m) == -47.9606 - assert pytest.approx(ret_dict['unit_estimate_error'].m) == 0.02396789 + assert pytest.approx(ret_dict["unit_estimate"].m) == -47.9606 + assert pytest.approx(ret_dict["unit_estimate_error"].m) == 0.02396789 # forward and reverse (since we do this ourselves) assert_allclose( - ret_dict['forward_and_reverse_energies']['fractions'], - np.array([0.08988764, 0.191011, 0.292135, 0.393258, 0.494382, - 0.595506, 0.696629, 0.797753, 0.898876, 1.0]), + ret_dict["forward_and_reverse_energies"]["fractions"], + np.array([0.08988764, 0.191011, 0.292135, 0.393258, 0.494382, 0.595506, 0.696629, 0.797753, 0.898876, 1.0]), rtol=1e-04, ) assert_allclose( - ret_dict['forward_and_reverse_energies']['forward_DGs'].m, - np.array([-48.057326, -48.038367, -48.033994, -48.0228, -48.028532, - -48.025258, -48.006349, -47.986304, -47.972138, -47.960623]), + ret_dict["forward_and_reverse_energies"]["forward_DGs"].m, + np.array( + [ + -48.057326, + -48.038367, + -48.033994, + -48.0228, + -48.028532, + -48.025258, + -48.006349, + -47.986304, + -47.972138, + -47.960623, + ], + ), rtol=1e-04, ) assert_allclose( - ret_dict['forward_and_reverse_energies']['forward_dDGs'].m, - np.array([0.07471 , 0.052914, 0.041508, 0.036613, 0.032827, 0.030489, - 0.028154, 0.026529, 0.025284, 0.023968]), + ret_dict["forward_and_reverse_energies"]["forward_dDGs"].m, + np.array( + [0.07471, 0.052914, 0.041508, 0.036613, 0.032827, 0.030489, 0.028154, 0.026529, 0.025284, 0.023968], + ), rtol=1e-04, ) assert_allclose( - 
ret_dict['forward_and_reverse_energies']['reverse_DGs'].m, - np.array([-47.823839, -47.833107, -47.845866, -47.858173, -47.883887, - -47.915963, -47.93319, -47.939125, -47.949016, -47.960623]), + ret_dict["forward_and_reverse_energies"]["reverse_DGs"].m, + np.array( + [ + -47.823839, + -47.833107, + -47.845866, + -47.858173, + -47.883887, + -47.915963, + -47.93319, + -47.939125, + -47.949016, + -47.960623, + ], + ), rtol=1e-04, ) assert_allclose( - ret_dict['forward_and_reverse_energies']['reverse_dDGs'].m, - np.array([0.081209, 0.055975, 0.044693, 0.038691, 0.034603, 0.031894, - 0.029417, 0.027082, 0.025316, 0.023968]), + ret_dict["forward_and_reverse_energies"]["reverse_dDGs"].m, + np.array( + [0.081209, 0.055975, 0.044693, 0.038691, 0.034603, 0.031894, 0.029417, 0.027082, 0.025316, 0.023968], + ), rtol=1e-04, ) def test_plots(self, analyzer, tmpdir): with tmpdir.as_cwd(): - analyzer.plot(filepath=Path('.'), filename_prefix='') - assert Path('forward_reverse_convergence.png').is_file() - assert Path('mbar_overlap_matrix.png').is_file() - assert Path('replica_exchange_matrix.png').is_file() - assert Path('replica_state_timeseries.png').is_file() + analyzer.plot(filepath=Path("."), filename_prefix="") + assert Path("forward_reverse_convergence.png").is_file() + assert Path("mbar_overlap_matrix.png").is_file() + assert Path("replica_exchange_matrix.png").is_file() + assert Path("replica_state_timeseries.png").is_file() def test_plot_convergence_bad_units(self, analyzer): - - with pytest.raises(ValueError, match='Unknown plotting units'): + + with pytest.raises(ValueError, match="Unknown plotting units"): openfe.analysis.plotting.plot_convergence( analyzer.forward_and_reverse_free_energies, unit.nanometer, @@ -265,11 +312,12 @@ def test_plot_convergence_bad_units(self, analyzer): def test_analyze_unknown_method_warning_and_error(self, reporter): - with pytest.warns(UserWarning, match='Unknown sampling method'): + with pytest.warns(UserWarning, match="Unknown sampling 
method"): ana = multistate_analysis.MultistateEquilFEAnalysis( - reporter, sampling_method='replex', - result_units=unit.kilocalorie_per_mole, - ) + reporter, + sampling_method="replex", + result_units=unit.kilocalorie_per_mole, + ) with pytest.raises(ValueError, match="Exchange matrix"): ana.replica_exchange_statistics @@ -278,25 +326,22 @@ def test_analyze_unknown_method_warning_and_error(self, reporter): class TestSystemCreation: def test_system_generator_nosolv_nocache(self, get_settings): ffsets, intsets, thermosets = get_settings - generator = system_creation.get_system_generator( - ffsets, thermosets, intsets, None, False - ) + generator = system_creation.get_system_generator(ffsets, thermosets, intsets, None, False) assert generator.barostat is None assert generator.template_generator._cache is None assert not generator.postprocess_system forcefield_kwargs = { - 'constraints': app.HBonds, - 'rigidWater': True, - 'removeCMMotion': False, - 'hydrogenMass': 3.0 * ommunit.amu + "constraints": app.HBonds, + "rigidWater": True, + "removeCMMotion": False, + "hydrogenMass": 3.0 * ommunit.amu, } assert generator.forcefield_kwargs == forcefield_kwargs - periodic_kwargs = { - 'nonbondedMethod': app.PME, - 'nonbondedCutoff': 1.0 * ommunit.nanometer + periodic_kwargs = {"nonbondedMethod": app.PME, "nonbondedCutoff": 1.0 * ommunit.nanometer} + nonperiodic_kwargs = { + "nonbondedMethod": app.NoCutoff, } - nonperiodic_kwargs = {'nonbondedMethod': app.NoCutoff,} assert generator.nonperiodic_forcefield_kwargs == nonperiodic_kwargs assert generator.periodic_forcefield_kwargs == periodic_kwargs @@ -306,18 +351,18 @@ def test_system_generator_solv_cache(self, get_settings): thermosets.temperature = 320 * unit.kelvin thermosets.pressure = 1.25 * unit.bar intsets.barostat_frequency = 200 * unit.timestep - generator = system_creation.get_system_generator( - ffsets, thermosets, intsets, Path('./db.json'), True - ) + generator = system_creation.get_system_generator(ffsets, 
thermosets, intsets, Path("./db.json"), True) # Check barostat conditions assert isinstance(generator.barostat, MonteCarloBarostat) pressure = ensure_quantity( - generator.barostat.getDefaultPressure(), 'openff', + generator.barostat.getDefaultPressure(), + "openff", ) temperature = ensure_quantity( - generator.barostat.getDefaultTemperature(), 'openff', + generator.barostat.getDefaultTemperature(), + "openff", ) assert pressure.m == pytest.approx(1.25) assert pressure.units == unit.bar @@ -326,48 +371,41 @@ def test_system_generator_solv_cache(self, get_settings): assert generator.barostat.getFrequency() == 200 # Check cache file - assert generator.template_generator._cache == 'db.json' + assert generator.template_generator._cache == "db.json" - def test_get_omm_modeller_complex(self, T4_protein_component, - benzene_modifications, - get_settings): + def test_get_omm_modeller_complex(self, T4_protein_component, benzene_modifications, get_settings): ffsets, intsets, thermosets = get_settings - generator = system_creation.get_system_generator( - ffsets, thermosets, intsets, None, True - ) + generator = system_creation.get_system_generator(ffsets, thermosets, intsets, None, True) - smc = benzene_modifications['toluene'] + smc = benzene_modifications["toluene"] mol = smc.to_openff() - generator.create_system(mol.to_topology().to_openmm(), - molecules=[mol]) + generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) model, comp_resids = system_creation.get_omm_modeller( - T4_protein_component, openfe.SolventComponent(), - {smc: mol}, - generator.forcefield, - OpenMMSolvationSettings()) + T4_protein_component, + openfe.SolventComponent(), + {smc: mol}, + generator.forcefield, + OpenMMSolvationSettings(), + ) resids = [r for r in model.topology.residues()] - assert resids[163].name == 'NME' - assert resids[164].name == 'UNK' - assert resids[165].name == 'HOH' + assert resids[163].name == "NME" + assert resids[164].name == "UNK" + assert resids[165].name == 
"HOH" assert_equal(comp_resids[T4_protein_component], np.linspace(0, 163, 164)) assert_equal(comp_resids[smc], np.array([164])) - assert_equal(comp_resids[openfe.SolventComponent()], - np.linspace(165, len(resids)-1, len(resids)-165)) + assert_equal(comp_resids[openfe.SolventComponent()], np.linspace(165, len(resids) - 1, len(resids) - 165)) def test_get_omm_modeller_ligand_no_neutralize(self, get_settings): ffsets, intsets, thermosets = get_settings - generator = system_creation.get_system_generator( - ffsets, thermosets, intsets, None, True - ) + generator = system_creation.get_system_generator(ffsets, thermosets, intsets, None, True) - offmol = OFFMol.from_smiles('[O-]C=O') + offmol = OFFMol.from_smiles("[O-]C=O") offmol.generate_conformers() smc = openfe.SmallMoleculeComponent.from_openff(offmol) - generator.create_system(offmol.to_topology().to_openmm(), - molecules=[offmol]) + generator.create_system(offmol.to_topology().to_openmm(), molecules=[offmol]) model, comp_resids = system_creation.get_omm_modeller( None, openfe.SolventComponent(neutralize=False), @@ -376,14 +414,10 @@ def test_get_omm_modeller_ligand_no_neutralize(self, get_settings): OpenMMSolvationSettings(), ) - system = generator.create_system( - model.topology, - molecules=[offmol] - ) + system = generator.create_system(model.topology, molecules=[offmol]) # Now let's check the total charge - nonbonded = [f for f in system.getForces() - if isinstance(f, NonbondedForce)][0] + nonbonded = [f for f in system.getForces() if isinstance(f, NonbondedForce)][0] charge = 0 * ommunit.elementary_charge @@ -391,20 +425,18 @@ def test_get_omm_modeller_ligand_no_neutralize(self, get_settings): c, s, e = nonbonded.getParticleParameters(i) charge += c - charge = ensure_quantity(charge, 'openff') + charge = ensure_quantity(charge, "openff") assert pytest.approx(charge.m) == -1.0 def test_convert_steps_per_iteration(): sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - 
production_length='10 ps', - time_per_iteration='1.0 ps', - ) - inty = omm_settings.IntegratorSettings( - timestep='4 fs' + equilibration_length="10 ps", + production_length="10 ps", + time_per_iteration="1.0 ps", ) + inty = omm_settings.IntegratorSettings(timestep="4 fs") spi = settings_validation.convert_steps_per_iteration(sim, inty) @@ -413,13 +445,11 @@ def test_convert_steps_per_iteration(): def test_convert_steps_per_iteration_failure(): sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - production_length='10 ps', - time_per_iteration='1.0 ps', - ) - inty = omm_settings.IntegratorSettings( - timestep='3 fs' + equilibration_length="10 ps", + production_length="10 ps", + time_per_iteration="1.0 ps", ) + inty = omm_settings.IntegratorSettings(timestep="3 fs") with pytest.raises(ValueError, match="does not evenly divide"): settings_validation.convert_steps_per_iteration(sim, inty) @@ -427,11 +457,11 @@ def test_convert_steps_per_iteration_failure(): def test_convert_real_time_analysis_iterations(): sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - production_length='10 ps', - time_per_iteration='1.0 ps', - real_time_analysis_interval='250 ps', - real_time_analysis_minimum_time='500 ps', + equilibration_length="10 ps", + production_length="10 ps", + time_per_iteration="1.0 ps", + real_time_analysis_interval="250 ps", + real_time_analysis_minimum_time="500 ps", ) rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations(sim) @@ -442,38 +472,38 @@ def test_convert_real_time_analysis_iterations(): def test_convert_real_time_analysis_iterations_interval_fail(): # shouldn't like 250.5 ps / 1.0 ps sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - production_length='10 ps', - time_per_iteration='1.0 ps', - real_time_analysis_interval='250.5 ps', - real_time_analysis_minimum_time='500 ps', + equilibration_length="10 ps", + production_length="10 ps", + 
time_per_iteration="1.0 ps", + real_time_analysis_interval="250.5 ps", + real_time_analysis_minimum_time="500 ps", ) - with pytest.raises(ValueError, match='does not evenly divide'): + with pytest.raises(ValueError, match="does not evenly divide"): settings_validation.convert_real_time_analysis_iterations(sim) def test_convert_real_time_analysis_iterations_min_interval_fail(): # shouldn't like 500.5 ps / 1 ps sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - production_length='10 ps', - time_per_iteration='1.0 ps', - real_time_analysis_interval='250 ps', - real_time_analysis_minimum_time='500.5 ps', + equilibration_length="10 ps", + production_length="10 ps", + time_per_iteration="1.0 ps", + real_time_analysis_interval="250 ps", + real_time_analysis_minimum_time="500.5 ps", ) - with pytest.raises(ValueError, match='does not evenly divide'): + with pytest.raises(ValueError, match="does not evenly divide"): settings_validation.convert_real_time_analysis_iterations(sim) def test_convert_real_time_analysis_iterations_None(): sim = omm_settings.MultiStateSimulationSettings( - equilibration_length='10 ps', - production_length='10 ps', - time_per_iteration='1.0 ps', + equilibration_length="10 ps", + production_length="10 ps", + time_per_iteration="1.0 ps", real_time_analysis_interval=None, - real_time_analysis_minimum_time='500 ps', + real_time_analysis_minimum_time="500 ps", ) rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations(sim) @@ -502,29 +532,25 @@ def test_convert_target_error_from_kcal_per_mole_to_kT_zero(): class TestOFFPartialCharge: - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def uncharged_mol(self, CN_molecule): return CN_molecule.to_openff() - @pytest.mark.parametrize('overwrite', [True, False]) - def test_offmol_chg_gen_charged_overwrite( - self, overwrite, uncharged_mol - ): - chg = [ - 1 for _ in range(len(uncharged_mol.atoms)) - ] * unit.elementary_charge + 
@pytest.mark.parametrize("overwrite", [True, False]) + def test_offmol_chg_gen_charged_overwrite(self, overwrite, uncharged_mol): + chg = [1 for _ in range(len(uncharged_mol.atoms))] * unit.elementary_charge uncharged_mol.partial_charges = copy.deepcopy(chg) - + charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=overwrite, - method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=None, nagl_model=None, ) - + assert np.allclose(uncharged_mol.partial_charges, chg) != overwrite def test_unknown_method(self, uncharged_mol): @@ -532,40 +558,41 @@ def test_unknown_method(self, uncharged_mol): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='foo', - toolkit_backend='ambertools', + method="foo", + toolkit_backend="ambertools", generate_n_conformers=None, nagl_model=None, ) - @pytest.mark.parametrize('method, backend', [ - ['am1bcc', 'rdkit'], - ['am1bccelf10', 'ambertools'], - ['nagl', 'bar'], - ['espaloma', 'openeye'], - ]) - def test_incompatible_backend_am1bcc( - self, method, backend, uncharged_mol - ): - with pytest.raises(ValueError, match='Selected toolkit_backend'): + @pytest.mark.parametrize( + "method, backend", + [ + ["am1bcc", "rdkit"], + ["am1bccelf10", "ambertools"], + ["nagl", "bar"], + ["espaloma", "openeye"], + ], + ) + def test_incompatible_backend_am1bcc(self, method, backend, uncharged_mol): + with pytest.raises(ValueError, match="Selected toolkit_backend"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, method=method, toolkit_backend=backend, generate_n_conformers=None, - nagl_model=None + nagl_model=None, ) def test_no_conformers(self, uncharged_mol): uncharged_mol._conformers = None - with pytest.raises(ValueError, match='No conformers'): + with pytest.raises(ValueError, match="No conformers"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - 
method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=None, nagl_model=None, ) @@ -581,20 +608,20 @@ def test_too_many_existing_conformers(self, uncharged_mol): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=None, nagl_model=None, ) def test_too_many_requested_conformers(self, uncharged_mol): - + with pytest.raises(ValueError, match="5 conformers were requested"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=5, nagl_model=None, ) @@ -603,11 +630,11 @@ def test_am1bcc_no_conformer(self, uncharged_mol): uncharged_mol._conformers = None - with pytest.raises(ValueError, match='at least one conformer'): + with pytest.raises(ValueError, match="at least one conformer"): charge_generation.assign_offmol_am1bcc_charges( uncharged_mol, - partial_charge_method='am1bcc', - toolkit_registry=ToolkitRegistry([RDKitToolkitWrapper()]) + partial_charge_method="am1bcc", + toolkit_registry=ToolkitRegistry([RDKitToolkitWrapper()]), ) @pytest.mark.slow @@ -621,8 +648,8 @@ def test_am1bcc_conformer_nochange(self, eg5_ligands): charge_generation.assign_offmol_partial_charges( lig, overwrite=False, - method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=None, nagl_model=None, ) @@ -637,10 +664,10 @@ def test_am1bcc_conformer_nochange(self, eg5_ligands): charge_generation.assign_offmol_partial_charges( lig, overwrite=True, - method='am1bcc', - toolkit_backend='ambertools', + method="am1bcc", + toolkit_backend="ambertools", generate_n_conformers=1, - nagl_model=None + nagl_model=None, ) # conformer shouldn't have changed @@ -649,71 +676,104 @@ def 
test_am1bcc_conformer_nochange(self, eg5_ligands): # but the charges should have assert not np.allclose(charges, lig.partial_charges) - @pytest.mark.skipif(not HAS_NAGL, reason='NAGL is not available') + @pytest.mark.skipif(not HAS_NAGL, reason="NAGL is not available") def test_no_production_nagl(self, uncharged_mol): - - with pytest.raises(ValueError, match='No production am1bcc NAGL'): + + with pytest.raises(ValueError, match="No production am1bcc NAGL"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='nagl', - toolkit_backend='rdkit', + method="nagl", + toolkit_backend="rdkit", generate_n_conformers=None, nagl_model=None, ) # Note: skipping nagl tests on macos/darwin due to known issues # see: https://github.com/openforcefield/openff-nagl/issues/78 - @pytest.mark.parametrize('method, backend, ref_key, confs', [ - ('am1bcc', 'ambertools', 'ambertools', None), - pytest.param( - 'am1bcc', 'openeye', 'openeye', None, - marks=pytest.mark.skipif( - not HAS_OPENEYE, reason='needs oechem', + @pytest.mark.parametrize( + "method, backend, ref_key, confs", + [ + ("am1bcc", "ambertools", "ambertools", None), + pytest.param( + "am1bcc", + "openeye", + "openeye", + None, + marks=pytest.mark.skipif( + not HAS_OPENEYE, + reason="needs oechem", + ), ), - ), - pytest.param( - 'am1bccelf10', 'openeye', 'openeye', 500, - marks=pytest.mark.skipif( - not HAS_OPENEYE, reason='needs oechem', + pytest.param( + "am1bccelf10", + "openeye", + "openeye", + 500, + marks=pytest.mark.skipif( + not HAS_OPENEYE, + reason="needs oechem", + ), ), - ), - pytest.param( - 'nagl', 'rdkit', 'nagl', None, - marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith('darwin'), - reason='needs NAGL and/or on macos', + pytest.param( + "nagl", + "rdkit", + "nagl", + None, + marks=pytest.mark.skipif( + not HAS_NAGL or sys.platform.startswith("darwin"), + reason="needs NAGL and/or on macos", + ), ), - ), - pytest.param( - 'nagl', 'ambertools', 'nagl', None, 
- marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith('darwin') - , reason='needs NAGL and/or on macos', + pytest.param( + "nagl", + "ambertools", + "nagl", + None, + marks=pytest.mark.skipif( + not HAS_NAGL or sys.platform.startswith("darwin"), + reason="needs NAGL and/or on macos", + ), ), - ), - pytest.param( - 'nagl', 'openeye', 'nagl', None, - marks=pytest.mark.skipif( - not HAS_NAGL or not HAS_OPENEYE or sys.platform.startswith('darwin'), - reason='needs NAGL and oechem and not on macos', + pytest.param( + "nagl", + "openeye", + "nagl", + None, + marks=pytest.mark.skipif( + not HAS_NAGL or not HAS_OPENEYE or sys.platform.startswith("darwin"), + reason="needs NAGL and oechem and not on macos", + ), ), - ), - pytest.param( - 'espaloma', 'rdkit', 'espaloma', None, - marks=pytest.mark.skipif( - not HAS_ESPALOMA, reason='needs espaloma', + pytest.param( + "espaloma", + "rdkit", + "espaloma", + None, + marks=pytest.mark.skipif( + not HAS_ESPALOMA, + reason="needs espaloma", + ), ), - ), - pytest.param( - 'espaloma', 'ambertools', 'espaloma', None, - marks=pytest.mark.skipif( - not HAS_ESPALOMA, reason='needs espaloma', + pytest.param( + "espaloma", + "ambertools", + "espaloma", + None, + marks=pytest.mark.skipif( + not HAS_ESPALOMA, + reason="needs espaloma", + ), ), - ), - ]) + ], + ) def test_am1bcc_reference( - self, uncharged_mol, method, backend, ref_key, confs, + self, + uncharged_mol, + method, + backend, + ref_key, + confs, am1bcc_ref_charges, ): """ @@ -729,59 +789,43 @@ def test_am1bcc_reference( nagl_model="openff-gnn-am1bcc-0.1.0-rc.1.pt", ) - assert_allclose( - am1bcc_ref_charges[ref_key], - uncharged_mol.partial_charges, - rtol=1e-4 - ) + assert_allclose(am1bcc_ref_charges[ref_key], uncharged_mol.partial_charges, rtol=1e-4) def test_nagl_import_error(self, monkeypatch, uncharged_mol): - monkeypatch.setattr( - sys.modules['openfe.protocols.openmm_utils.charge_generation'], - 'HAS_NAGL', - False - ) + 
monkeypatch.setattr(sys.modules["openfe.protocols.openmm_utils.charge_generation"], "HAS_NAGL", False) - with pytest.raises(ImportError, match='NAGL toolkit is not available'): + with pytest.raises(ImportError, match="NAGL toolkit is not available"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='nagl', - toolkit_backend='rdkit', + method="nagl", + toolkit_backend="rdkit", generate_n_conformers=None, - nagl_model=None + nagl_model=None, ) def test_espaloma_import_error(self, monkeypatch, uncharged_mol): - monkeypatch.setattr( - sys.modules['openfe.protocols.openmm_utils.charge_generation'], - 'HAS_ESPALOMA', - False - ) + monkeypatch.setattr(sys.modules["openfe.protocols.openmm_utils.charge_generation"], "HAS_ESPALOMA", False) - with pytest.raises(ImportError, match='Espaloma'): + with pytest.raises(ImportError, match="Espaloma"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='espaloma', - toolkit_backend='rdkit', + method="espaloma", + toolkit_backend="rdkit", generate_n_conformers=None, nagl_model=None, ) def test_openeye_import_error(self, monkeypatch, uncharged_mol): - monkeypatch.setattr( - sys.modules['openfe.protocols.openmm_utils.charge_generation'], - 'HAS_OPENEYE', - False - ) + monkeypatch.setattr(sys.modules["openfe.protocols.openmm_utils.charge_generation"], "HAS_OPENEYE", False) - with pytest.raises(ImportError, match='OpenEye is not available'): + with pytest.raises(ImportError, match="OpenEye is not available"): charge_generation.assign_offmol_partial_charges( uncharged_mol, overwrite=False, - method='am1bcc', - toolkit_backend='openeye', + method="am1bcc", + toolkit_backend="openeye", generate_n_conformers=None, nagl_model=None, ) @@ -790,18 +834,17 @@ def test_openeye_import_error(self, monkeypatch, uncharged_mol): @pytest.mark.slow @pytest.mark.download def test_forward_backwards_failure(simulation_nc): - rep = multistate.multistatereporter.MultiStateReporter( 
- simulation_nc, - open_mode='r' - ) + rep = multistate.multistatereporter.MultiStateReporter(simulation_nc, open_mode="r") ana = multistate_analysis.MultistateEquilFEAnalysis( rep, - sampling_method='repex', + sampling_method="repex", result_units=unit.kilocalorie_per_mole, ) - with mock.patch('openfe.protocols.openmm_utils.multistate_analysis.MultistateEquilFEAnalysis._get_free_energy', - side_effect=ParameterError): + with mock.patch( + "openfe.protocols.openmm_utils.multistate_analysis.MultistateEquilFEAnalysis._get_free_energy", + side_effect=ParameterError, + ): ret = ana.get_forward_and_reverse_analysis() assert ret is None diff --git a/openfe/tests/protocols/test_rfe_tokenization.py b/openfe/tests/protocols/test_rfe_tokenization.py index 92081e18b..f5b3439c6 100644 --- a/openfe/tests/protocols/test_rfe_tokenization.py +++ b/openfe/tests/protocols/test_rfe_tokenization.py @@ -1,9 +1,10 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from openfe.protocols import openmm_rfe -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin import pytest +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin + +from openfe.protocols import openmm_rfe """ todo: @@ -12,6 +13,7 @@ - RelativeHybridTopologyProtocolUnit """ + @pytest.fixture def protocol(): return openmm_rfe.RelativeHybridTopologyProtocol(openmm_rfe.RelativeHybridTopologyProtocol.default_settings()) @@ -20,7 +22,8 @@ def protocol(): @pytest.fixture def protocol_unit(protocol, benzene_system, toluene_system, benzene_to_toluene_mapping): pus = protocol.create( - stateA=benzene_system, stateB=toluene_system, + stateA=benzene_system, + stateB=toluene_system, mapping=[benzene_to_toluene_mapping], ) return list(pus.protocol_units)[0] diff --git a/openfe/tests/protocols/test_solvation_afe_tokenization.py b/openfe/tests/protocols/test_solvation_afe_tokenization.py index 436930e10..f8f998726 100644 --- 
a/openfe/tests/protocols/test_solvation_afe_tokenization.py +++ b/openfe/tests/protocols/test_solvation_afe_tokenization.py @@ -1,25 +1,25 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import json -import openfe -from openfe.protocols import openmm_afe + import gufe -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin import pytest +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin + +import openfe +from openfe.protocols import openmm_afe @pytest.fixture def protocol(): - return openmm_afe.AbsoluteSolvationProtocol( - openmm_afe.AbsoluteSolvationProtocol.default_settings() - ) + return openmm_afe.AbsoluteSolvationProtocol(openmm_afe.AbsoluteSolvationProtocol.default_settings()) @pytest.fixture def protocol_units(protocol, benzene_system): pus = protocol.create( stateA=benzene_system, - stateB=openfe.ChemicalSystem({'solvent': openfe.SolventComponent()}), + stateB=openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}), mapping=None, ) return list(pus.protocol_units) @@ -41,9 +41,8 @@ def vacuum_protocol_unit(protocol_units): @pytest.fixture def protocol_result(afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, - cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d['protocol_result']) + d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) return pr diff --git a/openfe/tests/setup/alchemical_network_planner/edge_types.py b/openfe/tests/setup/alchemical_network_planner/edge_types.py index bc17a54e4..0b7f55b2f 100644 --- a/openfe/tests/setup/alchemical_network_planner/edge_types.py +++ b/openfe/tests/setup/alchemical_network_planner/edge_types.py @@ -1,6 +1,6 @@ from gufe import Transformation -from ..chemicalsystem_generator.component_checks import 
proteinC_in_chem_sys, solventC_in_chem_sys, ligandC_in_chem_sys +from ..chemicalsystem_generator.component_checks import ligandC_in_chem_sys, proteinC_in_chem_sys, solventC_in_chem_sys def both_states_proteinC_edge(edge: Transformation) -> bool: @@ -16,7 +16,9 @@ def both_states_ligandC_edge(edge: Transformation) -> bool: def r_vacuum_edge(edge: Transformation) -> bool: - return both_states_ligandC_edge(edge) and not both_states_solventC_edge(edge) and not both_states_proteinC_edge(edge) + return ( + both_states_ligandC_edge(edge) and not both_states_solventC_edge(edge) and not both_states_proteinC_edge(edge) + ) def r_solvent_edge(edge: Transformation) -> bool: diff --git a/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py b/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py index 187ea87eb..49fd4a7fa 100644 --- a/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py +++ b/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py @@ -1,23 +1,23 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe import pytest +from gufe import AlchemicalNetwork, SolventComponent -from ...conftest import atom_mapping_basic_test_files, T4_protein_component +from openfe.setup.alchemical_network_planner import RBFEAlchemicalNetworkPlanner, RHFEAlchemicalNetworkPlanner -from gufe import SolventComponent, AlchemicalNetwork -from openfe.setup.alchemical_network_planner import RHFEAlchemicalNetworkPlanner, RBFEAlchemicalNetworkPlanner +from ...conftest import T4_protein_component, atom_mapping_basic_test_files from .edge_types import r_complex_edge, r_solvent_edge, r_vacuum_edge def test_rhfe_alchemical_network_planner_init(): alchem_planner = RHFEAlchemicalNetworkPlanner() - + assert alchem_planner.name == "easy_rhfe" def test_rbfe_alchemical_network_planner_init(): alchem_planner = RBFEAlchemicalNetworkPlanner() - + assert alchem_planner.name == "easy_rbfe" @@ -28,16 +28,18 @@ def test_rbfe_alchemical_network_planner_call(atom_mapping_basic_test_files, T4_ solvent=SolventComponent(), protein=T4_protein_component, ) - + assert isinstance(alchem_network, AlchemicalNetwork) - + edges = alchem_network.edges - assert len(edges) == 14 # we build 2envs*8ligands-2startLigands = 14 relative edges. + assert len(edges) == 14 # we build 2envs*8ligands-2startLigands = 14 relative edges. print(edges) - assert sum([r_complex_edge(e) for e in edges]) == 7 # half of the transformations should be complex (they always are)! - assert sum([r_solvent_edge(e) for e in edges]) == 7 # half of the transformations should be solvent! - assert sum([r_vacuum_edge(e) for e in edges]) == 0 # no vacuum here! + assert ( + sum([r_complex_edge(e) for e in edges]) == 7 + ) # half of the transformations should be complex (they always are)! + assert sum([r_solvent_edge(e) for e in edges]) == 7 # half of the transformations should be solvent! + assert sum([r_vacuum_edge(e) for e in edges]) == 0 # no vacuum here! 
def test_rhfe_alchemical_network_planner_call_multigraph(atom_mapping_basic_test_files): @@ -47,26 +49,28 @@ def test_rhfe_alchemical_network_planner_call_multigraph(atom_mapping_basic_test ligand_network_edges = list(ligand_network.edges) ligand_network_edges.extend(list(ligand_network.edges)) - chemical_system_generator = alchem_planner._chemical_system_generator_type( - solvent=SolventComponent() - ) + chemical_system_generator = alchem_planner._chemical_system_generator_type(solvent=SolventComponent()) - with pytest.raises(ValueError, match="There were multiple transformations with the same edge label! This might lead to overwritting your files."): + with pytest.raises( + ValueError, + match="There were multiple transformations with the same edge label! This might lead to overwritting your files.", + ): alchem_network = alchem_planner._build_transformations( ligand_network_edges=ligand_network_edges, protocol=alchem_planner.transformation_protocol, - chemical_system_generator=chemical_system_generator) + chemical_system_generator=chemical_system_generator, + ) def test_rhfe_alchemical_network_planner_call(atom_mapping_basic_test_files): alchem_planner = RHFEAlchemicalNetworkPlanner() alchem_network = alchem_planner(ligands=atom_mapping_basic_test_files.values(), solvent=SolventComponent()) - + assert isinstance(alchem_network, AlchemicalNetwork) - + edges = alchem_network.edges - assert len(edges) == 14 # we build 2envs*8ligands-2startLigands = 14 relative edges. + assert len(edges) == 14 # we build 2envs*8ligands-2startLigands = 14 relative edges. - assert sum([r_complex_edge(e) for e in edges]) == 0 # no complex! - assert sum([r_solvent_edge(e) for e in edges]) == 7 # half of the transformations should be solvent! - assert sum([r_vacuum_edge(e) for e in edges]) == 7 # half of the transformations should be vacuum! + assert sum([r_complex_edge(e) for e in edges]) == 0 # no complex! 
+ assert sum([r_solvent_edge(e) for e in edges]) == 7 # half of the transformations should be solvent! + assert sum([r_vacuum_edge(e) for e in edges]) == 7 # half of the transformations should be vacuum! diff --git a/openfe/tests/setup/atom_mapping/conftest.py b/openfe/tests/setup/atom_mapping/conftest.py index 105133732..e21b9215d 100644 --- a/openfe/tests/setup/atom_mapping/conftest.py +++ b/openfe/tests/setup/atom_mapping/conftest.py @@ -3,34 +3,34 @@ from typing import Dict, Tuple -from rdkit import Chem -from gufe import SmallMoleculeComponent import lomap import pytest +from gufe import SmallMoleculeComponent +from rdkit import Chem from openfe import LigandAtomMapping from ...conftest import mol_from_smiles -def _translate_lomap_mapping(atom_mapping_str: str) -> Dict[int, int]: - mapped_atom_tuples = map(lambda x: tuple(map(int, x.split(":"))), - atom_mapping_str.split(",")) +def _translate_lomap_mapping(atom_mapping_str: str) -> dict[int, int]: + mapped_atom_tuples = map(lambda x: tuple(map(int, x.split(":"))), atom_mapping_str.split(",")) return {i: j for i, j in mapped_atom_tuples} -def _get_atom_mapping_dict(lomap_atom_mappings) -> Dict[Tuple[int, int], - Dict[int, int]]: - return {mol_pair: _translate_lomap_mapping(atom_mapping_str) for - mol_pair, atom_mapping_str in - lomap_atom_mappings.mcs_map_store.items()} +def _get_atom_mapping_dict(lomap_atom_mappings) -> dict[tuple[int, int], dict[int, int]]: + return { + mol_pair: _translate_lomap_mapping(atom_mapping_str) + for mol_pair, atom_mapping_str in lomap_atom_mappings.mcs_map_store.items() + } @pytest.fixture() -def gufe_atom_mapping_matrix(lomap_basic_test_files_dir, - atom_mapping_basic_test_files - ) -> Dict[Tuple[int, int], LigandAtomMapping]: - dbmols = lomap.DBMolecules(lomap_basic_test_files_dir, verbose='off') +def gufe_atom_mapping_matrix( + lomap_basic_test_files_dir, + atom_mapping_basic_test_files, +) -> dict[tuple[int, int], LigandAtomMapping]: + dbmols = 
lomap.DBMolecules(lomap_basic_test_files_dir, verbose="off") _, _ = dbmols.build_matrices() molecule_pair_atom_mappings = _get_atom_mapping_dict(dbmols) @@ -41,14 +41,14 @@ def gufe_atom_mapping_matrix(lomap_basic_test_files_dir, ligand_atom_mappings[(i, j)] = LigandAtomMapping( componentA=atom_mapping_basic_test_files[nm1], componentB=atom_mapping_basic_test_files[nm2], - componentA_to_componentB=val) + componentA_to_componentB=val, + ) return ligand_atom_mappings @pytest.fixture() -def mol_pair_to_shock_perses_mapper() -> Tuple[SmallMoleculeComponent, - SmallMoleculeComponent]: +def mol_pair_to_shock_perses_mapper() -> tuple[SmallMoleculeComponent, SmallMoleculeComponent]: """ This pair of Molecules leads to an empty Atom mapping in Perses Mapper with certain settings. @@ -56,6 +56,6 @@ def mol_pair_to_shock_perses_mapper() -> Tuple[SmallMoleculeComponent, Returns: Tuple[SmallMoleculeComponent]: two molecule objs for the test """ - molA = SmallMoleculeComponent(mol_from_smiles('c1ccccc1'), 'benzene') - molB = SmallMoleculeComponent(mol_from_smiles('C1CCCCC1'), 'cyclohexane') + molA = SmallMoleculeComponent(mol_from_smiles("c1ccccc1"), "benzene") + molB = SmallMoleculeComponent(mol_from_smiles("C1CCCCC1"), "cyclohexane") return molA, molB diff --git a/openfe/tests/setup/atom_mapping/test_atommapper.py b/openfe/tests/setup/atom_mapping/test_atommapper.py index 99729ce7c..7f3f0f72e 100644 --- a/openfe/tests/setup/atom_mapping/test_atommapper.py +++ b/openfe/tests/setup/atom_mapping/test_atommapper.py @@ -28,7 +28,7 @@ def _defaults(cls): return {} def _to_dict(self): - return {'mappings': self.mappings} + return {"mappings": self.mappings} @classmethod def _from_dict(cls, dct): diff --git a/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py b/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py index c5d32e272..b2a35a70e 100644 --- a/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py +++ b/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py @@ 
-1,11 +1,11 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import pytest +from gufe import SmallMoleculeComponent from rdkit import Chem from rdkit.Chem import AllChem import openfe -from gufe import SmallMoleculeComponent from openfe.setup.atom_mapping import LomapAtomMapper from .conftest import mol_from_smiles @@ -13,8 +13,8 @@ def test_simple(atom_mapping_basic_test_files): # basic sanity check on the LigandAtomMapper - mol1 = atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = atom_mapping_basic_test_files["toluene"] mapper = LomapAtomMapper() @@ -28,8 +28,8 @@ def test_simple(atom_mapping_basic_test_files): def test_distances(atom_mapping_basic_test_files): # basic sanity check on the LigandAtomMapper - mol1 = atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = atom_mapping_basic_test_files["toluene"] mapper = LomapAtomMapper() mapping = next(mapper.suggest_mappings(mol1, mol2)) @@ -38,8 +38,8 @@ def test_distances(atom_mapping_basic_test_files): assert len(dists) == len(mapping.componentA_to_componentB) i, j = next(iter(mapping.componentA_to_componentB.items())) - ref_d = mol1.to_rdkit().GetConformer().GetAtomPosition(i).Distance( - mol2.to_rdkit().GetConformer().GetAtomPosition(j) + ref_d = ( + mol1.to_rdkit().GetConformer().GetAtomPosition(i).Distance(mol2.to_rdkit().GetConformer().GetAtomPosition(j)) ) assert pytest.approx(dists[0], rel=1e-5) == ref_d assert pytest.approx(dists[0], rel=1e-5) == 0.07249779 @@ -48,8 +48,8 @@ def test_distances(atom_mapping_basic_test_files): def test_generator_length(atom_mapping_basic_test_files): # check that we get one mapping back from Lomap LigandAtomMapper then the # generator stops correctly - mol1 = 
atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = atom_mapping_basic_test_files["toluene"] mapper = LomapAtomMapper() @@ -61,9 +61,8 @@ def test_generator_length(atom_mapping_basic_test_files): def test_bad_mapping(atom_mapping_basic_test_files): - toluene = atom_mapping_basic_test_files['toluene'] - NigelTheNitrogen = SmallMoleculeComponent(mol_from_smiles('N'), - name='Nigel') + toluene = atom_mapping_basic_test_files["toluene"] + NigelTheNitrogen = SmallMoleculeComponent(mol_from_smiles("N"), name="Nigel") mapper = LomapAtomMapper() @@ -75,8 +74,8 @@ def test_bad_mapping(atom_mapping_basic_test_files): # TODO: Remvoe these test when element changes are allowed - START def test_simple_no_element_changes(atom_mapping_basic_test_files): # basic sanity check on the LigandAtomMapper - mol1 = atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = atom_mapping_basic_test_files["toluene"] mapper = LomapAtomMapper() mapper._no_element_changes = True @@ -87,17 +86,16 @@ def test_simple_no_element_changes(atom_mapping_basic_test_files): # maps (CH3) off methyl and (6C + 5H) on ring assert len(mapping.componentA_to_componentB) == 15 - + def test_bas_mapping_no_element_changes(atom_mapping_basic_test_files): - toluene = atom_mapping_basic_test_files['toluene'] - NigelTheNitrogen = SmallMoleculeComponent(mol_from_smiles('N'), - name='Nigel') + toluene = atom_mapping_basic_test_files["toluene"] + NigelTheNitrogen = SmallMoleculeComponent(mol_from_smiles("N"), name="Nigel") mapper = LomapAtomMapper() mapper._no_element_changes = True mapping_gen = mapper.suggest_mappings(toluene, NigelTheNitrogen) with pytest.raises(StopIteration): next(mapping_gen) - -# TODO: Remvoe these test when element changes are allowed - END + +# TODO: Remvoe these test 
when element changes are allowed - END diff --git a/openfe/tests/setup/atom_mapping/test_lomap_scorers.py b/openfe/tests/setup/atom_mapping/test_lomap_scorers.py index 556dc05ca..3d7db7702 100644 --- a/openfe/tests/setup/atom_mapping/test_lomap_scorers.py +++ b/openfe/tests/setup/atom_mapping/test_lomap_scorers.py @@ -2,59 +2,56 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import itertools -import lomap import math -import numpy as np -from numpy.testing import assert_allclose -import openfe -from openfe.setup import lomap_scorers, LigandAtomMapping +import lomap +import numpy as np import pytest +from numpy.testing import assert_allclose from rdkit import Chem from rdkit.Chem.AllChem import Compute2DCoords +import openfe +from openfe.setup import LigandAtomMapping, lomap_scorers + from .conftest import mol_from_smiles @pytest.fixture() def toluene_to_cyclohexane(atom_mapping_basic_test_files): - meth = atom_mapping_basic_test_files['methylcyclohexane'] - tolu = atom_mapping_basic_test_files['toluene'] + meth = atom_mapping_basic_test_files["methylcyclohexane"] + tolu = atom_mapping_basic_test_files["toluene"] mapping = [(0, 0), (1, 1), (2, 6), (3, 5), (4, 4), (5, 3), (6, 2)] - return LigandAtomMapping(tolu, meth, - componentA_to_componentB=dict(mapping)) + return LigandAtomMapping(tolu, meth, componentA_to_componentB=dict(mapping)) @pytest.fixture() def toluene_to_methylnaphthalene(atom_mapping_basic_test_files): - tolu = atom_mapping_basic_test_files['toluene'] - naph = atom_mapping_basic_test_files['2-methylnaphthalene'] + tolu = atom_mapping_basic_test_files["toluene"] + naph = atom_mapping_basic_test_files["2-methylnaphthalene"] mapping = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 8), (5, 9), (6, 10)] - return LigandAtomMapping(tolu, naph, - componentA_to_componentB=dict(mapping)) + return LigandAtomMapping(tolu, naph, componentA_to_componentB=dict(mapping)) @pytest.fixture() def toluene_to_heptane(atom_mapping_basic_test_files): - tolu = 
atom_mapping_basic_test_files['toluene'] - hept = Chem.MolFromSmiles('CCCCCCC') + tolu = atom_mapping_basic_test_files["toluene"] + hept = Chem.MolFromSmiles("CCCCCCC") Chem.rdDepictor.Compute2DCoords(hept) hept = openfe.SmallMoleculeComponent(hept) mapping = [(6, 0)] - return LigandAtomMapping(tolu, hept, - componentA_to_componentB=dict(mapping)) + return LigandAtomMapping(tolu, hept, componentA_to_componentB=dict(mapping)) @pytest.fixture() def methylnaphthalene_to_naphthol(atom_mapping_basic_test_files): - m1 = atom_mapping_basic_test_files['2-methylnaphthalene'] - m2 = atom_mapping_basic_test_files['2-naftanol'] - mapping = [(0, 0), (1, 1), (2, 10), (3, 9), (4, 8), (5, 7), (6, 6), (7, 5), - (8, 4), (9, 3), (10, 2)] + m1 = atom_mapping_basic_test_files["2-methylnaphthalene"] + m2 = atom_mapping_basic_test_files["2-naftanol"] + mapping = [(0, 0), (1, 1), (2, 10), (3, 9), (4, 8), (5, 7), (6, 6), (7, 5), (8, 4), (9, 3), (10, 2)] return LigandAtomMapping(m1, m2, componentA_to_componentB=dict(mapping)) @@ -97,8 +94,7 @@ def test_atomic_number_score_pass(toluene_to_cyclohexane): def test_atomic_number_score_fail(methylnaphthalene_to_naphthol): - score = lomap_scorers.atomic_number_score( - methylnaphthalene_to_naphthol) + score = lomap_scorers.atomic_number_score(methylnaphthalene_to_naphthol) # single mismatch @ 0.5 assert score == pytest.approx(math.exp(-0.1 * 0.5)) @@ -109,8 +105,7 @@ def test_atomic_number_score_weights(methylnaphthalene_to_naphthol): 8: {6: 0.75}, # oxygen to carbon @ 12 } - score = lomap_scorers.atomic_number_score( - methylnaphthalene_to_naphthol, difficulty=difficulty) + score = lomap_scorers.atomic_number_score(methylnaphthalene_to_naphthol, difficulty=difficulty) # single mismatch @ (1 - 0.75) assert score == pytest.approx(math.exp(-0.1 * 0.25)) @@ -120,7 +115,7 @@ class TestSulfonamideRule: @staticmethod @pytest.fixture def ethylbenzene(): - m = Chem.AddHs(mol_from_smiles('c1ccccc1CCC')) + m = Chem.AddHs(mol_from_smiles("c1ccccc1CCC")) 
return openfe.SmallMoleculeComponent.from_rdkit(m) @@ -128,7 +123,7 @@ def ethylbenzene(): @pytest.fixture def sulfonamide(): # technically 3-phenylbutane-1-sulfonamide - m = Chem.AddHs(mol_from_smiles('c1ccccc1C(C)CCS(=O)(=O)N')) + m = Chem.AddHs(mol_from_smiles("c1ccccc1C(C)CCS(=O)(=O)N")) return openfe.SmallMoleculeComponent.from_rdkit(m) @@ -136,13 +131,32 @@ def sulfonamide(): @pytest.fixture def from_sulf_mapping(): # this is the standard output from lomap_scorers - return {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 14, - 8: 7, 9: 8, 10: 18, 14: 9, 15: 10, 16: 11, 17: 12, - 18: 13, 19: 15, 23: 16, 24: 17, 25: 19, 26: 20} + return { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 14, + 8: 7, + 9: 8, + 10: 18, + 14: 9, + 15: 10, + 16: 11, + 17: 12, + 18: 13, + 19: 15, + 23: 16, + 24: 17, + 25: 19, + 26: 20, + } @staticmethod - def test_sulfonamide_hit_backwards(ethylbenzene, sulfonamide, - from_sulf_mapping): + def test_sulfonamide_hit_backwards(ethylbenzene, sulfonamide, from_sulf_mapping): # a sulfonamide completely disappears on the RHS, so should trigger # the sulfonamide score to try and forbid this @@ -155,30 +169,30 @@ def test_sulfonamide_hit_backwards(ethylbenzene, sulfonamide, assert lomap_scorers.sulfonamides_score(mapping) == expected @staticmethod - def test_sulfonamide_hit_forwards(ethylbenzene, sulfonamide, - from_sulf_mapping): + def test_sulfonamide_hit_forwards(ethylbenzene, sulfonamide, from_sulf_mapping): AtoB = {v: k for k, v in from_sulf_mapping.items()} # this is the standard output from lomap_scorers - mapping = LigandAtomMapping(componentA=ethylbenzene, - componentB=sulfonamide, - componentA_to_componentB=AtoB) + mapping = LigandAtomMapping(componentA=ethylbenzene, componentB=sulfonamide, componentA_to_componentB=AtoB) expected = math.exp(-1 * 0.4) assert lomap_scorers.sulfonamides_score(mapping) == expected -@pytest.mark.parametrize('base,other,name,hit', [ - ('CCc1ccccc1', 'CCc1ccc(-c2ccco2)cc1', 'phenylfuran', False), 
- ('CCc1ccccc1', 'CCc1ccc(-c2cnc[nH]2)cc1', 'phenylimidazole', True), - ('CCc1ccccc1', 'CCc1ccc(-c2ccno2)cc1', 'phenylisoxazole', True), - ('CCc1ccccc1', 'CCc1ccc(-c2cnco2)cc1', 'phenyloxazole', True), - ('CCc1ccccc1', 'CCc1ccc(-c2cccnc2)cc1', 'phenylpyridine1', True), - ('CCc1ccccc1', 'CCc1ccc(-c2ccccn2)cc1', 'phenylpyridine2', True), - ('CCc1ccccc1', 'CCc1ccc(-c2cncnc2)cc1', 'phenylpyrimidine', True), - ('CCc1ccccc1', 'CCc1ccc(-c2ccc[nH]2)cc1', 'phenylpyrrole', False), - ('CCc1ccccc1', 'CCc1ccc(-c2ccccc2)cc1', 'phenylphenyl', False), -]) +@pytest.mark.parametrize( + "base,other,name,hit", + [ + ("CCc1ccccc1", "CCc1ccc(-c2ccco2)cc1", "phenylfuran", False), + ("CCc1ccccc1", "CCc1ccc(-c2cnc[nH]2)cc1", "phenylimidazole", True), + ("CCc1ccccc1", "CCc1ccc(-c2ccno2)cc1", "phenylisoxazole", True), + ("CCc1ccccc1", "CCc1ccc(-c2cnco2)cc1", "phenyloxazole", True), + ("CCc1ccccc1", "CCc1ccc(-c2cccnc2)cc1", "phenylpyridine1", True), + ("CCc1ccccc1", "CCc1ccc(-c2ccccn2)cc1", "phenylpyridine2", True), + ("CCc1ccccc1", "CCc1ccc(-c2cncnc2)cc1", "phenylpyrimidine", True), + ("CCc1ccccc1", "CCc1ccc(-c2ccc[nH]2)cc1", "phenylpyrrole", False), + ("CCc1ccccc1", "CCc1ccc(-c2ccccc2)cc1", "phenylphenyl", False), + ], +) def test_heterocycle_score(base, other, name, hit): # base -> other transform, if *hit* a forbidden heterocycle is created r1 = Chem.AddHs(mol_from_smiles(base)) @@ -198,42 +212,38 @@ def test_heterocycle_score(base, other, name, hit): # test individual scoring functions against lomap SCORE_NAMES = { - 'mcsr': 'mcsr_score', - 'mncar': 'mncar_score', - 'atomic_number_rule': 'atomic_number_score', - 'hybridization_rule': 'hybridization_score', - 'sulfonamides_rule': 'sulfonamides_score', - 'heterocycles_rule': 'heterocycles_score', - 'transmuting_methyl_into_ring_rule': 'transmuting_methyl_into_ring_score', - 'transmuting_ring_sizes_rule': 'transmuting_ring_sizes_score' + "mcsr": "mcsr_score", + "mncar": "mncar_score", + "atomic_number_rule": "atomic_number_score", + 
"hybridization_rule": "hybridization_score", + "sulfonamides_rule": "sulfonamides_score", + "heterocycles_rule": "heterocycles_score", + "transmuting_methyl_into_ring_rule": "transmuting_methyl_into_ring_score", + "transmuting_ring_sizes_rule": "transmuting_ring_sizes_score", } IX = itertools.combinations(range(8), 2) -@pytest.mark.parametrize('params', itertools.product(SCORE_NAMES, IX)) -def test_lomap_individual_scores(params, - atom_mapping_basic_test_files): +@pytest.mark.parametrize("params", itertools.product(SCORE_NAMES, IX)) +def test_lomap_individual_scores(params, atom_mapping_basic_test_files): scorename, (i, j) = params mols = sorted(atom_mapping_basic_test_files.items()) _, molA = mols[i] _, molB = mols[j] # reference value - lomap_version = getattr(lomap.MCS(molA.to_rdkit(), - molB.to_rdkit()), scorename)() + lomap_version = getattr(lomap.MCS(molA.to_rdkit(), molB.to_rdkit()), scorename)() # longer way mapper = openfe.setup.atom_mapping.LomapAtomMapper(threed=False) mapping = next(mapper.suggest_mappings(molA, molB)) openfe_version = getattr(lomap_scorers, SCORE_NAMES[scorename])(mapping) - assert lomap_version == pytest.approx(openfe_version), \ - f"{molA.name} {molB.name} {scorename}" + assert lomap_version == pytest.approx(openfe_version), f"{molA.name} {molB.name} {scorename}" # full back to back test again lomap -def test_lomap_regression(lomap_basic_test_files_dir, # in a dir for lomap - atom_mapping_basic_test_files): +def test_lomap_regression(lomap_basic_test_files_dir, atom_mapping_basic_test_files): # in a dir for lomap # run lomap dbmols = lomap.DBMolecules(lomap_basic_test_files_dir) matrix, _ = dbmols.build_matrices() @@ -275,6 +285,7 @@ def test_transmuting_methyl_into_ring_score(): The first mapping should trigger this rule, the second shouldn't """ + def makemol(smi): m = Chem.MolFromSmiles(smi) m = Chem.AddHs(m) @@ -282,10 +293,10 @@ def makemol(smi): return openfe.SmallMoleculeComponent(m) - core = 'CCC{}' - RC = 
makemol(core.format('C')) - RPh = makemol(core.format('c1ccccc1')) - RH = makemol(core.format('[H]')) + core = "CCC{}" + RC = makemol(core.format("C")) + RPh = makemol(core.format("c1ccccc1")) + RH = makemol(core.format("[H]")) RC_to_RPh = openfe.LigandAtomMapping(RC, RPh, {i: i for i in range(3)}) RH_to_RPh = openfe.LigandAtomMapping(RH, RPh, {i: i for i in range(3)}) diff --git a/openfe/tests/setup/atom_mapping/test_perses_atommapper.py b/openfe/tests/setup/atom_mapping/test_perses_atommapper.py index 58e0def43..133f7c667 100644 --- a/openfe/tests/setup/atom_mapping/test_perses_atommapper.py +++ b/openfe/tests/setup/atom_mapping/test_perses_atommapper.py @@ -1,10 +1,11 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import pytest -from openfe.setup.atom_mapping import PersesAtomMapper, LigandAtomMapping -pytest.importorskip('perses') -pytest.importorskip('openeye') +from openfe.setup.atom_mapping import LigandAtomMapping, PersesAtomMapper + +pytest.importorskip("perses") +pytest.importorskip("openeye") USING_NEW_OFF = True # by default we are now @@ -12,8 +13,8 @@ @pytest.mark.xfail(USING_NEW_OFF, reason="Perses #1108") def test_simple(atom_mapping_basic_test_files): # basic sanity check on the LigandAtomMapper - mol1 = atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = atom_mapping_basic_test_files["toluene"] mapper = PersesAtomMapper() @@ -29,8 +30,8 @@ def test_simple(atom_mapping_basic_test_files): def test_generator_length(atom_mapping_basic_test_files): # check that we get one mapping back from Lomap LigandAtomMapper then the # generator stops correctly - mol1 = atom_mapping_basic_test_files['methylcyclohexane'] - mol2 = atom_mapping_basic_test_files['toluene'] + mol1 = atom_mapping_basic_test_files["methylcyclohexane"] + mol2 = 
atom_mapping_basic_test_files["toluene"] mapper = PersesAtomMapper() diff --git a/openfe/tests/setup/atom_mapping/test_perses_scorers.py b/openfe/tests/setup/atom_mapping/test_perses_scorers.py index 19f117731..44a0d8c0f 100644 --- a/openfe/tests/setup/atom_mapping/test_perses_scorers.py +++ b/openfe/tests/setup/atom_mapping/test_perses_scorers.py @@ -1,58 +1,51 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import pytest -from numpy.testing import assert_allclose, assert_ - import numpy as np +import pytest +from numpy.testing import assert_, assert_allclose from openfe.setup import perses_scorers -pytest.importorskip('perses') -pytest.importorskip('openeye') +pytest.importorskip("perses") +pytest.importorskip("openeye") from ....utils.silence_root_logging import silence_root_logging + with silence_root_logging(): from perses.rjmc.atom_mapping import AtomMapper, AtomMapping USING_OLD_OFF = False -@pytest.mark.xfail(not USING_OLD_OFF, reason='perses #1108') +@pytest.mark.xfail(not USING_OLD_OFF, reason="perses #1108") def test_perses_normalization_not_using_positions(gufe_atom_mapping_matrix): # now run the openfe equivalent with the same ligand atom _mappings scorer = perses_scorers.default_perses_scorer - molecule_row = np.max(list(gufe_atom_mapping_matrix.keys()))+1 + molecule_row = np.max(list(gufe_atom_mapping_matrix.keys())) + 1 norm_scores = np.zeros([molecule_row, molecule_row]) for (i, j), ligand_atom_mapping in gufe_atom_mapping_matrix.items(): - norm_score = scorer( - ligand_atom_mapping, - use_positions=False, - normalize=True) + norm_score = scorer(ligand_atom_mapping, use_positions=False, normalize=True) norm_scores[i, j] = norm_scores[j, i] = norm_score assert norm_scores.shape == (8, 8) - assert_(np.all((norm_scores <= 1) & (norm_scores >= 0.0)), - msg="OpenFE norm value larger than 1 or smaller than 0") + assert_(np.all((norm_scores <= 1) & (norm_scores >= 0.0)), 
msg="OpenFE norm value larger than 1 or smaller than 0") -@pytest.mark.xfail(not USING_OLD_OFF, reason='perses #1108') +@pytest.mark.xfail(not USING_OLD_OFF, reason="perses #1108") def test_perses_not_implemented_position_using(gufe_atom_mapping_matrix): scorer = perses_scorers.default_perses_scorer first_key = list(gufe_atom_mapping_matrix.keys())[0] match_re = "normalizing using positions is not currently implemented" with pytest.raises(NotImplementedError, match=match_re): - norm_score = scorer( - gufe_atom_mapping_matrix[first_key], - use_positions=True, - normalize=True) + norm_score = scorer(gufe_atom_mapping_matrix[first_key], use_positions=True, normalize=True) -@pytest.mark.xfail(not USING_OLD_OFF, reason='perses #1108') +@pytest.mark.xfail(not USING_OLD_OFF, reason="perses #1108") def test_perses_regression(gufe_atom_mapping_matrix): # This is the way how perses does scoring - molecule_row = np.max(list(gufe_atom_mapping_matrix.keys()))+1 + molecule_row = np.max(list(gufe_atom_mapping_matrix.keys())) + 1 matrix = np.zeros([molecule_row, molecule_row]) for x in gufe_atom_mapping_matrix.items(): (i, j), ligand_atom_mapping = x @@ -60,11 +53,10 @@ def test_perses_regression(gufe_atom_mapping_matrix): perses_atom_mapping = AtomMapping( old_mol=ligand_atom_mapping.componentA.to_openff(), new_mol=ligand_atom_mapping.componentB.to_openff(), - old_to_new_atom_map=ligand_atom_mapping.componentA_to_componentB + old_to_new_atom_map=ligand_atom_mapping.componentA_to_componentB, ) # score Perses Mapping - Perses Style - matrix[i, j] = matrix[j, i] = AtomMapper( - ).score_mapping(perses_atom_mapping) + matrix[i, j] = matrix[j, i] = AtomMapper().score_mapping(perses_atom_mapping) assert matrix.shape == (8, 8) @@ -72,14 +64,8 @@ def test_perses_regression(gufe_atom_mapping_matrix): scorer = perses_scorers.default_perses_scorer scores = np.zeros_like(matrix) for (i, j), ligand_atom_mapping in gufe_atom_mapping_matrix.items(): - score = scorer( - ligand_atom_mapping, - 
use_positions=True, - normalize=False) + score = scorer(ligand_atom_mapping, use_positions=True, normalize=False) scores[i, j] = scores[j, i] = score - assert_allclose( - actual=matrix, - desired=scores, - err_msg="openFE was not close to perses") + assert_allclose(actual=matrix, desired=scores, err_msg="openFE was not close to perses") diff --git a/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py b/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py index bf706923f..5bc519a94 100644 --- a/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py +++ b/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py @@ -1,34 +1,33 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from rdkit import Chem import pytest +from gufe import ChemicalSystem, SolventComponent +from rdkit import Chem -from gufe import ChemicalSystem -from openfe.setup.chemicalsystem_generator.easy_chemicalsystem_generator import ( - EasyChemicalSystemGenerator, -) - +from openfe.setup.chemicalsystem_generator.easy_chemicalsystem_generator import EasyChemicalSystemGenerator from ...conftest import T4_protein_component -from gufe import SolventComponent -from .component_checks import proteinC_in_chem_sys, solventC_in_chem_sys, ligandC_in_chem_sys +from .component_checks import ligandC_in_chem_sys, proteinC_in_chem_sys, solventC_in_chem_sys def test_easy_chemical_system_generator_init(T4_protein_component): chem_sys_generator = EasyChemicalSystemGenerator(do_vacuum=True) - + chem_sys_generator = EasyChemicalSystemGenerator(solvent=SolventComponent()) - - chem_sys_generator = EasyChemicalSystemGenerator( - solvent=SolventComponent(), protein=T4_protein_component - ) - + + chem_sys_generator = EasyChemicalSystemGenerator(solvent=SolventComponent(), protein=T4_protein_component) + chem_sys_generator = 
EasyChemicalSystemGenerator( - solvent=SolventComponent(), protein=T4_protein_component, do_vacuum=True + solvent=SolventComponent(), + protein=T4_protein_component, + do_vacuum=True, ) - with pytest.raises(ValueError, match='Chemical system generator is unable to generate any chemical systems with neither protein nor solvent nor do_vacuum'): + with pytest.raises( + ValueError, + match="Chemical system generator is unable to generate any chemical systems with neither protein nor solvent nor do_vacuum", + ): chem_sys_generator = EasyChemicalSystemGenerator() @@ -69,9 +68,7 @@ def test_build_protein_chemical_system(ethane, T4_protein_component): def test_build_hydr_scenario_chemical_systems(ethane): - chem_sys_generator = EasyChemicalSystemGenerator( - do_vacuum=True, solvent=SolventComponent() - ) + chem_sys_generator = EasyChemicalSystemGenerator(do_vacuum=True, solvent=SolventComponent()) chem_sys_gen = chem_sys_generator(ethane) chem_syss = [chem_sys for chem_sys in chem_sys_gen] @@ -84,7 +81,8 @@ def test_build_hydr_scenario_chemical_systems(ethane): def test_build_binding_scenario_chemical_systems(ethane, T4_protein_component): chem_sys_generator = EasyChemicalSystemGenerator( - solvent=SolventComponent(), protein=T4_protein_component, + solvent=SolventComponent(), + protein=T4_protein_component, ) chem_sys_gen = chem_sys_generator(ethane) chem_syss = [chem_sys for chem_sys in chem_sys_gen] @@ -99,7 +97,9 @@ def test_build_binding_scenario_chemical_systems(ethane, T4_protein_component): def test_build_hbinding_scenario_chemical_systems(ethane, T4_protein_component): chem_sys_generator = EasyChemicalSystemGenerator( - do_vacuum=True, solvent=SolventComponent(), protein=T4_protein_component, + do_vacuum=True, + solvent=SolventComponent(), + protein=T4_protein_component, ) chem_sys_gen = chem_sys_generator(ethane) chem_syss = [chem_sys for chem_sys in chem_sys_gen] diff --git a/openfe/tests/setup/test_network_planning.py 
b/openfe/tests/setup/test_network_planning.py index 5d194282a..03d6da0c3 100644 --- a/openfe/tests/setup/test_network_planning.py +++ b/openfe/tests/setup/test_network_planning.py @@ -1,8 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from rdkit import Chem -import pytest import networkx as nx +import pytest +from rdkit import Chem import openfe.setup @@ -25,27 +25,27 @@ def _mappings_generator(self, molA, molB): yield {0: 0} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def toluene_vs_others(atom_mapping_basic_test_files): - central_ligand_name = 'toluene' - others = [v for (k, v) in atom_mapping_basic_test_files.items() - if k != central_ligand_name] + central_ligand_name = "toluene" + others = [v for (k, v) in atom_mapping_basic_test_files.items() if k != central_ligand_name] toluene = atom_mapping_basic_test_files[central_ligand_name] return toluene, others -@pytest.mark.parametrize('as_list', [False, True]) -def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others, - as_list): +@pytest.mark.parametrize("as_list", [False, True]) +def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others, as_list): toluene, others = toluene_vs_others - central_ligand_name = 'toluene' + central_ligand_name = "toluene" mapper = openfe.setup.atom_mapping.LomapAtomMapper() if as_list: mapper = [mapper] network = openfe.setup.ligand_network_planning.generate_radial_network( - ligands=others, central_ligand=toluene, - mappers=mapper, scorer=None, + ligands=others, + central_ligand=toluene, + mappers=mapper, + scorer=None, ) # couple sanity checks assert len(network.nodes) == len(atom_mapping_basic_test_files) @@ -54,8 +54,7 @@ def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others, ligands_in_network = {mol.name for mol in network.nodes} assert ligands_in_network == set(atom_mapping_basic_test_files.keys()) # check that every 
edge has the central ligand within - assert all((central_ligand_name in {mapping.componentA.name, mapping.componentB.name}) - for mapping in network.edges) + assert all((central_ligand_name in {mapping.componentA.name, mapping.componentB.name}) for mapping in network.edges) def test_radial_network_self_central(toluene_vs_others): @@ -65,8 +64,10 @@ def test_radial_network_self_central(toluene_vs_others): with pytest.warns(UserWarning, match="The central_ligand"): network = openfe.setup.ligand_network_planning.generate_radial_network( - ligands=ligs, central_ligand=ligs[0], - mappers=openfe.setup.atom_mapping.LomapAtomMapper(), scorer=None + ligands=ligs, + central_ligand=ligs[0], + mappers=openfe.setup.atom_mapping.LomapAtomMapper(), + scorer=None, ) assert len(network.edges) == len(ligs) - 1 @@ -82,15 +83,15 @@ def scorer(mapping): ligands=others, central_ligand=toluene, mappers=[BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()], - scorer=scorer + scorer=scorer, ) assert len(network.edges) == len(others) for edge in network.edges: # we didn't take the bad mapper assert len(edge.componentA_to_componentB) > 1 - assert 'score' in edge.annotations - assert edge.annotations['score'] == len(edge.componentA_to_componentB) + assert "score" in edge.annotations + assert edge.annotations["score"] == len(edge.componentA_to_componentB) def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others): @@ -100,7 +101,7 @@ def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others): network = openfe.setup.ligand_network_planning.generate_radial_network( ligands=others, central_ligand=toluene, - mappers=[BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()] + mappers=[BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()], ) assert len(network.edges) == len(others) @@ -109,28 +110,24 @@ def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others): def test_radial_network_failure(atom_mapping_basic_test_files): - nigel = 
openfe.SmallMoleculeComponent(mol_from_smiles('N')) + nigel = openfe.SmallMoleculeComponent(mol_from_smiles("N")) - with pytest.raises(ValueError, match='No mapping found for'): + with pytest.raises(ValueError, match="No mapping found for"): network = openfe.setup.ligand_network_planning.generate_radial_network( ligands=[nigel], - central_ligand=atom_mapping_basic_test_files['toluene'], + central_ligand=atom_mapping_basic_test_files["toluene"], mappers=[openfe.setup.atom_mapping.LomapAtomMapper()], - scorer=None + scorer=None, ) -@pytest.mark.parametrize('with_progress', [True, False]) -@pytest.mark.parametrize('with_scorer', [True, False]) -@pytest.mark.parametrize('extra_mapper', [True, False]) -def test_generate_maximal_network(toluene_vs_others, with_progress, - with_scorer, extra_mapper): +@pytest.mark.parametrize("with_progress", [True, False]) +@pytest.mark.parametrize("with_scorer", [True, False]) +@pytest.mark.parametrize("extra_mapper", [True, False]) +def test_generate_maximal_network(toluene_vs_others, with_progress, with_scorer, extra_mapper): toluene, others = toluene_vs_others if extra_mapper: - mappers = [ - openfe.setup.atom_mapping.LomapAtomMapper(), - BadMapper() - ] + mappers = [openfe.setup.atom_mapping.LomapAtomMapper(), BadMapper()] else: mappers = openfe.setup.atom_mapping.LomapAtomMapper() @@ -157,18 +154,19 @@ def scoring_func(mapping): if scorer: for edge in network.edges: - score = edge.annotations['score'] + score = edge.annotations["score"] assert score == len(edge.componentA_to_componentB) else: for edge in network.edges: - assert 'score' not in edge.annotations + assert "score" not in edge.annotations -@pytest.mark.parametrize('multi_mappers', [False, True]) +@pytest.mark.parametrize("multi_mappers", [False, True]) def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files, multi_mappers): - ligands = [atom_mapping_basic_test_files['toluene'], - atom_mapping_basic_test_files['2-naftanol'], - ] + ligands = [ + 
atom_mapping_basic_test_files["toluene"], + atom_mapping_basic_test_files["2-naftanol"], + ] if multi_mappers: mappers = [BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()] @@ -188,7 +186,7 @@ def scorer(mapping): assert list(network.edges) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def minimal_spanning_network(toluene_vs_others): toluene, others = toluene_vs_others mappers = [BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()] @@ -199,7 +197,7 @@ def scorer(mapping): network = openfe.setup.ligand_network_planning.generate_minimal_spanning_network( ligands=others + [toluene], mappers=mappers, - scorer=scorer + scorer=scorer, ) return network @@ -208,8 +206,7 @@ def test_minimal_spanning_network(minimal_spanning_network, toluene_vs_others): tol, others = toluene_vs_others assert len(minimal_spanning_network.nodes) == len(others) + 1 for edge in minimal_spanning_network.edges: - assert edge.componentA_to_componentB != { - 0: 0} # lomap should find something + assert edge.componentA_to_componentB != {0: 0} # lomap should find something def test_minimal_spanning_network_connectedness(minimal_spanning_network): @@ -225,19 +222,18 @@ def test_minimal_spanning_network_connectedness(minimal_spanning_network): def test_minimal_spanning_network_regression(minimal_spanning_network): # issue #244, this was previously giving non-reproducible (yet valid) # networks when scores were tied. 
- edge_ids = sorted( - (edge.componentA.name, edge.componentB.name) - for edge in minimal_spanning_network.edges + edge_ids = sorted((edge.componentA.name, edge.componentB.name) for edge in minimal_spanning_network.edges) + ref = sorted( + [ + ("1,3,7-trimethylnaphthalene", "2,6-dimethylnaphthalene"), + ("1-butyl-4-methylbenzene", "2-methyl-6-propylnaphthalene"), + ("2,6-dimethylnaphthalene", "2-methyl-6-propylnaphthalene"), + ("2,6-dimethylnaphthalene", "2-methylnaphthalene"), + ("2,6-dimethylnaphthalene", "2-naftanol"), + ("2,6-dimethylnaphthalene", "methylcyclohexane"), + ("2,6-dimethylnaphthalene", "toluene"), + ], ) - ref = sorted([ - ('1,3,7-trimethylnaphthalene', '2,6-dimethylnaphthalene'), - ('1-butyl-4-methylbenzene', '2-methyl-6-propylnaphthalene'), - ('2,6-dimethylnaphthalene', '2-methyl-6-propylnaphthalene'), - ('2,6-dimethylnaphthalene', '2-methylnaphthalene'), - ('2,6-dimethylnaphthalene', '2-naftanol'), - ('2,6-dimethylnaphthalene', 'methylcyclohexane'), - ('2,6-dimethylnaphthalene', 'toluene'), - ]) assert len(edge_ids) == len(ref) assert edge_ids == ref @@ -254,11 +250,11 @@ def scorer(mapping): network = openfe.setup.ligand_network_planning.generate_minimal_spanning_network( ligands=others + [toluene, nimrod], mappers=[openfe.setup.atom_mapping.LomapAtomMapper()], - scorer=scorer + scorer=scorer, ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def minimal_redundant_network(toluene_vs_others): toluene, others = toluene_vs_others mappers = [BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()] @@ -270,7 +266,7 @@ def scorer(mapping): ligands=others + [toluene], mappers=mappers, scorer=scorer, - mst_num=2 + mst_num=2, ) return network @@ -282,12 +278,10 @@ def test_minimal_redundant_network(minimal_redundant_network, toluene_vs_others) assert len(minimal_redundant_network.nodes) == len(others) + 1 # test for correct number of edges - assert len(minimal_redundant_network.edges) == 2 * \ - (len(minimal_redundant_network.nodes) 
- 1) + assert len(minimal_redundant_network.edges) == 2 * (len(minimal_redundant_network.nodes) - 1) for edge in minimal_redundant_network.edges: - assert edge.componentA_to_componentB != { - 0: 0} # lomap should find something + assert edge.componentA_to_componentB != {0: 0} # lomap should find something def test_minimal_redundant_network_connectedness(minimal_redundant_network): @@ -303,33 +297,31 @@ def test_minimal_redundant_network_connectedness(minimal_redundant_network): def test_redundant_vs_spanning_network(minimal_redundant_network, minimal_spanning_network): # when setting minimal redundant network to only take one MST, it should have as many # edges as the regular minimum spanning network - assert 2 * len(minimal_spanning_network.edges) == len( - minimal_redundant_network.edges) + assert 2 * len(minimal_spanning_network.edges) == len(minimal_redundant_network.edges) def test_minimal_redundant_network_edges(minimal_redundant_network): # issue #244, this was previously giving non-reproducible (yet valid) # networks when scores were tied. 
- edge_ids = sorted( - (edge.componentA.name, edge.componentB.name) - for edge in minimal_redundant_network.edges + edge_ids = sorted((edge.componentA.name, edge.componentB.name) for edge in minimal_redundant_network.edges) + ref = sorted( + [ + ("1,3,7-trimethylnaphthalene", "2,6-dimethylnaphthalene"), + ("1,3,7-trimethylnaphthalene", "2-methyl-6-propylnaphthalene"), + ("1-butyl-4-methylbenzene", "2,6-dimethylnaphthalene"), + ("1-butyl-4-methylbenzene", "2-methyl-6-propylnaphthalene"), + ("1-butyl-4-methylbenzene", "toluene"), + ("2,6-dimethylnaphthalene", "2-methyl-6-propylnaphthalene"), + ("2,6-dimethylnaphthalene", "2-methylnaphthalene"), + ("2,6-dimethylnaphthalene", "2-naftanol"), + ("2,6-dimethylnaphthalene", "methylcyclohexane"), + ("2,6-dimethylnaphthalene", "toluene"), + ("2-methyl-6-propylnaphthalene", "2-methylnaphthalene"), + ("2-methylnaphthalene", "2-naftanol"), + ("2-methylnaphthalene", "methylcyclohexane"), + ("2-methylnaphthalene", "toluene"), + ], ) - ref = sorted([ - ('1,3,7-trimethylnaphthalene', '2,6-dimethylnaphthalene'), - ('1,3,7-trimethylnaphthalene', '2-methyl-6-propylnaphthalene'), - ('1-butyl-4-methylbenzene', '2,6-dimethylnaphthalene'), - ('1-butyl-4-methylbenzene', '2-methyl-6-propylnaphthalene'), - ('1-butyl-4-methylbenzene', 'toluene'), - ('2,6-dimethylnaphthalene', '2-methyl-6-propylnaphthalene'), - ('2,6-dimethylnaphthalene', '2-methylnaphthalene'), - ('2,6-dimethylnaphthalene', '2-naftanol'), - ('2,6-dimethylnaphthalene', 'methylcyclohexane'), - ('2,6-dimethylnaphthalene', 'toluene'), - ('2-methyl-6-propylnaphthalene', '2-methylnaphthalene'), - ('2-methylnaphthalene', '2-naftanol'), - ('2-methylnaphthalene', 'methylcyclohexane'), - ('2-methylnaphthalene', 'toluene') - ]) assert len(edge_ids) == len(ref) assert edge_ids == ref @@ -339,8 +331,7 @@ def test_minimal_redundant_network_redundant(minimal_redundant_network): # test that each node is connected to 2 edges. 
network = minimal_redundant_network for node in network.nodes: - assert len(network.graph.in_edges(node)) + \ - len(network.graph.out_edges(node)) >= 2 + assert len(network.graph.in_edges(node)) + len(network.graph.out_edges(node)) >= 2 def test_minimal_redundant_network_unreachable(toluene_vs_others): @@ -354,7 +345,7 @@ def scorer(mapping): network = openfe.setup.ligand_network_planning.generate_minimal_redundant_network( ligands=others + [toluene, nimrod], mappers=[openfe.setup.atom_mapping.LomapAtomMapper()], - scorer=scorer + scorer=scorer, ) @@ -362,8 +353,8 @@ def test_network_from_names(atom_mapping_basic_test_files): ligs = list(atom_mapping_basic_test_files.values()) requested = [ - ('toluene', '2-naftanol'), - ('2-methylnaphthalene', '2-naftanol'), + ("toluene", "2-naftanol"), + ("2-methylnaphthalene", "2-naftanol"), ] network = openfe.setup.ligand_network_planning.generate_network_from_names( @@ -374,8 +365,7 @@ def test_network_from_names(atom_mapping_basic_test_files): assert len(network.nodes) == len(ligs) assert len(network.edges) == 2 - actual_edges = [(e.componentA.name, e.componentB.name) - for e in network.edges] + actual_edges = [(e.componentA.name, e.componentB.name) for e in network.edges] assert set(requested) == set(actual_edges) @@ -383,8 +373,8 @@ def test_network_from_names_bad_name(atom_mapping_basic_test_files): ligs = list(atom_mapping_basic_test_files.values()) requested = [ - ('hank', '2-naftanol'), - ('2-methylnaphthalene', '2-naftanol'), + ("hank", "2-naftanol"), + ("2-methylnaphthalene", "2-naftanol"), ] with pytest.raises(KeyError, match="Invalid name"): @@ -400,8 +390,8 @@ def test_network_from_names_duplicate_name(atom_mapping_basic_test_files): ligs = ligs + [ligs[0]] requested = [ - ('toluene', '2-naftanol'), - ('2-methylnaphthalene', '2-naftanol'), + ("toluene", "2-naftanol"), + ("2-methylnaphthalene", "2-naftanol"), ] with pytest.raises(ValueError, match="Duplicate names"): @@ -428,8 +418,7 @@ def 
test_network_from_indices(atom_mapping_basic_test_files): edges = list(network.edges) expected_edges = {(ligs[0], ligs[1]), (ligs[2], ligs[3])} - actual_edges = {(edges[0].componentA, edges[0].componentB), - (edges[1].componentA, edges[1].componentB)} + actual_edges = {(edges[0].componentA, edges[0].componentB), (edges[1].componentA, edges[1].componentB)} assert actual_edges == expected_edges @@ -459,14 +448,14 @@ def test_network_from_indices_disconnected_warning(atom_mapping_basic_test_files ) -@pytest.mark.parametrize('file_fixture, loader', [ - ['orion_network', - openfe.setup.ligand_network_planning.load_orion_network], - ['fepplus_network', - openfe.setup.ligand_network_planning.load_fepplus_network], -]) -def test_network_from_external(file_fixture, loader, request, - benzene_modifications): +@pytest.mark.parametrize( + "file_fixture, loader", + [ + ["orion_network", openfe.setup.ligand_network_planning.load_orion_network], + ["fepplus_network", openfe.setup.ligand_network_planning.load_fepplus_network], + ], +) +def test_network_from_external(file_fixture, loader, request, benzene_modifications): network_file = request.getfixturevalue(file_fixture) @@ -477,14 +466,12 @@ def test_network_from_external(file_fixture, loader, request, ) expected_edges = { - (benzene_modifications['benzene'], benzene_modifications['toluene']), - (benzene_modifications['benzene'], benzene_modifications['phenol']), - (benzene_modifications['benzene'], - benzene_modifications['benzonitrile']), - (benzene_modifications['benzene'], benzene_modifications['anisole']), - (benzene_modifications['benzene'], benzene_modifications['styrene']), - (benzene_modifications['benzene'], - benzene_modifications['benzaldehyde']), + (benzene_modifications["benzene"], benzene_modifications["toluene"]), + (benzene_modifications["benzene"], benzene_modifications["phenol"]), + (benzene_modifications["benzene"], benzene_modifications["benzonitrile"]), + (benzene_modifications["benzene"], 
benzene_modifications["anisole"]), + (benzene_modifications["benzene"], benzene_modifications["styrene"]), + (benzene_modifications["benzene"], benzene_modifications["benzaldehyde"]), } actual_edges = {(e.componentA, e.componentB) for e in list(network.edges)} @@ -494,16 +481,16 @@ def test_network_from_external(file_fixture, loader, request, assert actual_edges == expected_edges -@pytest.mark.parametrize('file_fixture, loader', [ - ['orion_network', - openfe.setup.ligand_network_planning.load_orion_network], - ['fepplus_network', - openfe.setup.ligand_network_planning.load_fepplus_network], -]) -def test_network_from_external_unknown_edge(file_fixture, loader, request, - benzene_modifications): +@pytest.mark.parametrize( + "file_fixture, loader", + [ + ["orion_network", openfe.setup.ligand_network_planning.load_orion_network], + ["fepplus_network", openfe.setup.ligand_network_planning.load_fepplus_network], + ], +) +def test_network_from_external_unknown_edge(file_fixture, loader, request, benzene_modifications): network_file = request.getfixturevalue(file_fixture) - ligs = [l for l in benzene_modifications.values() if l.name != 'phenol'] + ligs = [l for l in benzene_modifications.values() if l.name != "phenol"] with pytest.raises(KeyError, match="Invalid name"): network = loader( @@ -527,14 +514,14 @@ def test_network_from_external_unknown_edge(file_fixture, loader, request, def test_bad_orion_network(benzene_modifications, tmpdir): with tmpdir.as_cwd(): - with open('bad_orion_net.dat', 'w') as f: + with open("bad_orion_net.dat", "w") as f: f.write(BAD_ORION_NETWORK) with pytest.raises(KeyError, match="line does not match"): network = openfe.setup.ligand_network_planning.load_orion_network( ligands=[l for l in benzene_modifications.values()], mapper=openfe.LomapAtomMapper(), - network_file='bad_orion_net.dat', + network_file="bad_orion_net.dat", ) @@ -550,12 +537,12 @@ def test_bad_orion_network(benzene_modifications, tmpdir): def 
test_bad_edges_network(benzene_modifications, tmpdir): with tmpdir.as_cwd(): - with open('bad_edges.edges', 'w') as f: + with open("bad_edges.edges", "w") as f: f.write(BAD_EDGES) with pytest.raises(KeyError, match="line does not match"): network = openfe.setup.ligand_network_planning.load_fepplus_network( ligands=[l for l in benzene_modifications.values()], mapper=openfe.LomapAtomMapper(), - network_file='bad_edges.edges', + network_file="bad_edges.edges", ) diff --git a/openfe/tests/storage/conftest.py b/openfe/tests/storage/conftest.py index 13bb35f3d..9bf196612 100644 --- a/openfe/tests/storage/conftest.py +++ b/openfe/tests/storage/conftest.py @@ -1,28 +1,33 @@ -import pytest -from openff.units import unit import gufe -from gufe import SolventComponent, ChemicalSystem +import pytest +from gufe import ChemicalSystem, SolventComponent from gufe.tests.test_protocol import DummyProtocol +from openff.units import unit @pytest.fixture def solv_comp(): - yield SolventComponent(positive_ion="K", negative_ion="Cl", - ion_concentration=0.0 * unit.molar) + yield SolventComponent(positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar) + @pytest.fixture def solvated_complex(T4_protein_component, benzene_transforms, solv_comp): return ChemicalSystem( - {"ligand": benzene_transforms['toluene'], - "protein": T4_protein_component, "solvent": solv_comp,} + { + "ligand": benzene_transforms["toluene"], + "protein": T4_protein_component, + "solvent": solv_comp, + }, ) @pytest.fixture def solvated_ligand(benzene_transforms, solv_comp): return ChemicalSystem( - {"ligand": benzene_transforms['toluene'], - "solvent": solv_comp,} + { + "ligand": benzene_transforms["toluene"], + "solvent": solv_comp, + }, ) @@ -45,29 +50,34 @@ def complex_equilibrium(solvated_complex): @pytest.fixture -def benzene_variants_star_map(benzene_transforms, solv_comp, - T4_protein_component): - variants = ['toluene', 'phenol', 'benzonitrile', 'anisole', - 'benzaldehyde', 'styrene'] +def 
benzene_variants_star_map(benzene_transforms, solv_comp, T4_protein_component): + variants = ["toluene", "phenol", "benzonitrile", "anisole", "benzaldehyde", "styrene"] # define the solvent chemical systems and transformations between # benzene and the others solvated_ligands = {} solvated_ligand_transformations = {} - solvated_ligands['benzene'] = ChemicalSystem( - {"solvent": solv_comp, "ligand": benzene_transforms['benzene'],}, + solvated_ligands["benzene"] = ChemicalSystem( + { + "solvent": solv_comp, + "ligand": benzene_transforms["benzene"], + }, name="benzene-solvent", ) for ligand in variants: solvated_ligands[ligand] = ChemicalSystem( - {"solvent": solv_comp, "ligand": benzene_transforms[ligand],}, - name=f"{ligand}-solvent" + { + "solvent": solv_comp, + "ligand": benzene_transforms[ligand], + }, + name=f"{ligand}-solvent", ) solvated_ligand_transformations[("benzene", ligand)] = gufe.Transformation( - solvated_ligands['benzene'], solvated_ligands[ligand], + solvated_ligands["benzene"], + solvated_ligands[ligand], protocol=DummyProtocol(settings=DummyProtocol.default_settings()), mapping=None, ) @@ -78,15 +88,13 @@ def benzene_variants_star_map(benzene_transforms, solv_comp, solvated_complex_transformations = {} solvated_complexes["benzene"] = gufe.ChemicalSystem( - {"protein": T4_protein_component, "solvent": solv_comp, - "ligand": benzene_transforms['benzene']}, + {"protein": T4_protein_component, "solvent": solv_comp, "ligand": benzene_transforms["benzene"]}, name="benzene-complex", ) for ligand in variants: solvated_complexes[ligand] = gufe.ChemicalSystem( - {"protein": T4_protein_component, "solvent": solv_comp, - "ligand": benzene_transforms[ligand]}, + {"protein": T4_protein_component, "solvent": solv_comp, "ligand": benzene_transforms[ligand]}, name=f"{ligand}-complex", ) solvated_complex_transformations[("benzene", ligand)] = gufe.Transformation( @@ -97,6 +105,5 @@ def benzene_variants_star_map(benzene_transforms, solv_comp, ) return 
gufe.AlchemicalNetwork( - list(solvated_ligand_transformations.values()) - + list(solvated_complex_transformations.values()) + list(solvated_ligand_transformations.values()) + list(solvated_complex_transformations.values()), ) diff --git a/openfe/tests/storage/test_metadatastore.py b/openfe/tests/storage/test_metadatastore.py index 395a46e50..4ca116d68 100644 --- a/openfe/tests/storage/test_metadatastore.py +++ b/openfe/tests/storage/test_metadatastore.py @@ -1,37 +1,33 @@ -import pytest import json import pathlib -from openfe.storage.metadatastore import ( - JSONMetadataStore, PerFileJSONMetadataStore -) +import pytest +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError from gufe.storage.externalresource import FileStorage from gufe.storage.externalresource.base import Metadata -from gufe.storage.errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) + +from openfe.storage.metadatastore import JSONMetadataStore, PerFileJSONMetadataStore @pytest.fixture def json_metadata(tmpdir): - metadata_dict = {'path/to/foo.txt': {'md5': 'bar'}} + metadata_dict = {"path/to/foo.txt": {"md5": "bar"}} external_store = FileStorage(str(tmpdir)) - with open(tmpdir / 'metadata.json', mode='wb') as f: - f.write(json.dumps(metadata_dict).encode('utf-8')) + with open(tmpdir / "metadata.json", mode="wb") as f: + f.write(json.dumps(metadata_dict).encode("utf-8")) json_metadata = JSONMetadataStore(external_store) return json_metadata @pytest.fixture def per_file_metadata(tmp_path): - metadata_dict = {'path': 'path/to/foo.txt', - 'metadata': {'md5': 'bar'}} + metadata_dict = {"path": "path/to/foo.txt", "metadata": {"md5": "bar"}} external_store = FileStorage(str(tmp_path)) - metadata_loc = 'metadata/path/to/foo.txt.json' + metadata_loc = "metadata/path/to/foo.txt.json" metadata_path = tmp_path / pathlib.Path(metadata_loc) metadata_path.parent.mkdir(parents=True, exist_ok=True) - with open(metadata_path, mode='wb') as f: - 
f.write(json.dumps(metadata_dict).encode('utf-8')) + with open(metadata_path, mode="wb") as f: + f.write(json.dumps(metadata_dict).encode("utf-8")) per_file_metadata = PerFileJSONMetadataStore(external_store) return per_file_metadata @@ -39,27 +35,27 @@ def per_file_metadata(tmp_path): class MetadataTests: """Mixin with a few tests for any subclass of MetadataStore""" + def test_store_metadata(self, metadata): raise NotImplementedError() def test_load_all_metadata(self): - raise NotImplementedError("This should call " - "self._test_load_all_metadata") + raise NotImplementedError("This should call " "self._test_load_all_metadata") def test_delete(self): raise NotImplementedError("This should call self._test_delete") def _test_load_all_metadata(self, metadata): - expected = {'path/to/foo.txt': Metadata(md5='bar')} + expected = {"path/to/foo.txt": Metadata(md5="bar")} metadata._metadata_cache = {} loaded = metadata.load_all_metadata() assert loaded == expected def _test_delete(self, metadata): - assert 'path/to/foo.txt' in metadata + assert "path/to/foo.txt" in metadata assert len(metadata) == 1 - del metadata['path/to/foo.txt'] - assert 'path/to/foo.txt' not in metadata + del metadata["path/to/foo.txt"] + assert "path/to/foo.txt" not in metadata assert len(metadata) == 0 def _test_iter(self, metadata): @@ -77,16 +73,15 @@ def test_store_metadata(self, json_metadata): meta = Metadata(md5="other") json_metadata.store_metadata("path/to/other.txt", meta) base_path = json_metadata.external_store.root_dir - metadata_json = base_path / 'metadata.json' + metadata_json = base_path / "metadata.json" assert metadata_json.exists() - with open(metadata_json, mode='r') as f: + with open(metadata_json) as f: metadata_dict = json.load(f) - metadata = {key: Metadata(**val) - for key, val in metadata_dict.items()} + metadata = {key: Metadata(**val) for key, val in metadata_dict.items()} assert metadata == json_metadata._metadata_cache - assert json_metadata['path/to/other.txt'] == 
meta + assert json_metadata["path/to/other.txt"] == meta assert len(metadata) == 2 def test_load_all_metadata(self, json_metadata): @@ -121,9 +116,8 @@ def test_store_metadata(self, per_file_metadata): meta = Metadata(md5="other") per_file_metadata.store_metadata("path/to/other.txt", meta) assert expected_path.exists() - expected = {'path': "path/to/other.txt", - 'metadata': {"md5": "other"}} - with open(expected_path, mode='r')as f: + expected = {"path": "path/to/other.txt", "metadata": {"md5": "other"}} + with open(expected_path) as f: assert json.load(f) == expected def test_load_all_metadata(self, per_file_metadata): @@ -145,10 +139,9 @@ def test_getitem(self, per_file_metadata): def test_bad_metadata_contents(self, tmp_path): loc = tmp_path / "metadata/foo.txt.json" loc.parent.mkdir(parents=True, exist_ok=True) - bad_dict = {'foo': 'bar'} - with open(loc, mode='wb') as f: - f.write(json.dumps(bad_dict).encode('utf-8')) + bad_dict = {"foo": "bar"} + with open(loc, mode="wb") as f: + f.write(json.dumps(bad_dict).encode("utf-8")) - with pytest.raises(ChangedExternalResourceError, - match="Bad metadata"): + with pytest.raises(ChangedExternalResourceError, match="Bad metadata"): PerFileJSONMetadataStore(FileStorage(tmp_path)) diff --git a/openfe/tests/storage/test_resultclient.py b/openfe/tests/storage/test_resultclient.py index 99ec77bc1..470a3844a 100644 --- a/openfe/tests/storage/test_resultclient.py +++ b/openfe/tests/storage/test_resultclient.py @@ -1,13 +1,11 @@ import os - -import pytest from unittest import mock +import pytest from gufe.storage.externalresource import MemoryStorage from gufe.tokenization import TOKENIZABLE_REGISTRY -from openfe.storage.resultclient import ( - ResultClient, TransformationResult, CloneResult, ExtensionResult -) + +from openfe.storage.resultclient import CloneResult, ExtensionResult, ResultClient, TransformationResult @pytest.fixture @@ -16,10 +14,7 @@ def result_client(tmpdir): result_client = ResultClient(external) # store 
one file with contents "foo" - result_client.result_server.store_bytes( - "transformations/MAIN_TRANS/0/0/file.txt", - "foo".encode('utf-8') - ) + result_client.result_server.store_bytes("transformations/MAIN_TRANS/0/0/file.txt", b"foo") # create some empty files as well empty_files = [ @@ -45,7 +40,7 @@ def _make_mock_transformation(hash_str): def test_load_file(result_client): file_handler = result_client / "MAIN_TRANS" / "0" / 0 / "file.txt" with file_handler as f: - assert f.read().decode('utf-8') == "foo" + assert f.read().decode("utf-8") == "foo" class _ResultContainerTest: @@ -70,27 +65,27 @@ def _get_key(self, as_object, container): key = obj if as_object else obj._path_component return key, obj - @pytest.mark.parametrize('as_object', [True, False]) + @pytest.mark.parametrize("as_object", [True, False]) def test_getitem(self, as_object, result_client): container = self.get_container(result_client) key, obj = self._get_key(as_object, container) assert container[key] == obj - @pytest.mark.parametrize('as_object', [True, False]) + @pytest.mark.parametrize("as_object", [True, False]) def test_div(self, as_object, result_client): container = self.get_container(result_client) key, obj = self._get_key(as_object, container) assert container / key == obj - @pytest.mark.parametrize('load_with', ['div', 'getitem']) + @pytest.mark.parametrize("load_with", ["div", "getitem"]) def test_caching(self, result_client, load_with): # used to test caching regardless of how first loaded was loaded container = self.get_container(result_client) key, obj = self._get_key(False, container) - if load_with == 'div': + if load_with == "div": loaded = container / key - elif load_with == 'getitem': + elif load_with == "getitem": loaded = container[key] else: # -no-cov- raise RuntimeError(f"Bad input: can't load with '{load_with}'") @@ -107,12 +102,12 @@ def test_load_stream(self, result_client): container = self.get_container(result_client) loc = "transformations/MAIN_TRANS/0/0/file.txt" 
with container.load_stream(loc) as f: - assert f.read().decode('utf-8') == "foo" + assert f.read().decode("utf-8") == "foo" def test_load_bytes(self, result_client): container = self.get_container(result_client) loc = "transformations/MAIN_TRANS/0/0/file.txt" - assert container.load_bytes(loc).decode('utf-8') == "foo" + assert container.load_bytes(loc).decode("utf-8") == "foo" def test_path(self, result_client): container = self.get_container(result_client) @@ -138,10 +133,7 @@ def get_container(result_client): return result_client def _getitem_object(self, container): - return TransformationResult( - parent=container, - transformation=_make_mock_transformation("MAIN_TRANS") - ) + return TransformationResult(parent=container, transformation=_make_mock_transformation("MAIN_TRANS")) def test_store_protocol_dag_result(self): pytest.skip("Not implemented yet") @@ -159,8 +151,7 @@ def _test_store_load_same_process(obj, store_func_name, load_func_name): assert reloaded is obj @staticmethod - def _test_store_load_different_process(obj, store_func_name, - load_func_name): + def _test_store_load_different_process(obj, store_func_name, load_func_name): store = MemoryStorage() client = ResultClient(store) store_func = getattr(client, store_func_name) @@ -177,41 +168,37 @@ def _test_store_load_different_process(obj, store_func_name, assert reload == obj assert reload is not obj - - @pytest.mark.parametrize("fixture", [ - "absolute_transformation", - "complex_equilibrium", - ]) + @pytest.mark.parametrize( + "fixture", + [ + "absolute_transformation", + "complex_equilibrium", + ], + ) def test_store_load_transformation_same_process(self, request, fixture): transformation = request.getfixturevalue(fixture) - self._test_store_load_same_process(transformation, - "store_transformation", - "load_transformation") - - @pytest.mark.parametrize('fixture', [ - "absolute_transformation", - "complex_equilibrium", - ]) - def test_store_load_transformation_different_process(self, request, - 
fixture): + self._test_store_load_same_process(transformation, "store_transformation", "load_transformation") + + @pytest.mark.parametrize( + "fixture", + [ + "absolute_transformation", + "complex_equilibrium", + ], + ) + def test_store_load_transformation_different_process(self, request, fixture): transformation = request.getfixturevalue(fixture) - self._test_store_load_different_process(transformation, - "store_transformation", - "load_transformation") + self._test_store_load_different_process(transformation, "store_transformation", "load_transformation") @pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) def test_store_load_network_same_process(self, request, fixture): network = request.getfixturevalue(fixture) - self._test_store_load_same_process(network, - "store_network", - "load_network") + self._test_store_load_same_process(network, "store_network", "load_network") @pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) def test_store_load_network_different_process(self, request, fixture): network = request.getfixturevalue(fixture) - self._test_store_load_different_process(network, - "store_network", - "load_network") + self._test_store_load_different_process(network, "store_network", "load_network") def test_delete(self, result_client): file_to_delete = self.expected_files[0] @@ -234,7 +221,7 @@ class TestTransformationResults(_ResultContainerTest): def get_container(result_client): container = TransformationResult( parent=TestResultClient.get_container(result_client), - transformation=_make_mock_transformation("MAIN_TRANS") + transformation=_make_mock_transformation("MAIN_TRANS"), ) container._path_component = "MAIN_TRANS" return container @@ -253,10 +240,7 @@ class TestCloneResults(_ResultContainerTest): @staticmethod def get_container(result_client): - return CloneResult( - parent=TestTransformationResults.get_container(result_client), - clone=0 - ) + return 
CloneResult(parent=TestTransformationResults.get_container(result_client), clone=0) def _getitem_object(self, container): return ExtensionResult(parent=container, extension=0) @@ -271,15 +255,11 @@ class TestExtensionResults(_ResultContainerTest): @staticmethod def get_container(result_client): - return ExtensionResult( - parent=TestCloneResults.get_container(result_client), - extension=0 - ) + return ExtensionResult(parent=TestCloneResults.get_container(result_client), extension=0) def _get_key(self, as_object, container): if self.as_object: # -no-cov- - raise RuntimeError("TestExtensionResults does not support " - "as_object=True") + raise RuntimeError("TestExtensionResults does not support " "as_object=True") path = "transformations/MAIN_TRANS/0/0/" fname = "file.txt" return fname, container.result_server.load_stream(path + fname) @@ -288,12 +268,12 @@ def _get_key(self, as_object, container): def test_div(self, result_client): container = self.get_container(result_client) with container / "file.txt" as f: - assert f.read().decode('utf-8') == "foo" + assert f.read().decode("utf-8") == "foo" def test_getitem(self, result_client): container = self.get_container(result_client) with container["file.txt"] as f: - assert f.read().decode('utf-8') == "foo" + assert f.read().decode("utf-8") == "foo" def test_caching(self, result_client): # this one does not cache results; the cache should remain empty diff --git a/openfe/tests/storage/test_resultserver.py b/openfe/tests/storage/test_resultserver.py index 96e9018a7..b3a2ef075 100644 --- a/openfe/tests/storage/test_resultserver.py +++ b/openfe/tests/storage/test_resultserver.py @@ -1,16 +1,13 @@ -import pytest -from unittest import mock - import pathlib +from unittest import mock -from openfe.storage.resultserver import ResultServer +import pytest +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError +from gufe.storage.externalresource import FileStorage from 
gufe.storage.externalresource.base import Metadata -from gufe.storage.externalresource import FileStorage from openfe.storage.metadatastore import JSONMetadataStore -from gufe.storage.errors import ( - MissingExternalResourceError, ChangedExternalResourceError -) +from openfe.storage.resultserver import ResultServer @pytest.fixture @@ -18,7 +15,7 @@ def result_server(tmpdir): external = FileStorage(tmpdir) metadata = JSONMetadataStore(external) result_server = ResultServer(external, metadata) - result_server.store_bytes("path/to/foo.txt", "foo".encode('utf-8')) + result_server.store_bytes("path/to/foo.txt", b"foo") return result_server @@ -32,14 +29,10 @@ def test_store_bytes(self, result_server): assert result_server.external_store.exists(foo_loc) # also explicitly test storing here - mock_hash = mock.Mock( - return_value=mock.Mock( - hexdigest=mock.Mock(return_value="deadbeef") - ) - ) + mock_hash = mock.Mock(return_value=mock.Mock(hexdigest=mock.Mock(return_value="deadbeef"))) bar_loc = "path/to/bar.txt" - with mock.patch('hashlib.md5', mock_hash): - result_server.store_bytes(bar_loc, "bar".encode('utf-8')) + with mock.patch("hashlib.md5", mock_hash): + result_server.store_bytes(bar_loc, b"bar") assert len(metadata_store) == 2 assert bar_loc in metadata_store @@ -47,25 +40,21 @@ def test_store_bytes(self, result_server): assert metadata_store[bar_loc].to_dict() == {"md5": "deadbeef"} external = result_server.external_store with external.load_stream(bar_loc) as f: - assert f.read().decode('utf-8') == "bar" + assert f.read().decode("utf-8") == "bar" def test_store_path(self, result_server, tmp_path): orig_file = tmp_path / ".hidden" / "bar.txt" orig_file.parent.mkdir(parents=True, exist_ok=True) - with open(orig_file, mode='wb') as f: - f.write("bar".encode('utf-8')) - - mock_hash = mock.Mock( - return_value=mock.Mock( - hexdigest=mock.Mock(return_value="deadc0de") - ) - ) + with open(orig_file, mode="wb") as f: + f.write(b"bar") + + mock_hash = 
mock.Mock(return_value=mock.Mock(hexdigest=mock.Mock(return_value="deadc0de"))) bar_loc = "path/to/bar.txt" assert len(result_server.metadata_store) == 1 assert bar_loc not in result_server.metadata_store - with mock.patch('hashlib.md5', mock_hash): + with mock.patch("hashlib.md5", mock_hash): result_server.store_path(bar_loc, orig_file) assert len(result_server.metadata_store) == 2 @@ -74,7 +63,7 @@ def test_store_path(self, result_server, tmp_path): assert metadata_dict == {"md5": "deadc0de"} external = result_server.external_store with external.load_stream(bar_loc) as f: - assert f.read().decode('utf-8') == "bar" + assert f.read().decode("utf-8") == "bar" def test_iter(self, result_server): assert list(result_server) == ["path/to/foo.txt"] @@ -86,10 +75,10 @@ def test_find_missing_files(self, result_server): assert result_server.find_missing_files() == ["fake/file.txt"] def test_load_stream(self, result_server): - with result_server.load_stream('path/to/foo.txt') as f: + with result_server.load_stream("path/to/foo.txt") as f: contents = f.read() - assert contents.decode('utf-8') == "foo" + assert contents.decode("utf-8") == "foo" def test_delete(self, result_server, tmpdir): location = "path/to/foo.txt" @@ -105,17 +94,16 @@ def test_load_stream_missing(self, result_server): result_server.load_stream("path/does/not/exist.txt") def test_load_stream_error_bad_hash(self, result_server): - meta = Metadata(md5='1badc0de') - result_server.metadata_store.store_metadata('path/to/foo.txt', meta) + meta = Metadata(md5="1badc0de") + result_server.metadata_store.store_metadata("path/to/foo.txt", meta) with pytest.raises(ChangedExternalResourceError): - result_server.load_stream('path/to/foo.txt') + result_server.load_stream("path/to/foo.txt") def test_load_stream_allow_bad_hash(self, result_server): - meta = Metadata(md5='1badc0de') - result_server.metadata_store.store_metadata('path/to/foo.txt', meta) + meta = Metadata(md5="1badc0de") + 
result_server.metadata_store.store_metadata("path/to/foo.txt", meta) with pytest.warns(UserWarning, match="Metadata mismatch"): - file = result_server.load_stream("path/to/foo.txt", - allow_changed=True) + file = result_server.load_stream("path/to/foo.txt", allow_changed=True) with file as f: assert f.read().decode("utf-8") == "foo" diff --git a/openfe/tests/utils/conftest.py b/openfe/tests/utils/conftest.py index 7029dba63..90be909d3 100644 --- a/openfe/tests/utils/conftest.py +++ b/openfe/tests/utils/conftest.py @@ -1,17 +1,20 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe +from collections.abc import Iterable +from importlib import resources +from typing import NamedTuple + import pytest from rdkit import Chem -from importlib import resources -from openfe import SmallMoleculeComponent, LigandAtomMapping, LigandNetwork -from typing import Iterable, NamedTuple +from openfe import LigandAtomMapping, LigandNetwork, SmallMoleculeComponent from ..conftest import mol_from_smiles class _NetworkTestContainer(NamedTuple): """Container to facilitate network testing""" + network: LigandNetwork nodes: Iterable[SmallMoleculeComponent] edges: Iterable[LigandAtomMapping] @@ -49,13 +52,13 @@ def simple_network(mols, std_edges): ) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def benzene_transforms(): # a dict of Molecules for benzene transformations mols = {} - with resources.files('openfe.tests.data') as d: - fn = str(d / 'benzene_modifications.sdf') + with resources.files("openfe.tests.data") as d: + fn = str(d / "benzene_modifications.sdf") supplier = Chem.SDMolSupplier(fn, removeHs=False) for mol in supplier: - mols[mol.GetProp('_Name')] = SmallMoleculeComponent(mol) + mols[mol.GetProp("_Name")] = SmallMoleculeComponent(mol) return mols diff --git a/openfe/tests/utils/test_atommapping_network_plotting.py b/openfe/tests/utils/test_atommapping_network_plotting.py index 
8049c545d..1ce12d91d 100644 --- a/openfe/tests/utils/test_atommapping_network_plotting.py +++ b/openfe/tests/utils/test_atommapping_network_plotting.py @@ -1,17 +1,14 @@ +import importlib.resources import inspect -import pytest from unittest import mock -from matplotlib import pyplot as plt + import matplotlib import matplotlib.figure -import importlib.resources - -from openfe.utils.atommapping_network_plotting import ( - AtomMappingNetworkDrawing, plot_atommapping_network, - LigandNode, -) +import pytest +from matplotlib import pyplot as plt from openfe.tests.utils.test_network_plotting import mock_event +from openfe.utils.atommapping_network_plotting import AtomMappingNetworkDrawing, LigandNode, plot_atommapping_network def bound_args(func, args, kwargs): @@ -41,11 +38,7 @@ def bound_args(func, args, kwargs): def network_drawing(simple_network): nx_graph = simple_network.network.graph node_dict = {node.smiles: node for node in nx_graph.nodes} - positions = { - node_dict["CC"]: (0.0, 0.0), - node_dict["CO"]: (0.5, 0.0), - node_dict["CCO"]: (0.25, 0.25) - } + positions = {node_dict["CC"]: (0.0, 0.0), node_dict["CO"]: (0.5, 0.0), node_dict["CCO"]: (0.25, 0.25)} graph = AtomMappingNetworkDrawing(nx_graph, positions) graph.ax.set_xlim(0, 1) graph.ax.set_ylim(0, 1) @@ -65,7 +58,6 @@ def default_node(network_drawing): yield LigandNode(node_dict["CC"], 0.5, 0.5, 0.1, 0.1) - class TestAtomMappingEdge: def test_draw_mapped_molecule(self, default_edge): assert len(default_edge.artist.axes.images) == 0 @@ -73,7 +65,7 @@ def test_draw_mapped_molecule(self, default_edge): (0.05, 0.45, 0.5, 0.9), default_edge.node_artists[0].node, default_edge.node_artists[1].node, - {0: 0} + {0: 0}, ) # maybe add something about im itself? 
not sure what to test here assert len(default_edge.artist.axes.images) == 1 @@ -88,51 +80,48 @@ def test_select(self, default_edge, network_drawing): assert not default_edge.picked assert len(default_edge.artist.axes.images) == 0 - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) + event = mock_event("mouseup", 0.25, 0.0, network_drawing.fig) default_edge.select(event, network_drawing) assert default_edge.picked assert len(default_edge.artist.axes.images) == 2 - @pytest.mark.parametrize('edge_str,left_right,molA_to_molB', [ - (("CCO", "CC"), ("CC", "CCO"), {0: 0, 1: 1}), - (("CC", "CO"), ("CC", "CO"), {0: 0}), - (("CCO", "CO"), ("CCO", "CO"), {0: 0, 2: 1}), - ]) - def test_select_mock_drawing(self, edge_str, left_right, molA_to_molB, - network_drawing): + @pytest.mark.parametrize( + "edge_str,left_right,molA_to_molB", + [ + (("CCO", "CC"), ("CC", "CCO"), {0: 0, 1: 1}), + (("CC", "CO"), ("CC", "CO"), {0: 0}), + (("CCO", "CO"), ("CCO", "CO"), {0: 0, 2: 1}), + ], + ) + def test_select_mock_drawing(self, edge_str, left_right, molA_to_molB, network_drawing): # this tests that we call _draw_mapped_molecule with the correct # kwargs -- in particular, it ensures that we get the left and right # molecules correctly - node_dict = {node.smiles: node - for node in network_drawing.graph.nodes} + node_dict = {node.smiles: node for node in network_drawing.graph.nodes} edge_tuple = tuple(node_dict[node] for node in edge_str) edge = network_drawing.edges[edge_tuple] - left, right = [network_drawing.nodes[node_dict[node]] - for node in left_right] + left, right = (network_drawing.nodes[node_dict[node]] for node in left_right) # ensure that we have them labelled correctly assert left.xy[0] < right.xy[0] func = edge._draw_mapped_molecule # save for bound_args edge._draw_mapped_molecule = mock.Mock() - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) + event = mock_event("mouseup", 0.25, 0.0, network_drawing.fig) edge.select(event, network_drawing) - arg_dicts 
= [ - bound_args(func, call.args, call.kwargs) - for call in edge._draw_mapped_molecule.mock_calls - ] + arg_dicts = [bound_args(func, call.args, call.kwargs) for call in edge._draw_mapped_molecule.mock_calls] expected_left = { - 'extent': (0.05, 0.45, 0.5, 0.9), - 'molA': left.node, - 'molB': right.node, - 'molA_to_molB': molA_to_molB, + "extent": (0.05, 0.45, 0.5, 0.9), + "molA": left.node, + "molB": right.node, + "molA_to_molB": molA_to_molB, } expected_right = { - 'extent': (0.55, 0.95, 0.5, 0.9), - 'molA': right.node, - 'molB': left.node, - 'molA_to_molB': {v: k for k, v in molA_to_molB.items()}, + "extent": (0.55, 0.95, 0.5, 0.9), + "molA": right.node, + "molB": left.node, + "molA_to_molB": {v: k for k, v in molA_to_molB.items()}, } assert len(arg_dicts) == 2 assert expected_left in arg_dicts @@ -141,7 +130,7 @@ def test_select_mock_drawing(self, edge_str, left_right, molA_to_molB, def test_unselect(self, default_edge, network_drawing): # start by selecting; hard to be sure we mocked all the side effects # of select - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) + event = mock_event("mouseup", 0.25, 0.0, network_drawing.fig) default_edge.select(event, network_drawing) assert default_edge.picked assert len(default_edge.artist.axes.images) == 2 diff --git a/openfe/tests/utils/test_duecredit.py b/openfe/tests/utils/test_duecredit.py index dad7b9966..0d63a9b73 100644 --- a/openfe/tests/utils/test_duecredit.py +++ b/openfe/tests/utils/test_duecredit.py @@ -1,26 +1,37 @@ -import os import importlib +import os + import pytest import openfe - -pytest.importorskip('duecredit') +pytest.importorskip("duecredit") -@pytest.mark.skipif((os.environ.get('DUECREDIT_ENABLE', 'no').lower() - in ('no', '0', 'false')), - reason="duecredit is disabled") +@pytest.mark.skipif( + (os.environ.get("DUECREDIT_ENABLE", "no").lower() in ("no", "0", "false")), + reason="duecredit is disabled", +) class TestDuecredit: - @pytest.mark.parametrize('module, dois', [ - 
['openfe.protocols.openmm_afe.equil_solvation_afe_method', - ['10.5281/zenodo.596504', '10.48550/arxiv.2302.06758', - '10.5281/zenodo.596622', '10.1371/journal.pcbi.1005659']], - ['openfe.protocols.openmm_rfe.equil_rfe_methods', - ['10.5281/zenodo.1297683', '10.5281/zenodo.596622', - '10.1371/journal.pcbi.1005659']], - ]) + @pytest.mark.parametrize( + "module, dois", + [ + [ + "openfe.protocols.openmm_afe.equil_solvation_afe_method", + [ + "10.5281/zenodo.596504", + "10.48550/arxiv.2302.06758", + "10.5281/zenodo.596622", + "10.1371/journal.pcbi.1005659", + ], + ], + [ + "openfe.protocols.openmm_rfe.equil_rfe_methods", + ["10.5281/zenodo.1297683", "10.5281/zenodo.596622", "10.1371/journal.pcbi.1005659"], + ], + ], + ) def test_duecredit_protocol_collection(self, module, dois): importlib.import_module(module) for doi in dois: diff --git a/openfe/tests/utils/test_network_plotting.py b/openfe/tests/utils/test_network_plotting.py index c99ff1e73..18bef583b 100644 --- a/openfe/tests/utils/test_network_plotting.py +++ b/openfe/tests/utils/test_network_plotting.py @@ -1,14 +1,12 @@ -import pytest from unittest import mock -from numpy import testing as npt -from matplotlib import pyplot as plt import networkx as nx -from openfe.utils.network_plotting import ( - Node, Edge, EventHandler, GraphDrawing -) +import pytest +from matplotlib import pyplot as plt +from matplotlib.backend_bases import MouseButton, MouseEvent +from numpy import testing as npt -from matplotlib.backend_bases import MouseEvent, MouseButton +from openfe.utils.network_plotting import Edge, EventHandler, GraphDrawing, Node def _get_fig_ax(fig): @@ -16,8 +14,7 @@ def _get_fig_ax(fig): fig, _ = plt.subplots() if len(fig.axes) != 1: # -no-cov- - raise RuntimeError("Error in test setup: figure must have exactly " - "one Axes object associated") + raise RuntimeError("Error in test setup: figure must have exactly " "one Axes object associated") return fig, fig.axes[0] @@ -25,15 +22,15 @@ def _get_fig_ax(fig): def 
mock_event(event_name, xdata, ydata, fig=None): fig, ax = _get_fig_ax(fig) name = { - 'mousedown': 'button_press_event', - 'mouseup': 'button_release_event', - 'drag': 'motion_notify_event', + "mousedown": "button_press_event", + "mouseup": "button_release_event", + "drag": "motion_notify_event", }[event_name] matplotlib_buttons = { - 'mousedown': MouseButton.LEFT, - 'mouseup': MouseButton.LEFT, - 'drag': MouseButton.LEFT, + "mousedown": MouseButton.LEFT, + "mouseup": MouseButton.LEFT, + "drag": MouseButton.LEFT, } button = matplotlib_buttons.get(event_name, None) x, y = ax.transData.transform((xdata, ydata)) @@ -53,14 +50,13 @@ def make_mock_edge(node1, node2, data): node_B = make_mock_node("B", 0.5, 0.0) node_C = make_mock_node("C", 0.5, 0.5) node_D = make_mock_node("D", 0.0, 0.5) - edge_AB = make_mock_edge(node_A, node_B, {'data': "AB"}) - edge_BC = make_mock_edge(node_B, node_C, {'data': "BC"}) - edge_BD = make_mock_edge(node_B, node_D, {'data': "BD"}) + edge_AB = make_mock_edge(node_A, node_B, {"data": "AB"}) + edge_BC = make_mock_edge(node_B, node_C, {"data": "BC"}) + edge_BD = make_mock_edge(node_B, node_D, {"data": "BD"}) mock_graph = mock.Mock( nodes={node.node: node for node in [node_A, node_B, node_C, node_D]}, - edges={tuple(edge.node_artists): edge - for edge in [edge_AB, edge_BC, edge_BD]}, + edges={tuple(edge.node_artists): edge for edge in [edge_AB, edge_BC, edge_BD]}, ) return mock_graph @@ -110,17 +106,20 @@ def test_update_location(self): assert self.node.artist.xy == (0.7, 0.5) assert self.node.xy == (0.7, 0.5) - @pytest.mark.parametrize('point,expected', [ - ((0.55, 0.05), True), - ((0.5, 0.5), False), - ((-10, -10), False), - ]) + @pytest.mark.parametrize( + "point,expected", + [ + ((0.55, 0.05), True), + ((0.5, 0.5), False), + ((-10, -10), False), + ], + ) def test_contains(self, point, expected): - event = mock_event('drag', *point, fig=self.fig) + event = mock_event("drag", *point, fig=self.fig) assert self.node.contains(event) == expected 
def test_on_mousedown_in_rect(self): - event = mock_event('mousedown', 0.55, 0.05, self.fig) + event = mock_event("mousedown", 0.55, 0.05, self.fig) drawing_graph = make_mock_graph(self.fig) assert Node.lock is None assert self.node.press is None @@ -131,7 +130,7 @@ def test_on_mousedown_in_rect(self): Node.lock = None def test_on_mousedown_in_axes(self): - event = mock_event('mousedown', 0.25, 0.25, self.fig) + event = mock_event("mousedown", 0.25, 0.25, self.fig) drawing_graph = make_mock_graph(self.fig) assert Node.lock is None @@ -142,7 +141,7 @@ def test_on_mousedown_in_axes(self): def test_on_mousedown_out_axes(self): node = Node("B", 0.5, 0.6) - event = mock_event('mousedown', 0.55, 0.05, self.fig) + event = mock_event("mousedown", 0.55, 0.05, self.fig) drawing_graph = make_mock_graph(self.fig) fig2, ax2 = plt.subplots() @@ -156,12 +155,11 @@ def test_on_mousedown_out_axes(self): plt.close(fig2) def test_on_drag(self): - event = mock_event('drag', 0.7, 0.7, self.fig) + event = mock_event("drag", 0.7, 0.7, self.fig) # this test some integration, so we need more than a mock drawing_graph = GraphDrawing( - nx.MultiDiGraph(([("A", "B"), ("B", "C"), ("B", "D")])), - positions={"A": (0.0, 0.0), "B": (0.5, 0.0), - "C": (0.5, 0.5), "D": (0.0, 0.5)} + nx.MultiDiGraph([("A", "B"), ("B", "C"), ("B", "D")]), + positions={"A": (0.0, 0.0), "B": (0.5, 0.0), "C": (0.5, 0.5), "D": (0.0, 0.5)}, ) # set up things that should happen on mousedown Node.lock = self.node @@ -175,7 +173,7 @@ def test_on_drag(self): Node.lock = None def test_on_drag_do_nothing(self): - event = mock_event('drag', 0.7, 0.7, self.fig) + event = mock_event("drag", 0.7, 0.7, self.fig) drawing_graph = make_mock_graph(self.fig) # don't set lock -- early exit @@ -184,7 +182,7 @@ def test_on_drag_do_nothing(self): assert self.node.xy == original def test_on_drag_no_mousedown(self): - event = mock_event('drag', 0.7, 0.7, self.fig) + event = mock_event("drag", 0.7, 0.7, self.fig) drawing_graph = 
make_mock_graph(self.fig) Node.lock = self.node @@ -194,7 +192,7 @@ def test_on_drag_no_mousedown(self): Node.lock = None def test_on_mouseup(self): - event = mock_event('drag', 0.7, 0.7, self.fig) + event = mock_event("drag", 0.7, 0.7, self.fig) drawing_graph = make_mock_graph(self.fig) Node.lock = self.node self.node.press = (0.5, 0.0), (0.55, 0.05) @@ -229,21 +227,22 @@ def test_register_artist(self): assert ax.get_lines()[0] == edge.artist plt.close(fig) - @pytest.mark.parametrize('point,expected', [ - ((0.25, 0.05), True), - ((0.6, 0.1), False), - ]) + @pytest.mark.parametrize( + "point,expected", + [ + ((0.25, 0.05), True), + ((0.6, 0.1), False), + ], + ) def test_contains(self, point, expected): - event = mock_event('drag', *point, fig=self.fig) + event = mock_event("drag", *point, fig=self.fig) assert self.edge.contains(event) == expected def test_edge_xs_ys(self): - npt.assert_allclose(self.edge._edge_xs_ys(*self.nodes), - ((0.05, 0.55), (0.05, 0.05))) + npt.assert_allclose(self.edge._edge_xs_ys(*self.nodes), ((0.05, 0.55), (0.05, 0.05))) def _get_colors(self): - colors = {node: node.artist.get_facecolor() - for node in self.nodes} + colors = {node: node.artist.get_facecolor() for node in self.nodes} colors[self.edge] = self.edge.artist.get_color() return colors @@ -251,9 +250,9 @@ def test_unselect(self): original = self._get_colors() for node in self.nodes: - node.artist.set(color='red') + node.artist.set(color="red") - self.edge.artist.set(color='red') + self.edge.artist.set(color="red") # ensure that we have changed from the original values changed = self._get_colors() @@ -265,7 +264,7 @@ def test_unselect(self): assert after == original def test_select(self): - event = mock_event('mouseup', 0.25, 0.05, self.fig) + event = mock_event("mouseup", 0.25, 0.05, self.fig) drawing_graph = make_mock_graph(self.fig) original = self._get_colors() self.edge.select(event, drawing_graph) @@ -312,15 +311,15 @@ def _mock_for_connections(self): 
self.event_handler.on_mouseup = mock.Mock() self.event_handler.on_drag = mock.Mock() - @pytest.mark.parametrize('event_type', ['mousedown', 'mouseup', 'drag']) + @pytest.mark.parametrize("event_type", ["mousedown", "mouseup", "drag"]) def test_connect(self, event_type): self._mock_for_connections() event = mock_event(event_type, 0.2, 0.2, self.fig) methods = { - 'mousedown': self.event_handler.on_mousedown, - 'mouseup': self.event_handler.on_mouseup, - 'drag': self.event_handler.on_drag, + "mousedown": self.event_handler.on_mousedown, + "mouseup": self.event_handler.on_mouseup, + "drag": self.event_handler.on_drag, } should_call = methods[event_type] should_not_call = set(methods.values()) - {should_call} @@ -335,7 +334,7 @@ def test_connect(self, event_type): for method in should_not_call: assert not method.called - @pytest.mark.parametrize('event_type', ['mousedown', 'mouseup', 'drag']) + @pytest.mark.parametrize("event_type", ["mousedown", "mouseup", "drag"]) def test_disconnect(self, event_type): self._mock_for_connections() fig, _ = plt.subplots() @@ -346,9 +345,7 @@ def test_disconnect(self, event_type): self.event_handler.disconnect(fig.canvas) assert len(self.event_handler.connections) == 0 - methods = [self.event_handler.on_mousedown, - self.event_handler.on_mousedown, - self.event_handler.on_drag] + methods = [self.event_handler.on_mousedown, self.event_handler.on_mousedown, self.event_handler.on_drag] fig.canvas.callbacks.process(event.name, event) for method in methods: @@ -365,7 +362,7 @@ def _mock_contains(self, mock_objs): else: obj.contains = mock.Mock(return_value=False) - @pytest.mark.parametrize('hit', ['node', 'edge', 'node+edge', 'miss']) + @pytest.mark.parametrize("hit", ["node", "edge", "node+edge", "miss"]) def test_get_event_container_select_node(self, hit): expected, contains_event = self.setup_contains[hit] expected_count = { @@ -387,11 +384,11 @@ def test_get_event_container_select_node(self, hit): contains_count = 
sum(obj.contains.called for obj in all_objs) assert contains_count == expected_count - @pytest.mark.parametrize('hit', ['node', 'edge', 'node+edge', 'miss']) + @pytest.mark.parametrize("hit", ["node", "edge", "node+edge", "miss"]) def test_on_mousedown(self, hit): expected, contains_event = self.setup_contains[hit] self._mock_contains(contains_event) - event = mock_event('mousedown', 0.5, 0.5) + event = mock_event("mousedown", 0.5, 0.5) assert self.event_handler.click_location is None assert self.event_handler.active is None @@ -403,11 +400,11 @@ def test_on_mousedown(self, hit): plt.close(event.canvas.figure) - @pytest.mark.parametrize('is_active', [True, False]) + @pytest.mark.parametrize("is_active", [True, False]) def test_on_drag(self, is_active): node = self.event_handler.graph.nodes["C"] node.artist.axes = self.ax - event = mock_event('drag', 0.25, 0.25, self.fig) + event = mock_event("drag", 0.25, 0.25, self.fig) if is_active: self.event_handler.active = node @@ -418,7 +415,7 @@ def test_on_drag(self, is_active): else: assert not node.on_drag.called - @pytest.mark.parametrize('has_selected', [True, False]) + @pytest.mark.parametrize("has_selected", [True, False]) def test_on_mouseup_click_select(self, has_selected): # start: mouse hasn't moved, and something is active graph = self.event_handler.graph @@ -428,7 +425,7 @@ def test_on_mouseup_click_select(self, has_selected): self.event_handler.selected = old_selected self._mock_contains([edge]) - event = mock_event('mouseup', 0.25, 0.25) + event = mock_event("mouseup", 0.25, 0.25) self.event_handler.click_location = (event.xdata, event.ydata) self.event_handler.active = edge @@ -447,7 +444,7 @@ def test_on_mouseup_click_select(self, has_selected): plt.close(event.canvas.figure) - @pytest.mark.parametrize('has_selected', [True, False]) + @pytest.mark.parametrize("has_selected", [True, False]) def test_on_mouseup_click_not_select(self, has_selected): # start: mouse hasn't moved, nothing is active graph = 
self.event_handler.graph @@ -455,7 +452,7 @@ def test_on_mouseup_click_not_select(self, has_selected): old_selected = graph.edges[graph.nodes["A"], graph.nodes["B"]] self.event_handler.selected = old_selected - event = mock_event('mouseup', 0.25, 0.25) + event = mock_event("mouseup", 0.25, 0.25) self.event_handler.click_location = (event.xdata, event.ydata) self.event_handler.on_mouseup(event) @@ -469,7 +466,7 @@ def test_on_mouseup_click_not_select(self, has_selected): graph.draw.assert_called_once() plt.close(event.canvas.figure) - @pytest.mark.parametrize('has_selected', [True, False]) + @pytest.mark.parametrize("has_selected", [True, False]) def test_on_mouseup_drag(self, has_selected): # start: mouse has moved, something is active graph = self.event_handler.graph @@ -478,7 +475,7 @@ def test_on_mouseup_drag(self, has_selected): old_selected = graph.edges[graph.nodes["A"], graph.nodes["B"]] self.event_handler.selected = old_selected - event = mock_event('mouseup', 0.25, 0.25) + event = mock_event("mouseup", 0.25, 0.25) self.event_handler.click_location = (0.5, 0.5) self.event_handler.active = edge @@ -498,15 +495,14 @@ def test_on_mouseup_drag(self, has_selected): class TestGraphDrawing: def setup_method(self): self.nx_graph = nx.MultiDiGraph() - self.nx_graph.add_edges_from([ - ("A", "B", {'data': "AB"}), - ("B", "C", {'data': "BC"}), - ("B", "D", {'data': "BD"}), - ]) - self.positions = { - "A": (0.0, 0.0), "B": (0.5, 0.0), "C": (0.5, 0.5), - "D": (-0.1, 0.6) - } + self.nx_graph.add_edges_from( + [ + ("A", "B", {"data": "AB"}), + ("B", "C", {"data": "BC"}), + ("B", "D", {"data": "BD"}), + ], + ) + self.positions = {"A": (0.0, 0.0), "B": (0.5, 0.0), "C": (0.5, 0.5), "D": (-0.1, 0.6)} self.graph = GraphDrawing(self.nx_graph, positions=self.positions) def test_init(self): @@ -520,24 +516,23 @@ def test_init(self): def test_init_custom_ax(self): fig, ax = plt.subplots() - graph = GraphDrawing(self.nx_graph, positions=self.positions, - ax=ax) + graph = 
GraphDrawing(self.nx_graph, positions=self.positions, ax=ax) assert graph.fig is fig assert graph.ax is ax plt.close(fig) def test_register_node_error(self): with pytest.raises(RuntimeError, match="multiple times"): - self.graph._register_node( - node=list(self.nx_graph.nodes)[0], - position=(0, 0) - ) - - @pytest.mark.parametrize('node,edges', [ - ("A", [("A", "B")]), - ("B", [("A", "B"), ("B", "C"), ("B", "D")]), - ("C", [("B", "C")]), - ]) + self.graph._register_node(node=list(self.nx_graph.nodes)[0], position=(0, 0)) + + @pytest.mark.parametrize( + "node,edges", + [ + ("A", [("A", "B")]), + ("B", [("A", "B"), ("B", "C"), ("B", "D")]), + ("C", [("B", "C")]), + ], + ) def test_edges_for_node(self, node, edges): expected_edges = {self.graph.edges[n1, n2] for n1, n2 in edges} assert set(self.graph.edges_for_node(node)) == expected_edges diff --git a/openfe/tests/utils/test_optional_imports.py b/openfe/tests/utils/test_optional_imports.py index 96ca6b98a..f6ee1f1f0 100644 --- a/openfe/tests/utils/test_optional_imports.py +++ b/openfe/tests/utils/test_optional_imports.py @@ -1,11 +1,12 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from openfe.utils import requires_package import pytest +from openfe.utils import requires_package + -@requires_package('no_such_package_hopefully') +@requires_package("no_such_package_hopefully") def the_answer(): return 42 diff --git a/openfe/tests/utils/test_remove_oechem.py b/openfe/tests/utils/test_remove_oechem.py index 0bd208f07..d1467fb16 100644 --- a/openfe/tests/utils/test_remove_oechem.py +++ b/openfe/tests/utils/test_remove_oechem.py @@ -1,8 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -from openfe.utils import without_oechem_backend from openff.toolkit import GLOBAL_TOOLKIT_REGISTRY, OpenEyeToolkitWrapper +from openfe.utils import without_oechem_backend + def test_remove_oechem(): original_tks = GLOBAL_TOOLKIT_REGISTRY.registered_toolkits @@ -12,6 +13,5 @@ def test_remove_oechem(): for tk in GLOBAL_TOOLKIT_REGISTRY.registered_toolkits: assert not isinstance(tk, OpenEyeToolkitWrapper) assert len(GLOBAL_TOOLKIT_REGISTRY.registered_toolkits) == original_n_tks - for ref_tk, tk in zip(original_tks, - GLOBAL_TOOLKIT_REGISTRY.registered_toolkits): + for ref_tk, tk in zip(original_tks, GLOBAL_TOOLKIT_REGISTRY.registered_toolkits): assert isinstance(tk, type(ref_tk)) diff --git a/openfe/tests/utils/test_system_probe.py b/openfe/tests/utils/test_system_probe.py index 744eab5a8..234a8307a 100644 --- a/openfe/tests/utils/test_system_probe.py +++ b/openfe/tests/utils/test_system_probe.py @@ -1,13 +1,13 @@ import contextlib -from collections import namedtuple import logging import pathlib import sys +from collections import namedtuple from unittest.mock import Mock, patch import psutil -from psutil._common import sdiskusage import pytest +from psutil._common import sdiskusage from openfe.utils.system_probe import ( _get_disk_usage, @@ -18,7 +18,6 @@ log_system_probe, ) - # Named tuples from https://github.com/giampaolo/psutil/blob/master/psutil/_pslinux.py svmem = namedtuple( "svmem", @@ -116,19 +115,15 @@ "mount_point": "/mnt/data", }, }, - } + }, } def fake_disk_usage(path): if str(path) == "/foo": - return sdiskusage( - total=1958854045696, used=1232985415680, free=626288726016, percent=66.3 - ) + return sdiskusage(total=1958854045696, used=1232985415680, free=626288726016, percent=66.3) if str(path) == "/bar": - return sdiskusage( - total=4000770252800, used=1678226952192, free=2322615496704, percent=41.9 - ) + return sdiskusage(total=4000770252800, used=1678226952192, free=2322615496704, 
percent=41.9) @contextlib.contextmanager @@ -159,13 +154,11 @@ def patch_system(): pss=10777600, swap=0, ), - } + }, ), ) # Since this attribute doesn't exist on OSX, we have to create it - patch_psutil_Process_rlimit = patch( - "psutil.Process.rlimit", Mock(return_value=(-1, -1)) - ) + patch_psutil_Process_rlimit = patch("psutil.Process.rlimit", Mock(return_value=(-1, -1))) patch_psutil_virtual_memory = patch( "psutil.virtual_memory", Mock( @@ -181,12 +174,10 @@ def patch_system(): cached=34111094784, shared=1021571072, slab=1518297088, - ) + ), ), ) - patch_psutil_disk_usage = patch( - "psutil.disk_usage", Mock(side_effect=fake_disk_usage) - ) + patch_psutil_disk_usage = patch("psutil.disk_usage", Mock(side_effect=fake_disk_usage)) # assumes that each shell command is called in only one way cmd_to_output = { @@ -219,18 +210,14 @@ def patch_system(): yield stack -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) +@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_get_hostname(): with patch_system(): hostname = _get_hostname() assert hostname == "mock-hostname" -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) +@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_get_gpu_info(): with patch_system(): gpu_info = _get_gpu_info() @@ -257,9 +244,7 @@ def test_get_gpu_info(): assert gpu_info == expected_gpu_info -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) +@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_get_psutil_info(): with patch_system(): psutil_info = _get_psutil_info() @@ -301,9 +286,7 @@ def test_get_psutil_info(): assert psutil_info == expected_psutil_info -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) 
+@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_get_disk_usage(): with patch_system(): disk_info = _get_disk_usage() @@ -326,9 +309,7 @@ def test_get_disk_usage(): assert disk_info == expected_disk_info -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) +@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_get_disk_usage_with_path(): with patch_system(): disk_info = _get_disk_usage(paths=[pathlib.Path("/foo"), pathlib.Path("/bar")]) @@ -350,9 +331,7 @@ def test_get_disk_usage_with_path(): assert disk_info == expected_disk_info -@pytest.mark.skipif( - sys.platform == "darwin", reason="test requires psutil.Process.rlimit" -) +@pytest.mark.skipif(sys.platform == "darwin", reason="test requires psutil.Process.rlimit") def test_probe_system(): with patch_system(): system_info = _probe_system() @@ -430,7 +409,7 @@ def test_probe_system(): "mount_point": "/mnt/data", }, }, - } + }, } assert system_info == expected_system_info @@ -445,9 +424,9 @@ def test_log_system_probe_unconfigured(): # if probe loggers aren't configured to run, then we shouldn't even call # _probe_system() logger_names = [ - 'openfe.utils.system_probe.log', - 'openfe.utils.system_probe.log.gpu', - 'openfe.utils.system_probe.log.hostname', + "openfe.utils.system_probe.log", + "openfe.utils.system_probe.log.gpu", + "openfe.utils.system_probe.log.hostname", ] # check that initial conditions are as expected for logger_name in logger_names: @@ -455,14 +434,14 @@ def test_log_system_probe_unconfigured(): assert not logger.isEnabledFor(logging.DEBUG) sysprobe_mock = Mock(return_value=EXPECTED_SYSTEM_INFO) - with patch('openfe.utils.system_probe._probe_system', sysprobe_mock): + with patch("openfe.utils.system_probe._probe_system", sysprobe_mock): log_system_probe(logging.DEBUG) assert sysprobe_mock.call_count == 0 # now check that it does get called if we use a 
level that will emit # (this is effectively tests that the previous assert isn't a false # positive) - with patch('openfe.utils.system_probe._probe_system', sysprobe_mock): + with patch("openfe.utils.system_probe._probe_system", sysprobe_mock): log_system_probe(logging.WARNING) assert sysprobe_mock.call_count == 1 @@ -470,7 +449,7 @@ def test_log_system_probe_unconfigured(): def test_log_system_probe(caplog): # this checks that the expected contents show up in log_system_probe sysprobe_mock = Mock(return_value=EXPECTED_SYSTEM_INFO) - with patch('openfe.utils.system_probe._probe_system', sysprobe_mock): + with patch("openfe.utils.system_probe._probe_system", sysprobe_mock): with caplog.at_level(logging.DEBUG): log_system_probe() @@ -480,7 +459,7 @@ def test_log_system_probe(caplog): "GPU: uuid='GPU-UUID-2' NVIDIA GeForce RTX 2060 mode=Default", "Memory used: 27.8G (52.8%)", "/dev/mapper/data-root: 37% full (1.1T free)", - "/dev/dm-3: 42% full (2.2T free)" + "/dev/dm-3: 42% full (2.2T free)", ] for line in expected: assert line in caplog.text diff --git a/openfe/tests/utils/test_visualization_3D.py b/openfe/tests/utils/test_visualization_3D.py index 3bea583df..102534344 100644 --- a/openfe/tests/utils/test_visualization_3D.py +++ b/openfe/tests/utils/test_visualization_3D.py @@ -1,9 +1,9 @@ import pytest -from openfe.setup import LigandAtomMapping +from openfe.setup import LigandAtomMapping -pytest.importorskip('py3Dmol') -from openfe.utils.visualization_3D import view_mapping_3d, view_components_3d +pytest.importorskip("py3Dmol") +from openfe.utils.visualization_3D import view_components_3d, view_mapping_3d @pytest.fixture(scope="module") @@ -38,6 +38,7 @@ def test_visualize_component_coords_give_iterable_shift(benzene_transforms): components = [benzene_transforms["benzene"], benzene_transforms["phenol"]] view_components_3d(components, shift=(1, 1, 1)) + def test_visualize_component_coords_reuse_view(benzene_transforms): """ smoke test just checking if nothing goes 
horribly wrong diff --git a/openfe/utils/atommapping_network_plotting.py b/openfe/utils/atommapping_network_plotting.py index 8e119b432..cd26299b4 100644 --- a/openfe/utils/atommapping_network_plotting.py +++ b/openfe/utils/atommapping_network_plotting.py @@ -1,16 +1,15 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import io +from typing import Dict, Tuple + import matplotlib +from gufe.visualization.mapping_visualization import draw_one_molecule_mapping from rdkit import Chem -from typing import Dict, Tuple -from openfe.utils.network_plotting import GraphDrawing, Node, Edge -from gufe.visualization.mapping_visualization import ( - draw_one_molecule_mapping, -) +from openfe import LigandNetwork, SmallMoleculeComponent from openfe.utils.custom_typing import MPL_MouseEvent -from openfe import SmallMoleculeComponent, LigandNetwork +from openfe.utils.network_plotting import Edge, GraphDrawing, Node class AtomMappingEdge(Edge): @@ -27,26 +26,24 @@ class AtomMappingEdge(Edge): Data dictionary for this edge. Must have key ``object``, which maps to an :class:`.AtomMapping`. 
""" - def __init__(self, node_artist1: Node, node_artist2: Node, data: Dict): + + def __init__(self, node_artist1: Node, node_artist2: Node, data: dict): super().__init__(node_artist1, node_artist2, data) self.left_image = None self.right_image = None def _draw_mapped_molecule( self, - extent: Tuple[float, float, float, float], + extent: tuple[float, float, float, float], molA: SmallMoleculeComponent, molB: SmallMoleculeComponent, - molA_to_molB: Dict[int, int] + molA_to_molB: dict[int, int], ): # create the image in a format matplotlib can handle d2d = Chem.Draw.rdMolDraw2D.MolDraw2DCairo(300, 300, 300, 300) d2d.drawOptions().setBackgroundColour((1, 1, 1, 0.7)) # TODO: use a custom draw2d object; figure size from transforms - img_bytes = draw_one_molecule_mapping(molA_to_molB, - molA.to_rdkit(), - molB.to_rdkit(), - d2d=d2d) + img_bytes = draw_one_molecule_mapping(molA_to_molB, molA.to_rdkit(), molB.to_rdkit(), d2d=d2d) img_filelike = io.BytesIO(img_bytes) # imread needs filelike img_data = matplotlib.pyplot.imread(img_filelike) @@ -85,7 +82,7 @@ def _get_image_extents(self): def select(self, event, graph): super().select(event, graph) - mapping = self.data['object'] + mapping = self.data["object"] # figure out which node is to the left and which to the right xs = [node.xy[0] for node in self.node_artists] @@ -102,14 +99,8 @@ def select(self, event, graph): left_extent, right_extent = self._get_image_extents() - self.left_image = self._draw_mapped_molecule(left_extent, - left, - right, - left_to_right) - self.right_image = self._draw_mapped_molecule(right_extent, - right, - left, - right_to_left) + self.left_image = self._draw_mapped_molecule(left_extent, left, right, left_to_right) + self.right_image = self._draw_mapped_molecule(right_extent, right, left, right_to_left) graph.fig.canvas.draw() def unselect(self): @@ -124,8 +115,7 @@ def unselect(self): class LigandNode(Node): def _make_artist(self, x, y, dx, dy): - artist = matplotlib.text.Text(x, y, 
self.node.name, color='blue', - backgroundcolor='white') + artist = matplotlib.text.Text(x, y, self.node.name, color="blue", backgroundcolor="white") return artist def register_artist(self, ax): @@ -154,6 +144,7 @@ class AtomMappingNetworkDrawing(GraphDrawing): positions : Optional[Dict[SmallMoleculeComponent, Tuple[float, float]]] mapping of node to position """ + NodeCls = LigandNode EdgeCls = AtomMappingEdge diff --git a/openfe/utils/custom_typing.py b/openfe/utils/custom_typing.py index 6cb28f696..ebed80c40 100644 --- a/openfe/utils/custom_typing.py +++ b/openfe/utils/custom_typing.py @@ -2,9 +2,10 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from typing import TypeVar -from rdkit import Chem + import matplotlib.axes import matplotlib.backend_bases +from rdkit import Chem try: from typing import TypeAlias # type: ignore @@ -13,7 +14,7 @@ RDKitMol: TypeAlias = Chem.rdchem.Mol -OEMol = TypeVar('OEMol') +OEMol = TypeVar("OEMol") MPL_FigureCanvasBase: TypeAlias = matplotlib.backend_bases.FigureCanvasBase MPL_MouseEvent: TypeAlias = matplotlib.backend_bases.MouseEvent MPL_Axes: TypeAlias = matplotlib.axes.Axes diff --git a/openfe/utils/handle_trajectories.py b/openfe/utils/handle_trajectories.py index a10d376ee..2624c7791 100644 --- a/openfe/utils/handle_trajectories.py +++ b/openfe/utils/handle_trajectories.py @@ -1,15 +1,16 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe +from pathlib import Path +from typing import Optional + import netCDF4 as nc import numpy as np -from pathlib import Path from openff.units import unit + from openfe import __version__ -from typing import Optional -def _state_to_replica(dataset: nc.Dataset, state_num: int, - frame_num: int) -> int: +def _state_to_replica(dataset: nc.Dataset, state_num: int, frame_num: int) -> int: """Convert a state index to replica index at a given frame Parameters @@ -28,13 +29,11 @@ def _state_to_replica(dataset: nc.Dataset, state_num: int, Index of the replica which represents that thermodynamic state for that frame. """ - state_distribution = dataset.variables['states'][frame_num].data + state_distribution = dataset.variables["states"][frame_num].data return np.where(state_distribution == state_num)[0][0] -def _replica_positions_at_frame(dataset: nc.Dataset, - replica_index: int, - frame_num: int) -> unit.Quantity: +def _replica_positions_at_frame(dataset: nc.Dataset, replica_index: int, frame_num: int) -> unit.Quantity: """ Helper method to extract atom positions of a state at a given frame. 
@@ -52,13 +51,12 @@ def _replica_positions_at_frame(dataset: nc.Dataset, unit.Quantity n_atoms * 3 position Quantity array """ - pos = dataset.variables['positions'][frame_num][replica_index].data - pos_units = dataset.variables['positions'].units + pos = dataset.variables["positions"][frame_num][replica_index].data + pos_units = dataset.variables["positions"].units return pos * unit(pos_units) -def _create_new_dataset(filename: Path, n_atoms: int, - title: str) -> nc.Dataset: +def _create_new_dataset(filename: Path, n_atoms: int, title: str) -> nc.Dataset: """ Helper method to create a new NetCDF dataset which follows the AMBER convention (see: https://ambermd.org/netcdf/nctraj.xhtml) @@ -78,44 +76,40 @@ def _create_new_dataset(filename: Path, n_atoms: int, AMBER Conventions compliant NetCDF dataset to store information contained in MultiState reporter generated NetCDF file. """ - ncfile = nc.Dataset(filename, 'w', format='NETCDF3_64BIT') - ncfile.Conventions = 'AMBER' + ncfile = nc.Dataset(filename, "w", format="NETCDF3_64BIT") + ncfile.Conventions = "AMBER" ncfile.ConventionVersion = "1.0" ncfile.application = "openfe" ncfile.program = f"openfe {__version__}" ncfile.programVersion = f"{__version__}" ncfile.title = title - + # Set the dimensions - ncfile.createDimension('frame', None) - ncfile.createDimension('spatial', 3) - ncfile.createDimension('atom', n_atoms) - ncfile.createDimension('cell_spatial', 3) - ncfile.createDimension('cell_angular', 3) - ncfile.createDimension('label', 5) - + ncfile.createDimension("frame", None) + ncfile.createDimension("spatial", 3) + ncfile.createDimension("atom", n_atoms) + ncfile.createDimension("cell_spatial", 3) + ncfile.createDimension("cell_angular", 3) + ncfile.createDimension("label", 5) + # Set the variables # positions - pos = ncfile.createVariable('coordinates', 'f4', ('frame', 'atom', 'spatial')) - pos.units = 'angstrom' + pos = ncfile.createVariable("coordinates", "f4", ("frame", "atom", "spatial")) + pos.units = 
"angstrom" # we could also set this to 0.1 and do no nm to angstrom scaling on write - pos.scale_factor = 1.0 + pos.scale_factor = 1.0 # Note: OpenMMTools NetCDF files store velocities # but honestly it's rather useless, so we don't populate them - # Note 2: NetCDF file doesn't contain any time information... + # Note 2: NetCDF file doesn't contain any time information... # so we can't populate that either, this might trip up some readers.. # Note 3: We'll need to convert box vectors (in nm) to # unitcell (in angstrom & degrees) - cell_lengths = ncfile.createVariable( - 'cell_lengths', 'f8', ('frame', 'cell_spatial') - ) - cell_lengths.units = 'angstrom' - cell_angles = ncfile.createVariable( - 'cell_angles', 'f8', ('frame', 'cell_spatial') - ) - cell_angles.units = 'degree' - + cell_lengths = ncfile.createVariable("cell_lengths", "f8", ("frame", "cell_spatial")) + cell_lengths.units = "angstrom" + cell_angles = ncfile.createVariable("cell_angles", "f8", ("frame", "cell_spatial")) + cell_angles.units = "degree" + return ncfile @@ -139,9 +133,9 @@ def _get_unitcell(dataset: nc.Dataset, replica_index: int, frame_num: int): Tuple[lx, ly, lz, alpha, beta, gamma] Unit cell lengths and angles in angstroms and degrees. 
""" - vecs = dataset.variables['box_vectors'][frame_num][replica_index].data - vecs_units = dataset.variables['box_vectors'].units - x, y, z = (vecs * unit(vecs_units)).to('angstrom').m + vecs = dataset.variables["box_vectors"][frame_num][replica_index].data + vecs_units = dataset.variables["box_vectors"].units + x, y, z = (vecs * unit(vecs_units)).to("angstrom").m lx = np.linalg.norm(x) ly = np.linalg.norm(y) lz = np.linalg.norm(z) @@ -155,9 +149,12 @@ def _get_unitcell(dataset: nc.Dataset, replica_index: int, frame_num: int): return lx, ly, lz, np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma) -def trajectory_from_multistate(input_file: Path, output_file: Path, - state_number: Optional[int] = None, - replica_number: Optional[int] = None) -> None: +def trajectory_from_multistate( + input_file: Path, + output_file: Path, + state_number: Optional[int] = None, + replica_number: Optional[int] = None, +) -> None: """ Extract a state's trajectory (in an AMBER compliant format) from a MultiState sampler generated NetCDF file. 
@@ -176,28 +173,27 @@ def trajectory_from_multistate(input_file: Path, output_file: Path, Index of the replica to write out """ if not ((state_number is None) ^ (replica_number is None)): - raise ValueError("Supply either state or replica number, " - f"got state_number={state_number} " - f"and replica_number={replica_number}") + raise ValueError( + "Supply either state or replica number, " + f"got state_number={state_number} " + f"and replica_number={replica_number}", + ) # Open MultiState NC file and get number of atoms and frames - multistate = nc.Dataset(input_file, 'r') - n_atoms = len(multistate.variables['positions'][0][0]) - n_replicas = len(multistate.variables['positions'][0]) - n_frames = len(multistate.variables['positions']) - + multistate = nc.Dataset(input_file, "r") + n_atoms = len(multistate.variables["positions"][0][0]) + n_replicas = len(multistate.variables["positions"][0]) + n_frames = len(multistate.variables["positions"]) + # Sanity check if state_number is not None and (state_number + 1 > n_replicas): # Note this works for now, but when we have more states # than replicas (e.g. 
SAMS) this won't really work errmsg = "State does not exist" raise ValueError(errmsg) - + # Create output AMBER NetCDF convention file - traj = _create_new_dataset( - output_file, n_atoms, - title=f"state {state_number} trajectory from {input_file}" - ) + traj = _create_new_dataset(output_file, n_atoms, title=f"state {state_number} trajectory from {input_file}") replica_id: int = -1 if replica_number is not None: @@ -208,12 +204,12 @@ def trajectory_from_multistate(input_file: Path, output_file: Path, if state_number is not None: replica_id = _state_to_replica(multistate, state_number, frame) - traj.variables['coordinates'][frame] = _replica_positions_at_frame( - multistate, replica_id, frame - ).to('angstrom').m + traj.variables["coordinates"][frame] = ( + _replica_positions_at_frame(multistate, replica_id, frame).to("angstrom").m + ) unitcell = _get_unitcell(multistate, replica_id, frame) - traj.variables['cell_lengths'][frame] = unitcell[:3] - traj.variables['cell_angles'][frame] = unitcell[3:] + traj.variables["cell_lengths"][frame] = unitcell[:3] + traj.variables["cell_angles"][frame] = unitcell[3:] # Make sure to clean up when you are done multistate.close() diff --git a/openfe/utils/logging_filter.py b/openfe/utils/logging_filter.py index 78d00de57..9336433d6 100644 --- a/openfe/utils/logging_filter.py +++ b/openfe/utils/logging_filter.py @@ -1,5 +1,6 @@ import logging + class MsgIncludesStringFilter: """Logging filter to silence specfic log messages. 
@@ -11,6 +12,7 @@ class MsgIncludesStringFilter: if an exact for this is included in the log message, the log record is suppressed """ + def __init__(self, string): self.string = string diff --git a/openfe/utils/network_plotting.py b/openfe/utils/network_plotting.py index ad767dc73..11a4f38f1 100644 --- a/openfe/utils/network_plotting.py +++ b/openfe/utils/network_plotting.py @@ -11,15 +11,14 @@ from __future__ import annotations import itertools +from typing import Any, Optional, Union, cast + import networkx as nx from matplotlib import pyplot as plt -from matplotlib.patches import Rectangle from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle -from typing import Optional, Any, Union, cast -from openfe.utils.custom_typing import ( - MPL_MouseEvent, MPL_FigureCanvasBase, MPL_Axes, TypeAlias -) +from openfe.utils.custom_typing import MPL_Axes, MPL_FigureCanvasBase, MPL_MouseEvent, TypeAlias ClickLocation: TypeAlias = tuple[tuple[float, float], tuple[Any, Any]] @@ -32,6 +31,7 @@ class Node: for this node. This acts as an adapter class, allowing different artists to be used, as well as enabling different functionalities. """ + # TODO: someday it might be good to separate the artist adapter from the # functionality on select, etc. 
draggable = True @@ -44,10 +44,10 @@ def __init__(self, node, x: float, y: float, dx=0.1, dy=0.1): self.dy = dx self.artist = self._make_artist(x, y, dx, dy) self.picked = False - self.press: Optional[ClickLocation] = None + self.press: ClickLocation | None = None def _make_artist(self, x, y, dx, dy): - return Rectangle((x, y), dx, dy, color='blue') + return Rectangle((x, y), dx, dy, color="blue") def register_artist(self, ax: MPL_Axes): """Register this node's artist with the matplotlib Axes""" @@ -57,8 +57,7 @@ def register_artist(self, ax: MPL_Axes): def extent(self) -> tuple[float, float, float, float]: """extent of this node in matplotlib data coordinates""" bounds = self.artist.get_bbox().bounds - return (bounds[0], bounds[0] + bounds[2], - bounds[1], bounds[1] + bounds[3]) + return (bounds[0], bounds[0] + bounds[2], bounds[1], bounds[1] + bounds[3]) @property def xy(self) -> tuple[float, float]: @@ -71,11 +70,11 @@ def select(self, event: MPL_MouseEvent, graph: GraphDrawing): # -no-cov- def unselect(self): """Reset this node to its standard, unselected visualization""" - self.artist.set(color='blue') + self.artist.set(color="blue") def edge_select(self, edge: Edge): """Change node visualization when one of its edges is selected""" - self.artist.set(color='red') + self.artist.set(color="red") def update_location(self, x: float, y: float): """Update the location of the underlying artist""" @@ -151,6 +150,7 @@ class Edge: data : Dict data dictionary for this edge """ + pickable = True def __init__(self, node_artist1: Node, node_artist2: Node, data: dict): @@ -159,10 +159,9 @@ def __init__(self, node_artist1: Node, node_artist2: Node, data: dict): self.artist = self._make_artist(node_artist1, node_artist2, data) self.picked = False - def _make_artist(self, node_artist1: Node, node_artist2: Node, - data: dict) -> Any: + def _make_artist(self, node_artist1: Node, node_artist2: Node, data: dict) -> Any: xs, ys = self._edge_xs_ys(node_artist1, node_artist2) - return 
Line2D(xs, ys, color='black', picker=True, zorder=-1) + return Line2D(xs, ys, color="black", picker=True, zorder=-1) def register_artist(self, ax: MPL_Axes): """Register this edge's artist with the matplotlib Axes""" @@ -198,14 +197,14 @@ def on_mouseup(self, event: MPL_MouseEvent, graph: GraphDrawing): def unselect(self): """Reset this edge to its standard, unselected visualization""" - self.artist.set(color='black') + self.artist.set(color="black") for node_artist in self.node_artists: node_artist.unselect() self.picked = False def select(self, event: MPL_MouseEvent, graph: GraphDrawing): """Mark this edge as selected, update visualization""" - self.artist.set(color='red') + self.artist.set(color="red") for artist in self.node_artists: artist.edge_select(self) self.picked = True @@ -243,20 +242,23 @@ class EventHandler: connections : List[int] list of IDs for connections to matplotlib canvas """ + def __init__(self, graph: GraphDrawing): self.graph = graph - self.active: Optional[Union[Node, Edge]] = None - self.selected: Optional[Union[Node, Edge]] = None - self.click_location: Optional[tuple[Optional[float], Optional[float]]] = None + self.active: Node | Edge | None = None + self.selected: Node | Edge | None = None + self.click_location: tuple[float | None, float | None] | None = None self.connections: list[int] = [] def connect(self, canvas: MPL_FigureCanvasBase): """Connect our methods to events in the matplotlib canvas""" - self.connections.extend([ - canvas.mpl_connect('button_press_event', self.on_mousedown), # type: ignore - canvas.mpl_connect('motion_notify_event', self.on_drag), # type: ignore - canvas.mpl_connect('button_release_event', self.on_mouseup), # type: ignore - ]) + self.connections.extend( + [ + canvas.mpl_connect("button_press_event", self.on_mousedown), # type: ignore + canvas.mpl_connect("motion_notify_event", self.on_drag), # type: ignore + canvas.mpl_connect("button_release_event", self.on_mouseup), # type: ignore + ], + ) def 
disconnect(self, canvas: MPL_FigureCanvasBase): """Disconnect all connections to the canvas.""" @@ -271,8 +273,7 @@ def _get_event_container(self, event: MPL_MouseEvent): could be a node or an edge, it is interpreted as clicking on the node. """ - containers = itertools.chain(self.graph.nodes.values(), - self.graph.edges.values()) + containers = itertools.chain(self.graph.nodes.values(), self.graph.edges.values()) for container in containers: if container.contains(event): break @@ -339,6 +340,7 @@ class GraphDrawing: positions : Optional[Dict[Any, Tuple[float, float]]] mapping of node to position """ + NodeCls = Node EdgeCls = Edge @@ -350,7 +352,7 @@ def __init__(self, graph: nx.Graph, positions=None, ax=None): self.edges: dict[tuple[Node, Node], Any] = {} if positions is None: - positions = nx.nx_agraph.graphviz_layout(self.graph, prog='neato') + positions = nx.nx_agraph.graphviz_layout(self.graph, prog="neato") was_interactive = plt.isinteractive() plt.ioff() @@ -380,15 +382,12 @@ def _ipython_display_(self): # -no-cov- def edges_for_node(self, node: Node) -> list[Edge]: """List of edges for the given node""" - edges = (list(self.graph.in_edges(node)) - + list(self.graph.out_edges(node))) + edges = list(self.graph.in_edges(node)) + list(self.graph.out_edges(node)) return [self.edges[edge] for edge in edges] def _get_nodes_extent(self): """Find the extent of all nodes (used in setting bounds)""" - min_xs, max_xs, min_ys, max_ys = zip(*( - node.extent for node in self.nodes.values() - )) + min_xs, max_xs, min_ys, max_ys = zip(*(node.extent for node in self.nodes.values())) return min(min_xs), max(max_xs), min(min_ys), max(max_ys) def reset_bounds(self): diff --git a/openfe/utils/optional_imports.py b/openfe/utils/optional_imports.py index 768beb243..f6c3c1110 100644 --- a/openfe/utils/optional_imports.py +++ b/openfe/utils/optional_imports.py @@ -2,6 +2,7 @@ Tools for integration with miscellaneous non-required packages. 
shamelessly borrowed from openff.toolkit """ + import functools from typing import Callable @@ -29,8 +30,7 @@ def wrapper(*args, **kwargs): try: importlib.import_module(package_name) except (ImportError, ModuleNotFoundError): - raise ImportError(function.__name__ + " requires package: " + - package_name) + raise ImportError(function.__name__ + " requires package: " + package_name) except Exception as e: raise e diff --git a/openfe/utils/remove_oechem.py b/openfe/utils/remove_oechem.py index f489f04cc..d42f6f6f8 100644 --- a/openfe/utils/remove_oechem.py +++ b/openfe/utils/remove_oechem.py @@ -1,16 +1,15 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe +from contextlib import contextmanager + from openff.toolkit import GLOBAL_TOOLKIT_REGISTRY, OpenEyeToolkitWrapper from openff.toolkit.utils.toolkit_registry import ToolkitUnavailableException -from contextlib import contextmanager - @contextmanager def without_oechem_backend(): """For temporarily removing oechem from openff's toolkit registry""" - current_toolkits = [type(tk) - for tk in GLOBAL_TOOLKIT_REGISTRY.registered_toolkits] + current_toolkits = [type(tk) for tk in GLOBAL_TOOLKIT_REGISTRY.registered_toolkits] try: GLOBAL_TOOLKIT_REGISTRY.deregister_toolkit(OpenEyeToolkitWrapper()) @@ -22,7 +21,6 @@ def without_oechem_backend(): finally: # this is order dependent; we want to prepend OEChem back to first while GLOBAL_TOOLKIT_REGISTRY.registered_toolkits: - GLOBAL_TOOLKIT_REGISTRY.deregister_toolkit( - GLOBAL_TOOLKIT_REGISTRY.registered_toolkits[0]) + GLOBAL_TOOLKIT_REGISTRY.deregister_toolkit(GLOBAL_TOOLKIT_REGISTRY.registered_toolkits[0]) for tk in current_toolkits: GLOBAL_TOOLKIT_REGISTRY.register_toolkit(tk) diff --git a/openfe/utils/silence_root_logging.py b/openfe/utils/silence_root_logging.py index a0222927b..c37b0da3c 100644 --- a/openfe/utils/silence_root_logging.py +++ b/openfe/utils/silence_root_logging.py @@ -1,6 +1,7 
@@ import contextlib import logging + @contextlib.contextmanager def silence_root_logging(): """Context manager to silence logging from root logging handlers. @@ -21,6 +22,3 @@ def silence_root_logging(): root.removeHandler(null) for handler in old_handlers: root.addHandler(handler) - - - diff --git a/openfe/utils/system_probe.py b/openfe/utils/system_probe.py index 483d3f2dd..dee600e59 100644 --- a/openfe/utils/system_probe.py +++ b/openfe/utils/system_probe.py @@ -4,7 +4,8 @@ import socket import subprocess import sys -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional import psutil from psutil._common import bytes2human @@ -223,7 +224,7 @@ def _get_psutil_info() -> dict[str, dict[str, str]]: "num_fds", "pid", "status", - ] + ], ) # OSX doesn't have rlimit for Process if sys.platform != "darwin": @@ -337,13 +338,11 @@ def _get_gpu_info() -> dict[str, dict[str, str]]: ) try: - nvidia_smi_output = subprocess.check_output( - ["nvidia-smi", GPU_QUERY, "--format=csv"] - ).decode("utf-8") + nvidia_smi_output = subprocess.check_output(["nvidia-smi", GPU_QUERY, "--format=csv"]).decode("utf-8") except FileNotFoundError: logging.debug( "Error: nvidia-smi command not found. Make sure NVIDIA drivers are" - " installed, this is expected if there is no GPU available" + " installed, this is expected if there is no GPU available", ) return {} @@ -478,12 +477,11 @@ def _probe_system(paths: Optional[Iterable[pathlib.Path]] = None) -> dict: "gpu information": gpu_info, "psutil information": psutil_info, "disk usage information": disk_usage_info, - } + }, } -def log_system_probe(level=logging.DEBUG, - paths: Optional[Iterable[os.PathLike]] = None): +def log_system_probe(level=logging.DEBUG, paths: Optional[Iterable[os.PathLike]] = None): """Print the system information via configurable logging. 
This creates a logger tree under "{__name__}.log", allowing one to turn @@ -499,29 +497,26 @@ def log_system_probe(level=logging.DEBUG, hostname = logging.getLogger(basename + ".hostname") loggers = [base, gpu, hostname] if any(l.isEnabledFor(level) for l in loggers): - sysinfo = _probe_system(pl_paths)['system information'] + sysinfo = _probe_system(pl_paths)["system information"] base.log(level, "SYSTEM CONFIG DETAILS:") hostname.log(level, f"hostname: '{sysinfo['hostname']}'") - if gpuinfo := sysinfo['gpu information']: + if gpuinfo := sysinfo["gpu information"]: for uuid, gpu_card in gpuinfo.items(): - gpu.log(level, f"GPU: {uuid=} {gpu_card['name']} " - f"mode={gpu_card['compute_mode']}") + gpu.log(level, f"GPU: {uuid=} {gpu_card['name']} " f"mode={gpu_card['compute_mode']}") # gpu.log(level, f"CUDA driver: {...}") # gpu.log(level, f"CUDA toolkit: {...}") else: # -no-cov- gpu.log(level, f"CUDA-based GPU not found") psutilinfo = sysinfo["psutil information"] - memused = psutilinfo['virtual_memory']['used'] - mempct = psutilinfo['virtual_memory']['percent'] + memused = psutilinfo["virtual_memory"]["used"] + mempct = psutilinfo["virtual_memory"]["percent"] base.log(level, f"Memory used: {bytes2human(memused)} ({mempct}%)") - for diskdev, disk in sysinfo['disk usage information'].items(): - base.log(level, f"{diskdev}: {disk['percent_used']} full " - f"({disk['available']} free)") + for diskdev, disk in sysinfo["disk usage information"].items(): + base.log(level, f"{diskdev}: {disk['percent_used']} full " f"({disk['available']} free)") if __name__ == "__main__": from pprint import pprint pprint(_probe_system()) - diff --git a/openfe/utils/visualization_3D.py b/openfe/utils/visualization_3D.py index fc7551df1..0b86003bf 100644 --- a/openfe/utils/visualization_3D.py +++ b/openfe/utils/visualization_3D.py @@ -1,19 +1,20 @@ +from collections.abc import Iterable +from typing import Dict, Optional, Tuple, Union + import numpy as np +from matplotlib import pyplot as plt 
+from matplotlib.colors import rgb2hex from numpy.typing import NDArray -from typing import Tuple, Union, Optional, Dict, Iterable - from rdkit import Chem from rdkit.Geometry.rdGeometry import Point3D -from matplotlib import pyplot as plt -from matplotlib.colors import rgb2hex try: import py3Dmol except ImportError: - pass # Don't throw error, will happen later + pass # Don't throw error, will happen later -from gufe.mapping import AtomMapping from gufe.components.explicitmoleculecomponent import ExplicitMoleculeComponent +from gufe.mapping import AtomMapping from openfe.utils import requires_package @@ -43,7 +44,7 @@ def _get_max_dist_in_x(atom_mapping: AtomMapping) -> float: return estm if (estm > 5) else 5 -def _translate(mol, shift:Union[Tuple[float, float, float], NDArray[np.float64]]): +def _translate(mol, shift: Union[tuple[float, float, float], NDArray[np.float64]]): """ shifts the molecule by the shift vector @@ -68,7 +69,7 @@ def _translate(mol, shift:Union[Tuple[float, float, float], NDArray[np.float64]] return mol -def _add_spheres(view:py3Dmol.view, mol1:Chem.Mol, mol2:Chem.Mol, mapping:Dict[int, int]): +def _add_spheres(view: py3Dmol.view, mol1: Chem.Mol, mol2: Chem.Mol, mapping: dict[int, int]): """ will add spheres according to mapping to the view. (inplace!) 
@@ -95,7 +96,7 @@ def _add_spheres(view:py3Dmol.view, mol1:Chem.Mol, mol2:Chem.Mol, mapping:Dict[i "radius": 0.6, "color": color, "alpha": 0.8, - } + }, ) view.addSphere( { @@ -103,16 +104,17 @@ def _add_spheres(view:py3Dmol.view, mol1:Chem.Mol, mol2:Chem.Mol, mapping:Dict[i "radius": 0.6, "color": color, "alpha": 0.8, - } + }, ) @requires_package("py3Dmol") -def view_components_3d(mols: Iterable[ExplicitMoleculeComponent], - style: Optional[str] ="stick", - shift: Optional[Tuple[float, float, float]] = None, - view: py3Dmol.view = None - ) -> py3Dmol.view: +def view_components_3d( + mols: Iterable[ExplicitMoleculeComponent], + style: Optional[str] = "stick", + shift: Optional[tuple[float, float, float]] = None, + view: py3Dmol.view = None, +) -> py3Dmol.view: """visualize multiple component coordinates in one interactive view. It helps to understand how the components are aligned in the system to each other. @@ -129,24 +131,24 @@ def view_components_3d(mols: Iterable[ExplicitMoleculeComponent], Amount to i*shift each mols_i in order to allow inspection of them in heavy overlap cases. 
view : py3Dmol, optional Allows to pass an already existing view, by default None - + Returns ------- py3Dmol.view view containing all component coordinates """ - if(view is None): + if view is None: view = py3Dmol.view(width=600, height=600) - + for i, component in enumerate(mols): mol = Chem.Mol(component.to_rdkit()) - if(shift is not None): - tmp_shift = np.array(shift, dtype=np.float64)*i + if shift is not None: + tmp_shift = np.array(shift, dtype=np.float64) * i mol = _translate(mol, tmp_shift) view.addModel(Chem.MolToMolBlock(mol)) - + view.setStyle({style: {}}) view.zoomTo() @@ -159,9 +161,8 @@ def view_mapping_3d( spheres: Optional[bool] = True, show_atomIDs: Optional[bool] = False, style: Optional[str] = "stick", - shift: Optional[Union[Tuple[float, float, float], NDArray[np.float64]]] = None, + shift: Optional[Union[tuple[float, float, float], NDArray[np.float64]]] = None, ) -> py3Dmol.view: - """ Render relative transformation edge in 3D using py3Dmol. diff --git a/openfecli/__init__.py b/openfecli/__init__.py index 74e718cd2..76a0f2f57 100644 --- a/openfecli/__init__.py +++ b/openfecli/__init__.py @@ -1,8 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from .plugins import OFECommandPlugin +from importlib.metadata import version + from . import commands +from .plugins import OFECommandPlugin -from importlib.metadata import version __version__ = version("openfe") diff --git a/openfecli/cli.py b/openfecli/cli.py index 014d85f7c..6c36b36c9 100644 --- a/openfecli/cli.py +++ b/openfecli/cli.py @@ -1,23 +1,21 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -import pathlib import logging import logging.config +import pathlib import click from plugcli.cli import CLI, CONTEXT_SETTINGS from plugcli.plugin_management import FilePluginLoader import openfecli - from openfecli.plugins import OFECommandPlugin class OpenFECLI(CLI): # COMMAND_SECTIONS = ["Setup", "Simulation", "Orchestration", "Analysis"] - COMMAND_SECTIONS = ["Network Planning", "Quickrun Executor", - "Miscellaneous"] + COMMAND_SECTIONS = ["Network Planning", "Quickrun Executor", "Miscellaneous"] def get_loaders(self): commands = str(pathlib.Path(__file__).parent.resolve() / "commands") @@ -35,11 +33,9 @@ def get_installed_plugins(self): """ -@click.command(cls=OpenFECLI, name="openfe", help=_MAIN_HELP, - context_settings=CONTEXT_SETTINGS) +@click.command(cls=OpenFECLI, name="openfe", help=_MAIN_HELP, context_settings=CONTEXT_SETTINGS) @click.version_option(version=openfecli.__version__) -@click.option('--log', type=click.Path(exists=True, readable=True), - help="logging configuration file") +@click.option("--log", type=click.Path(exists=True, readable=True), help="logging configuration file") def main(log): # Subcommand runs after this is processed. 
# set logging if provided diff --git a/openfecli/clicktypes/hyphenchoice.py b/openfecli/clicktypes/hyphenchoice.py index 70ac77501..bc63856eb 100644 --- a/openfecli/clicktypes/hyphenchoice.py +++ b/openfecli/clicktypes/hyphenchoice.py @@ -1,8 +1,10 @@ import click + def _normalize_to_hyphen(string): return string.replace("_", "-") + class HyphenAwareChoice(click.Choice): def __init__(self, choices, case_sensitive=True): choices = [_normalize_to_hyphen(choice) for choice in choices] diff --git a/openfecli/commands/atommapping.py b/openfecli/commands/atommapping.py index 142061db5..989cca071 100644 --- a/openfecli/commands/atommapping.py +++ b/openfecli/commands/atommapping.py @@ -2,8 +2,9 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import click + from openfecli import OFECommandPlugin -from openfecli.parameters import MOL, MAPPER, OUTPUT_FILE_AND_EXT +from openfecli.parameters import MAPPER, MOL, OUTPUT_FILE_AND_EXT def allow_two_molecules(ctx, param, value): @@ -13,16 +14,15 @@ def allow_two_molecules(ctx, param, value): return value -@click.command( - "atommapping", - short_help="Check the atom mapping of a given pair of ligands" +@click.command("atommapping", short_help="Check the atom mapping of a given pair of ligands") +@MOL.parameter( + multiple=True, + callback=allow_two_molecules, + required=True, + help=MOL.kwargs["help"] + " Must be specified twice.", ) -@MOL.parameter(multiple=True, callback=allow_two_molecules, required=True, - help=MOL.kwargs['help'] + " Must be specified twice.") @MAPPER.parameter(required=True) -@OUTPUT_FILE_AND_EXT.parameter( - help=OUTPUT_FILE_AND_EXT.kwargs['help'] + " (PNG format)" -) +@OUTPUT_FILE_AND_EXT.parameter(help=OUTPUT_FILE_AND_EXT.kwargs["help"] + " (PNG format)") def atommapping(mol, mapper, output): """ This provides tools for looking at a specific atommapping. 
@@ -60,8 +60,7 @@ def generate_mapping(mapper, molA, molB): mappings = list(mapper.suggest_mappings(molA, molB)) if len(mappings) != 1: raise click.UsageError( - f"Found {len(mappings)} mappings; this command requires a mapper " - "to provide exactly 1 mapping" + f"Found {len(mappings)} mappings; this command requires a mapper " "to provide exactly 1 mapping", ) return mappings[0] @@ -73,8 +72,8 @@ def atommapping_print_dict_main(mapper, molA, molB): def atommapping_visualize_main(mapper, molA, molB, file, ext): - from rdkit.Chem import Draw from gufe.visualization import mapping_visualization as vis + from rdkit.Chem import Draw mapping = generate_mapping(mapper, molA, molB) ext_to_artist = { @@ -85,12 +84,15 @@ def atommapping_visualize_main(mapper, molA, molB, file, ext): except KeyError: raise click.BadParameter( f"Unknown file format: '{ext}'. The following formats are " - "supported: " + ", ".join([f"'{ext}'" for ext in ext_to_artist]) + "supported: " + ", ".join([f"'{ext}'" for ext in ext_to_artist]), ) - contents = vis.draw_mapping(mapping.componentA_to_componentB, - mapping.componentA.to_rdkit(), - mapping.componentB.to_rdkit(), d2d=artist) + contents = vis.draw_mapping( + mapping.componentA_to_componentB, + mapping.componentA.to_rdkit(), + mapping.componentB.to_rdkit(), + d2d=artist, + ) file.write(contents) diff --git a/openfecli/commands/fetch.py b/openfecli/commands/fetch.py index 3159559aa..69913c371 100644 --- a/openfecli/commands/fetch.py +++ b/openfecli/commands/fetch.py @@ -1,23 +1,24 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -import click -import urllib +# MOVE SINGLEMODULEPLUGINLOADER UPSTREAM TO PLUGCLI +import importlib import shutil +import urllib + +import click from plugcli.cli import CLI, CONTEXT_SETTINGS -from openfecli.fetching import FetchablePlugin +from plugcli.plugin_management import CLIPluginLoader + from openfecli import OFECommandPlugin +from openfecli.fetching import FetchablePlugin + -# MOVE SINGLEMODULEPLUGINLOADER UPSTREAM TO PLUGCLI -import importlib -from plugcli.plugin_management import CLIPluginLoader class SingleModulePluginLoader(CLIPluginLoader): - """Load plugins from a specific module - """ + """Load plugins from a specific module""" + def __init__(self, module_name, plugin_class): - super().__init__(plugin_type="single_module", - search_path=module_name, - plugin_class=plugin_class) + super().__init__(plugin_type="single_module", search_path=module_name, plugin_class=plugin_class) def _find_candidates(self): return [importlib.import_module(self.search_path)] @@ -33,28 +34,25 @@ class FetchCLI(CLI): This provides the command sections used in help and defines where plugins should be kept. """ + COMMAND_SECTIONS = ["Tutorials"] def get_loaders(self): - return [ - SingleModulePluginLoader('openfecli.fetchables', - FetchablePlugin) - ] + return [SingleModulePluginLoader("openfecli.fetchables", FetchablePlugin)] def get_installed_plugins(self): loader = self.get_loaders()[0] return list(loader()) -@click.command( - cls=FetchCLI, - short_help="Fetch tutorial or other resource." -) + +@click.command(cls=FetchCLI, short_help="Fetch tutorial or other resource.") def fetch(): """ Fetch the given resource. Some resources require internet; others are built-in. 
""" + PLUGIN = OFECommandPlugin( command=fetch, section="Miscellaneous", diff --git a/openfecli/commands/gather.py b/openfecli/commands/gather.py index 9e6aa42f3..a8a33dd0a 100644 --- a/openfecli/commands/gather.py +++ b/openfecli/commands/gather.py @@ -1,15 +1,18 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe +import pathlib +import warnings + import click + from openfecli import OFECommandPlugin from openfecli.clicktypes import HyphenAwareChoice -import pathlib -import warnings def _get_column(val): import numpy as np + if val == 0: return 0 @@ -28,6 +31,7 @@ def format_estimate_uncertainty( unc_prec: int = 1, ) -> tuple[str, str]: import numpy as np + # get the last column needed for uncertainty unc_col = _get_column(unc) - (unc_prec - 1) @@ -43,60 +47,61 @@ def format_estimate_uncertainty( def is_results_json(f): # sanity check on files before we try and deserialize - return 'estimate' in open(f, 'r').read(20) + return "estimate" in open(f).read(20) def load_results(f): # path to deserialized results import json + from gufe.tokenization import JSON_HANDLER - return json.load(open(f, 'r'), cls=JSON_HANDLER.decoder) + return json.load(open(f), cls=JSON_HANDLER.decoder) def get_names(result) -> tuple[str, str]: # Result to tuple of ligand names - nm = list(result['unit_results'].values())[0]['name'] + nm = list(result["unit_results"].values())[0]["name"] toks = nm.split() - if toks[2] == 'repeat': + if toks[2] == "repeat": return toks[0], toks[1] else: return toks[0], toks[2] def get_type(res): - list_of_pur = list(res['protocol_result']['data'].values())[0] + list_of_pur = list(res["protocol_result"]["data"].values())[0] pur = list_of_pur[0] - components = pur['inputs']['stateA']['components'] + components = pur["inputs"]["stateA"]["components"] - if 'solvent' not in components: - return 'vacuum' - elif 'protein' in components: - return 'complex' + if "solvent" not in components: 
+ return "vacuum" + elif "protein" in components: + return "complex" else: - return 'solvent' + return "solvent" def legacy_get_type(res_fn): - if 'solvent' in res_fn: - return 'solvent' - elif 'vacuum' in res_fn: - return 'vacuum' + if "solvent" in res_fn: + return "solvent" + elif "vacuum" in res_fn: + return "vacuum" else: - return 'complex' + return "complex" def _generate_bad_legs_error_message(set_vals, ligpair): - expected_rbfe = {'complex', 'solvent'} - expected_rhfe = {'solvent', 'vacuum'} + expected_rbfe = {"complex", "solvent"} + expected_rhfe = {"solvent", "vacuum"} maybe_rhfe = bool(set_vals & expected_rhfe) maybe_rbfe = bool(set_vals & expected_rbfe) if maybe_rhfe and not maybe_rbfe: msg = ( - "This appears to be an RHFE calculation, but we're " - f"missing {expected_rhfe - set_vals} runs for the " - f"edge with ligands {ligpair}." - ) + "This appears to be an RHFE calculation, but we're " + f"missing {expected_rhfe - set_vals} runs for the " + f"edge with ligands {ligpair}." 
+ ) elif maybe_rbfe and not maybe_rhfe: msg = ( "This appears to be an RBFE calculation, but we're " @@ -131,14 +136,13 @@ def _generate_bad_legs_error_message(set_vals, ligpair): def _parse_raw_units(results: dict) -> list[tuple]: # grab individual unit results from master results dict # returns list of (estimate, uncertainty) tuples - pus = list(results['unit_results'].values()) - return [(pu['outputs']['unit_estimate'], - pu['outputs']['unit_estimate_error']) - for pu in pus] + pus = list(results["unit_results"].values()) + return [(pu["outputs"]["unit_estimate"], pu["outputs"]["unit_estimate_error"]) for pu in pus] def _get_ddgs(legs, error_on_missing=True): import numpy as np + DDGs = [] for ligpair, vals in sorted(legs.items()): set_vals = set(vals) @@ -147,20 +151,20 @@ def _get_ddgs(legs, error_on_missing=True): bind_unc = None hyd_unc = None - do_rbfe = (len(set_vals & {'complex', 'solvent'}) == 2) - do_rhfe = (len(set_vals & {'vacuum', 'solvent'}) == 2) + do_rbfe = len(set_vals & {"complex", "solvent"}) == 2 + do_rhfe = len(set_vals & {"vacuum", "solvent"}) == 2 if do_rbfe: - DG1_mag, DG1_unc = vals['complex'] - DG2_mag, DG2_unc = vals['solvent'] + DG1_mag, DG1_unc = vals["complex"] + DG2_mag, DG2_unc = vals["solvent"] if not ((DG1_mag is None) or (DG2_mag is None)): # DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent DDGbind = (DG1_mag - DG2_mag).m bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m]))) if do_rhfe: - DG1_mag, DG1_unc = vals['solvent'] - DG2_mag, DG2_unc = vals['vacuum'] + DG1_mag, DG1_unc = vals["solvent"] + DG2_mag, DG2_unc = vals["vacuum"] if not ((DG1_mag is None) or (DG2_mag is None)): DDGhyd = (DG1_mag - DG2_mag).m hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m]))) @@ -179,8 +183,7 @@ def _get_ddgs(legs, error_on_missing=True): def _write_ddg(legs, writer, allow_partial): DDGs = _get_ddgs(legs, error_on_missing=not allow_partial) - writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)", - "uncertainty 
(kcal/mol)"]) + writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)"]) for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs: if DDGbind is not None: DDGbind, bind_unc = format_estimate_uncertainty(DDGbind, bind_unc) @@ -191,14 +194,13 @@ def _write_ddg(legs, writer, allow_partial): def _write_raw(legs, writer, allow_partial=True): - writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)", - "MBAR uncertainty (kcal/mol)"]) + writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"]) for ligpair, vals in sorted(legs.items()): for simtype, repeats in sorted(vals.items()): for m, u in repeats: if m is None: - m, u = 'NaN', 'NaN' + m, u = "NaN", "NaN" else: m, u = format_estimate_uncertainty(m.m, u.m) @@ -206,12 +208,11 @@ def _write_raw(legs, writer, allow_partial=True): def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover - writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)", - "uncertainty (kcal/mol)"]) + writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)", "uncertainty (kcal/mol)"]) for ligpair, vals in sorted(legs.items()): for simtype, (m, u) in sorted(vals.items()): if m is None: - m, u = 'NaN', 'NaN' + m, u = "NaN", "NaN" else: m, u = format_estimate_uncertainty(m.m, u.m) writer.writerow([simtype, *ligpair, m, u]) @@ -221,6 +222,7 @@ def _write_dg_mle(legs, writer, allow_partial): import networkx as nx import numpy as np from cinnabar.stats import mle + DDGs = _get_ddgs(legs, error_on_missing=not allow_partial) MLEs = [] # 4b) perform MLE @@ -250,53 +252,49 @@ def _write_dg_mle(legs, writer, allow_partial): nm_to_idx[ligB] = idB g.add_edge( - idA, idB, calc_DDG=DDGbind, calc_dDDG=bind_unc, + idA, + idB, + calc_DDG=DDGbind, + calc_dDDG=bind_unc, ) if DDGbind_count > 2: idx_to_nm = {v: k for k, v in nm_to_idx.items()} - f_i, df_i = mle(g, factor='calc_DDG') + f_i, df_i = mle(g, factor="calc_DDG") df_i = 
np.diagonal(df_i) ** 0.5 for node, f, df in zip(g.nodes, f_i, df_i): ligname = idx_to_nm[node] MLEs.append((ligname, f, df)) - writer.writerow(["ligand", "DG(MLE) (kcal/mol)", - "uncertainty (kcal/mol)"]) + writer.writerow(["ligand", "DG(MLE) (kcal/mol)", "uncertainty (kcal/mol)"]) for ligA, DG, unc_DG in MLEs: DG, unc_DG = format_estimate_uncertainty(DG, unc_DG) writer.writerow([ligA, DG, unc_DG]) -@click.command( - 'gather', - short_help="Gather result jsons for network of RFE results into a TSV file" -) -@click.argument('rootdir', - type=click.Path(dir_okay=True, file_okay=False, - path_type=pathlib.Path), - required=True) +@click.command("gather", short_help="Gather result jsons for network of RFE results into a TSV file") +@click.argument("rootdir", type=click.Path(dir_okay=True, file_okay=False, path_type=pathlib.Path), required=True) @click.option( - '--report', - type=HyphenAwareChoice(['dg', 'ddg', 'raw'], - case_sensitive=False), - default="dg", show_default=True, + "--report", + type=HyphenAwareChoice(["dg", "ddg", "raw"], case_sensitive=False), + default="dg", + show_default=True, help=( "What data to report. 'dg' gives maximum-likelihood estimate of " "absolute deltaG, 'ddg' gives delta-delta-G, and 'dg-raw' gives " "the raw result of the deltaG for a leg." - ) + ), ) -@click.option('output', '-o', - type=click.File(mode='w'), - default='-') +@click.option("output", "-o", type=click.File(mode="w"), default="-") @click.option( - '--allow-partial', is_flag=True, default=False, + "--allow-partial", + is_flag=True, + default=False, help=( "Do not raise errors is results are missing parts for some edges. " "(Skip those edges and issue warning instead.)" - ) + ), ) def gather(rootdir, output, report, allow_partial): """Gather simulation result jsons of relative calculations to a tsv file @@ -320,12 +318,12 @@ def gather(rootdir, output, report, allow_partial): The output is a table of **tab** separated values. 
By default, this outputs to stdout, use the -o option to choose an output file. """ - from collections import defaultdict - import glob import csv + import glob + from collections import defaultdict # 1) find all possible jsons - json_fns = glob.glob(str(rootdir) + '/**/*json', recursive=True) + json_fns = glob.glob(str(rootdir) + "/**/*json", recursive=True) # 2) filter only result jsons result_fns = filter(is_results_json, json_fns) @@ -337,9 +335,8 @@ def gather(rootdir, output, report, allow_partial): result = load_results(result_fn) if result is None: continue - elif result['estimate'] is None or result['uncertainty'] is None: - click.echo(f"WARNING: Calculations for {result_fn} did not finish successfully!", - err=True) + elif result["estimate"] is None or result["uncertainty"] is None: + click.echo(f"WARNING: Calculations for {result_fn} did not finish successfully!", err=True) try: names = get_names(result) @@ -350,10 +347,10 @@ def gather(rootdir, output, report, allow_partial): except KeyError: simtype = legacy_get_type(result_fn) - if report.lower() == 'raw': + if report.lower() == "raw": legs[names][simtype] = _parse_raw_units(result) else: - legs[names][simtype] = result['estimate'], result['uncertainty'] + legs[names][simtype] = result["estimate"], result["uncertainty"] writer = csv.writer( output, @@ -365,17 +362,17 @@ def gather(rootdir, output, report, allow_partial): # 5b) write out DDG values # 5c) write out each leg writing_func = { - 'dg': _write_dg_mle, - 'ddg': _write_ddg, + "dg": _write_dg_mle, + "ddg": _write_ddg, # 'dg-raw': _write_dg_raw, - 'raw': _write_raw, + "raw": _write_raw, }[report.lower()] writing_func(legs, writer, allow_partial) PLUGIN = OFECommandPlugin( command=gather, - section='Quickrun Executor', + section="Quickrun Executor", requires_ofe=(0, 6), ) diff --git a/openfecli/commands/plan_rbfe_network.py b/openfecli/commands/plan_rbfe_network.py index 997f4db23..d712053f4 100644 --- a/openfecli/commands/plan_rbfe_network.py 
+++ b/openfecli/commands/plan_rbfe_network.py @@ -1,14 +1,14 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import click from typing import List -from openfecli.utils import write, print_duration + +import click + from openfecli import OFECommandPlugin -from openfecli.parameters import ( - MOL_DIR, PROTEIN, MAPPER, OUTPUT_DIR, COFACTORS, YAML_OPTIONS, -) +from openfecli.parameters import COFACTORS, MAPPER, MOL_DIR, OUTPUT_DIR, PROTEIN, YAML_OPTIONS from openfecli.plan_alchemical_networks_utils import plan_alchemical_network_output +from openfecli.utils import print_duration, write def plan_rbfe_network_main( @@ -46,9 +46,7 @@ def plan_rbfe_network_main( associated ligand network """ - from openfe.setup.alchemical_network_planner.relative_alchemical_network_planner import ( - RBFEAlchemicalNetworkPlanner, - ) + from openfe.setup.alchemical_network_planner.relative_alchemical_network_planner import RBFEAlchemicalNetworkPlanner network_planner = RBFEAlchemicalNetworkPlanner( mappers=mapper, @@ -56,7 +54,9 @@ def plan_rbfe_network_main( ligand_network_planner=ligand_network_planner, ) alchemical_network = network_planner( - ligands=small_molecules, solvent=solvent, protein=protein, + ligands=small_molecules, + solvent=solvent, + protein=protein, cofactors=cofactors, ) return alchemical_network, network_planner._ligand_network @@ -64,22 +64,15 @@ def plan_rbfe_network_main( @click.command( "plan-rbfe-network", - short_help=( - "Plan a relative binding free energy network, saved as JSON files " - "for the quickrun command." - ) -) -@MOL_DIR.parameter( - required=True, help=MOL_DIR.kwargs["help"] + " Any number of sdf paths." 
-) -@PROTEIN.parameter( - multiple=False, required=True, default=None, help=PROTEIN.kwargs["help"] -) -@COFACTORS.parameter( - multiple=True, required=False, default=None, help=COFACTORS.kwargs["help"] + short_help=("Plan a relative binding free energy network, saved as JSON files " "for the quickrun command."), ) +@MOL_DIR.parameter(required=True, help=MOL_DIR.kwargs["help"] + " Any number of sdf paths.") +@PROTEIN.parameter(multiple=False, required=True, default=None, help=PROTEIN.kwargs["help"]) +@COFACTORS.parameter(multiple=True, required=False, default=None, help=COFACTORS.kwargs["help"]) @YAML_OPTIONS.parameter( - multiple=False, required=False, default=None, + multiple=False, + required=False, + default=None, help=YAML_OPTIONS.kwargs["help"], ) @OUTPUT_DIR.parameter( @@ -88,9 +81,11 @@ def plan_rbfe_network_main( ) @print_duration def plan_rbfe_network( - molecules: List[str], protein: str, cofactors: tuple[str], - yaml_settings: str, - output_dir: str, + molecules: list[str], + protein: str, + cofactors: tuple[str], + yaml_settings: str, + output_dir: str, ): """ Plan a relative binding free energy network, saved as JSON files for @@ -125,10 +120,7 @@ def plan_rbfe_network( write("\tGot input: ") small_molecules = MOL_DIR.get(molecules) - write( - "\t\tSmall Molecules: " - + " ".join([str(sm) for sm in small_molecules]) - ) + write("\t\tSmall Molecules: " + " ".join([str(sm) for sm in small_molecules])) protein = PROTEIN.get(protein) write("\t\tProtein: " + str(protein)) @@ -182,6 +174,4 @@ def plan_rbfe_network( ) -PLUGIN = OFECommandPlugin( - command=plan_rbfe_network, section="Network Planning", requires_ofe=(0, 3) -) +PLUGIN = OFECommandPlugin(command=plan_rbfe_network, section="Network Planning", requires_ofe=(0, 3)) diff --git a/openfecli/commands/plan_rhfe_network.py b/openfecli/commands/plan_rhfe_network.py index fe5bd3114..539114338 100644 --- a/openfecli/commands/plan_rhfe_network.py +++ b/openfecli/commands/plan_rhfe_network.py @@ -2,19 +2,21 @@ 
# For details, see https://github.com/OpenFreeEnergy/openfe -import click from typing import List -from openfecli.utils import write, print_duration +import click + from openfecli import OFECommandPlugin -from openfecli.parameters import ( - MOL_DIR, MAPPER, OUTPUT_DIR, YAML_OPTIONS, -) +from openfecli.parameters import MAPPER, MOL_DIR, OUTPUT_DIR, YAML_OPTIONS from openfecli.plan_alchemical_networks_utils import plan_alchemical_network_output +from openfecli.utils import print_duration, write def plan_rhfe_network_main( - mapper, mapping_scorer, ligand_network_planner, small_molecules, + mapper, + mapping_scorer, + ligand_network_planner, + small_molecules, solvent, ): """Utility method to plan a relative hydration free energy network. @@ -38,34 +40,27 @@ def plan_rhfe_network_main( Alchemical network with protocol for executing simulations, and the associated ligand network """ - from openfe.setup.alchemical_network_planner.relative_alchemical_network_planner import ( - RHFEAlchemicalNetworkPlanner - ) + from openfe.setup.alchemical_network_planner.relative_alchemical_network_planner import RHFEAlchemicalNetworkPlanner network_planner = RHFEAlchemicalNetworkPlanner( mappers=mapper, mapping_scorer=mapping_scorer, ligand_network_planner=ligand_network_planner, ) - alchemical_network = network_planner( - ligands=small_molecules, solvent=solvent - ) + alchemical_network = network_planner(ligands=small_molecules, solvent=solvent) return alchemical_network, network_planner._ligand_network @click.command( "plan-rhfe-network", - short_help=( - "Plan a relative hydration free energy network, saved as JSON files " - "for the quickrun command." - ), -) -@MOL_DIR.parameter( - required=True, help=MOL_DIR.kwargs["help"] + " Any number of sdf paths." 
+ short_help=("Plan a relative hydration free energy network, saved as JSON files " "for the quickrun command."), ) +@MOL_DIR.parameter(required=True, help=MOL_DIR.kwargs["help"] + " Any number of sdf paths.") @YAML_OPTIONS.parameter( - multiple=False, required=False, default=None, + multiple=False, + required=False, + default=None, help=YAML_OPTIONS.kwargs["help"], ) @OUTPUT_DIR.parameter( @@ -73,7 +68,7 @@ def plan_rhfe_network_main( default="alchemicalNetwork", ) @print_duration -def plan_rhfe_network(molecules: List[str], yaml_settings: str, output_dir: str): +def plan_rhfe_network(molecules: list[str], yaml_settings: str, output_dir: str): """ Plan a relative hydration free energy network, saved as JSON files for the quickrun command. @@ -110,10 +105,7 @@ def plan_rhfe_network(molecules: List[str], yaml_settings: str, output_dir: str) write("\tGot input: ") small_molecules = MOL_DIR.get(molecules) - write( - "\t\tSmall Molecules: " - + " ".join([str(sm) for sm in small_molecules]) - ) + write("\t\tSmall Molecules: " + " ".join([str(sm) for sm in small_molecules])) yaml_options = YAML_OPTIONS.get(yaml_settings) mapper_obj = yaml_options.mapper @@ -156,6 +148,4 @@ def plan_rhfe_network(molecules: List[str], yaml_settings: str, output_dir: str) ) -PLUGIN = OFECommandPlugin( - command=plan_rhfe_network, section="Network Planning", requires_ofe=(0, 3) -) +PLUGIN = OFECommandPlugin(command=plan_rhfe_network, section="Network Planning", requires_ofe=(0, 3)) diff --git a/openfecli/commands/quickrun.py b/openfecli/commands/quickrun.py index b8f6dc5e0..9ef2f615a 100644 --- a/openfecli/commands/quickrun.py +++ b/openfecli/commands/quickrun.py @@ -1,40 +1,35 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -import click import json import pathlib +import click + from openfecli import OFECommandPlugin from openfecli.parameters.output import ensure_file_does_not_exist -from openfecli.utils import write, print_duration, configure_logger +from openfecli.utils import configure_logger, print_duration, write def _format_exception(exception) -> str: - """Takes the exception as stored by Gufe and reformats it. - """ + """Takes the exception as stored by Gufe and reformats it.""" return f"{exception[0]}: {exception[1][0]}" - -@click.command( - 'quickrun', - short_help="Run a given transformation, saved as a JSON file" -) -@click.argument('transformation', type=click.File(mode='r'), - required=True) +@click.command("quickrun", short_help="Run a given transformation, saved as a JSON file") +@click.argument("transformation", type=click.File(mode="r"), required=True) @click.option( - '--work-dir', '-d', default=None, - type=click.Path(dir_okay=True, file_okay=False, writable=True, - path_type=pathlib.Path), - help=( - "directory to store files in (defaults to current directory)" - ), + "--work-dir", + "-d", + default=None, + type=click.Path(dir_okay=True, file_okay=False, writable=True, path_type=pathlib.Path), + help=("directory to store files in (defaults to current directory)"), ) @click.option( - 'output', '-o', default=None, - type=click.Path(dir_okay=False, file_okay=True, writable=True, - path_type=pathlib.Path), + "output", + "-o", + default=None, + type=click.Path(dir_okay=False, file_okay=True, writable=True, path_type=pathlib.Path), help="output file (JSON format) for the final results", callback=ensure_file_does_not_exist, ) @@ -49,31 +44,31 @@ def quickrun(transformation, work_dir, output): That will save a JSON file suitable to be input for this command. 
""" - import gufe + import logging import os import sys + + import gufe from gufe.protocols.protocoldag import execute_DAG from gufe.tokenization import JSON_HANDLER + from openfe.utils.logging_filter import MsgIncludesStringFilter - import logging # avoid problems with output not showing if queueing system kills a job sys.stdout.reconfigure(line_buffering=True) stdout_handler = logging.StreamHandler(sys.stdout) - configure_logger('gufekey', handler=stdout_handler) - configure_logger('gufe', handler=stdout_handler) - configure_logger('openfe', handler=stdout_handler) + configure_logger("gufekey", handler=stdout_handler) + configure_logger("gufe", handler=stdout_handler) + configure_logger("openfe", handler=stdout_handler) # silence the openmmtools.multistate API warning stfu = MsgIncludesStringFilter( - "The openmmtools.multistate API is experimental and may change in " - "future releases" + "The openmmtools.multistate API is experimental and may change in " "future releases", ) omm_multistate = "openmmtools.multistate" - modules = ["multistatereporter", "multistateanalyzer", - "multistatesampler"] + modules = ["multistatereporter", "multistateanalyzer", "multistatesampler"] for module in modules: ms_log = logging.getLogger(omm_multistate + "." + module) ms_log.addFilter(stfu) @@ -93,13 +88,14 @@ def quickrun(transformation, work_dir, output): write("Planning simulations for this edge...") dag = trans.create() write("Starting the simulations for this edge...") - dagresult = execute_DAG(dag, - shared_basedir=work_dir, - scratch_basedir=work_dir, - keep_shared=True, - raise_error=False, - n_retries=2, - ) + dagresult = execute_DAG( + dag, + shared_basedir=work_dir, + scratch_basedir=work_dir, + keep_shared=True, + raise_error=False, + n_retries=2, + ) write("Done with all simulations! 
Analyzing the results....") prot_result = trans.protocol.gather([dagresult]) @@ -110,19 +106,16 @@ def quickrun(transformation, work_dir, output): estimate = uncertainty = None # for output file out_dict = { - 'estimate': estimate, - 'uncertainty': uncertainty, - 'protocol_result': prot_result.to_dict(), - 'unit_results': { - unit.key: unit.to_keyed_dict() - for unit in dagresult.protocol_unit_results - } + "estimate": estimate, + "uncertainty": uncertainty, + "protocol_result": prot_result.to_dict(), + "unit_results": {unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results}, } if output is None: - output = work_dir / (str(trans.key) + '_results.json') + output = work_dir / (str(trans.key) + "_results.json") - with open(output, mode='w') as outf: + with open(output, mode="w") as outf: json.dump(out_dict, outf, cls=JSON_HANDLER.encoder) write(f"Here is the result:\n\tdG = {estimate} ± {uncertainty}\n") @@ -140,15 +133,11 @@ def quickrun(transformation, work_dir, output): raise click.ClickException( f"The protocol unit '{failure.name}' failed with the error " f"message:\n{_format_exception(failure.exception)}\n\n" - "Details provided in output." 
+ "Details provided in output.", ) -PLUGIN = OFECommandPlugin( - command=quickrun, - section="Quickrun Executor", - requires_ofe=(0, 3) -) +PLUGIN = OFECommandPlugin(command=quickrun, section="Quickrun Executor", requires_ofe=(0, 3)) if __name__ == "__main__": quickrun() diff --git a/openfecli/commands/test.py b/openfecli/commands/test.py index 3a536dc3c..fe896e339 100644 --- a/openfecli/commands/test.py +++ b/openfecli/commands/test.py @@ -1,17 +1,14 @@ -import click -from openfecli import OFECommandPlugin +import os +import click import pytest -import os +from openfecli import OFECommandPlugin from openfecli.utils import write -@click.command( - "test", - short_help="Run the OpenFE test suite" -) -@click.option('--long', is_flag=True, default=False, - help="Run additional tests (takes much longer)") + +@click.command("test", short_help="Run the OpenFE test suite") +@click.option("--long", is_flag=True, default=False, help="Run additional tests (takes much longer)") def test(long): """ Run the OpenFE test suite. 
This first checks that OpenFE is correctly @@ -29,14 +26,12 @@ def test(long): write("Testing can import....") import openfe + write("Running the main package tests") pytest.main(["-v", "--pyargs", "openfe", "--pyargs", "openfecli"]) os.environ.clear() os.environ.update(old_env) -PLUGIN = OFECommandPlugin( - test, - "Miscellaneous", - requires_ofe=(0, 7,5) -) + +PLUGIN = OFECommandPlugin(test, "Miscellaneous", requires_ofe=(0, 7, 5)) diff --git a/openfecli/commands/view_ligand_network.py b/openfecli/commands/view_ligand_network.py index ce27f13d1..743426e56 100644 --- a/openfecli/commands/view_ligand_network.py +++ b/openfecli/commands/view_ligand_network.py @@ -1,22 +1,19 @@ import click + from openfecli import OFECommandPlugin -@click.command( - "view-ligand-network", - short_help="Visualize a ligand network" -) + +@click.command("view-ligand-network", short_help="Visualize a ligand network") @click.argument( "ligand-network", - type=click.Path(exists=True, readable=True, dir_okay=False, - file_okay=True), + type=click.Path(exists=True, readable=True, dir_okay=False, file_okay=True), ) def view_ligand_network(ligand_network): - from openfe.utils.atommapping_network_plotting import ( - plot_atommapping_network - ) - from openfe.setup import LigandNetwork import matplotlib + from openfe.setup import LigandNetwork + from openfe.utils.atommapping_network_plotting import plot_atommapping_network + matplotlib.use("TkAgg") with open(ligand_network) as f: graphml = f.read() diff --git a/openfecli/fetchables.py b/openfecli/fetchables.py index 5b695182a..b4b505875 100644 --- a/openfecli/fetchables.py +++ b/openfecli/fetchables.py @@ -1,9 +1,8 @@ """Plugins for the ``fetch`` command""" -from openfecli.fetching import URLFetcher, PkgResourceFetcher +from openfecli.fetching import PkgResourceFetcher, URLFetcher -_EXAMPLE_NB_BASE = ("https://raw.githubusercontent.com/" - "OpenFreeEnergy/ExampleNotebooks/main/") +_EXAMPLE_NB_BASE = "https://raw.githubusercontent.com/" 
"OpenFreeEnergy/ExampleNotebooks/main/" RBFE_SHOWCASE = URLFetcher( resources=[ diff --git a/openfecli/fetching.py b/openfecli/fetching.py index 678046a03..637f904ee 100644 --- a/openfecli/fetching.py +++ b/openfecli/fetching.py @@ -1,12 +1,13 @@ +import importlib.resources +import pathlib +import shutil +import urllib.request + import click from plugcli.plugin_management import CommandPlugin -import urllib.request -import importlib.resources -import shutil from .utils import write -import pathlib class _Fetcher: """Base class for fetchers. Defines the API and plugin creation. @@ -24,16 +25,10 @@ class _Fetcher: requires_ofe: Tuple minimum version of OpenFE required """ + REQUIRES_INTERNET = None - def __init__( - self, - resources, - short_name, - short_help, - requires_ofe, - section=None, - long_help=None - ): + + def __init__(self, resources, short_name, short_help, requires_ofe, section=None, long_help=None): self._resources = resources self.short_name = short_name self.short_help = short_help @@ -73,7 +68,9 @@ def plugin(self): help=docs, ) @click.option( - '-d', '--directory', default='.', + "-d", + "--directory", + default=".", help="output directory, defaults to current directory", type=click.Path(file_okay=False, dir_okay=True, writable=True), ) @@ -82,12 +79,8 @@ def command(directory): directory.mkdir(parents=True, exist_ok=True) self(directory) - return FetchablePlugin( - command, - section=self.section, - requires_ofe=self.requires_ofe, - fetcher=self - ) + return FetchablePlugin(command, section=self.section, requires_ofe=self.requires_ofe, fetcher=self) + class URLFetcher(_Fetcher): """Fetcher for URLs. @@ -95,11 +88,13 @@ class URLFetcher(_Fetcher): Resources should be (base, filename), e.g., ("https://google.com/", "index.html). 
""" + REQUIRES_INTERNET = True + def __call__(self, dest_dir): for base, filename in self.resources: # let's just prevent one footgun here - if not base.endswith('/'): + if not base.endswith("/"): base += "/" write(f"Fetching {base}{filename}") @@ -107,7 +102,7 @@ def __call__(self, dest_dir): with urllib.request.urlopen(base + filename) as resp: contents = resp.read() - with open(dest_dir / filename, mode='wb') as f: + with open(dest_dir / filename, mode="wb") as f: f.write(contents) @@ -117,7 +112,9 @@ class PkgResourceFetcher(_Fetcher): Resources should be (package, filename), e.g., ("openfecli", "__init__.py"). """ + REQUIRES_INTERNET = False + def __call__(self, dest_dir): for package, filename in self.resources: ref = importlib.resources.files(package) / filename @@ -148,15 +145,11 @@ class FetchablePlugin(CommandPlugin): This includes the fetcher to simplify testing and introspection. """ + def __init__(self, command, section, requires_ofe, fetcher): - super().__init__(command=command, - section=section, - requires_lib=requires_ofe, - requires_cli=requires_ofe) + super().__init__(command=command, section=section, requires_lib=requires_ofe, requires_cli=requires_ofe) self.fetcher = fetcher @property def filenames(self): return [res[1] for res in self.fetcher.resources] - - diff --git a/openfecli/parameters/__init__.py b/openfecli/parameters/__init__.py index f5a1cfa6f..3f9a30122 100644 --- a/openfecli/parameters/__init__.py +++ b/openfecli/parameters/__init__.py @@ -1,10 +1,10 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -from .mol import MOL from .mapper import MAPPER +from .mol import MOL +from .molecules import COFACTORS, MOL_DIR from .output import OUTPUT_FILE_AND_EXT from .output_dir import OUTPUT_DIR -from .protein import PROTEIN -from .molecules import MOL_DIR, COFACTORS from .plan_network_options import YAML_OPTIONS +from .protein import PROTEIN diff --git a/openfecli/parameters/mapper.py b/openfecli/parameters/mapper.py index d38b2f588..b060d3e16 100644 --- a/openfecli/parameters/mapper.py +++ b/openfecli/parameters/mapper.py @@ -2,8 +2,10 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from plugcli.params import MultiStrategyGetter, Option + from openfecli.parameters.utils import import_parameter + def _atommapper_from_openfe_setup(user_input, context): return import_parameter("openfe.setup." + user_input) @@ -17,14 +19,13 @@ def _atommapper_from_qualname(user_input, context): _atommapper_from_qualname, _atommapper_from_openfe_setup, ], - error_message=("Unable to create atom mapper from user input " - "'{user_input}'. Please check spelling and " - "capitalization.") + error_message=( + "Unable to create atom mapper from user input " "'{user_input}'. Please check spelling and " "capitalization." + ), ) MAPPER = Option( "--mapper", getter=get_atommapper, - help=("Atom mapper; can either be a name in the openfe.setup namespace " - "or a custom fully-qualified name.") + help=("Atom mapper; can either be a name in the openfe.setup namespace " "or a custom fully-qualified name."), ) diff --git a/openfecli/parameters/mol.py b/openfecli/parameters/mol.py index 465d9e46b..bf2bfc725 100644 --- a/openfecli/parameters/mol.py +++ b/openfecli/parameters/mol.py @@ -1,12 +1,14 @@ # This code is part of OpenFE and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/openfe -from plugcli.params import MultiStrategyGetter, Option, NOT_PARSED +from plugcli.params import NOT_PARSED, MultiStrategyGetter, Option def _load_molecule_from_smiles(user_input, context): - from openfe import SmallMoleculeComponent from rdkit import Chem + + from openfe import SmallMoleculeComponent + # MolFromSmiles returns None if the string is not a molecule # TODO: find some way to redirect the error messages? Messages stayed # after either redirect_stdout or redirect_stderr. @@ -21,10 +23,11 @@ def _load_molecule_from_smiles(user_input, context): def _load_molecule_from_sdf(user_input, context): - if '.sdf' not in str(user_input): # this silences some stderr spam + if ".sdf" not in str(user_input): # this silences some stderr spam return NOT_PARSED from openfe import SmallMoleculeComponent + try: return SmallMoleculeComponent.from_sdf_file(user_input) except ValueError: # any exception should try other strategies @@ -32,10 +35,11 @@ def _load_molecule_from_sdf(user_input, context): def _load_molecule_from_mol2(user_input, context): - if '.mol2' not in str(user_input): + if ".mol2" not in str(user_input): return NOT_PARSED from rdkit import Chem + from openfe import SmallMoleculeComponent m = Chem.MolFromMol2File(user_input) @@ -53,12 +57,12 @@ def _load_molecule_from_mol2(user_input, context): # failure will give meaningless user-facing errors _load_molecule_from_smiles, ], - error_message="Unable to generate a molecule from '{user_input}'." + error_message="Unable to generate a molecule from '{user_input}'.", ) MOL = Option( - "-m", "--mol", - help=("SmallMoleculeComponent. Can be provided as an SDF file or as a SMILES " - " string."), - getter=get_molecule + "-m", + "--mol", + help=("SmallMoleculeComponent. 
Can be provided as an SDF file or as a SMILES " " string."), + getter=get_molecule, ) diff --git a/openfecli/parameters/molecules.py b/openfecli/parameters/molecules.py index 5c9d55ecb..6509979c2 100644 --- a/openfecli/parameters/molecules.py +++ b/openfecli/parameters/molecules.py @@ -1,26 +1,33 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import click import glob import itertools import pathlib -from plugcli.params import MultiStrategyGetter, Option, NOT_PARSED +import click +from plugcli.params import NOT_PARSED, MultiStrategyGetter, Option + # MOVE TO GUFE #################################################### def _smcs_from_sdf(sdf): - from openfe import SmallMoleculeComponent from rdkit import Chem + + from openfe import SmallMoleculeComponent + supp = Chem.SDMolSupplier(str(sdf), removeHs=False) mols = [SmallMoleculeComponent(mol) for mol in supp] return mols + def _smcs_from_mol2(mol2): - from openfe import SmallMoleculeComponent from rdkit import Chem + + from openfe import SmallMoleculeComponent + rdmol = Chem.MolFromMol2File(str(mol2), removeHs=False) return [SmallMoleculeComponent.from_rdkit(rdmol)] + def load_molecules(file_or_directory): """ Load SmallMoleculeComponents in the given file or directory. 
@@ -40,8 +47,7 @@ def load_molecules(file_or_directory): inp = pathlib.Path(file_or_directory) # for shorter lines if inp.is_dir(): sdf_files = [f for f in inp.iterdir() if f.suffix.lower() == ".sdf"] - mol2_files = [f for f in inp.iterdir() - if f.suffix.lower() == ".mol2"] + mol2_files = [f for f in inp.iterdir() if f.suffix.lower() == ".mol2"] else: sdf_files = [inp] if inp.suffix.lower() == ".sdf" else [] mol2_files = [inp] if inp.suffix.lower() == ".mol2" else [] @@ -53,6 +59,8 @@ def load_molecules(file_or_directory): mol2_mols = sum([_smcs_from_mol2(mol2) for mol2 in mol2_files], []) return sdf_mols + mol2_mols + + # END MOVE TO GUFE ################################################ @@ -65,15 +73,15 @@ def molecule_getter(user_input, context): "--molecules", type=click.Path(exists=True), help=( - "A directory or file containing all molecules to be loaded, either" - " as a single SDF or multiple MOL2/SDFs." + "A directory or file containing all molecules to be loaded, either" " as a single SDF or multiple MOL2/SDFs." ), getter=molecule_getter, ) COFACTORS = Option( - "-C", "--cofactors", + "-C", + "--cofactors", type=click.Path(exists=True), help="Path to cofactors sdf file. 
This may contain multiple molecules", getter=molecule_getter, -) \ No newline at end of file +) diff --git a/openfecli/parameters/output.py b/openfecli/parameters/output.py index a82652883..8398d2f3b 100644 --- a/openfecli/parameters/output.py +++ b/openfecli/parameters/output.py @@ -1,11 +1,12 @@ -import click import pathlib -from plugcli.params import MultiStrategyGetter, Option, NOT_PARSED + +import click +from plugcli.params import NOT_PARSED, MultiStrategyGetter, Option def get_file_and_extension(user_input, context): file = user_input - ext = file.name.split('.')[-1] if file else None + ext = file.name.split(".")[-1] if file else None return file, ext @@ -16,8 +17,9 @@ def ensure_file_does_not_exist(ctx, param, value): OUTPUT_FILE_AND_EXT = Option( - "-o", "--output", + "-o", + "--output", help="output file", getter=get_file_and_extension, - type=click.File(mode='wb'), + type=click.File(mode="wb"), ) diff --git a/openfecli/parameters/output_dir.py b/openfecli/parameters/output_dir.py index 95c271aa6..e9a6b41e3 100644 --- a/openfecli/parameters/output_dir.py +++ b/openfecli/parameters/output_dir.py @@ -1,7 +1,8 @@ import os -import click import pathlib -from plugcli.params import MultiStrategyGetter, Option, NOT_PARSED + +import click +from plugcli.params import NOT_PARSED, MultiStrategyGetter, Option def get_dir(user_input, context): diff --git a/openfecli/parameters/plan_network_options.py b/openfecli/parameters/plan_network_options.py index c89ec9ec9..e17371e6e 100644 --- a/openfecli/parameters/plan_network_options.py +++ b/openfecli/parameters/plan_network_options.py @@ -3,28 +3,29 @@ """Pydantic models for the definition of advanced CLI options """ -import click from collections import namedtuple + +import click + try: # todo; once we're fully v2, we can use ConfigDict not nested class from pydantic.v1 import BaseModel # , ConfigDict except ImportError: from pydantic import BaseModel -from plugcli.params import Option -from typing import Any, Optional 
-import yaml + import warnings +from typing import Any, Optional +import yaml +from plugcli.params import Option -PlanNetworkOptions = namedtuple('PlanNetworkOptions', - ['mapper', 'scorer', - 'ligand_network_planner', 'solvent']) +PlanNetworkOptions = namedtuple("PlanNetworkOptions", ["mapper", "scorer", "ligand_network_planner", "solvent"]) class MapperSelection(BaseModel): # model_config = ConfigDict(extra='allow', str_to_lower=True) class Config: - extra = 'allow' + extra = "allow" anystr_lower = True method: Optional[str] = None @@ -34,7 +35,7 @@ class Config: class NetworkSelection(BaseModel): # model_config = ConfigDict(extra='allow', str_to_lower=True) class Config: - extra = 'allow' + extra = "allow" anystr_lower = True method: Optional[str] = None @@ -44,7 +45,7 @@ class Config: class CliYaml(BaseModel): # model_config = ConfigDict(extra='allow') class Config: - extra = 'allow' + extra = "allow" mapper: Optional[MapperSelection] = None network: Optional[NetworkSelection] = None @@ -72,7 +73,7 @@ def parse_yaml_planner_options(contents: str) -> CliYaml: if False: # todo: warnings about extra fields we don't expect? - expected = {'mapper', 'network'} + expected = {"mapper", "network"} for field in raw: if field in expected: continue @@ -98,24 +99,21 @@ def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOption and 'solvent' fields. 
these fields each hold appropriate objects ready for use """ + from functools import partial + from gufe import SolventComponent + + from openfe.setup import KartografAtomMapper, LomapAtomMapper + from openfe.setup.atom_mapping.lomap_scorers import default_lomap_score from openfe.setup.ligand_network_planning import ( - generate_radial_network, - generate_minimal_spanning_network, generate_maximal_network, generate_minimal_redundant_network, + generate_minimal_spanning_network, + generate_radial_network, ) - from openfe.setup import ( - LomapAtomMapper, - KartografAtomMapper, - ) - from openfe.setup.atom_mapping.lomap_scorers import ( - default_lomap_score, - ) - from functools import partial if path is not None: - with open(path, 'r') as f: + with open(path) as f: raw = f.read() # convert raw yaml to normalised pydantic model @@ -126,10 +124,10 @@ def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOption # convert normalised inputs to objects if opt and opt.mapper: mapper_choices = { - 'lomap': LomapAtomMapper, - 'lomapatommapper': LomapAtomMapper, - 'kartograf': KartografAtomMapper, - 'kartografatommapper': KartografAtomMapper, + "lomap": LomapAtomMapper, + "lomapatommapper": LomapAtomMapper, + "kartograf": KartografAtomMapper, + "kartografatommapper": KartografAtomMapper, } try: @@ -138,20 +136,19 @@ def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOption raise KeyError(f"Bad mapper choice: '{opt.mapper.method}'") mapper_obj = cls(**opt.mapper.settings) else: - mapper_obj = LomapAtomMapper(time=20, threed=True, element_change=False, - max3d=1) + mapper_obj = LomapAtomMapper(time=20, threed=True, element_change=False, max3d=1) # todo: choice of scorer goes here mapping_scorer = default_lomap_score if opt and opt.network: network_choices = { - 'generate_radial_network': generate_radial_network, - 'radial': generate_radial_network, - 'generate_minimal_spanning_network': generate_minimal_spanning_network, - 'mst': 
generate_minimal_spanning_network, - 'generate_minimal_redundant_network': generate_minimal_redundant_network, - 'generate_maximal_network': generate_maximal_network, + "generate_radial_network": generate_radial_network, + "radial": generate_radial_network, + "generate_minimal_spanning_network": generate_minimal_spanning_network, + "mst": generate_minimal_spanning_network, + "generate_minimal_redundant_network": generate_minimal_redundant_network, + "generate_maximal_network": generate_maximal_network, } try: @@ -175,7 +172,9 @@ def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOption YAML_OPTIONS = Option( - '-s', "--settings", "yaml_settings", + "-s", + "--settings", + "yaml_settings", type=click.Path(exists=True, dir_okay=False), help="Path to planning settings yaml file.", getter=load_yaml_planner_options, diff --git a/openfecli/parameters/protein.py b/openfecli/parameters/protein.py index 76625161a..599ba4397 100644 --- a/openfecli/parameters/protein.py +++ b/openfecli/parameters/protein.py @@ -1,7 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from plugcli.params import MultiStrategyGetter, Option, NOT_PARSED +from plugcli.params import NOT_PARSED, MultiStrategyGetter, Option + def _load_protein_from_pdb(user_input, context): if ".pdb" not in str(user_input): # this silences some stderr spam @@ -38,9 +39,6 @@ def _load_protein_from_pdbx(user_input, context): PROTEIN = Option( "-p", "--protein", - help=( - "ProteinComponent. Can be provided as an PDB or as a PDBx/mmCIF file. " - " string." - ), + help=("ProteinComponent. Can be provided as an PDB or as a PDBx/mmCIF file. 
" " string."), getter=get_molecule, ) diff --git a/openfecli/parameters/utils.py b/openfecli/parameters/utils.py index 8befac7ef..f7b8bc201 100644 --- a/openfecli/parameters/utils.py +++ b/openfecli/parameters/utils.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/openfe from plugcli.params import NOT_PARSED + from openfecli.utils import import_thing diff --git a/openfecli/plan_alchemical_networks_utils.py b/openfecli/plan_alchemical_networks_utils.py index fd15bd754..d8eaac317 100644 --- a/openfecli/plan_alchemical_networks_utils.py +++ b/openfecli/plan_alchemical_networks_utils.py @@ -1,22 +1,23 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe from __future__ import annotations -import os -from collections import defaultdict + import json +import os import pathlib -from openfecli.utils import write import typing +from collections import defaultdict + from openfe import AlchemicalNetwork, LigandNetwork +from openfecli.utils import write def plan_alchemical_network_output( alchemical_network: AlchemicalNetwork, ligand_network: LigandNetwork, - folder_path: pathlib.Path + folder_path: pathlib.Path, ): - """Write the contents of an alchemical network into the structure - """ + """Write the contents of an alchemical network into the structure""" import gufe from gufe import tokenization @@ -30,7 +31,7 @@ def plan_alchemical_network_output( write("\t\t- " + base_name + ".json") ln_fname = "ligand_network.graphml" - with open(folder_path / ln_fname, mode='w') as f: + with open(folder_path / ln_fname, mode="w") as f: f.write(ligand_network.to_graphml()) write(f"\t\t- {ln_fname}") @@ -42,5 +43,3 @@ def plan_alchemical_network_output( filename = f"{transformation_name}.json" transformation.dump(transformations_dir / filename) write("\t\t\t\t- " + filename) - - diff --git a/openfecli/plugins.py b/openfecli/plugins.py index 16aa17d43..34bd55562 100644 --- 
a/openfecli/plugins.py +++ b/openfecli/plugins.py @@ -6,7 +6,4 @@ class OFECommandPlugin(CommandPlugin): def __init__(self, command, section, requires_ofe): - super().__init__(command=command, - section=section, - requires_lib=requires_ofe, - requires_cli=requires_ofe) + super().__init__(command=command, section=section, requires_lib=requires_ofe, requires_cli=requires_ofe) diff --git a/openfecli/tests/clicktypes/test_hyphenchoice.py b/openfecli/tests/clicktypes/test_hyphenchoice.py index c7c0bb5a2..dbd2ca53f 100644 --- a/openfecli/tests/clicktypes/test_hyphenchoice.py +++ b/openfecli/tests/clicktypes/test_hyphenchoice.py @@ -1,19 +1,17 @@ -import pytest import click +import pytest + from openfecli.clicktypes.hyphenchoice import HyphenAwareChoice + class TestHyphenAwareChoice: - @pytest.mark.parametrize('value', [ - "foo_bar_baz", "foo_bar-baz", "foo-bar_baz", "foo-bar-baz" - ]) + @pytest.mark.parametrize("value", ["foo_bar_baz", "foo_bar-baz", "foo-bar_baz", "foo-bar-baz"]) def test_init(self, value): ch = HyphenAwareChoice([value]) assert ch.choices == ["foo-bar-baz"] - @pytest.mark.parametrize('value', [ - "foo_bar_baz", "foo_bar-baz", "foo-bar_baz", "foo-bar-baz" - ]) + @pytest.mark.parametrize("value", ["foo_bar_baz", "foo_bar-baz", "foo-bar_baz", "foo-bar-baz"]) def test_convert(self, value): - ch = HyphenAwareChoice(['foo-bar-baz']) + ch = HyphenAwareChoice(["foo-bar-baz"]) # counting on __call__ to get to convert() assert ch(value) == "foo-bar-baz" diff --git a/openfecli/tests/commands/test_atommapping.py b/openfecli/tests/commands/test_atommapping.py index 0251113ae..d8c1d111a 100644 --- a/openfecli/tests/commands/test_atommapping.py +++ b/openfecli/tests/commands/test_atommapping.py @@ -1,16 +1,17 @@ from unittest import mock -import pytest import click +import pytest from click.testing import CliRunner from openfe.setup import LigandAtomMapping, LomapAtomMapper - -from openfecli.parameters import MOL from openfecli.commands.atommapping import ( - 
atommapping, generate_mapping, atommapping_print_dict_main, - atommapping_visualize_main + atommapping, + atommapping_print_dict_main, + atommapping_visualize_main, + generate_mapping, ) +from openfecli.parameters import MOL @pytest.fixture @@ -45,15 +46,14 @@ def print_test_with_file(mapper, molA, molB, file, ext): print(ext) -@pytest.mark.parametrize('with_file', [True, False]) +@pytest.mark.parametrize("with_file", [True, False]) def test_atommapping(molA_args, molB_args, mapper_args, with_file): # Patch out the main function with a simple function to output # information about the objects we pass to the main; test the output of # that using tools from click. This tests the creation of objects from # user input on the command line. args = molA_args + molB_args + mapper_args - expected_output = (f"{molA_args[1]}\n{molB_args[1]}\n" - f"{mapper_args[1]}\n") + expected_output = f"{molA_args[1]}\n{molB_args[1]}\n" f"{mapper_args[1]}\n" patch_base = "openfecli.commands.atommapping." if with_file: args += ["-o", "myfile.png"] @@ -88,16 +88,17 @@ def test_atommapping_missing_mapper(molA_args, molB_args): assert "Missing option '--mapper'" in result.output -@pytest.mark.parametrize('n_mappings', [0, 1, 2]) +@pytest.mark.parametrize("n_mappings", [0, 1, 2]) def test_generate_mapping(n_mappings, mols): - molA, molB, = mols + ( + molA, + molB, + ) = mols mappings = [ LigandAtomMapping(molA, molB, {i: i for i in range(7)}), LigandAtomMapping(molA, molB, {i: (i + 1) % 7 for i in range(7)}), ] - mapper = mock.Mock( - suggest_mappings=mock.Mock(return_value=mappings[:n_mappings]) - ) + mapper = mock.Mock(suggest_mappings=mock.Mock(return_value=mappings[:n_mappings])) if n_mappings == 1: assert generate_mapping(mapper, molA, molB) == mappings[0] @@ -110,8 +111,7 @@ def test_atommapping_print_dict_main(capsys, mols): molA, molB = mols mapper = LomapAtomMapper mapping = LigandAtomMapping(molA, molB, {i: i for i in range(7)}) - with 
mock.patch('openfecli.commands.atommapping.generate_mapping', - mock.Mock(return_value=mapping)): + with mock.patch("openfecli.commands.atommapping.generate_mapping", mock.Mock(return_value=mapping)): atommapping_print_dict_main(mapper, molA, molB) captured = capsys.readouterr() assert captured.out == str(mapping.componentA_to_componentB) + "\n" @@ -127,9 +127,7 @@ def test_atommapping_visualize_main_bad_extension(mols, tmpdir): molA, molB = mols mapper = LomapAtomMapper mapping = LigandAtomMapping(molA, molB, {i: i for i in range(7)}) - with mock.patch('openfecli.commands.atommapping.generate_mapping', - mock.Mock(return_value=mapping)): - with open(tmpdir / "foo.bar", mode='w') as f: - with pytest.raises(click.BadParameter, - match="Unknown file format"): + with mock.patch("openfecli.commands.atommapping.generate_mapping", mock.Mock(return_value=mapping)): + with open(tmpdir / "foo.bar", mode="w") as f: + with pytest.raises(click.BadParameter, match="Unknown file format"): atommapping_visualize_main(mapper, molA, molB, f, "bar") diff --git a/openfecli/tests/commands/test_gather.py b/openfecli/tests/commands/test_gather.py index 543e5a31a..48cae79ee 100644 --- a/openfecli/tests/commands/test_gather.py +++ b/openfecli/tests/commands/test_gather.py @@ -1,28 +1,42 @@ -from click.testing import CliRunner -from importlib import resources -import tarfile import os import pathlib -import pytest +import tarfile +from importlib import resources + import pooch +import pytest +from click.testing import CliRunner -from openfecli.commands.gather import ( - gather, format_estimate_uncertainty, _get_column, - _generate_bad_legs_error_message, -) +from openfecli.commands.gather import _generate_bad_legs_error_message, _get_column, format_estimate_uncertainty, gather -@pytest.mark.parametrize('est,unc,unc_prec,est_str,unc_str', [ - (12.432, 0.111, 2, "12.43", "0.11"), - (0.9999, 0.01, 2, "1.000", "0.010"), - (1234, 100, 2, "1230", "100"), -]) + +@pytest.mark.parametrize( + 
"est,unc,unc_prec,est_str,unc_str", + [ + (12.432, 0.111, 2, "12.43", "0.11"), + (0.9999, 0.01, 2, "1.000", "0.010"), + (1234, 100, 2, "1230", "100"), + ], +) def test_format_estimate_uncertainty(est, unc, unc_prec, est_str, unc_str): assert format_estimate_uncertainty(est, unc, unc_prec) == (est_str, unc_str) -@pytest.mark.parametrize('val, col', [ - (1.0, 1), (0.1, -1), (-0.0, 0), (0.0, 0), (0.2, -1), (0.9, -1), - (0.011, -2), (9, 1), (10, 2), (15, 2), -]) + +@pytest.mark.parametrize( + "val, col", + [ + (1.0, 1), + (0.1, -1), + (-0.0, 0), + (0.0, 0), + (0.2, -1), + (0.9, -1), + (0.011, -2), + (9, 1), + (10, 2), + (15, 2), + ], +) def test_get_column(val, col): assert _get_column(val) == col @@ -30,12 +44,13 @@ def test_get_column(val, col): @pytest.fixture def results_dir(tmpdir): with tmpdir.as_cwd(): - with resources.files('openfecli.tests.data') as d: - t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r') - t.extractall('.') + with resources.files("openfecli.tests.data") as d: + t = tarfile.open(d / "rbfe_results.tar.gz", mode="r") + t.extractall(".") yield + _EXPECTED_DG = b""" ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_ejm_31 -0.09 0.05 @@ -146,7 +161,7 @@ def results_dir(tmpdir): @pytest.mark.xfail -@pytest.mark.parametrize('report', ["", "dg", "ddg"]) +@pytest.mark.parametrize("report", ["", "dg", "ddg"]) def test_gather(results_dir, report): expected = { "": _EXPECTED_DG, @@ -161,25 +176,24 @@ def test_gather(results_dir, report): else: args = [] - result = runner.invoke(gather, ['results'] + args + ['-o', '-']) + result = runner.invoke(gather, ["results"] + args + ["-o", "-"]) assert result.exit_code == 0 - actual_lines = set(result.stdout_bytes.split(b'\n')) + actual_lines = set(result.stdout_bytes.split(b"\n")) - assert set(expected.split(b'\n')) == actual_lines + assert set(expected.split(b"\n")) == actual_lines -@pytest.mark.parametrize('include', ['complex', 'solvent', 'vacuum']) +@pytest.mark.parametrize("include", ["complex", 
"solvent", "vacuum"]) def test_generate_bad_legs_error_message(include): expected = { - 'complex': ("appears to be an RBFE", "missing {'solvent'}"), - 'vacuum': ("appears to be an RHFE", "missing {'solvent'}"), - 'solvent': ("whether this is an RBFE or an RHFE", - "'complex'", "'solvent'"), + "complex": ("appears to be an RBFE", "missing {'solvent'}"), + "vacuum": ("appears to be an RHFE", "missing {'solvent'}"), + "solvent": ("whether this is an RBFE or an RHFE", "'complex'", "'solvent'"), }[include] set_vals = {include} - ligpair = {'lig1', 'lig2'} + ligpair = {"lig1", "lig2"} msg = _generate_bad_legs_error_message(set_vals, ligpair) for string in expected: assert string in msg @@ -191,7 +205,7 @@ def test_missing_leg_error(results_dir): (pathlib.Path("results") / file_to_remove).unlink() runner = CliRunner() - result = runner.invoke(gather, ['results'] + ['-o', '-']) + result = runner.invoke(gather, ["results"] + ["-o", "-"]) assert result.exit_code == 1 assert isinstance(result.exception, RuntimeError) assert "Unable to determine" in str(result.exception) @@ -205,13 +219,12 @@ def test_missing_leg_allow_partial(results_dir): (pathlib.Path("results") / file_to_remove).unlink() runner = CliRunner() - result = runner.invoke(gather, - ['results'] + ['--allow-partial', '-o', '-']) + result = runner.invoke(gather, ["results"] + ["--allow-partial", "-o", "-"]) assert result.exit_code == 0 RBFE_RESULTS = pooch.create( - pooch.os_cache('openfe'), + pooch.os_cache("openfe"), base_url="doi:10.6084/m9.figshare.25148945", registry={"results.tar.gz": "bf27e728935b31360f95188f41807558156861f6d89b8a47854502a499481da3"}, ) @@ -221,9 +234,9 @@ def test_missing_leg_allow_partial(results_dir): def rbfe_results(): # fetches rbfe results from online # untars into local directory and returns path to this - d = RBFE_RESULTS.fetch('results.tar.gz', processor=pooch.Untar()) + d = RBFE_RESULTS.fetch("results.tar.gz", processor=pooch.Untar()) - return 
os.path.join(pooch.os_cache('openfe'), 'results.tar.gz.untar', 'results') + return os.path.join(pooch.os_cache("openfe"), "results.tar.gz.untar", "results") @pytest.mark.download @@ -231,7 +244,7 @@ def rbfe_results(): def test_rbfe_results(rbfe_results): runner = CliRunner() - result = runner.invoke(gather, ['--report', 'raw', rbfe_results]) + result = runner.invoke(gather, ["--report", "raw", rbfe_results]) assert result.exit_code == 0 assert result.stdout_bytes == _EXPECTED_RAW diff --git a/openfecli/tests/commands/test_ligand_network_viewer.py b/openfecli/tests/commands/test_ligand_network_viewer.py index 525447122..e76b92c32 100644 --- a/openfecli/tests/commands/test_ligand_network_viewer.py +++ b/openfecli/tests/commands/test_ligand_network_viewer.py @@ -1,15 +1,17 @@ -import pytest -from unittest import mock -from click.testing import CliRunner import importlib.resources +from unittest import mock + import matplotlib +import pytest +from click.testing import CliRunner from openfecli.commands.view_ligand_network import view_ligand_network + @pytest.mark.filterwarnings("ignore:.*non-GUI backend") def test_view_ligand_network(): # smoke test - resource = importlib.resources.files('openfe.tests.data.serialization') + resource = importlib.resources.files("openfe.tests.data.serialization") ref = resource / "network_template.graphml" runner = CliRunner() diff --git a/openfecli/tests/commands/test_plan_rbfe_network.py b/openfecli/tests/commands/test_plan_rbfe_network.py index 0def8bca9..e5c501643 100644 --- a/openfecli/tests/commands/test_plan_rbfe_network.py +++ b/openfecli/tests/commands/test_plan_rbfe_network.py @@ -1,23 +1,20 @@ +import os +import shutil +from importlib import resources from unittest import mock import pytest -from importlib import resources -import os -import shutil from click.testing import CliRunner -from openfecli.commands.plan_rbfe_network import ( - plan_rbfe_network, - plan_rbfe_network_main, -) +from openfecli.commands.plan_rbfe_network 
import plan_rbfe_network, plan_rbfe_network_main -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def mol_dir_args(tmpdir_factory): - ofe_dir_path = tmpdir_factory.mktemp('moldir') + ofe_dir_path = tmpdir_factory.mktemp("moldir") - with resources.files('openfe.tests.data.openmm_rfe') as d: - for f in ['ligand_23.sdf', 'ligand_55.sdf']: + with resources.files("openfe.tests.data.openmm_rfe") as d: + for f in ["ligand_23.sdf", "ligand_55.sdf"]: shutil.copyfile(d / f, ofe_dir_path / f) return ["--molecules", ofe_dir_path] @@ -44,26 +41,14 @@ def print_test_with_file( def test_plan_rbfe_network_main(): - from gufe import ( - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, - ) - from openfe.setup import ( - LomapAtomMapper, - lomap_scorers, - ligand_network_planning, - ) + from gufe import ProteinComponent, SmallMoleculeComponent, SolventComponent + + from openfe.setup import LomapAtomMapper, ligand_network_planning, lomap_scorers with resources.files("openfe.tests.data.openmm_rfe") as d: - smallM_components = [ - SmallMoleculeComponent.from_sdf_file(d / f) - for f in ['ligand_23.sdf', 'ligand_55.sdf'] - ] + smallM_components = [SmallMoleculeComponent.from_sdf_file(d / f) for f in ["ligand_23.sdf", "ligand_55.sdf"]] with resources.files("openfe.tests.data") as d: - protein_compontent = ProteinComponent.from_pdb_file( - str(d / "181l_only.pdb") - ) + protein_compontent = ProteinComponent.from_pdb_file(str(d / "181l_only.pdb")) solvent_component = SolventComponent() alchemical_network, ligand_network = plan_rbfe_network_main( @@ -101,9 +86,7 @@ def test_plan_rbfe_network(mol_dir_args, protein_args): "- easy_rbfe_ligand_55_solvent_ligand_23_solvent.json", ] - patch_base = ( - "openfecli.commands.plan_rbfe_network." - ) + patch_base = "openfecli.commands.plan_rbfe_network." 
args += ["-o", "tmp_network"] patch_loc = patch_base + "plan_rbfe_network" @@ -124,10 +107,10 @@ def test_plan_rbfe_network(mol_dir_args, protein_args): @pytest.fixture def eg5_files(): - with resources.files('openfe.tests.data.eg5') as p: - pdb_path = str(p.joinpath('eg5_protein.pdb')) - lig_path = str(p.joinpath('eg5_ligands.sdf')) - cof_path = str(p.joinpath('eg5_cofactor.sdf')) + with resources.files("openfe.tests.data.eg5") as p: + pdb_path = str(p.joinpath("eg5_protein.pdb")) + lig_path = str(p.joinpath("eg5_ligands.sdf")) + cof_path = str(p.joinpath("eg5_cofactor.sdf")) yield pdb_path, lig_path, cof_path @@ -137,9 +120,12 @@ def test_plan_rbfe_network_cofactors(eg5_files): runner = CliRunner() args = [ - '-p', eg5_files[0], - '-M', eg5_files[1], - '-C', eg5_files[2], + "-p", + eg5_files[0], + "-M", + eg5_files[1], + "-C", + eg5_files[2], ] with runner.isolated_filesystem(): @@ -175,10 +161,14 @@ def test_custom_yaml_plan_rbfe_smoke_test(custom_yaml_settings, eg5_files, tmpdi assert settings_path.exists() args = [ - '-p', protein, - '-M', ligand, - '-C', cofactor, - '-s', settings_path, + "-p", + protein, + "-M", + ligand, + "-C", + cofactor, + "-s", + settings_path, ] runner = CliRunner() diff --git a/openfecli/tests/commands/test_plan_rhfe_network.py b/openfecli/tests/commands/test_plan_rhfe_network.py index f5a72431f..8eca785e7 100644 --- a/openfecli/tests/commands/test_plan_rhfe_network.py +++ b/openfecli/tests/commands/test_plan_rhfe_network.py @@ -1,31 +1,26 @@ +import os +import shutil +from importlib import resources from unittest import mock import pytest -from importlib import resources -import os -import shutil from click.testing import CliRunner -from openfecli.commands.plan_rhfe_network import ( - plan_rhfe_network, - plan_rhfe_network_main, -) +from openfecli.commands.plan_rhfe_network import plan_rhfe_network, plan_rhfe_network_main -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def mol_dir_args(tmpdir_factory): - 
ofe_dir_path = tmpdir_factory.mktemp('moldir') + ofe_dir_path = tmpdir_factory.mktemp("moldir") - with resources.files('openfe.tests.data.openmm_rfe') as d: - for f in ['ligand_23.sdf', 'ligand_55.sdf']: + with resources.files("openfe.tests.data.openmm_rfe") as d: + for f in ["ligand_23.sdf", "ligand_55.sdf"]: shutil.copyfile(d / f, ofe_dir_path / f) return ["--molecules", ofe_dir_path] -def print_test_with_file( - mapping_scorer, ligand_network_planner, small_molecules, solvent -): +def print_test_with_file(mapping_scorer, ligand_network_planner, small_molecules, solvent): print(mapping_scorer) print(ligand_network_planner) print(small_molecules) @@ -34,18 +29,13 @@ def print_test_with_file( def test_plan_rhfe_network_main(): import os + from gufe import SmallMoleculeComponent, SolventComponent - from openfe.setup import ( - LomapAtomMapper, - lomap_scorers, - ligand_network_planning, - ) + + from openfe.setup import LomapAtomMapper, ligand_network_planning, lomap_scorers with resources.files("openfe.tests.data.openmm_rfe") as d: - smallM_components = [ - SmallMoleculeComponent.from_sdf_file(d / f) - for f in ['ligand_23.sdf', 'ligand_55.sdf'] - ] + smallM_components = [SmallMoleculeComponent.from_sdf_file(d / f) for f in ["ligand_23.sdf", "ligand_55.sdf"]] solvent_component = SolventComponent() alchemical_network, ligand_network = plan_rhfe_network_main( @@ -82,9 +72,7 @@ def test_plan_rhfe_network(mol_dir_args): "- easy_rhfe_ligand_55_solvent_ligand_23_solvent.json", ] - patch_base = ( - "openfecli.commands.plan_rhfe_network." - ) + patch_base = "openfecli.commands.plan_rhfe_network." 
args += ["-o", "tmp_network"] patch_loc = patch_base + "plan_rhfe_network" @@ -126,7 +114,7 @@ def test_custom_yaml_plan_rhfe_smoke_test(custom_yaml_settings, mol_dir_args, tm assert settings_path.exists() - args = mol_dir_args + ['-s', settings_path] + args = mol_dir_args + ["-s", settings_path] runner = CliRunner() diff --git a/openfecli/tests/commands/test_quickrun.py b/openfecli/tests/commands/test_quickrun.py index c09ca7d76..71ffdd158 100644 --- a/openfecli/tests/commands/test_quickrun.py +++ b/openfecli/tests/commands/test_quickrun.py @@ -1,26 +1,24 @@ -import pytest -import click -from importlib import resources -import pathlib import json +import pathlib +from importlib import resources + +import click +import pytest from click.testing import CliRunner +from gufe.tokenization import JSON_HANDLER from openfecli.commands.quickrun import quickrun -from gufe.tokenization import JSON_HANDLER @pytest.fixture def json_file(): - with resources.files('openfecli.tests.data') as d: - json_file = str(d / 'transformation.json') + with resources.files("openfecli.tests.data") as d: + json_file = str(d / "transformation.json") return json_file -@pytest.mark.parametrize('extra_args', [ - {}, - {'-d': 'foo_dir', '-o': 'foo.json'} -]) +@pytest.mark.parametrize("extra_args", [{}, {"-d": "foo_dir", "-o": "foo.json"}]) def test_quickrun(extra_args, json_file): extras = sum([list(kv) for kv in extra_args.items()], []) @@ -31,13 +29,12 @@ def test_quickrun(extra_args, json_file): assert "Here is the result" in result.output assert "Additional information" in result.output - if outfile := extra_args.get('-o'): + if outfile := extra_args.get("-o"): assert pathlib.Path(outfile).exists() - with open(outfile, mode='r') as outf: + with open(outfile) as outf: dct = json.load(outf, cls=JSON_HANDLER.decoder) - assert set(dct) == {'estimate', 'uncertainty', - 'protocol_result', 'unit_results'} + assert set(dct) == {"estimate", "uncertainty", "protocol_result", "unit_results"} # TODO: need 
a protocol that drops files to actually do this! # if directory := extra_args.get('-d'): @@ -50,19 +47,19 @@ def test_quickrun(extra_args, json_file): def test_quickrun_output_file_exists(json_file): runner = CliRunner() with runner.isolated_filesystem(): - pathlib.Path('foo.json').touch() - result = runner.invoke(quickrun, [json_file, '-o', 'foo.json']) + pathlib.Path("foo.json").touch() + result = runner.invoke(quickrun, [json_file, "-o", "foo.json"]) assert result.exit_code == 2 # usage error assert "File 'foo.json' already exists." in result.output def test_quickrun_unit_error(): - with resources.files('openfecli.tests.data') as d: - json_file = str(d / 'bad_transformation.json') + with resources.files("openfecli.tests.data") as d: + json_file = str(d / "bad_transformation.json") runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke(quickrun, [json_file, '-o', 'foo.json']) + result = runner.invoke(quickrun, [json_file, "-o", "foo.json"]) assert result.exit_code == 1 assert pathlib.Path("foo.json").exists() # TODO: I'm still not happy with this... 
failure result does not see diff --git a/openfecli/tests/commands/test_test.py b/openfecli/tests/commands/test_test.py index d10be06ef..e799537b3 100644 --- a/openfecli/tests/commands/test_test.py +++ b/openfecli/tests/commands/test_test.py @@ -1,17 +1,20 @@ -import pytest +import os from unittest import mock + +import pytest from click.testing import CliRunner -import os from openfecli.commands.test import test + def mock_func(args): print(os.environ.get("OFE_SLOW_TESTS")) -@pytest.mark.parametrize('slow', [True, False]) + +@pytest.mark.parametrize("slow", [True, False]) def test_test(slow): runner = CliRunner() - args = ['--long'] if slow else [] + args = ["--long"] if slow else [] patchloc = "openfecli.commands.test.pytest.main" ofe_slow_tests = os.environ.get("OFE_SLOW_TESTS") with mock.patch(patchloc, mock_func): diff --git a/openfecli/tests/conftest.py b/openfecli/tests/conftest.py index a1e10dfd9..d5ea7119d 100644 --- a/openfecli/tests/conftest.py +++ b/openfecli/tests/conftest.py @@ -1,7 +1,7 @@ import urllib.request try: - urllib.request.urlopen('https://www.google.com') + urllib.request.urlopen("https://www.google.com") except: # -no-cov- HAS_INTERNET = False else: diff --git a/openfecli/tests/data/bad_transformation.json b/openfecli/tests/data/bad_transformation.json index fe0bc82f0..ef3cb9569 100644 --- a/openfecli/tests/data/bad_transformation.json +++ b/openfecli/tests/data/bad_transformation.json @@ -1 +1 @@ -{":version:": 1, "__module__": "gufe.transformations.transformation", "__qualname__": "Transformation", "mapping": {":version:": 1, "__module__": "gufe.mapping.ligandatommapping", "__qualname__": "LigandAtomMapping", "annotations": "{}", "componentA": {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, true, 0, 0, {}], [6, 0, 0, true, 0, 0, {}], [6, 0, 0, true, 0, 0, {}], [6, 0, 0, true, 0, 0, {}], [6, 0, 0, true, 0, 0, {}], [6, 0, 0, true, 0, 0, {}], [1, 0, 0, 
false, 0, 0, {}], [1, 0, 0, false, 0, 0, {}], [1, 0, 0, false, 0, 0, {}], [1, 0, 0, false, 0, 0, {}], [1, 0, 0, false, 0, 0, {}], [1, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 12, 0, {}], [0, 5, 12, 0, {}], [0, 6, 1, 0, {}], [1, 2, 12, 0, {}], [1, 7, 1, 0, {}], [2, 3, 12, 0, {}], [2, 8, 1, 0, {}], [3, 4, 12, 0, {}], [3, 9, 1, 0, {}], [4, 5, 12, 0, {}], [4, 10, 1, 0, {}], [5, 11, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 @$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 
@$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 @$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 
@$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 @$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 
@$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 @$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@\u00e8<@[\u00b1\u00bf\u00ec\u009e|!@\u00b0rh\u0091\u00ed|\u0014@\u00c3d\u00aa`T2<@\u009a\b\u001b\u009e^I @\u00caT\u00c1\u00a8\u00a4\u008e\u001a@?\u00c6\u00dc\u00b5\u0084\u00fc;@\u00da\u001b|a2\u00d5 
@$(~\u008c\u00b9k\u0016@n\u00a3\u0001\u00bc\u0005B;@\u00c0\u00ec\u009e<,t\"@\u0084\u009e\u00cd\u00aa\u00cfU\u0016@\u00ee|?5^\u00fa9@\u0002+\u0087\u0016\u00d9N\u0015@\u0004V\u000e-\u00b2\u001d\u0013@\u0085\u00ebQ\u00b8\u001ee:@\u00b2\u009d\u00ef\u00a7\u00c6K\u0014@\u00cb\u00a1E\u00b6\u00f3\u00fd\u000b@\u00d7\u00a3p=\nW;@q=\n\u00d7\u00a3p\u0017@\u009e\u00ef\u00a7\u00c6K7\u0007@\u0083\u00c0\u00ca\u00a1E\u00d6;@\u00c9v\u00be\u009f\u001a\u00af\u001b@Zd;\u00dfO\u008d\f@\u00ecQ\u00b8\u001e\u0085k;@b\u0010X9\u00b4\u00c8\u001c@\u0006\u0081\u0095C\u008bl\u0013@sh\u0091\u00ed|\u007f:@j\u00bct\u0093\u0018\u0084\u0019@\u00c7K7\u0089A\u00e0\u0015@\u00ed\u009e<,\u00d4:9@ 0 -@pytest.mark.parametrize('with_log', [True, False]) +@pytest.mark.parametrize("with_log", [True, False]) def test_main_log(with_log): logged_text = "Running null command\n" - logfile_text = "\n".join([ - "[loggers]", "keys=root", "", - "[handlers]", "keys=std", "", - "[formatters]", "keys=default", "", - "[formatter_default]", "format=%(message)s", "", - "[handler_std]", "class=StreamHandler", "level=NOTSET", - "formatter=default", "args=(sys.stdout,)", "" - "[logger_root]", "level=DEBUG", "handlers=std" - ]) + logfile_text = "\n".join( + [ + "[loggers]", + "keys=root", + "", + "[handlers]", + "keys=std", + "", + "[formatters]", + "keys=default", + "", + "[formatter_default]", + "format=%(message)s", + "", + "[handler_std]", + "class=StreamHandler", + "level=NOTSET", + "formatter=default", + "args=(sys.stdout,)", + "" "[logger_root]", + "level=DEBUG", + "handlers=std", + ], + ) runner = click.testing.CliRunner() - invocation = ['null_command'] + invocation = ["null_command"] if with_log: - invocation = ['--log', 'logging.conf'] + invocation + invocation = ["--log", "logging.conf"] + invocation expected = logged_text if with_log else "" with runner.isolated_filesystem(): - with open("logging.conf", mode='w') as log_conf: + with open("logging.conf", mode="w") as log_conf: log_conf.write(logfile_text) with 
null_command_context(main): result = runner.invoke(main, invocation) found = result.stdout_bytes - assert found.decode('utf-8') == expected + assert found.decode("utf-8") == expected diff --git a/openfecli/tests/test_fetchables.py b/openfecli/tests/test_fetchables.py index a30e7e247..9a500a64a 100644 --- a/openfecli/tests/test_fetchables.py +++ b/openfecli/tests/test_fetchables.py @@ -1,13 +1,13 @@ +import pathlib + import pytest from click.testing import CliRunner -from .conftest import HAS_INTERNET -import pathlib +from openfecli.fetchables import RBFE_SHOWCASE, RBFE_TUTORIAL, RBFE_TUTORIAL_RESULTS from openfecli.fetching import FetchablePlugin -from openfecli.fetchables import ( - RBFE_TUTORIAL, RBFE_TUTORIAL_RESULTS, RBFE_SHOWCASE -) +from .conftest import HAS_INTERNET + def fetchable_test(fetchable): """Unit test to ensure that a given FetchablePlugin works""" @@ -17,7 +17,7 @@ def fetchable_test(fetchable): if fetchable.fetcher.REQUIRES_INTERNET and not HAS_INTERNET: # -no-cov- pytest.skip("Internet seems to be unavailable") with runner.isolated_filesystem(): - result = runner.invoke(fetchable.command, ['-d' 'output-dir']) + result = runner.invoke(fetchable.command, ["-d" "output-dir"]) assert result.exit_code == 0 for path in expected_paths: assert (pathlib.Path("output-dir") / path).exists() @@ -26,8 +26,10 @@ def fetchable_test(fetchable): def test_rhfe_tutorial(): fetchable_test(RBFE_TUTORIAL) + def test_rhfe_tutorial_results(): fetchable_test(RBFE_TUTORIAL_RESULTS) + def test_rhfe_showcase(): fetchable_test(RBFE_SHOWCASE) diff --git a/openfecli/tests/test_fetching.py b/openfecli/tests/test_fetching.py index 7aa5b8c1c..ee0ead093 100644 --- a/openfecli/tests/test_fetching.py +++ b/openfecli/tests/test_fetching.py @@ -1,9 +1,9 @@ import pytest +from openfecli.fetching import FetchablePlugin, PkgResourceFetcher, URLFetcher + from .conftest import HAS_INTERNET -from openfecli.fetching import URLFetcher, PkgResourceFetcher -from openfecli.fetching import 
FetchablePlugin class FetcherTester: @pytest.fixture @@ -40,27 +40,25 @@ def fetcher(self): short_name="google", short_help="The Goog", requires_ofe=(0, 7, 0), - long_help="Google, an Alphabet company" + long_help="Google, an Alphabet company", ) def test_resources(self, fetcher): expected = [("https://www.google.com/", "index.html")] assert list(fetcher.resources) == expected - @pytest.mark.skipif(not HAS_INTERNET, - reason="Internet seems to be unavailable") + @pytest.mark.skipif(not HAS_INTERNET, reason="Internet seems to be unavailable") def test_call(self, fetcher, tmp_path): super().test_call(fetcher, tmp_path) - @pytest.mark.skipif(not HAS_INTERNET, - reason="Internet seems to be unavailable") + @pytest.mark.skipif(not HAS_INTERNET, reason="Internet seems to be unavailable") def test_without_trailing_slash(self, tmp_path): fetcher = URLFetcher( resources=[("https://www.google.com", "index.html")], short_name="goog2", short_help="more goog", requires_ofe=(0, 7, 0), - long_help="What if you forget the trailing slash?" + long_help="What if you forget the trailing slash?", ) self.test_call(fetcher, tmp_path) @@ -70,14 +68,13 @@ class TestPkgResourceFetcher(FetcherTester): @pytest.fixture def fetcher(self): return PkgResourceFetcher( - resources=[('openfecli.tests', 'test_fetching.py')], + resources=[("openfecli.tests", "test_fetching.py")], short_name="me", short_help="download this file", requires_ofe=(0, 7, 4), - long_help="whoa, meta." 
+ long_help="whoa, meta.", ) def test_resources(self, fetcher): - expected = [('openfecli.tests', 'test_fetching.py')] + expected = [("openfecli.tests", "test_fetching.py")] assert list(fetcher.resources) == expected - diff --git a/openfecli/tests/test_plugins.py b/openfecli/tests/test_plugins.py index e631d1f0a..cc13bb2cc 100644 --- a/openfecli/tests/test_plugins.py +++ b/openfecli/tests/test_plugins.py @@ -1,4 +1,5 @@ import click + from openfecli.plugins import OFECommandPlugin @@ -9,11 +10,7 @@ def fake(): class TestOFECommandPlugin: def setup_method(self): - self.plugin = OFECommandPlugin( - command=fake, - section="Some Section", - requires_ofe=(0, 0, 1) - ) + self.plugin = OFECommandPlugin(command=fake, section="Some Section", requires_ofe=(0, 0, 1)) def test_plugin_setup(self): assert self.plugin.command is fake @@ -22,4 +19,3 @@ def test_plugin_setup(self): assert self.plugin.requires_lib == self.plugin.requires_cli assert self.plugin.requires_lib == (0, 0, 1) assert self.plugin.requires_cli == (0, 0, 1) - diff --git a/openfecli/tests/test_rbfe_tutorial.py b/openfecli/tests/test_rbfe_tutorial.py index 89c55829c..98d667ebc 100644 --- a/openfecli/tests/test_rbfe_tutorial.py +++ b/openfecli/tests/test_rbfe_tutorial.py @@ -5,80 +5,82 @@ - mocks the calculations and performs gathers on the mocked outputs """ -import pytest from importlib import resources -from click.testing import CliRunner from os import path from unittest import mock + +import pytest +from click.testing import CliRunner from openff.units import unit +from openfecli.commands.gather import gather from openfecli.commands.plan_rbfe_network import plan_rbfe_network from openfecli.commands.quickrun import quickrun -from openfecli.commands.gather import gather @pytest.fixture def tyk2_ligands(): - with resources.files('openfecli.tests.data.rbfe_tutorial') as d: - yield str(d / 'tyk2_ligands.sdf') + with resources.files("openfecli.tests.data.rbfe_tutorial") as d: + yield str(d / "tyk2_ligands.sdf") 
@pytest.fixture def tyk2_protein(): - with resources.files('openfecli.tests.data.rbfe_tutorial') as d: - yield str(d / 'tyk2_protein.pdb') + with resources.files("openfecli.tests.data.rbfe_tutorial") as d: + yield str(d / "tyk2_protein.pdb") @pytest.fixture def expected_transformations(): - return ['easy_rbfe_lig_ejm_31_complex_lig_ejm_42_complex.json', - 'easy_rbfe_lig_ejm_31_solvent_lig_ejm_50_solvent.json', - 'easy_rbfe_lig_ejm_31_complex_lig_ejm_46_complex.json', - 'easy_rbfe_lig_ejm_42_complex_lig_ejm_43_complex.json', - 'easy_rbfe_lig_ejm_31_complex_lig_ejm_47_complex.json', - 'easy_rbfe_lig_ejm_42_solvent_lig_ejm_43_solvent.json', - 'easy_rbfe_lig_ejm_31_complex_lig_ejm_48_complex.json', - 'easy_rbfe_lig_ejm_46_complex_lig_jmc_23_complex.json', - 'easy_rbfe_lig_ejm_31_complex_lig_ejm_50_complex.json', - 'easy_rbfe_lig_ejm_46_complex_lig_jmc_27_complex.json', - 'easy_rbfe_lig_ejm_31_solvent_lig_ejm_42_solvent.json', - 'easy_rbfe_lig_ejm_46_complex_lig_jmc_28_complex.json', - 'easy_rbfe_lig_ejm_31_solvent_lig_ejm_46_solvent.json', - 'easy_rbfe_lig_ejm_46_solvent_lig_jmc_23_solvent.json', - 'easy_rbfe_lig_ejm_31_solvent_lig_ejm_47_solvent.json', - 'easy_rbfe_lig_ejm_46_solvent_lig_jmc_27_solvent.json', - 'easy_rbfe_lig_ejm_31_solvent_lig_ejm_48_solvent.json', - 'easy_rbfe_lig_ejm_46_solvent_lig_jmc_28_solvent.json'] + return [ + "easy_rbfe_lig_ejm_31_complex_lig_ejm_42_complex.json", + "easy_rbfe_lig_ejm_31_solvent_lig_ejm_50_solvent.json", + "easy_rbfe_lig_ejm_31_complex_lig_ejm_46_complex.json", + "easy_rbfe_lig_ejm_42_complex_lig_ejm_43_complex.json", + "easy_rbfe_lig_ejm_31_complex_lig_ejm_47_complex.json", + "easy_rbfe_lig_ejm_42_solvent_lig_ejm_43_solvent.json", + "easy_rbfe_lig_ejm_31_complex_lig_ejm_48_complex.json", + "easy_rbfe_lig_ejm_46_complex_lig_jmc_23_complex.json", + "easy_rbfe_lig_ejm_31_complex_lig_ejm_50_complex.json", + "easy_rbfe_lig_ejm_46_complex_lig_jmc_27_complex.json", + "easy_rbfe_lig_ejm_31_solvent_lig_ejm_42_solvent.json", + 
"easy_rbfe_lig_ejm_46_complex_lig_jmc_28_complex.json", + "easy_rbfe_lig_ejm_31_solvent_lig_ejm_46_solvent.json", + "easy_rbfe_lig_ejm_46_solvent_lig_jmc_23_solvent.json", + "easy_rbfe_lig_ejm_31_solvent_lig_ejm_47_solvent.json", + "easy_rbfe_lig_ejm_46_solvent_lig_jmc_27_solvent.json", + "easy_rbfe_lig_ejm_31_solvent_lig_ejm_48_solvent.json", + "easy_rbfe_lig_ejm_46_solvent_lig_jmc_28_solvent.json", + ] def test_plan_tyk2(tyk2_ligands, tyk2_protein, expected_transformations): runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke(plan_rbfe_network, ['-M', tyk2_ligands, - '-p', tyk2_protein]) + result = runner.invoke(plan_rbfe_network, ["-M", tyk2_ligands, "-p", tyk2_protein]) assert result.exit_code == 0 - assert path.exists('alchemicalNetwork/transformations') + assert path.exists("alchemicalNetwork/transformations") for f in expected_transformations: - assert path.exists( - path.join('alchemicalNetwork/transformations', f)) + assert path.exists(path.join("alchemicalNetwork/transformations", f)) @pytest.fixture def mock_execute(expected_transformations): def fake_execute(*args, **kwargs): return { - 'repeat_id': kwargs['repeat_id'], - 'generation': kwargs['generation'], - 'nc': 'file.nc', - 'last_checkpoint': 'checkpoint.nc', - 'unit_estimate': 4.2 * unit.kilocalories_per_mole + "repeat_id": kwargs["repeat_id"], + "generation": kwargs["generation"], + "nc": "file.nc", + "last_checkpoint": "checkpoint.nc", + "unit_estimate": 4.2 * unit.kilocalories_per_mole, } - with mock.patch('openfe.protocols.openmm_rfe.equil_rfe_methods.' - 'RelativeHybridTopologyProtocolUnit._execute') as m: + with mock.patch( + "openfe.protocols.openmm_rfe.equil_rfe_methods." 
"RelativeHybridTopologyProtocolUnit._execute", + ) as m: m.side_effect = fake_execute yield m @@ -100,21 +102,19 @@ def ref_gather(): """ -def test_run_tyk2(tyk2_ligands, tyk2_protein, expected_transformations, - mock_execute, ref_gather): +def test_run_tyk2(tyk2_ligands, tyk2_protein, expected_transformations, mock_execute, ref_gather): runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke(plan_rbfe_network, ['-M', tyk2_ligands, - '-p', tyk2_protein]) + result = runner.invoke(plan_rbfe_network, ["-M", tyk2_ligands, "-p", tyk2_protein]) assert result.exit_code == 0 for f in expected_transformations: - fn = path.join('alchemicalNetwork/transformations', f) + fn = path.join("alchemicalNetwork/transformations", f) result2 = runner.invoke(quickrun, [fn]) assert result2.exit_code == 0 - gather_result = runner.invoke(gather, ["--report", "ddg", '.']) + gather_result = runner.invoke(gather, ["--report", "ddg", "."]) assert gather_result.exit_code == 0 assert gather_result.stdout == ref_gather diff --git a/openfecli/tests/test_utils.py b/openfecli/tests/test_utils.py index dca30b873..00ee64ddd 100644 --- a/openfecli/tests/test_utils.py +++ b/openfecli/tests/test_utils.py @@ -1,13 +1,11 @@ +import contextlib +import logging import os -import pytest - from unittest.mock import patch -import logging -import contextlib -from openfecli.utils import ( - import_thing, _should_configure_logger, configure_logger -) +import pytest + +from openfecli.utils import _should_configure_logger, configure_logger, import_thing # looks like this can't be done as a fixture; related to @@ -27,32 +25,41 @@ def patch_root_logger(): logging.root = old_root -@pytest.mark.parametrize('import_string,expected', [ - ('os.path.exists', os.path.exists), - ('os.getcwd', os.getcwd), - ('os', os), -]) +@pytest.mark.parametrize( + "import_string,expected", + [ + ("os.path.exists", os.path.exists), + ("os.getcwd", os.getcwd), + ("os", os), + ], +) def test_import_thing(import_string, 
expected): assert import_thing(import_string) is expected def test_import_thing_import_error(): with pytest.raises(ImportError): - import_thing('foo.bar') + import_thing("foo.bar") def test_import_thing_attribute_error(): with pytest.raises(AttributeError): - import_thing('os.foo') - - -@pytest.mark.parametrize("logger_name, expected", [ - ("default", True), ("default.default", True), - ("level", False), ("level.default", False), - ("handler", False), ("handler.default", False), - ("default.noprop", False) -]) -@pytest.mark.parametrize('with_adapter', [True, False]) + import_thing("os.foo") + + +@pytest.mark.parametrize( + "logger_name, expected", + [ + ("default", True), + ("default.default", True), + ("level", False), + ("level.default", False), + ("handler", False), + ("handler.default", False), + ("default.noprop", False), + ], +) +@pytest.mark.parametrize("with_adapter", [True, False]) def test_should_configure_logger(logger_name, expected, with_adapter): with patch_root_logger(): logging.getLogger("level").setLevel(logging.INFO) @@ -73,17 +80,15 @@ def test_root_logger_level_configured(): assert not _should_configure_logger(logger) -@pytest.mark.parametrize('with_handler', [True, False]) +@pytest.mark.parametrize("with_handler", [True, False]) def test_configure_logger(with_handler): handler = logging.NullHandler() if with_handler else None expected_handlers = [handler] if handler else [] with patch_root_logger(): - configure_logger('default.default', handler=handler) - logger = logging.getLogger('default.default') - parent = logging.getLogger('default') + configure_logger("default.default", handler=handler) + logger = logging.getLogger("default.default") + parent = logging.getLogger("default") assert logger.isEnabledFor(logging.INFO) assert not parent.isEnabledFor(logging.INFO) assert logger.handlers == expected_handlers assert parent.handlers == [] - - diff --git a/openfecli/utils.py b/openfecli/utils.py index 0376019b2..d5ec0f5d3 100644 --- 
a/openfecli/utils.py +++ b/openfecli/utils.py @@ -1,12 +1,13 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import click -import importlib import functools -from typing import Callable, Optional -from datetime import datetime +import importlib import logging +from datetime import datetime +from typing import Callable, Optional + +import click def import_thing(import_string: str): @@ -22,7 +23,7 @@ def import_thing(import_string: str): Any : the object from that namespace """ - splitted = import_string.split('.') + splitted = import_string.split(".") if len(splitted) > 1: # if the string has a dot, import the module and getattr the object obj = splitted[-1] @@ -65,13 +66,12 @@ def _should_configure_logger(logger: logging.Logger): ): l = l.parent - is_default = (l == logging.root and l.level == logging.WARNING) + is_default = l == logging.root and l.level == logging.WARNING return is_default -def configure_logger(logger_name: str, level: int = logging.INFO, *, - handler: Optional[logging.Handler] = None): +def configure_logger(logger_name: str, level: int = logging.INFO, *, handler: Optional[logging.Handler] = None): """Configure the logger at ``logger_name`` to be at ``level``. This is used to prevent accidentally overwriting existing logging @@ -103,6 +103,7 @@ def print_duration(function: Callable) -> Callable: the decorated function. """ + @functools.wraps(function) def wrapper(*args, **kwargs): start_time = datetime.now() @@ -115,4 +116,3 @@ def wrapper(*args, **kwargs): return result return wrapper -