Skip to content

Commit

Permalink
Replace combine() with a function specialized on MPSSum objects
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Mar 31, 2024
1 parent df8614d commit 25cb048
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 154 deletions.
4 changes: 2 additions & 2 deletions src/seemps/truncate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .antilinear import AntilinearForm
from .simplify import simplify, combine, SIMPLIFICATION_STRATEGY
from .simplify import simplify, SIMPLIFICATION_STRATEGY

__all__ = ["simplify", "combine", "AntilinearForm", "SIMPLIFICATION_STRATEGY"]
__all__ = ["simplify", "AntilinearForm", "SIMPLIFICATION_STRATEGY"]
167 changes: 50 additions & 117 deletions src/seemps/truncate/simplify.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from __future__ import annotations

from typing import Optional, Union

import numpy as np
import scipy.linalg # type: ignore # scipy does not provide type declarations

from .. import tools
from ..state import (
DEFAULT_TOLERANCE,
Expand All @@ -17,7 +13,7 @@
Truncation,
)
from ..state.environments import scprod
from ..typing import Tensor3, Weight
from ..typing import Weight
from .antilinear import AntilinearForm

# TODO: We have to rationalize all this about directions. The user should
Expand All @@ -38,6 +34,7 @@ def simplify(
state: Union[MPS, MPSSum],
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1,
guess: Optional[MPS] = None,
) -> CanonicalMPS:
"""Simplify an MPS state transforming it into another one with a smaller bond
dimension, sweeping until convergence is achieved.
Expand All @@ -50,31 +47,33 @@ def simplify(
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : { +1, -1 }
Initial direction for the sweeping algorithm.
guess : MPS
A guess for the new state, to ease the optimization.
Returns
-------
CanonicalMPS
Approximation :math:`\\xi` to the state.
"""
if isinstance(state, MPSSum):
return combine(
state.weights,
state.states,
strategy=strategy,
direction=direction,
)
#
return simplify_mps_sum(state, strategy, direction, guess)

# Prepare initial guess
normalize = strategy.get_normalize_flag()
size = state.size
start = 0 if direction > 0 else -1
mps = CanonicalMPS(state, center=start, strategy=strategy)

# If we only do canonical forms, not variational optimization, a second
# pass on that initial guess suffices
if strategy.get_simplification_method() == Simplification.CANONICAL_FORM:
mps = CanonicalMPS(state, center=start, strategy=strategy)
return CanonicalMPS(mps, center=-1 - start, strategy=strategy)

if guess is None:
mps = CanonicalMPS(state, center=start, strategy=strategy)
else:
mps = CanonicalMPS(guess)

