Skip to content

Commit

Permalink
Fix mypy type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jjrodriguezaldavero committed Jul 5, 2024
1 parent da9fd5e commit 1107b4c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/seemps/analysis/cross/cross.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import scipy.linalg
import scipy.linalg # type: ignore
import dataclasses
import functools

Expand Down
2 changes: 1 addition & 1 deletion src/seemps/analysis/cross/cross_dmrg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import scipy.linalg
import scipy.linalg # type: ignore
from dataclasses import dataclass
from typing import Optional, Callable

Expand Down
5 changes: 3 additions & 2 deletions src/seemps/analysis/cross/cross_greedy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import scipy.linalg
import scipy.linalg # type: ignore
from typing import TypeVar, Union, Optional, Callable
from dataclasses import dataclass

Expand Down Expand Up @@ -155,7 +155,8 @@ def get_row_indices(rows, all_rows):
G_cores = [self.Q_to_G(Q, j_l) for Q, j_l in zip(self.Q_factors, self.J_l[1:])]
self.mps = MPS(G_cores + [self.fibers[-1]])

_Index = TypeVar("_Index", bound=Union[np.intp, np.ndarray, slice])
# _Index = TypeVar("_Index", bound=Union[np.intp, np.ndarray, slice])
_Index = Union[np.intp, np.ndarray, slice]

def sample_superblock(
self, k: int, j_l: _Index = slice(None), j_g: _Index = slice(None)
Expand Down
10 changes: 5 additions & 5 deletions src/seemps/analysis/cross/cross_maxvol.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _update_maxvol(
fiber = cross.sample_fiber(k)
r_l, s, r_g = fiber.shape
if forward:
C = fiber.reshape(r_l * s, r_g, order=order)
C = fiber.reshape(r_l * s, r_g, order=order) # type: ignore
Q, _ = scipy.linalg.qr(C, mode="economic", overwrite_a=True, check_finite=False) # type: ignore
I, _ = choose_maxvol(
Q, # type: ignore
Expand All @@ -168,7 +168,7 @@ def _update_maxvol(
cross.I_l[k + 1] = combine_indices(cross.I_l[k], cross.I_s[k])[I]
else:
if k > 0:
R = fiber.reshape(r_l, s * r_g, order=order)
R = fiber.reshape(r_l, s * r_g, order=order) # type: ignore
Q, _ = scipy.linalg.qr( # type: ignore
R.T, mode="economic", overwrite_a=True, check_finite=False
)
Expand All @@ -179,7 +179,7 @@ def _update_maxvol(
cross_strategy.tol_maxvol_square,
cross_strategy.tol_maxvol_rect,
)
cross.mps[k] = (G.T).reshape(-1, s, r_g, order=order)
cross.mps[k] = (G.T).reshape(-1, s, r_g, order=order) # type: ignore
cross.I_g[k - 1] = combine_indices(cross.I_s[k], cross.I_g[k])[I]
else:
cross.mps[0] = fiber
Expand Down Expand Up @@ -219,11 +219,11 @@ def maxvol_rectangular(
if r_min < r or r_min > r_max or r_max > n:
raise ValueError("Invalid minimum/maximum number of added rows")
I0, B = maxvol_square(A, maxiter, tol)
I = np.hstack([I0, np.zeros(r_max - r, dtype=I0.dtype)])
I = np.hstack([I0, np.zeros(r_max - r, dtype=I0.dtype)]) # type: ignore
S = np.ones(n, dtype=int)
S[I0] = 0
F = S * np.linalg.norm(B) ** 2
for k in range(r, r_max):
for k in range(r, int(r_max)):
i = np.argmax(F)
if k >= r_min and F[i] <= tol_rect**2:
break
Expand Down
10 changes: 8 additions & 2 deletions src/seemps/analysis/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,15 @@ def __init__(self, start: int, stop: int, step: int = 1):
size = (stop - start + step - 1) // step
super().__init__(start, stop, size)

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[int, np.ndarray]:
@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float: ...

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
return self.start + idx * self.step # type: ignore
return self.start + idx * self.step


class RegularInterval(Interval):
Expand Down

0 comments on commit 1107b4c

Please sign in to comment.