Skip to content

Commit

Permalink
Use Strategy objects in MPS simplify/combine
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll authored Oct 26, 2023
2 parents d27ccc8 + 0305aca commit 3bb0e19
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 60 deletions.
1 change: 1 addition & 0 deletions seemps/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .core import (
Strategy,
Truncation,
Simplification,
DEFAULT_STRATEGY,
DEFAULT_TOLERANCE,
NO_TRUNCATION,
Expand Down
10 changes: 10 additions & 0 deletions seemps/state/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ class Truncation:
RELATIVE_NORM_SQUARED_ERROR = 2
ABSOLUTE_SINGULAR_VALUE = 3

class Simplification:
CANONICAL_FORM = 0
VARIATIONAL = 1

class Strategy:
def __init__(
self: Strategy,
method: int = 1,
simplification_method: int = 1,
tolerance: float = 1e-8,
simplification_tolerance: float = 1e-8,
max_bond_dimension: int = 0x8FFFFFFF,
max_sweeps: int = 16,
normalize: bool = False,
Expand All @@ -21,18 +27,22 @@ class Strategy:
def replace(
self: Strategy,
method: Optional[int] = None,
simplification_method: Optional[int] = None,
tolerance: Optional[float] = None,
simplification_tolerance: Optional[float] = None,
max_bond_dimension: Optional[int] = None,
max_sweeps: Optional[int] = None,
normalize: Optional[bool] = None,
simplify: Optional[bool] = None,
) -> Strategy: ...
def set_normalization(self: Strategy, normalize: bool) -> Strategy: ...
def get_tolerance(self) -> float: ...
def get_simplification_tolerance(self) -> float: ...
def get_max_bond_dimension(self) -> int: ...
def get_max_sweeps(self) -> int: ...
def get_normalize_flag(self) -> bool: ...
def get_simplify_flag(self) -> bool: ...
def get_simplification_method(self) -> int: ...
def __str__(self) -> str: ...

DEFAULT_TOLERANCE: float
Expand Down
36 changes: 32 additions & 4 deletions seemps/state/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@ class Truncation:
RELATIVE_NORM_SQUARED_ERROR = 2
ABSOLUTE_SINGULAR_VALUE = 3

class Simplification:
CANONICAL_FORM = 0
VARIATIONAL = 1

cdef class Strategy:
cdef int method
cdef int simplification_method
cdef double tolerance
cdef double simplification_tolerance
cdef int max_bond_dimension
cdef int max_sweeps
cdef bool normalize
cdef bool simplify

def __init__(self,
method: int = Truncation.RELATIVE_SINGULAR_VALUE,
simplification_method: int = Simplification.VARIATIONAL,
tolerance: float = 1e-8,
simplification_tolerance: float = 1e-8,
max_bond_dimension: Optional[int] = INT_MAX,
normalize: bool = False,
simplify: bool = False,
Expand All @@ -35,9 +43,13 @@ cdef class Strategy:
if tolerance == 0 and method > 0:
method = 3
self.tolerance = tolerance
self.simplification_tolerance = simplification_tolerance
if method < 0 or method > 3:
raise AssertionError("Invalid method argument passed to Strategy")
self.method = method
if simplification_method < 0 or simplification_method > 1:
raise AssertionError("Invalid simplification_method argument passed to Strategy")
self.simplification_method = simplification_method
if max_bond_dimension is None:
self.max_bond_dimension = INT_MAX
elif max_bond_dimension <= 0:
Expand All @@ -52,13 +64,17 @@ cdef class Strategy:

def replace(self,
method: Optional[Truncation] = None,
simplification_method: Optional[Simplification] = None,
tolerance: Optional[float] = None,
simplification_tolerance: Optional[float] = None,
max_bond_dimension: Optional[int] = None,
normalize: Optional[bool] = None,
simplify: Optional[bool] = None,
max_sweeps: Optional[int] = None):
return Strategy(method = self.method if method is None else method,
simplification_method = self.simplification_method if simplification_method is None else simplification_method,
tolerance = self.tolerance if tolerance is None else tolerance,
simplification_tolerance = self.simplification_tolerance if simplification_tolerance is None else simplification_tolerance,
max_bond_dimension = self.max_bond_dimension if max_bond_dimension is None else max_bond_dimension,
normalize = self.normalize if normalize is None else normalize,
simplify = self.simplify if simplify is None else simplify,
Expand All @@ -67,9 +83,15 @@ cdef class Strategy:
def get_method(self) -> int:
return self.method

def get_simplification_method(self) -> int:
return self.simplification_method

def get_tolerance(self) -> float:
return self.tolerance

def get_simplification_tolerance(self) -> float:
return self.simplification_tolerance

def get_max_bond_dimension(self) -> int:
return self.max_bond_dimension

Expand All @@ -91,14 +113,20 @@ cdef class Strategy:
method="RelativeNorm"
else:
method="AbsoluteSVD"
return f"Strategy(method={method}, tolerance={self.tolerance}, " \
f"max_bond_dimension={self.max_bond_dimension}, normalize={self.normalize}, " \
f"simplify={self.simplify}, max_sweeps={self.max_sweeps})"
if self.simplification_method == 0:
simplification_method="CanonicalForm"
elif self.simplification_method == 1:
simplification_method="Variational"
return f"Strategy(method={method}, simplification_method={simplification_method}, " \
f"tolerance={self.tolerance}, max_bond_dimension={self.max_bond_dimension}, " \
f"normalize={self.normalize}, simplify={self.simplify}, max_sweeps={self.max_sweeps})"

DEFAULT_TOLERANCE = np.finfo(np.float64).eps

DEFAULT_STRATEGY = Strategy(method = Truncation.RELATIVE_NORM_SQUARED_ERROR,
tolerance = np.finfo(np.float64).eps,
simplification_method = Simplification.VARIATIONAL,
tolerance = DEFAULT_TOLERANCE,
simplification_tolerance = DEFAULT_TOLERANCE,
max_bond_dimension = INT_MAX,
normalize = False)

Expand Down
99 changes: 43 additions & 56 deletions seemps/truncate/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@
from ..typing import *
from .antilinear import AntilinearForm


# TODO: We have to rationalize all this about directions. The user should
# not really care about it and we can guess the direction from the canonical
# form of either the guess or the state.

SIMPLIFICATION_STRATEGY = Strategy(
method=Truncation.RELATIVE_NORM_SQUARED_ERROR,
tolerance=DEFAULT_TOLERANCE,
max_bond_dimension=MAX_BOND_DIMENSION,
normalize=True,
max_sweeps=4
)

def simplify(
state: Union[MPS, MPSSum],
maxsweeps: int = 4,
direction: int = +1,
tolerance: float = DEFAULT_TOLERANCE,
normalize: bool = True,
max_bond_dimension: int = MAX_BOND_DIMENSION,
state: Union[MPS, MPSSum],
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1

) -> MPS:
"""Simplify an MPS state transforming it into another one with a smaller bond
dimension, sweeping until convergence is achieved.
Expand All @@ -28,15 +34,10 @@ def simplify(
----------
state : MPS | MPSSum
State to approximate.
strategy : Strategy
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : { +1, -1 }
Direction of the first sweep
maxsweeps : int
Maximum number of sweeps to run
tolerance : float
Relative tolerance when splitting the tensors. Defaults to
`DEFAULT_TOLERANCE`
max_bond_dimension : int
Maximum bond dimension. Defaults to `MAX_BOND_DIMENSION`
Initial direction for the sweeping algorithm.
Returns
-------
Expand All @@ -47,44 +48,40 @@ def simplify(
return combine(
state.weights,
state.states,
maxsweeps=maxsweeps,
strategy=strategy,
direction=direction,
tolerance=tolerance,
max_bond_dimension=max_bond_dimension,
normalize=normalize,
)

size = state.size
start = 0 if direction > 0 else size - 1

truncation = Strategy(
method=Truncation.RELATIVE_NORM_SQUARED_ERROR,
tolerance=tolerance,
max_bond_dimension=max_bond_dimension,
normalize=normalize,
)
mps = CanonicalMPS(state, center=start, strategy=truncation)
normalize= strategy.get_normalize_flag()
maxsweeps = strategy.get_max_sweeps()
simplification_tolerance = strategy.get_simplification_tolerance()
max_bond_dimension = strategy.get_max_bond_dimension()
mps = CanonicalMPS(state, center=start, strategy=strategy)
if normalize:
mps.normalize_inplace()
if max_bond_dimension == 0 and tolerance <= 0:
if not strategy.get_simplification_method():
return mps
if max_bond_dimension == 0 and simplification_tolerance <= 0:
return mps

form = AntilinearForm(mps, state, center=start)
norm_state_sqr = scprod(state, state).real
base_error = state.error()
err = 1.0
log(
f"SIMPLIFY state with |state|={norm_state_sqr**0.5} for {maxsweeps} sweeps, with tolerance {tolerance}."
f"SIMPLIFY state with |state|={norm_state_sqr**0.5} for {maxsweeps} sweeps, with tolerance {simplification_tolerance}."
)
for sweep in range(maxsweeps):
if direction > 0:
for n in range(0, size - 1):
mps.update_2site_right(form.tensor2site(direction), n, truncation)
mps.update_2site_right(form.tensor2site(direction), n, strategy)
form.update(direction)
last = size - 1
else:
for n in reversed(range(0, size - 1)):
mps.update_2site_left(form.tensor2site(direction), n, truncation)
mps.update_2site_left(form.tensor2site(direction), n, strategy)
form.update(direction)
last = 0
#
Expand All @@ -104,7 +101,7 @@ def simplify(
log(
f"sweep={sweep}, rel.err.={err}, old err.={old_err}, |mps|={norm_mps_sqr**0.5}"
)
if err < tolerance or err > old_err:
if err < simplification_tolerance or err > old_err:
log("Stopping, as tolerance reached")
break
direction = -direction
Expand Down Expand Up @@ -159,12 +156,9 @@ def combine_tensors(A: Tensor3, sumA: Tensor3) -> Tensor3:
def combine(
weights: list[Weight],
states: list[MPS],
guess: Optional[MPS] = None,
maxsweeps: int = 4,
direction: int = +1,
tolerance: float = DEFAULT_TOLERANCE,
max_bond_dimension: int = MAX_BOND_DIMENSION,
normalize: bool = True,
guess: Optional[MPS] = None,
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1
) -> MPS:
"""Approximate a linear combination of MPS :math:`\\sum_i w_i \\psi_i` by
another one with a smaller bond dimension, sweeping until convergence is achieved.
Expand All @@ -174,17 +168,13 @@ def combine(
weights : list[Weight]
Weights of the linear combination :math:`w_i` in list form.
states : list[MPS]
List of states :math:`\\psi_i`
List of states :math:`\\psi_i`.
guess : MPS, optional
Initial guess for the iterative algorithm
Initial guess for the iterative algorithm.
strategy : Strategy
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : {+1, -1}
Initial direction for the sweeping algorithm
maxsweeps : int
Maximum number of iterations
tolerance :
Relative tolerance when splitting the tensors
max_bond_dimension :
Maximum bond dimension
Initial direction for the sweeping algorithm.
Returns
-------
Expand All @@ -197,19 +187,16 @@ def combine(
np.sqrt(np.abs(weights)) * np.sqrt(state.error())
for weights, state in zip(weights, states)
)
strategy = Strategy(
method=Truncation.RELATIVE_NORM_SQUARED_ERROR,
tolerance=tolerance,
max_bond_dimension=max_bond_dimension,
normalize=normalize,
)
normalize= strategy.get_normalize_flag()
maxsweeps = strategy.get_max_sweeps()
simplification_tolerance = strategy.get_simplification_tolerance()
start = 0 if direction > 0 else guess.size - 1
φ = CanonicalMPS(guess, center=start, strategy=strategy, normalize=normalize)
err = norm_ψsqr = multi_norm_squared(weights, states)
if norm_ψsqr < tolerance:
if norm_ψsqr < simplification_tolerance:
return MPS([np.zeros((1, P.shape[1], 1)) for P in φ])
log(
f"COMBINE state with |state|={norm_ψsqr**0.5} for {maxsweeps} sweeps with tolerance {strategy.get_tolerance()}.\nWeights: {weights}"
f"COMBINE state with |state|={norm_ψsqr**0.5} for {maxsweeps} sweeps with tolerance {simplification_tolerance}.\nWeights: {weights}"
)

size = φ.size
Expand Down Expand Up @@ -249,7 +236,7 @@ def combine(
old_err = err
err = 2 * abs(1.0 - scprod_φψ.real / np.sqrt(norm_φsqr * norm_ψsqr))
log(f"sweep={sweep}, rel.err.={err}, old err.={old_err}, |φ|={norm_φsqr**0.5}")
if err < tolerance or err > old_err:
if err < simplification_tolerance or err > old_err:
log("Stopping, as tolerance reached")
break
direction = -direction
Expand Down
53 changes: 53 additions & 0 deletions tests/test_combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
from seemps.state import DEFAULT_STRATEGY, NO_TRUNCATION, random_uniform_mps
from seemps.truncate.simplify import combine

from .tools import *


class TestCombine(TestCase):

def test_no_truncation(self):
d = 2
for n in range(3,9):
ψ1 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ1 = ψ1 * (1/ψ1.norm())
ψ2 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ2 = ψ2 * (1/ψ2.norm())
a1 = np.random.randn()
a2 = np.random.randn()
ψ = a1*ψ1.to_vector() + a2*ψ2.to_vector()
φ = combine(weights=[a1,a2], states=[ψ1,ψ2], truncation=NO_TRUNCATION)
self.assertSimilar(ψ, φ.to_vector())

def test_tolerance(self):
d = 2
tolerance = 1e-10
strategy = DEFAULT_STRATEGY.replace(simplification_tolerance=tolerance)
for n in range(3,15):
ψ1 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ1 = ψ1 * (1/ψ1.norm())
ψ2 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ2 = ψ2 * (1/ψ2.norm())
a1 = np.random.randn()
a2 = np.random.randn()
ψ = a1*ψ1.to_vector() + a2*ψ2.to_vector()
φ = combine(weights=[a1,a2], states=[ψ1,ψ2], truncation=strategy)
err = 2 * abs(
1.0 - np.vdot(ψ, φ.to_vector()).real / (np.linalg.norm(ψ) * φ.norm()))
self.assertTrue(err < tolerance)

def test_max_bond_dimensions(self):
d = 2
n = 14
for D in range(2,15):
strategy = DEFAULT_STRATEGY.replace(max_bond_dimension=D)
ψ1 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ1 = ψ1 * (1/ψ1.norm())
ψ2 = random_uniform_mps(d, n, D=int(2**(n/2)))
ψ2 = ψ2 * (1/ψ2.norm())
a1 = np.random.randn()
a2 = np.random.randn()
φ = combine(weights=[a1,a2], states=[ψ1,ψ2], truncation=strategy)
max_D_φ = max([max(t.shape) for t in φ])
self.assertTrue(max_D_φ <= D)
Loading

0 comments on commit 3bb0e19

Please sign in to comment.