Skip to content

Commit

Permalink
Merge pull request #1115 from openforcefield/use-charge-caching
Browse files Browse the repository at this point in the history
More consistently use existing charge caching
  • Loading branch information
mattwthompson authored Jan 15, 2025
2 parents 6a515c0 + 9ad8427 commit 9460db5
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 60 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ jobs:
run: micromamba install "foyer >=0.12.1" -c conda-forge -yq

- name: Run tests
if: always()
run: |
python -m pytest $COV openff/interchange/ -r fExs -n logical --durations=10
run: python -m pytest $COV openff/interchange/ -r fExs -n logical --durations=10

- name: Run small molecule regression tests
if: ${{ matrix.openeye == true && matrix.openmm == true }}
Expand Down
4 changes: 1 addition & 3 deletions examples/protein_ligand/protein_ligand.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,7 @@
"id": "47",
"metadata": {},
"source": [
"### GROMACS\n",
"\n",
"Interchange's GROMACS exporter is a little slow for biopolymers; this will be faster in a future release."
"### GROMACS"
]
},
{
Expand Down
3 changes: 1 addition & 2 deletions openff/interchange/interop/amber/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,7 @@ def to_prmtop(interchange: "Interchange", file_path: Path | str):

prmtop.write("%FLAG CHARGE\n%FORMAT(5E16.8)\n")
charges = [
charge.m_as(unit.e) * AMBER_COULOMBS_CONSTANT
for charge in interchange["Electrostatics"]._get_charges().values()
charge.m_as(unit.e) * AMBER_COULOMBS_CONSTANT for charge in interchange["Electrostatics"].charges.values()
]
text_blob = "".join([f"{val:16.8E}" for val in charges])
_write_text_blob(prmtop, text_blob)
Expand Down
2 changes: 1 addition & 1 deletion openff/interchange/interop/lammps/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _write_atoms(lmp_file: IO, interchange: Interchange, atom_type_map: dict):

vdw_handler = interchange["vdW"]

charges = interchange["Electrostatics"]._get_charges()
charges = interchange["Electrostatics"].charges
positions = interchange.positions.m_as(unit.angstrom) # type: ignore[union-attr]

"""
Expand Down
9 changes: 3 additions & 6 deletions openff/interchange/interop/openmm/_import/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
from openff.toolkit import Quantity, Topology
from openff.utilities.utilities import has_package, requires_package

from openff.interchange.common._nonbonded import vdWCollection
from openff.interchange.common._nonbonded import ElectrostaticsCollection, vdWCollection
from openff.interchange.common._valence import (
AngleCollection,
BondCollection,
ConstraintCollection,
ProperTorsionCollection,
)
from openff.interchange.exceptions import UnsupportedImportError
from openff.interchange.interop.openmm._import._nonbonded import (
BasicElectrostaticsCollection,
)
from openff.interchange.interop.openmm._import.compat import _check_compatible_inputs
from openff.interchange.warnings import MissingPositionsWarning

Expand Down Expand Up @@ -182,7 +179,7 @@ def _convert_constraints(

def _convert_nonbonded_force(
force: "openmm.NonbondedForce",
) -> tuple[vdWCollection, BasicElectrostaticsCollection]:
) -> tuple[vdWCollection, ElectrostaticsCollection]:
from openff.units.openmm import from_openmm as from_openmm_quantity

from openff.interchange.components.potentials import Potential
Expand All @@ -194,7 +191,7 @@ def _convert_nonbonded_force(
)

vdw = vdWCollection()
electrostatics = BasicElectrostaticsCollection(version=0.4, scale_14=0.833333)
electrostatics = ElectrostaticsCollection(version=0.4, scale_14=0.833333)

n_parametrized_particles = force.getNumParticles()

Expand Down
6 changes: 4 additions & 2 deletions openff/interchange/interop/openmm/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def _create_single_nonbonded_force(
)

if data.electrostatics_collection is not None:
partial_charges = data.electrostatics_collection._get_charges()
partial_charges = data.electrostatics_collection.charges

# mapping between (openmm) index of each atom and the (openmm) index of each virtual particle
# of that parent atom (if any)
Expand All @@ -394,7 +394,9 @@ def _create_single_nonbonded_force(
other_top_key = SingleAtomChargeTopologyKey(
this_atom_index=atom_index,
)

partial_charge = partial_charges[other_top_key].m_as(unit.e)

else:
partial_charge = 0.0

Expand Down Expand Up @@ -941,7 +943,7 @@ def _set_particle_parameters(
# handling for electrostatics_force = None
electrostatics: ElectrostaticsCollection = data.electrostatics_collection

partial_charges = electrostatics._get_charges()
partial_charges = electrostatics.charges

vdw: vdWCollection = data.vdw_collection

Expand Down
10 changes: 3 additions & 7 deletions openff/interchange/smirnoff/_gromacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _convert(

try:
vdw_collection = interchange["vdW"]
electrostatics_collection = interchange["Electrostatics"]
interchange["Electrostatics"]
except KeyError:
raise UnsupportedExportError("Plugins not implemented.")

Expand Down Expand Up @@ -146,8 +146,6 @@ def _convert(

vdw_parameters = vdw_collection.potentials[vdw_collection.key_map[key]].parameters

charge = electrostatics_collection._get_charges()[key]

# Build atom types
system.atom_types[atom_type_name] = LennardJonesAtomType(
name=_atom_atom_type_map[atom],
Expand All @@ -168,8 +166,6 @@ def _convert(

vdw_parameters = vdw_collection.potentials[vdw_collection.key_map[virtual_site_key]].parameters

charge = electrostatics_collection._get_charges()[key]

# TODO: Separate class for "atom types" representing virtual sites?
system.atom_types[atom_type_name] = LennardJonesAtomType(
name=_atom_atom_type_map[virtual_site_key],
Expand All @@ -185,7 +181,7 @@ def _convert(
_partial_charges: dict[int | VirtualSiteKey, float] = dict()

# Indexed by particle (atom or virtual site) indices
for key, charge in interchange["Electrostatics"]._get_charges().items():
for key, charge in interchange["Electrostatics"].charges.items():
if type(key) is TopologyKey:
_partial_charges[key.atom_indices[0]] = charge
elif type(key) is VirtualSiteKey:
Expand Down Expand Up @@ -585,7 +581,7 @@ def _convert_virtual_sites(
residue_index=molecule.atoms[0].residue_index,
residue_name=molecule.atoms[0].residue_name,
charge_group_number=1,
charge=interchange["Electrostatics"]._get_charges()[virtual_site_key],
charge=interchange["Electrostatics"].charges[virtual_site_key],
mass=Quantity(0.0, unit.dalton),
),
)
Expand Down
44 changes: 8 additions & 36 deletions openff/interchange/smirnoff/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,6 @@
LibraryChargeHandler,
]

_ZERO_CHARGE = Quantity(0.0, unit.elementary_charge)


@unit.wraps(
ret=unit.elementary_charge,
args=(unit.elementary_charge, unit.elementary_charge),
strict=True,
)
def _add_charges(
charge1: "Quantity",
charge2: "Quantity",
) -> "Quantity":
"""Add two charges together."""
return charge1 + charge2


def _upconvert_vdw_handler(vdw_handler: vdWHandler):
"""Given a vdW with version 0.3 or 0.4, up-convert to 0.4 or short-circuit if already 0.4."""
Expand Down Expand Up @@ -335,6 +320,7 @@ def _get_charges(
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom or particle."""
# Keyed by index for atoms and by VirtualSiteKey for virtual sites.
# work in unitless (float, implicit elementary_charge) values until returning
charges: dict[VirtualSiteKey | int, _ElementaryChargeQuantity] = dict()

for topology_key, potential_key in self.key_map.items():
Expand All @@ -348,19 +334,14 @@ def _get_charges(
"virtual sites, not by a `ChargeIncrementModelHandler`.",
)

total_charge: Quantity = numpy.sum(parameter_value)
# assumes virtual sites can only have charges determined in one step

charges[topology_key] = -1.0 * total_charge
charges[topology_key] = -1.0 * numpy.sum(parameter_value)

# Apply increments to "orientation" atoms
for i, increment in enumerate(parameter_value):
orientation_atom_index = topology_key.orientation_atom_indices[i]

charges[orientation_atom_index] = _add_charges(
charges.get(orientation_atom_index, _ZERO_CHARGE),
increment,
)
charges[orientation_atom_index] = charges.get(orientation_atom_index, 0.0) + increment.m

elif parameter_key == "charge":
assert len(topology_key.atom_indices) == 1
Expand All @@ -373,10 +354,7 @@ def _get_charges(
"molecules_with_preset_charges",
"ExternalSource",
):
charges[atom_index] = _add_charges(
charges.get(atom_index, _ZERO_CHARGE),
parameter_value,
)
charges[atom_index] = charges.get(atom_index, 0.0) + parameter_value.m

elif potential_key.associated_handler in ( # type: ignore[operator]
"ChargeIncrementModelHandler"
Expand All @@ -385,10 +363,7 @@ def _get_charges(
# we "add" the charge whether or not the increment was already applied.
# There should be a better way to do this.

charges[atom_index] = _add_charges(
charges.get(atom_index, _ZERO_CHARGE),
parameter_value,
)
charges[atom_index] = charges.get(atom_index, 0.0) + parameter_value.m

else:
raise RuntimeError(
Expand All @@ -400,10 +375,7 @@ def _get_charges(

atom_index = topology_key.atom_indices[0]

charges[atom_index] = _add_charges(
charges.get(atom_index, _ZERO_CHARGE),
parameter_value,
)
charges[atom_index] = charges.get(atom_index, 0.0) + parameter_value.m

logger.info(
"Charge section ChargeIncrementModel, applying charge increment from atom " # type: ignore[union-attr]
Expand All @@ -428,7 +400,7 @@ def _get_charges(
if include_virtual_sites:
returned_charges[index] = charge

return returned_charges
return {key: Quantity(val, "elementary_charge") for key, val in returned_charges.items()}

@classmethod
def parameter_handler_precedence(cls) -> list[str]:
Expand Down Expand Up @@ -956,7 +928,7 @@ def store_matches(

topology_charges = [0.0] * topology.n_atoms

for key, val in self._get_charges().items():
for key, val in self.charges.items():
topology_charges[key.atom_indices[0]] = val.m

# TODO: Better data structures in Topology.identical_molecule_groups will make this
Expand Down

0 comments on commit 9460db5

Please sign in to comment.