simplification_tolerance = strategy.get_simplification_tolerance()
if not (norm_state_sqr := state.norm_squared()):
return CanonicalMPS(state.zero_state(), is_canonical=True)
Expand Down Expand Up @@ -122,15 +121,13 @@ def simplify(
return mps


def select_nonzero_mps_components(
weights: list[Weight], states: list[MPS]
) -> tuple[float, list[Weight], list[MPS]]:
def select_nonzero_mps_components(state: MPSSum) -> tuple[float, MPSSum]:
"""Compute the norm-squared of the linear combination of weights and
states and eliminate states that are zero or have zero weight."""
c: float = 0.0
final_weights: list[Weight] = []
final_states: list[MPS] = []
for wi, si in zip(weights, states):
for wi, si in zip(state.weights, state.states):
wic = wi.conjugate()
ni = (wic * wi).real * si.norm_squared()
if ni:
Expand All @@ -139,97 +136,28 @@ def select_nonzero_mps_components(
final_states.append(si)
final_weights.append(wi)
c += ni
return abs(c), final_weights, final_states


def crappy_guess_combine_state(weights: list[Weight], states: list[MPS]) -> MPS:
"""Make an educated guess that ensures convergence of the :func:`combine`
algorithm."""

def combine_tensors(A: Tensor3, sumA: Tensor3) -> Tensor3:
DL, d, DR = sumA.shape
a, d, b = A.shape
if DL < a or DR < b:
# Extend with zeros to accommodate new contribution
newA = np.zeros((max(DL, a), d, max(DR, b)), dtype=sumA.dtype)
newA[:DL, :, :DR] = sumA
else:
newA = sumA.copy()
dt = type(A[0, 0, 0] + sumA[0, 0, 0])
if sumA.dtype != dt:
newA = newA.astype(dt)
else:
newA[:a, :, :b] += A
return newA

guess: MPS = weights[0] * states[0]
for n, state in enumerate(states[1:]):
for i, (A, sumA) in enumerate(zip(state, guess)):
guess[i] = combine_tensors(A if i > 0 else A * weights[n], sumA)
return guess


def guess_combine_state(weights: list, states: list[MPS]) -> MPS:
"""Make an educated guess that ensures convergence of the :func:`combine`
algorithm."""

def combine_tensors(A, sumA, idx):
DL, d, DR = sumA.shape
a, d, b = A.shape
if idx == 0:
new_A = np.zeros((d, DR + b), dtype=sumA.dtype)
for d_i in range(d):
new_A[d_i, :] = np.concatenate(
(sumA.reshape(d, DR)[d_i, :], A.reshape(d, b)[d_i, :])
)
new_A = new_A.reshape(1, d, DR + b)
elif idx == -1:
new_A = np.zeros((DL + a, d), dtype=sumA.dtype)
for d_i in range(d):
new_A[:, d_i] = np.concatenate(
(sumA.reshape(DL, d)[:, d_i], A.reshape(a, d)[:, d_i])
)
new_A = new_A.reshape(DL + a, d, 1)
else:
new_A = np.zeros((DL + a, d, DR + b), dtype=sumA.dtype)
for d_i in range(d):
new_A[:, d_i, :] = scipy.linalg.block_diag(
sumA[:, d_i, :], A[:, d_i, :]
)
return new_A

guess = []
size = states[0].size
for i in range(size):
sumA = states[0][i] * weights[0] if i == 0 else states[0][i]
if i == size - 1:
i = -1
for n, state in enumerate(states[1:]):
A = state[i]
sumA = combine_tensors(A * weights[n + 1] if i == 0 else A, sumA, i)
guess.append(sumA)
return MPS(guess)
if len(final_weights) < state.size:
return abs(c), MPSSum(final_weights, final_states, check_args=False)
else:
return abs(c), state


# TODO: We have to rationalize all this about directions. The user should
# not really care about it and we can guess the direction from the canonical
# form of either the guess or the state.
def combine(
weights: list[Weight],
states: list[MPS],
guess: Optional[MPS] = None,
def simplify_mps_sum(
sum_state: MPSSum,
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1,
guess: Optional[MPS] = None,
) -> CanonicalMPS:
"""Approximate a linear combination of MPS :math:`\\sum_i w_i \\psi_i` by
another one with a smaller bond dimension, sweeping until convergence is achieved.
Parameters
----------
weights : list[Weight]
Weights of the linear combination :math:`w_i` in list form.
states : list[MPS]
List of states :math:`\\psi_i`.
state : MPSSum
State to approximate
guess : MPS, optional
Initial guess for the iterative algorithm.
strategy : Strategy
Expand All @@ -242,13 +170,21 @@ def combine(
CanonicalMPS
Approximation to the linear combination in canonical form
"""
# Compute norm of output and eliminate zero states
orig_sum_state = sum_state
norm_state_sqr, state = select_nonzero_mps_components(sum_state)
if not norm_state_sqr:
tools.log(
"COMBINE state with |state|=0. Returning zero state.",
debug_level=2,
)
return CanonicalMPS(orig_sum_state.states[0].zero_state(), is_canonical=True)

normalize = strategy.get_normalize_flag()
start = 0 if direction > 0 else -1
# CANONICAL_FORM implements a simplification based on two passes
if strategy.get_simplification_method() == Simplification.CANONICAL_FORM:
mps = CanonicalMPS(
guess_combine_state(weights, states), center=start, strategy=strategy
)
mps = CanonicalMPS(sum_state.join(), center=start, strategy=strategy)
mps = CanonicalMPS(mps, center=-1 - start, strategy=strategy)
if tools.DEBUG >= 2:
tools.log(
Expand All @@ -262,9 +198,7 @@ def combine(
# output is expected to be a CanonicalMPS, we must use the
# strategy to construct it.
if strategy.get_simplification_method() == Simplification.DO_NOT_SIMPLIFY:
mps = CanonicalMPS(
guess_combine_state(weights, states), center=-1 - start, strategy=strategy
)
mps = CanonicalMPS(sum_state.join(), center=-1 - start, strategy=strategy)
if tools.DEBUG >= 2:
tools.log(
f"SIMPLIFY state with |state|={mps.norm():5e}\nusing single-pass "
Expand All @@ -273,32 +207,20 @@ def combine(
debug_level=2,
)

# Compute norm of output and eliminate zero states
orig_state_0 = states[0]
norm_state_sqr, weights, states = select_nonzero_mps_components(weights, states)
if not norm_state_sqr:
tools.log(
"COMBINE state with |state|=0. Returning zero state.",
debug_level=2,
)
return CanonicalMPS(orig_state_0.zero_state(), is_canonical=True)

# Prepare initial guess
if guess is None:
if strategy.get_simplification_method() == Simplification.VARIATIONAL:
guess = crappy_guess_combine_state(weights, states)
else: # Simplification.VARIATIONAL_EXACT_GUESS:
guess = guess_combine_state(weights, states)
guess = sum_state.join()
mps = CanonicalMPS(guess, center=start, strategy=strategy)
simplification_tolerance = strategy.get_simplification_tolerance()

size = mps.size
weights, states = sum_state.weights, sum_state.states
forms = [AntilinearForm(mps, si, center=start) for si in states]
tools.log(
f"COMBINE state with |state|={norm_state_sqr**0.5:5e} for {strategy.get_max_sweeps():5e}"
f"sweeps with tolerance {simplification_tolerance:5e}.\nWeights: {weights}",
debug_level=2,
)
size = mps.size
forms = [AntilinearForm(mps, state, center=start) for state in states]
err = 2.0
for sweep in range(max(1, strategy.get_max_sweeps())):
if direction > 0:
Expand Down Expand Up @@ -352,3 +274,14 @@ def combine(
if normalize and norm_mps_sqr:
last_tensor /= norm_mps_sqr
return mps


def combine(
weights: list[Weight],
states: list[MPS],
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1,
guess: Optional[MPS] = None,
) -> CanonicalMPS:
"""Deprecated, use `simplify` instead."""
return simplify_mps_sum(MPSSum(weights, states))
2 changes: 0 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
test_contractions,
test_strategy,
test_circuits,
test_combine,
test_contractions,
test_hdf5,
test_linear_form,
test_qft,
test_register,
test_simplify,
test_strategy,
test_tools,
test_truncate,
Expand Down
1 change: 1 addition & 0 deletions tests/test_truncate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import test_truncation, test_simplify, test_simplify_sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
scprod,
)
from seemps.truncate import simplify
from .. import tools

from .tools import *


class TestSimplify(TestCase):
class TestSimplify(tools.TestCase):
def test_no_truncation(self):
d = 2
strategy = DEFAULT_STRATEGY.replace(
Expand Down
Loading

0 comments on commit 25cb048

Please sign in to comment.