Skip to content

Commit

Permalink
Update chebyshev, factories and mesh modules from jjrodriguezaldavero…
Browse files Browse the repository at this point in the history
…/main
  • Loading branch information
juanjosegarciaripoll authored Jan 4, 2024
2 parents cf76365 + 9530cdd commit f399002
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 112 deletions.
16 changes: 8 additions & 8 deletions src/seemps/analysis/chebyshev.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations
from typing import Callable, Optional, Union
from typing import Callable, Optional

import numpy as np
from scipy.fft import dct # type: ignore
from typing import Callable, Optional

from .mesh import ChebyshevZerosInterval, Interval
from .factories import mps_interval
from .sampling import infinity_norm
from ..operators import MPO
from ..state import MPS, Strategy, DEFAULT_STRATEGY, Truncation, Simplification
from ..state import MPS, Strategy, Truncation, Simplification
from ..truncate import simplify


Expand All @@ -24,10 +23,11 @@ def chebyshev_coefficients(
Returns the Chebyshev coefficients for a given function on a specified
interval using the Discrete Cosine Transform (DCT II).
The accuracy of the Chebyshev approximation is correlated to the magnitude
of the last few coefficients in the series (depending on their periodicity),
with smaller absolute values typically indicating a better approximation of
the function.
The error of the Chebyshev approximation is related to the magnitude of the
last (or second-to-last) coefficient of the series. As a rule of thumb, the
error is given by the sum of all the neglected coefficients, which can be
close to the magnitude of the last coefficient if the series decays sufficiently
fast (for example, exponentially fast for analytical functions).
Parameters
----------
Expand Down Expand Up @@ -134,7 +134,7 @@ def cheb2mps(
L = len(x_mps)
I = MPS([np.ones((1, 2, 1))] * L)
if np.abs(b) > np.finfo(np.float64).eps:
x_mpo = (x_mps + b * I).join()
x_mps = (x_mps + b * I).join()
else:
raise Exception("In cheb2mps, either domain or an MPS must be provided.")

Expand Down
69 changes: 48 additions & 21 deletions src/seemps/analysis/factories.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from typing import List

from ..state import MPS, Strategy
from ..state import MPS, Strategy, DEFAULT_STRATEGY
from ..truncate import simplify
from .mesh import (
Interval,
Expand All @@ -10,10 +10,6 @@
ChebyshevZerosInterval,
)

# TODO: These descriptions are wrong. `sites` is not the number of
# discretization points, but rather the number of qubits used for this
# representation. The size is 2**sites


def mps_equispaced(start: float, stop: float, sites: int):
"""
Expand All @@ -26,12 +22,12 @@ def mps_equispaced(start: float, stop: float, sites: int):
stop : float
The end of the interval.
sites : int
The number of discretization points.
The number of sites or qubits for the MPS.
Returns
-------
MPS
An MPS representing the equispaced discretization of the interval.
An MPS representing an equispaced discretization within [start, stop].
"""
step = (stop - start) / 2**sites
tensor_1 = np.zeros((1, 2, 2))
Expand All @@ -47,7 +43,7 @@ def mps_equispaced(start: float, stop: float, sites: int):
return MPS(tensors)


def mps_exponential(start: float, stop: float, sites: int, c: complex) -> MPS:
def mps_exponential(start: float, stop: float, sites: int, c: complex = 1) -> MPS:
"""
Returns an MPS representing an exponential function discretized over an interval.
Expand All @@ -58,7 +54,7 @@ def mps_exponential(start: float, stop: float, sites: int, c: complex) -> MPS:
stop : float
The end of the interval.
sites : int
The number of discretization points.
The number of sites or qubits for the MPS.
c : complex
The coefficient in the exponent of the exponential function.
Expand All @@ -84,7 +80,37 @@ def mps_exponential(start: float, stop: float, sites: int, c: complex) -> MPS:
return MPS(tensors)


def mps_cosine(start: float, stop: float, sites: int) -> MPS:
def mps_sine(
start: float, stop: float, sites: int, strategy: Strategy = DEFAULT_STRATEGY
) -> MPS:
"""
Returns an MPS representing a sine function discretized over an interval.
Parameters
----------
start : float
The start of the interval.
stop : float
The end of the interval.
sites : int
The number of sites or qubits for the MPS.
strategy : Strategy, default = DEFAULT_STRATEGY
The MPS simplification strategy to apply.
Returns
-------
MPS
An MPS representing the discretized sine function over the interval.
"""
mps_1 = mps_exponential(start, stop, sites, c=1j)
mps_2 = mps_exponential(start, stop, sites, c=-1j)

return simplify(-0.5j * (mps_1 - mps_2), strategy=strategy)


def mps_cosine(
start: float, stop: float, sites: int, strategy: Strategy = DEFAULT_STRATEGY
) -> MPS:
"""
Returns an MPS representing a cosine function discretized over an interval.
Expand All @@ -95,7 +121,9 @@ def mps_cosine(start: float, stop: float, sites: int) -> MPS:
stop : float
The end of the interval.
sites : int
The number of discretization points.
The number of sites or qubits for the MPS.
strategy : Strategy, default = DEFAULT_STRATEGY
The MPS simplification strategy to apply.
Returns
-------
Expand All @@ -105,28 +133,27 @@ def mps_cosine(start: float, stop: float, sites: int) -> MPS:
mps_1 = mps_exponential(start, stop, sites, c=1j)
mps_2 = mps_exponential(start, stop, sites, c=-1j)

return simplify(0.5 * (mps_1 + mps_2))
return simplify(0.5 * (mps_1 + mps_2), strategy=strategy)


# TODO: Eliminate the `rescale` argument now that we have `Interval.map_to`
def mps_interval(interval: Interval, rescale: bool = False):
def mps_interval(interval: Interval, strategy: Strategy = DEFAULT_STRATEGY):
"""
Returns an MPS corresponding to a specific type of interval (open, closed, or Chebyshev zeros).
Parameters
----------
interval : Interval
The interval object containing start and stop points and the interval type.
rescale : bool, optional
Flag to rescale the interval to [-1, 1].
strategy : Strategy, default = DEFAULT_STRATEGY
The MPS simplification strategy to apply.
Returns
-------
MPS
An MPS representing the interval according to its type.
"""
start = interval.start if not rescale else -1
stop = interval.stop if not rescale else 1
start = interval.start
stop = interval.stop
sites = int(np.log2(interval.size))
if isinstance(interval, RegularHalfOpenInterval):
return mps_equispaced(start, stop, sites)
Expand All @@ -136,7 +163,7 @@ def mps_interval(interval: Interval, rescale: bool = False):
elif isinstance(interval, ChebyshevZerosInterval):
start_mapped = np.pi / (2 ** (sites + 1))
stop_mapped = np.pi + start_mapped
return -1.0 * mps_cosine(start_mapped, stop_mapped, sites)
return -1.0 * mps_cosine(start_mapped, stop_mapped, sites, strategy=strategy)
else:
raise ValueError(f"Unsupported interval type {type(interval)}")

Expand All @@ -160,15 +187,15 @@ def mps_tensor_product(mps_list: List[MPS]) -> MPS:
return MPS(flattened_sites)


def mps_tensor_sum(mps_list: List[MPS], strategy: Strategy = Strategy()) -> MPS:
def mps_tensor_sum(mps_list: List[MPS], strategy: Strategy = DEFAULT_STRATEGY) -> MPS:
"""
Returns the tensor sum of a list of MPS.
Parameters
----------
mps_list : List[MPS]
The list of MPS objects to sum.
strategy : Strategy, optional
strategy : Strategy, default = DEFAULT_STRATEGY
The MPS simplification strategy to apply.
Returns
Expand Down
3 changes: 3 additions & 0 deletions src/seemps/analysis/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def to_vector(self) -> np.ndarray:
def map_to(self, start: float, stop: float) -> Interval:
return type(self)(start, stop, self.size)

def update_size(self, size: int) -> Interval:
return type(self)(self.start, self.stop, size)

def __iter__(self) -> Iterator:
return (self[i] for i in range(self.size))

Expand Down
83 changes: 83 additions & 0 deletions tests/test_analysis_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
from seemps.analysis import (
mps_equispaced,
mps_exponential,
mps_sine,
mps_cosine,
RegularHalfOpenInterval,
RegularClosedInterval,
ChebyshevZerosInterval,
mps_interval,
mps_tensor_sum,
mps_tensor_product,
)

from .tools import TestCase


class TestMPSFactories(TestCase):
def test_mps_equispaced(self):
self.assertSimilar(
mps_equispaced(-1, 1, 5).to_vector(),
np.linspace(-1, 1, 2**5, endpoint=False),
)

def test_mps_exponential(self):
self.assertSimilar(
mps_exponential(-1, 1, 5, c=1).to_vector(),
np.exp(np.linspace(-1, 1, 2**5, endpoint=False)),
)
self.assertSimilar(
mps_exponential(-1, 1, 5, c=-1).to_vector(),
np.exp(-np.linspace(-1, 1, 2**5, endpoint=False)),
)

def test_mps_sine(self):
self.assertSimilar(
mps_sine(-1, 1, 5).to_vector(),
np.sin(np.linspace(-1, 1, 2**5, endpoint=False)),
)

def test_mps_cosine(self):
self.assertSimilar(
mps_cosine(-1, 1, 5).to_vector(),
np.cos(np.linspace(-1, 1, 2**5, endpoint=False)),
)

def test_mps_interval(self):
start = -1
stop = 1
sites = 5
mps_half_open = mps_interval(RegularHalfOpenInterval(start, stop, 2**sites))
mps_closed = mps_interval(RegularClosedInterval(start, stop, 2**sites))
mps_zeros = mps_interval(ChebyshevZerosInterval(start, stop, 2**sites))
zeros = lambda d: np.array(
[np.cos(np.pi * (2 * k - 1) / (2 * d)) for k in range(d, 0, -1)]
)
self.assertSimilar(
mps_half_open, np.linspace(start, stop, 2**sites, endpoint=False)
)
self.assertSimilar(
mps_closed, np.linspace(start, stop, 2**sites, endpoint=True)
)
self.assertSimilar(mps_zeros, zeros(2**sites))


class TestMPSOperations(TestCase):
def test_tensor_product(self):
sites = 5
interval = RegularHalfOpenInterval(-1, 2, 2**sites)
mps_x = mps_interval(interval)
mps_x_times_y = mps_tensor_product([mps_x, mps_x])
Z_mps = mps_x_times_y.to_vector().reshape((2**sites, 2**sites))
X, Y = np.meshgrid(interval.to_vector(), interval.to_vector())
self.assertSimilar(Z_mps, X * Y)

def test_tensor_sum(self):
sites = 5
interval = RegularHalfOpenInterval(-1, 2, 2**sites)
mps_x = mps_interval(interval)
mps_x_plus_y = mps_tensor_sum([mps_x, mps_x])
Z_mps = mps_x_plus_y.to_vector().reshape((2**sites, 2**sites))
X, Y = np.meshgrid(interval.to_vector(), interval.to_vector())
self.assertSimilar(Z_mps, X + Y)
Loading

0 comments on commit f399002

Please sign in to comment.