Skip to content

Commit

Permalink
Additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Sep 17, 2024
1 parent 3c0f682 commit cb8fa6e
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 18 deletions.
31 changes: 16 additions & 15 deletions src/dxtb/_src/calculators/types/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def get_dipole_deriv(
**kwargs: Any,
) -> Tensor:
prop = self.get_property(
"dipole_derivatives", positions, chrg=chrg, spin=spin, **kwargs
"dipole_deriv", positions, chrg=chrg, spin=spin, **kwargs
)
assert isinstance(prop, Tensor)
return prop
Expand All @@ -264,11 +264,7 @@ def get_dipole_derivatives(
spin: Tensor | float | int | None = defaults.SPIN,
**kwargs: Any,
) -> Tensor:
prop = self.get_property(
"dipole_derivatives", positions, chrg=chrg, spin=spin, **kwargs
)
assert isinstance(prop, Tensor)
return prop
return self.get_dipole_deriv(positions, chrg=chrg, spin=spin, **kwargs)

def get_polarizability(
self,
Expand All @@ -291,11 +287,7 @@ def get_pol_deriv(
**kwargs: Any,
) -> Tensor:
prop = self.get_property(
"polarizability_derivatives",
positions,
chrg=chrg,
spin=spin,
**kwargs,
"pol_deriv", positions, chrg=chrg, spin=spin, **kwargs
)
assert isinstance(prop, Tensor)
return prop
Expand Down Expand Up @@ -440,11 +432,15 @@ def get_charges(
spin: Tensor | float | int | None = defaults.SPIN,
**kwargs: Any,
) -> Tensor:
# pylint: disable=import-outside-toplevel
from dxtb._src.scf.base import Charges

prop = self.get_property(
"charges", positions, chrg=chrg, spin=spin, **kwargs
)
assert isinstance(prop, Tensor)
return prop
assert isinstance(prop, Charges)

return prop.mono

def get_mulliken_charges(
self,
Expand Down Expand Up @@ -501,8 +497,13 @@ def get_potential(
spin: Tensor | float | int | None = defaults.SPIN,
**kwargs: Any,
) -> Tensor:
# pylint: disable=import-outside-toplevel
from dxtb._src.scf.base import Potential

prop = self.get_property(
"potential", positions, chrg=chrg, spin=spin, **kwargs
)
assert isinstance(prop, Tensor)
return prop
assert isinstance(prop, Potential)
assert isinstance(prop.mono, Tensor)

return prop.mono
1 change: 0 additions & 1 deletion src/dxtb/_src/calculators/types/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,6 @@ def calculate(
props.remove("bond_orders")

if set(props) & set(properties):
print("Calculating energy")
self.energy(positions, chrg, spin, **kwargs)

if "forces" in properties:
Expand Down
4 changes: 2 additions & 2 deletions src/dxtb/_src/calculators/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_cache_key(self, key: str) -> str | None:

# printing

def __str__(self) -> str:
def __str__(self) -> str: # pragma: no cover
"""Return a string representation of the Cache object."""
counter = 0
l = []
Expand All @@ -396,7 +396,7 @@ def __str__(self) -> str:
f"{', '.join(l)})"
)

def __repr__(self) -> str:
def __repr__(self) -> str: # pragma: no cover
"""Return a representation of the Cache object."""
return str(self)

Expand Down
222 changes: 222 additions & 0 deletions test/test_calculator/test_cache/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,68 @@ def test_energy(dtype: torch.dtype) -> None:
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_scf_props(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)

options = dict(
opts,
**{
"cache_charges": True,
"cache_coefficients": True,
"cache_density": True,
"cache_iterations": True,
"cache_mo_energies": True,
"cache_occupation": True,
"cache_potential": True,
},
)

calc = GFN1Calculator(numbers, opts=options, **dd)
assert calc._ncalcs == 0

energy = calc.get_energy(positions)
assert calc._ncalcs == 1
assert isinstance(energy, Tensor)

# get other properties

prop = calc.get_charges(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_mulliken_charges(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_coefficients(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_density(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_iterations(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_occupation(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

prop = calc.get_potential(positions)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("grad_mode", ["functorch", "row"])
def test_forces(
Expand Down Expand Up @@ -259,6 +321,166 @@ def test_dipole(dtype: torch.dtype) -> None:
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.skipif(not has_libcint, reason="libcint not available")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_dipole_deriv(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
pos = positions.clone().requires_grad_(True)

options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"})

field = torch.tensor([0, 0, 0], **dd, requires_grad=True)
efield = new_efield(field, **dd)

calc = AutogradCalculator(
numbers, GFN1_XTB, opts=options, interaction=efield, **dd
)
assert calc._ncalcs == 0

kwargs = {"use_analytical_dipmom": False, "use_functorch": True}

prop = calc.get_dipole_deriv(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for same calc
prop = calc.get_dipole_derivatives(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for energy (kwargs mess up the cache key!)
prop = calc.get_energy(pos)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.skipif(not has_libcint, reason="libcint not available")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_polarizability(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
pos = positions.clone().requires_grad_(True)

options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"})

field = torch.tensor([0, 0, 0], **dd, requires_grad=True)
efield = new_efield(field, **dd)

calc = AutogradCalculator(
numbers, GFN1_XTB, opts=options, interaction=efield, **dd
)
assert calc._ncalcs == 0

kwargs = {"use_functorch": True}

prop = calc.get_polarizability(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for same calc
prop = calc.get_polarizability(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for energy (kwargs mess up the cache key!)
prop = calc.get_energy(pos)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.skipif(not has_libcint, reason="libcint not available")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_pol_deriv(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
pos = positions.clone().requires_grad_(True)

options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"})

field = torch.tensor([0, 0, 0], **dd, requires_grad=True)
efield = new_efield(field, **dd)

calc = AutogradCalculator(
numbers, GFN1_XTB, opts=options, interaction=efield, **dd
)
assert calc._ncalcs == 0

kwargs = {"use_functorch": True}

prop = calc.get_pol_deriv(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for same calc
prop = calc.get_polarizability_derivatives(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for energy (kwargs mess up the cache key!)
prop = calc.get_energy(pos)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.skipif(not has_libcint, reason="libcint not available")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_hyperpolarizability(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
pos = positions.clone().requires_grad_(True)

options = dict(opts, **{"scf_mode": "full", "mixer": "anderson"})

field = torch.tensor([0, 0, 0], **dd, requires_grad=True)
efield = new_efield(field, **dd)

calc = AutogradCalculator(
numbers, GFN1_XTB, opts=options, interaction=efield, **dd
)
assert calc._ncalcs == 0

kwargs = {"use_functorch": True}

prop = calc.get_hyperpolarizability(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for same calc
prop = calc.get_hyperpolarizability(pos, **kwargs)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# cache is used for energy (kwargs mess up the cache key!)
prop = calc.get_energy(pos)
assert calc._ncalcs == 1
assert isinstance(prop, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0


@pytest.mark.skipif(not has_libcint, reason="libcint not available")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_ir(dtype: torch.dtype) -> None:
Expand Down
Loading

0 comments on commit cb8fa6e

Please sign in to comment.