Skip to content

Commit

Permalink
More linting (#163)
Browse files Browse the repository at this point in the history
* notbooks:

* linting

* update configs

* linting

* linting

* linting

* linting

* cover

* lint

* lint

* lint

* lint

* lint

* lint

* lint

* lint

* lint

* lint

* lint
  • Loading branch information
jmmshn authored Dec 14, 2023
1 parent 564210e commit a28b25a
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 44 deletions.
1 change: 0 additions & 1 deletion pymatgen/analysis/defects/ccd.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def _get_ediff(self, output_order="skb") -> npt.NDArray:
rearrangement here so that we have a single point of failure.
Args:
band_structure: The band structure of the relaxed defect calculation.
output_order: The order of the output. Defaults to "skb" (spin, kpoint, band]).
You can also use "bks" (band, kpoint, spin).
Expand Down
30 changes: 21 additions & 9 deletions pymatgen/analysis/defects/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .utils import get_plane_spacing

if TYPE_CHECKING:
from numpy.typing import ArrayLike
from pymatgen.core import Structure
from pymatgen.symmetry.structure import SymmetrizedStructure

Expand Down Expand Up @@ -312,11 +313,7 @@ def _has_oxi(struct):

@property
def symmetrized_structure(self) -> SymmetrizedStructure:
"""Returns the multiplicity of a defect site within the structure.
This is required for concentration analysis and confirms that defect_site is a
site in bulk_structure.
"""
"""Get the symmetrized version of the bulk structure."""
sga = SpacegroupAnalyzer(
self.structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance
)
Expand Down Expand Up @@ -895,7 +892,19 @@ def get_vacancy(structure: Structure, isite: int, **kwargs) -> Vacancy:
return Vacancy(structure=structure, site=site, **kwargs)


def _set_selective_dynamics(structure, site_pos, relax_radius):
def _set_selective_dynamics(
structure: Structure, site_pos: ArrayLike, relax_radius: float | str | None
):
"""Set the selective dynamics behavior.
Allow atoms to move for sites within a given radius of a given site,
all other atoms are fixed. Modify the structure in place.
Args:
structure: The structure to set the selective dynamics.
site_pos: The center of the relaxation sphere.
relax_radius: The radius of the relaxation sphere.
"""
if relax_radius is None:
return
if relax_radius == "auto":
Expand Down Expand Up @@ -974,11 +983,15 @@ def _get_mapped_sites(uc_structure: Structure, sc_structure: Structure, r=0.001)
return mapped_site_indices


def center_structure(structure, ref_fpos) -> Structure:
def center_structure(structure: Structure, ref_fpos: ArrayLike) -> Structure:
"""Shift the sites around a center.
Move all the sites in the structure so that they
are in the periodic image closest to the reference fractional position.
Args:
structure: The structure to be centered.
ref_fpos: The reference fractional position that will be set to the center.
"""
struct = structure.copy()
for idx, d_site in enumerate(struct):
Expand All @@ -997,8 +1010,7 @@ def _get_el_changes_from_structures(defect_sc: Structure, bulk_sc: Structure) ->
bulk_sc: The bulk structure.
Returns:
str: The name of the defect, if the defect is a complex, the names of the
individual defects are separated by "+".
dict: A dictionary representing the species changes in creating the defect.
"""

def _check_int(n):
Expand Down
43 changes: 33 additions & 10 deletions pymatgen/analysis/defects/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def get_site_groups(struct, symprec=0.01, angle_tolerance=5.0) -> List[SiteGroup
return site_groups


def get_soap_vec(struct: "Structure") -> "NDArray":
def get_soap_vec(struct: "Structure") -> NDArray:
"""Get the SOAP vector for each site in the structure.
Args:
Expand All @@ -237,17 +237,32 @@ def get_soap_vec(struct: "Structure") -> "NDArray":
return vecs


def get_site_vecs(struct: "Structure"):
"""Get the SiteVec representation of each site in the structure."""
def get_site_vecs(struct: Structure) -> List[SiteVec]:
"""Get the SiteVec representation of each site in the structure.
Args:
struct: Structure object to compute the site vectors (SOAP).
Returns:
List[SiteVec]: List of SiteVec representing each site in the structure.
"""
vecs = get_soap_vec(struct)
site_vecs = []
for i, site in enumerate(struct):
site_vecs.append(SiteVec(species=site.species_string, site=site, vec=vecs[i]))
return site_vecs
return [
SiteVec(species=site.species_string, site=site, vec=vecs[i])
for i, site in enumerate(struct)
]


def cosine_similarity(vec1, vec2) -> float:
"""Cosine similarity between two vectors."""
"""Cosine similarity between two vectors.
Args:
vec1: First vector
vec2: Second vector
Returns:
float: Cosine similarity between the two vectors
"""
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


Expand Down Expand Up @@ -278,11 +293,19 @@ def best_match(sv: SiteVec, sgs: List[SiteGroup]) -> Tuple[SiteGroup, float]:
return best_match, best_similarity


