diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index cb0a8c8..88e5c17 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -10,8 +10,8 @@ dependencies: - openff-units - openmm - pymbar + - pydantic=1 # TODO: Modify when we support pydantic 2 - python - - pip # Testing - pytest diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 4d8e377..ee908a5 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -1,6 +1,7 @@ # Adapted from perses: https://github.com/choderalab/perses/blob/protocol-neqcyc/perses/protocols/nonequilibrium_cycling.py from typing import Optional, Iterable, List, Dict, Any +from itertools import chain import datetime import logging @@ -130,11 +131,16 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): phase = self._detect_phase(state_a, state_b) # infer phase from systems and components - # Get components from systems if found (None otherwise) -- NOTE: Uses hardcoded keys! + # Get receptor components from systems if found (None otherwise) -- NOTE: Uses hardcoded keys! receptor_a = state_a.components.get("protein") # receptor_b = state_b.components.get("protein") # Should not be needed - ligand_a = mapping.get("ligand").componentA - ligand_b = mapping.get("ligand").componentB + + # Get ligand/small-mol components + ligand_mapping = mapping["ligand"] + ligand_a = ligand_mapping.componentA + ligand_b = ligand_mapping.componentB + + # Get solvent components solvent_a = state_a.components.get("solvent") # solvent_b = state_b.components.get("solvent") # Should not be needed @@ -163,26 +169,37 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): # Note: by default this is cached to ctx.shared/db.json so shouldn't # incur too large a cost self.logger.info("Parameterizing molecules") - small_mols_a = [] + # The following creates a dictionary with all the small molecules in the states, with the structure: + # Dict[SmallMoleculeComponent, openff.toolkit.Molecule] + # Alchemical small mols + alchemical_small_mols_a = {ligand_a: ligand_a.to_openff()} + alchemical_small_mols_b = {ligand_b: ligand_b.to_openff()} + all_alchemical_mols = alchemical_small_mols_a | alchemical_small_mols_b + # non-alchemical common small mols + common_small_mols = {} for comp in state_a.components.values(): - if isinstance(comp, SmallMoleculeComponent): - small_mols_a.append(comp) - - for comp in small_mols_a: - offmol = comp.to_openff() - system_generator.create_system(offmol.to_topology().to_openmm(), - molecules=[offmol]) - if comp == ligand_a: - mol_b = ligand_b.to_openff() - system_generator.create_system(mol_b.to_topology().to_openmm(), - molecules=[mol_b]) + # TODO: Refactor if/when gufe provides the functionality https://github.com/OpenFreeEnergy/gufe/issues/251 + if isinstance(comp, SmallMoleculeComponent) and comp not in all_alchemical_mols: + common_small_mols[comp] = comp.to_openff() + + # Assign charges to ALL small mols, if unassigned -- more info: Openfe issue #576 + for off_mol in chain(all_alchemical_mols.values(), common_small_mols.values()): + # skip if we already have user charges + if not (off_mol.partial_charges is not None and np.any(off_mol.partial_charges)): + # due to issues with partial charge generation in ambertools + # we default to using the input conformer for charge generation + off_mol.assign_partial_charges( + 'am1bcc', use_conformers=off_mol.conformers + ) + system_generator.create_system(off_mol.to_topology().to_openmm(), + molecules=[off_mol]) # c. get OpenMM Modeller + a dictionary of resids for each component solvation_settings = settings.solvation_settings state_a_modeller, comp_resids = system_creation.get_omm_modeller( protein_comp=receptor_a, solvent_comp=solvent_a, - small_mols=small_mols_a, + small_mols=alchemical_small_mols_a | common_small_mols, omm_forcefield=system_generator.forcefield, solvent_settings=solvation_settings, ) @@ -197,7 +214,8 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): # e. create the stateA System state_a_system = system_generator.create_system( state_a_modeller.topology, - molecules=[s.to_openff() for s in small_mols_a], + molecules=list(chain(alchemical_small_mols_a.values(), + common_small_mols.values())), ) # 2. Get stateB system @@ -208,15 +226,10 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): exclude_resids=comp_resids[ligand_a], ) - # b. get a list of small molecules for stateB - off_mols_state_b = [ligand_b.to_openff(), ] - for comp in small_mols_a: - if comp != ligand_a: - off_mols_state_b.append(comp.to_openff()) - state_b_system = system_generator.create_system( state_b_topology, - molecules=off_mols_state_b, + molecules=list(chain(alchemical_small_mols_b.values(), + common_small_mols.values())), ) # c. Define correspondence mappings between the two systems diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index d9473c2..e412e94 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -1,17 +1,16 @@ # fixtures for chemicalcomponents and chemicalsystems to test protocols with import gufe import pytest -import importlib.resources +from importlib.resources import files, as_file from rdkit import Chem from gufe.mapping import LigandAtomMapping @pytest.fixture def benzene_modifications(): - with importlib.resources.path('gufe.tests.data', - 'benzene_modifications.sdf') as f: + source = files("gufe.tests.data").joinpath("benzene_modifications.sdf") + with as_file(source) as f: supp = Chem.SDMolSupplier(str(f), removeHs=False) - mols = list(supp) return {m.GetProp('_Name'): m for m in mols}