Skip to content

Commit

Permalink
Update MPS Chebyshev approximation from jjrodriguezaldavero/pr_chebyshev
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll authored May 13, 2024
2 parents bfd8238 + 0e6264a commit 6ecac4e
Show file tree
Hide file tree
Showing 9 changed files with 555 additions and 219 deletions.
425 changes: 290 additions & 135 deletions src/seemps/analysis/chebyshev.py

Large diffs are not rendered by default.

47 changes: 29 additions & 18 deletions src/seemps/analysis/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@
MPS,
MPSSum,
Strategy,
DEFAULT_STRATEGY,
Truncation,
DEFAULT_TOLERANCE,
Simplification,
)
from ..truncate import simplify
from ..truncate import simplify, SIMPLIFICATION_STRATEGY
from .mesh import (
Interval,
RegularClosedInterval,
RegularHalfOpenInterval,
ChebyshevZerosInterval,
ChebyshevExtremaInterval,
)

DEFAULT_FACTORY_STRATEGY = Strategy(
COMPUTER_PRECISION = SIMPLIFICATION_STRATEGY.replace(
tolerance=float(np.finfo(np.double).eps),
simplification_tolerance=float(np.finfo(np.double).eps),
simplify=Simplification.DO_NOT_SIMPLIFY,
method=Truncation.RELATIVE_SINGULAR_VALUE,
tolerance=DEFAULT_TOLERANCE,
simplify=Simplification.VARIATIONAL,
simplification_tolerance=DEFAULT_TOLERANCE,
normalize=False,
)

Expand Down Expand Up @@ -96,7 +95,7 @@ def mps_exponential(start: float, stop: float, sites: int, c: complex = 1) -> MP


def mps_sin(
start: float, stop: float, sites: int, strategy: Strategy = DEFAULT_FACTORY_STRATEGY
start: float, stop: float, sites: int, strategy: Strategy = COMPUTER_PRECISION
) -> MPS:
"""
Returns an MPS representing a sine function discretized over an interval.
Expand Down Expand Up @@ -124,7 +123,7 @@ def mps_sin(


def mps_cos(
start: float, stop: float, sites: int, strategy: Strategy = DEFAULT_FACTORY_STRATEGY
start: float, stop: float, sites: int, strategy: Strategy = COMPUTER_PRECISION
) -> MPS:
"""
Returns an MPS representing a cosine function discretized over an interval.
Expand Down Expand Up @@ -154,7 +153,7 @@ def mps_cos(
_State = TypeVar("_State", bound=Union[MPS, MPSSum])


def mps_affine_transformation(mps: _State, orig: tuple, dest: tuple) -> _State:
def mps_affine(mps: _State, orig: tuple, dest: tuple) -> _State:
"""
Applies an affine transformation to an MPS, mapping it from one interval [x0, x1] to another [u0, u1].
This is a transformation u = a * x + b, with u0 = a * x0 + b and and u1 = a * x1 + b.
Expand Down Expand Up @@ -188,7 +187,7 @@ def mps_affine_transformation(mps: _State, orig: tuple, dest: tuple) -> _State:
return mps_affine


def mps_interval(interval: Interval, strategy: Strategy = DEFAULT_FACTORY_STRATEGY):
def mps_interval(interval: Interval, strategy: Strategy = COMPUTER_PRECISION):
"""
Returns an MPS corresponding to a specific type of interval (open, closed, or Chebyshev zeros).
Expand Down Expand Up @@ -216,7 +215,9 @@ def mps_interval(interval: Interval, strategy: Strategy = DEFAULT_FACTORY_STRATE
start_zeros = np.pi / (2 ** (sites + 1))
stop_zeros = np.pi + start_zeros
mps_zeros = -1.0 * mps_cos(start_zeros, stop_zeros, sites, strategy=strategy)
return mps_affine_transformation(mps_zeros, (-1, 1), (start, stop))
return mps_affine(mps_zeros, (-1, 1), (start, stop))
elif isinstance(interval, ChebyshevExtremaInterval):
raise NotImplementedError()
else:
raise ValueError(f"Unsupported interval type {type(interval)}")

Expand Down Expand Up @@ -285,7 +286,7 @@ def extend_mps(mps_id: int, mps_map: list[tuple[int, Tensor3]]) -> MPS:
def mps_tensor_product(
mps_list: list[MPS],
mps_order: str = "A",
strategy: Strategy = DEFAULT_FACTORY_STRATEGY,
strategy: Strategy = COMPUTER_PRECISION,
) -> MPS:
"""
Returns the tensor product of a list of MPS, with the sites arranged
Expand All @@ -309,16 +310,20 @@ def mps_tensor_product(
nested_sites = [mps._data for mps in mps_list]
flattened_sites = [site for sites in nested_sites for site in sites]
result = MPS(flattened_sites)
else:
elif mps_order == "B":
terms = mps_tensor_terms(mps_list, mps_order)
result = terms[0]
for idx, mps in enumerate(terms[1:]):
for _, mps in enumerate(terms[1:]):
result = result * mps
else:
raise ValueError(f"Invalid mps order {mps_order}")
return simplify(result, strategy=strategy)


def mps_tensor_sum(
mps_list: list[MPS], mps_order: str = "A", strategy: Strategy = DEFAULT_STRATEGY
mps_list: list[MPS],
mps_order: str = "A",
strategy: Strategy = COMPUTER_PRECISION,
) -> MPS:
"""
Returns the tensor sum of a list of MPS, with the sites arranged
Expand All @@ -340,16 +345,22 @@ def mps_tensor_sum(
"""
if mps_order == "A":
result = _mps_tensor_sum_serial_order(mps_list)
else:
elif mps_order == "B":
result = MPSSum(
[1.0] * len(mps_list), mps_tensor_terms(mps_list, mps_order)
).join()
else:
raise ValueError(f"Invalid mps order {mps_order}")
if strategy.get_simplify_flag():
return simplify(result, strategy=strategy)
return result


def _mps_tensor_sum_serial_order(mps_list: list[MPS]) -> MPS:
"""
Computes the MPS tensor sum in serial order in an optimized manner.
"""

def extend_tensor(A: Tensor3, first: bool, last: bool) -> Tensor3:
a, d, b = A.shape
output = np.zeros((a + 2, d, b + 2), dtype=A.dtype)
Expand All @@ -368,7 +379,7 @@ def extend_tensor(A: Tensor3, first: bool, last: bool) -> Tensor3:

output = [
extend_tensor(Ai, i == 0, i == len(A) - 1)
for n, A in enumerate(mps_list)
for _, A in enumerate(mps_list)
for i, Ai in enumerate(A)
]
output[0] = output[0][[0], :, :]
Expand Down
4 changes: 2 additions & 2 deletions src/seemps/analysis/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..truncate import simplify
from ..qft import iqft, qft_flip
from .mesh import RegularHalfOpenInterval, Mesh
from .factories import mps_tensor_product, mps_affine_transformation
from .factories import mps_tensor_product, mps_affine
from .cross import cross_interpolation, CrossStrategy


Expand Down Expand Up @@ -284,7 +284,7 @@ def func(k):
mps_v = mps_k2 * mps_phase
mps = (1 / sqrt(2) ** sites) * qft_flip(iqft(mps_v, strategy=strategy))

return mps_affine_transformation(mps, (-1, 1), (start, stop)).as_mps()
return mps_affine(mps, (-1, 1), (start, stop)).as_mps()


# TODO: Consider if this helper function is necessary
Expand Down
16 changes: 9 additions & 7 deletions src/seemps/analysis/lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

from ..state import MPS, Strategy
from ..state.schmidt import _destructive_svd
from ..state._contractions import _contract_last_and_first
from ..state.core import destructively_truncate_vector
from ..truncate import simplify, SIMPLIFICATION_STRATEGY
from .mesh import affine_transformation
from .mesh import array_affine


# TODO: Implement multivariate Lagrange interpolation and multirresolution constructions

DEFAULT_LAGRANGE_STRATEGY = SIMPLIFICATION_STRATEGY.replace(normalize=False)

Expand Down Expand Up @@ -99,7 +103,7 @@ def lagrange_rank_revealing(
U_L, R = np.linalg.qr(Al.reshape((2, order + 1)))
tensors = [U_L.reshape(1, 2, 2)]
for _ in range(sites - 2):
B = np.tensordot(R, Ac, axes=1)
B = _contract_last_and_first(R, Ac)
r1, s, r2 = B.shape
## SVD
U, S, V = _destructive_svd(B.reshape(r1 * s, r2))
Expand All @@ -109,7 +113,7 @@ def lagrange_rank_revealing(
R = S.reshape(D, 1) * V[:D, :]
##
tensors.append(U.reshape(r1, s, -1))
U_R = np.tensordot(R, Ar, axes=1)
U_R = _contract_last_and_first(R, Ar)
tensors.append(U_R)
return MPS(tensors)

Expand Down Expand Up @@ -246,9 +250,7 @@ def local_chebyshev_cardinal(self, x: float, j: int) -> float:
gamma_res = (
-gamma
if gamma < 0
else self.d - (gamma - self.d)
if gamma > self.d
else gamma
else self.d - (gamma - self.d) if gamma > self.d else gamma
)
if j == gamma_res:
P += self.local_angular_cardinal(theta, gamma)
Expand All @@ -275,7 +277,7 @@ def A_L(self, func: Callable, start: float, stop: float) -> np.ndarray:
"""

def affine_func(u):
return func(affine_transformation(u, orig=(0, 1), dest=(start, stop)))
return func(array_affine(u, orig=(0, 1), dest=(start, stop)))

A = np.zeros((1, 2, self.D))
for s in range(2):
Expand Down
20 changes: 13 additions & 7 deletions src/seemps/analysis/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __getitem__(self, idx: int) -> float: ...
def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
zero = np.cos(np.pi * (2 * idx + 1) / (2 * self.size))
return affine_transformation(zero, orig=(-1, 1), dest=(self.stop, self.start))
return array_affine(zero, orig=(-1, 1), dest=(self.stop, self.start))


class ChebyshevExtremaInterval(Interval):
Expand All @@ -132,7 +132,7 @@ def __getitem__(self, idx: int) -> float: ...
def __getitem__(self, idx: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
super()._validate_index(idx)
maxima = np.cos(np.pi * idx / (self.size - 1))
return affine_transformation(maxima, orig=(-1, 1), dest=(self.stop, self.start))
return array_affine(maxima, orig=(-1, 1), dest=(self.stop, self.start))


class Mesh:
Expand Down Expand Up @@ -210,7 +210,11 @@ def to_tensor(self):
)


def affine_transformation(x: np.ndarray, orig: tuple, dest: tuple) -> np.ndarray:
def array_affine(
x: np.ndarray,
orig: tuple,
dest: tuple,
) -> np.ndarray:
"""
Performs an affine transformation of x as u = a*x + b from orig=(x0, x1) to dest=(u0, u1).
"""
Expand All @@ -220,13 +224,13 @@ def affine_transformation(x: np.ndarray, orig: tuple, dest: tuple) -> np.ndarray
a = (u1 - u0) / (x1 - x0)
b = 0.5 * ((u1 + u0) - a * (x0 + x1))
x_affine = a * x
if np.abs(b) > np.finfo(np.float64).eps:
if abs(b) > np.finfo(np.float64).eps:
x_affine = x_affine + b
return x_affine


def mps_to_mesh_matrix(
sites_per_dimension: list[int], mps_order: str = "A"
sites_per_dimension: list[int], mps_order: str = "A", base: int = 2
) -> np.ndarray:
"""
Returns a matrix that transforms an array of MPS indices
Expand All @@ -236,13 +240,15 @@ def mps_to_mesh_matrix(
T = np.zeros((sum(sites_per_dimension), len(sites_per_dimension)), dtype=int)
start = 0
for m, n in enumerate(sites_per_dimension):
T[start : start + n, m] = 2 ** np.arange(n)[::-1]
T[start : start + n, m] = base ** np.arange(n)[::-1]
start += n
return T
elif mps_order == "B":
T = np.vstack(
[
np.diag([2 ** (n - i - 1) if n > i else 0 for n in sites_per_dimension])
np.diag(
[base ** (n - i - 1) if n > i else 0 for n in sites_per_dimension]
)
for i in range(max(sites_per_dimension))
]
)
Expand Down
18 changes: 17 additions & 1 deletion src/seemps/analysis/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import numpy as np
from ..operators import MPO, MPOList
from ..operators import MPO, MPOList, MPOSum
from ..state import Strategy, DEFAULT_STRATEGY
from typing import Union

Expand Down Expand Up @@ -292,3 +292,19 @@ def sin_mpo(n: int, a: float, dx: float, strategy=DEFAULT_STRATEGY):
exp2 = exponential_mpo(n, a, dx, c=-1j, strategy=strategy)
sin_mpo = (-1j) * 0.5 * (exp1 - exp2)
return sin_mpo.join(strategy=strategy)


def mpo_affine(
mpo: MPO,
orig: tuple,
dest: tuple,
):
x0, x1 = orig
u0, u1 = dest
a = (u1 - u0) / (x1 - x0)
b = 0.5 * ((u1 + u0) - a * (x0 + x1))
mpo_affine = a * mpo
if abs(b) > np.finfo(np.float64).eps:
I = MPO([np.ones((1, 2, 2, 1))] * len(mpo_affine))
mpo_affine = MPOSum(mpos=[mpo_affine, I], weights=[1, b]).join()
return mpo_affine
47 changes: 31 additions & 16 deletions src/seemps/truncate/simplify_mpo.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from typing import Optional
from typing import Optional, Union
from math import isqrt

from ..operators import MPO
from ..operators import MPO, MPOList, MPOSum
from ..state import DEFAULT_STRATEGY, MPS, Strategy
from ..truncate import SIMPLIFICATION_STRATEGY, simplify


def mpo_as_mps(mpo):
def mpo_as_mps(mpo: MPO) -> MPS:
"""Recast MPO as MPS."""
_, i, j, _ = mpo[0].shape
return MPS([t.reshape(t.shape[0], i * j, t.shape[-1]) for t in mpo._data])


def mps_as_mpo(
mps: MPS,
mpo_strategy: Strategy = DEFAULT_STRATEGY,
) -> MPO:
"""Recast MPS as MPO."""
_, S, _ = mps[0].shape
s = isqrt(S)
if s**2 != S:
raise ValueError("The physical dimensions of the MPS must be a perfect square")
return MPO(
[t.reshape(t.shape[0], s, s, t.shape[-1]) for t in mps._data],
strategy=mpo_strategy,
)


# TODO: As opposed to MPS, the MPO class does not have an error attribute to keep track
# of the simplification errors
def simplify_mpo(
operator: MPO,
operator: Union[MPO, MPOList, MPOSum],
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1,
guess: Optional[MPS] = None,
Expand All @@ -23,26 +41,23 @@ def simplify_mpo(
Parameters
----------
operator : MPO
MPO to approximate.
operator : Union[MPO, MPOList, MPOSum]
Operator to approximate.
strategy : Strategy
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : { +1, -1 }
Initial direction for the sweeping algorithm. Defaults to +1.
Initial direction for the sweeping algorithm. Defaults to +1.
guess : MPS
A guess for the new state, to ease the optimization. Defaults to None.
A guess for the new state, to ease the optimization. Defaults to None.
mpo_strategy : Strategy
Strategy of the resulting MPO. Defaults to `DEFAULT_STRATEGY`.
Strategy of the resulting MPO. Defaults to `DEFAULT_STRATEGY`.
Returns
-------
MPO
Approximation O to the operator.
"""
_, i, j, _ = operator[0].shape
if isinstance(operator, MPOList) or isinstance(operator, MPOSum):
operator = operator.join()
mps = simplify(mpo_as_mps(operator), strategy, direction, guess)
[t.reshape(t.shape[0], i, j, t.shape[-1]) for t in mps._data]
return MPO(
[t.reshape(t.shape[0], i, j, t.shape[-1]) for t in mps._data],
strategy=mpo_strategy,
)
return mps_as_mpo(mps, mpo_strategy=mpo_strategy)
Loading

0 comments on commit 6ecac4e

Please sign in to comment.