def _get_broundary(arr, n_max=16, n_skip=3):
def _get_broundary(arr, n_max=16, n_skip=3) -> int:
"""Get the boundary index for the high-distortion indices.
Assuming arr is sorted in reverse order,
find the biggest value drop in arr[n_skip:n_max].
Args:
arr: List of numbers
n_max: Maximum index to consider
n_skip: Number of indices to skip
Returns:
int: The boundary index
"""
sub_arr = np.array(arr[n_skip:n_max])
diffs = sub_arr[1:] - sub_arr[:-1]
Expand All @@ -291,7 +314,7 @@ def _get_broundary(arr, n_max=16, n_skip=3):

def get_weighted_average_position(
lattice: Lattice, frac_positions: ArrayLike, weights: ArrayLike | None = None
) -> "NDArray":
) -> NDArray:
"""Get the weighted average position of a set of positions in frac coordinates.
The algorithm starts at position with the highest weight, and gradually moves
Expand Down
18 changes: 16 additions & 2 deletions pymatgen/analysis/defects/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,20 @@ def _space_group_analyzer(self, structure: Structure) -> SpacegroupAnalyzer:
"This generator is using the `SpaceGroupAnalyzer` and requires `symprec` and `angle_tolerance` to be set."
)

def generate(self, *args, **kwargs) -> Generator[Defect, None, None]:
"""Generate a defect.
Args:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Generator[Defect, None, None]: Generator that yields a list of ``Defect`` objects.
"""
raise NotImplementedError

def get_defects(self, *args, **kwargs) -> list[Defect]:
"""Call the generator and convert the results into a list."""
"""Alias for self.generate."""
return list(self.generate(*args, **kwargs))


Expand Down Expand Up @@ -254,7 +266,7 @@ def generate(
insertions: The insertions to be made given as a dictionary {"Mg": [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]}.
multiplicities: The multiplicities of the insertions to be made given as a dictionary {"Mg": [1, 2]}.
equivalent_positions: The equivalent positions of the each inserted species given as a dictionary.
Note that they should typically be the same but we allow for more flexibility.
Note that they should typically be the same but we allow for more flexibility here.
**kwargs: Additional keyword arguments for the ``Interstitial`` constructor.
Returns:
Expand Down Expand Up @@ -418,6 +430,8 @@ class ChargeInterstitialGenerator(InterstitialGenerator):
min_dist: Minimum to atoms in the host structure
avg_radius: The radius around each local minima used to evaluate the average charge.
max_avg_charge: The maximum average charge to accept.
max_insertions: The maximum number of insertion sites to consider.
Will choose the sites with the lowest average charge.
"""

