Skip to content

Commit

Permalink
Renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jan 15, 2025
1 parent 424dabb commit 96b96e6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
31 changes: 18 additions & 13 deletions src/dxtb/_src/components/interactions/coulomb/thirdorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,26 @@ class ES3(Interaction):
hubbard_derivs: Tensor
"""Hubbard derivatives of all atoms."""

shell_params: Tensor | None
"""Scaling factors for shell-resolved third-order electrostatics."""
shell_scale: Tensor | None
"""
Scaling factors for shell-resolved third-order electrostatics.
In GFN2-xTB, this is a tensor of shape ``(3,)`` containing the scaling
factors for the s, p, and d shells.
"""

__slots__ = ["hubbard_derivs", "shell_params"]
__slots__ = ["hubbard_derivs", "shell_scale"]

def __init__(
self,
hubbard_derivs: Tensor,
shell_params: Tensor | None = None,
shell_scale: Tensor | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(device, dtype)
self.hubbard_derivs = hubbard_derivs
self.shell_params = shell_params
self.shell_scale = shell_scale

# pylint: disable=unused-argument
@override
Expand Down Expand Up @@ -217,12 +222,12 @@ def get_cache(
# if the cache is built, store the cachevar for validation
self._cachevars = cachvars

if self.shell_params is None:
if self.shell_scale is None:
hd = ihelp.spread_uspecies_to_atom(self.hubbard_derivs)
else:
hd = (
ihelp.spread_uspecies_to_shell(self.hubbard_derivs)
* self.shell_params[ihelp.angular]
* self.shell_scale[ihelp.angular]
)

self.cache = ES3Cache(hd)
Expand Down Expand Up @@ -254,7 +259,7 @@ def get_atom_energy(self, charges: Tensor, cache: ES3Cache) -> Tensor:
"""
return (
cache.hd * torch.pow(charges, 3.0) / 3.0
if self.shell_params is None
if self.shell_scale is None
else torch.zeros_like(charges)
)

Expand All @@ -276,7 +281,7 @@ def get_shell_energy(self, charges: Tensor, cache: ES3Cache) -> Tensor:
"""
return (
torch.zeros_like(charges)
if self.shell_params is None
if self.shell_scale is None
else cache.hd * torch.pow(charges, 3.0) / 3.0
)

Expand All @@ -299,7 +304,7 @@ def get_atom_potential(self, charges: Tensor, cache: ES3Cache) -> Tensor:
"""
return (
cache.hd * torch.pow(charges, 2.0)
if self.shell_params is None
if self.shell_scale is None
else torch.zeros_like(charges)
)

Expand All @@ -322,7 +327,7 @@ def get_shell_potential(self, charges: Tensor, cache: ES3Cache) -> Tensor:
"""
return (
torch.zeros_like(charges)
if self.shell_params is None
if self.shell_scale is None
else cache.hd * torch.pow(charges, 2.0)
)

Expand Down Expand Up @@ -369,7 +374,7 @@ def new_es3(
torch.unique(numbers), par.element, "gam3", **dd
)

shell_params = (
shell_scale = (
None
if par.thirdorder.shell is False
else torch.tensor(
Expand All @@ -382,4 +387,4 @@ def new_es3(
)
)

return ES3(hubbard_derivs, shell_params=shell_params, **dd)
return ES3(hubbard_derivs, shell_scale=shell_scale, **dd)
8 changes: 4 additions & 4 deletions test/test_coulomb/test_es3_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_qat_single(dtype: torch.dtype, name: str) -> None:
assert es is not None

# atom-resolved charge test
es.shell_params = None
es.shell_scale = None

cache = es.get_cache(numbers=numbers, ihelp=ihelp)
e = es.get_atom_energy(qat, cache)
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_qat_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
assert es is not None

# atom-resolved charge test
es.shell_params = None
es.shell_scale = None

cache = es.get_cache(numbers=numbers, ihelp=ihelp)
e = es.get_atom_energy(qat, cache)
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_grad_param_shell(name: str) -> None:
assert GFN2_XTB.thirdorder is not None
assert GFN2_XTB.thirdorder.shell is not False

shell_params = torch.tensor(
shell_scale = torch.tensor(
[
GFN2_XTB.thirdorder.shell.s,
GFN2_XTB.thirdorder.shell.p,
Expand All @@ -221,7 +221,7 @@ def test_grad_param_shell(name: str) -> None:
hd.requires_grad_(True)

def func(hubbard_derivs: Tensor):
es = es3.ES3(hubbard_derivs, shell_params=shell_params, **dd)
es = es3.ES3(hubbard_derivs, shell_scale=shell_scale, **dd)
cache = es.get_cache(numbers=numbers, ihelp=ihelp)
return es.get_shell_energy(qsh, cache)

Expand Down

0 comments on commit 96b96e6

Please sign in to comment.