Skip to content

Commit

Permalink
Add affine_transformation to mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
jjrodriguezaldavero committed Apr 17, 2024
1 parent 6320dfd commit a27f891
Showing 1 changed file with 51 additions and 23 deletions.
74 changes: 51 additions & 23 deletions src/seemps/analysis/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,13 @@ def _validate_index(self, idx):
raise TypeError("Index must be an integer or a NumPy array")

@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray:
...
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float:
...
def __getitem__(self, idx: int) -> float: ...

@abstractmethod
def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
...
def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]: ...

def to_vector(self) -> np.ndarray:
return np.array([self[idx] for idx in range(self.size)])
Expand All @@ -72,12 +69,10 @@ def __init__(self, start: float, stop: float, size: int):
self.step = (stop - start) / (size - 1)

@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray:
...
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float:
...
def __getitem__(self, idx: int) -> float: ...

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
Expand All @@ -92,12 +87,10 @@ def __init__(self, start: float, stop: float, size: int):
self.step = (stop - start) / size

@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray:
...
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float:
...
def __getitem__(self, idx: int) -> float: ...

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
Expand All @@ -112,17 +105,34 @@ def __init__(self, start: float, stop: float, size: int):
super().__init__(start, stop, size)

@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray:
...
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float: ...

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
zero = np.cos(np.pi * (2 * idx + 1) / (2 * self.size))
return affine_transformation(zero, orig=(-1, 1), dest=(self.stop, self.start))


class ChebyshevExtremaInterval(Interval):
"""Irregular discretization given by an affine map between the
N extrema of the (N-1)-th Chebyshev polynomial in [-1, 1] to (start, stop)."""

def __init__(self, start: float, stop: float, size: int):
super().__init__(start, stop, size)

@overload
def __getitem__(self, idx: np.ndarray) -> np.ndarray: ...

@overload
def __getitem__(self, idx: int) -> float:
...
def __getitem__(self, idx: int) -> float: ...

def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
zero = np.cos(np.pi * (2 * (self.size - idx) - 1) / (2 * self.size))
return (self.stop - self.start) * (zero + 1) / 2 + self.start
maxima = np.cos(np.pi * idx / (self.size - 1))
return affine_transformation(maxima, orig=(-1, 1), dest=(self.stop, self.start))


class Mesh:
Expand Down Expand Up @@ -155,6 +165,7 @@ class Mesh:
dimensions: tuple[int, ...]

def __init__(self, intervals: list[Interval]):
# TODO: Rename dimensions to shape = (dimensions, dimension)
self.intervals = intervals
self.dimension = len(intervals)
self.dimensions = tuple(interval.size for interval in self.intervals)
Expand Down Expand Up @@ -199,19 +210,36 @@ def to_tensor(self):
)


def mps_to_mesh_matrix(sites_per_dimension: list[int], order: str = "A") -> np.ndarray:
def affine_transformation(x: np.ndarray, orig: tuple, dest: tuple) -> np.ndarray:
"""
Performs an affine transformation of x as u = a*x + b from orig=(x0, x1) to dest=(u0, u1).
"""
# TODO: Combine the affine transformations for vectors, MPS and MPO.
x0, x1 = orig
u0, u1 = dest
a = (u1 - u0) / (x1 - x0)
b = 0.5 * ((u1 + u0) - a * (x0 + x1))
x_affine = a * x
if np.abs(b) > np.finfo(np.float64).eps:
x_affine = x_affine + b
return x_affine


def mps_to_mesh_matrix(
sites_per_dimension: list[int], mps_order: str = "A"
) -> np.ndarray:
"""
Returns a matrix that transforms an array of MPS indices
to an array of Mesh indices based on the specified order and base.
"""
if order == "A":
if mps_order == "A":
T = np.zeros((sum(sites_per_dimension), len(sites_per_dimension)), dtype=int)
start = 0
for m, n in enumerate(sites_per_dimension):
T[start : start + n, m] = 2 ** np.arange(n)[::-1]
start += n
return T
elif order == "B":
elif mps_order == "B":
T = np.vstack(
[
np.diag([2 ** (n - i - 1) if n > i else 0 for n in sites_per_dimension])
Expand Down

0 comments on commit a27f891

Please sign in to comment.