def __init__(
Expand Down
53 changes: 48 additions & 5 deletions pymatgen/analysis/defects/plotting/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,25 @@ def _plot_eigs(
x_width: float = 0.3,
**kwargs,
) -> None:
"""Plot the eigenvalues."""
"""Plot the eigenvalues.
Args:
d_eigs:
The dictionary of eigenvalues for the defect state. In the format of
(iband, ikpt, ispin) -> eigenvalue
e_fermi:
The bands above and below the Fermi level will be colored differently.
If not provided, they will all be colored the same.
ax:
The matplotlib axis object to plot on.
x0:
The x coordinate of the center of the set of lines representing the eigenvalues.
x_width:
The width of the set of lines representing the eigenvalues.
**kwargs:
Keyword arguments to pass to `matplotlib.pyplot.hlines`.
For example, `linestyles`, `alpha`, etc.
"""
if ax is None: # pragma: no cover
ax = plt.gca()

Expand Down Expand Up @@ -215,7 +233,7 @@ def _plot_matrix_elements(
arrow_width=0.1,
cmap=None,
norm=None,
):
) -> tuple[list[tuple], plt.cm, plt.Normalize]:
"""Plot arrow for the transition from the defect state to all other states.
Args:
Expand All @@ -242,13 +260,21 @@ def _plot_matrix_elements(
The cartesian direction of the WAVDER tensor to sum over for the plot.
If not provided, all the absolute values of the matrix for all
three diagonal entries will be summed.
Returns:
plot_data:
A list of tuples in the format of (iband, ikpt, ispin, eigenvalue, matrix element)
cmap:
The matplotlib color map used.
norm:
The matplotlib normalization used.
"""
if ax is None: # pragma: no cover
ax = plt.gca()
ax.set_aspect("equal")
jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys()))
y0 = d_eig[jb, jkpt, jspin]
plot_data = []
plot_data: list[tuple] = []
for (ib, ik, ispin), eig in d_eig.items():
A = 0
for idir, jdir in ijdirs:
Expand Down Expand Up @@ -289,8 +315,25 @@ def _plot_matrix_elements(
return plot_data, cmap, norm


def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame:
"""Convert the eigenvalue and matrix element data into a pandas dataframe."""
def _get_dataframe(d_eigs: dict, me_plot_data: list[tuple]) -> pd.DataFrame:
"""Convert the eigenvalue and matrix element data into a pandas dataframe.
Args:
d_eigs:
The dictionary of eigenvalues for the defect state. In the format of
(iband, ikpt, ispin) -> eigenvalue
me_plot_data:
A list of tuples in the format of (iband, ikpt, ispin, eigenvalue, matrix element)
Returns:
A pandas dataframe with the following columns:
ib: The band index of the state the arrow is pointing to.
jb: The band index of the defect state.
kpt: The kpoint index of the state the arrow is pointing to.
spin: The spin index of the state the arrow is pointing to.
eig: The eigenvalue of the state the arrow is pointing to.
M.E.: The matrix element of the transition.
"""
_, ikpt, ispin = next(iter(d_eigs.keys()))
df = pd.DataFrame(
me_plot_data,
Expand Down
1 change: 0 additions & 1 deletion pymatgen/analysis/defects/plotting/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _convex_hull_2d(
points: list[dict],
x_element: Element,
y_element: Element,
tol: float = 0.001,
competing_phases: list = None,
) -> list[dict]:
"""Compute the convex hull of a set of points in 2D.
Expand Down
19 changes: 17 additions & 2 deletions pymatgen/analysis/defects/recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@

@njit(cache=True)
def fact(n: int) -> float: # pragma: no cover
"""Compute the factorial of n."""
"""Compute the factorial of n.
Args:
n: The number to compute the factorial of.
Returns:
The factorial of n.
"""
if n > 20:
return LOOKUP_TABLE[-1] * np.prod(
np.array(list(range(21, n + 1)), dtype=np.double)
Expand All @@ -40,7 +47,15 @@ def fact(n: int) -> float: # pragma: no cover

@njit(cache=True)
def herm(x: float, n: int) -> float: # pragma: no cover
"""Recursive definition of hermite polynomial."""
"""Recursive definition of hermite polynomial.
Args:
x: The value to evaluate the hermite polynomial at.
n: The order of the hermite polynomial.
Returns:
The value of the hermite polynomial at x.
"""
if n == 0:
return 1.0
if n == 1:
Expand Down
28 changes: 20 additions & 8 deletions pymatgen/analysis/defects/supercells.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# from pymatgen.io.ase import AseAtomsAdaptor

if TYPE_CHECKING:
import numpy as np
from numpy.typing import ArrayLike, NDArray
from pymatgen.core import Structure

__author__ = "Jimmy-Xuan Shen"
Expand All @@ -29,7 +29,7 @@ def get_sc_fromstruct(
max_atoms: int = 240,
min_length: float = 10.0,
force_diagonal: bool = False,
) -> np.ndarray | np.array | None:
) -> NDArray | ArrayLike | None:
"""Generate the best supercell from a unitcell.
The CubicSupercellTransformation from PMG is much faster but don't iterate over as
Expand Down Expand Up @@ -92,7 +92,7 @@ def _cubic_cell(
max_atoms: int = 240,
min_length: float = 10.0,
force_diagonal: bool = False,
) -> np.ndarray | None:
) -> NDArray | None:
"""Generate the best supercell from a unit cell.
This is done using the pymatgen CubicSupercellTransformation class.
Expand Down Expand Up @@ -125,23 +125,35 @@ def _cubic_cell(
return cst.transformation_matrix


def _ase_cubic(base_struture, min_atoms: int = 80, max_atoms: int = 240):
def _ase_cubic(base_structure, min_atoms: int = 80, max_atoms: int = 240):
"""Generate the best supercell from a unit cell.
Use ASE's find_optimal_cell_shape function to find the best supercell.
Args:
base_structure: structure of the unit cell
max_atoms: Maximum number of atoms allowed in the supercell.
min_atoms: Minimum number of atoms allowed in the supercell.
Returns:
3x3 matrix: supercell matrix
"""
from ase.build import find_optimal_cell_shape, get_deviation_from_optimal_cell_shape
from pymatgen.io.ase import AseAtomsAdaptor

_logger.warn("ASE cubic supercell generation.")

aaa = AseAtomsAdaptor()
ase_atoms = aaa.get_atoms(base_struture)
lower = math.ceil(min_atoms / base_struture.num_sites)
upper = math.floor(max_atoms / base_struture.num_sites)
ase_atoms = aaa.get_atoms(base_structure)
lower = math.ceil(min_atoms / base_structure.num_sites)
upper = math.floor(max_atoms / base_structure.num_sites)
min_dev = (float("inf"), None)
for size in range(lower, upper + 1):
_logger.warn(f"Trying size {size} out of {upper}.")
sc = find_optimal_cell_shape(
ase_atoms.cell, target_size=size, target_shape="sc"
)
sc_cell = aaa.get_atoms(base_struture * sc).cell
sc_cell = aaa.get_atoms(base_structure * sc).cell
deviation = get_deviation_from_optimal_cell_shape(sc_cell, target_shape="sc")
min_dev = min(min_dev, (deviation, sc))
if min_dev[1] is None:
Expand Down
Loading

0 comments on commit a28b25a

Please sign in to comment.