Skip to content

Commit

Permalink
Change variable truncation to strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulaGarciaMolina committed Oct 26, 2023
1 parent 4e9abdc commit 0305aca
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions seemps/truncate/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def simplify(
state: Union[MPS, MPSSum],
truncation: Strategy = SIMPLIFICATION_STRATEGY,
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1

) -> MPS:
Expand All @@ -34,7 +34,7 @@ def simplify(
----------
state : MPS | MPSSum
State to approximate.
truncation : Strategy
strategy : Strategy
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : { +1, -1 }
Initial direction for the sweeping algorithm.
Expand All @@ -48,20 +48,20 @@ def simplify(
return combine(
state.weights,
state.states,
truncation=truncation,
strategy=strategy,
direction=direction,
)

size = state.size
start = 0 if direction > 0 else size - 1
normalize= truncation.get_normalize_flag()
maxsweeps = truncation.get_max_sweeps()
simplification_tolerance = truncation.get_simplification_tolerance()
max_bond_dimension = truncation.get_max_bond_dimension()
mps = CanonicalMPS(state, center=start, strategy=truncation)
normalize= strategy.get_normalize_flag()
maxsweeps = strategy.get_max_sweeps()
simplification_tolerance = strategy.get_simplification_tolerance()
max_bond_dimension = strategy.get_max_bond_dimension()
mps = CanonicalMPS(state, center=start, strategy=strategy)
if normalize:
mps.normalize_inplace()
if not truncation.get_simplification_method():
if not strategy.get_simplification_method():
return mps
if max_bond_dimension == 0 and simplification_tolerance <= 0:
return mps
Expand All @@ -76,12 +76,12 @@ def simplify(
for sweep in range(maxsweeps):
if direction > 0:
for n in range(0, size - 1):
mps.update_2site_right(form.tensor2site(direction), n, truncation)
mps.update_2site_right(form.tensor2site(direction), n, strategy)
form.update(direction)
last = size - 1
else:
for n in reversed(range(0, size - 1)):
mps.update_2site_left(form.tensor2site(direction), n, truncation)
mps.update_2site_left(form.tensor2site(direction), n, strategy)
form.update(direction)
last = 0
#
Expand Down Expand Up @@ -157,7 +157,7 @@ def combine(
weights: list[Weight],
states: list[MPS],
guess: Optional[MPS] = None,
truncation: Strategy = SIMPLIFICATION_STRATEGY,
strategy: Strategy = SIMPLIFICATION_STRATEGY,
direction: int = +1
) -> MPS:
"""Approximate a linear combination of MPS :math:`\\sum_i w_i \\psi_i` by
Expand All @@ -171,7 +171,7 @@ def combine(
List of states :math:`\\psi_i`.
guess : MPS, optional
Initial guess for the iterative algorithm.
truncation : Strategy
strategy : Strategy
Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
direction : {+1, -1}
Initial direction for the sweeping algorithm.
Expand All @@ -187,11 +187,11 @@ def combine(
np.sqrt(np.abs(weights)) * np.sqrt(state.error())
for weights, state in zip(weights, states)
)
normalize= truncation.get_normalize_flag()
maxsweeps = truncation.get_max_sweeps()
simplification_tolerance = truncation.get_simplification_tolerance()
normalize= strategy.get_normalize_flag()
maxsweeps = strategy.get_max_sweeps()
simplification_tolerance = strategy.get_simplification_tolerance()
start = 0 if direction > 0 else guess.size - 1
φ = CanonicalMPS(guess, center=start, strategy=truncation, normalize=normalize)
φ = CanonicalMPS(guess, center=start, strategy=strategy, normalize=normalize)
err = norm_ψsqr = multi_norm_squared(weights, states)
if norm_ψsqr < simplification_tolerance:
return MPS([np.zeros((1, P.shape[1], 1)) for P in φ])
Expand All @@ -209,7 +209,7 @@ def combine(
weights * f.tensor2site(direction)
for weights, f in zip(weights, forms)
) # type: ignore
φ.update_2site_right(tensor, n, truncation)
φ.update_2site_right(tensor, n, strategy)
for f in forms:
f.update(direction)
else:
Expand All @@ -218,7 +218,7 @@ def combine(
weights * f.tensor2site(direction)
for weights, f in zip(weights, forms)
) # type: ignore
φ.update_2site_left(tensor, n, truncation)
φ.update_2site_left(tensor, n, strategy)
for f in forms:
f.update(direction)
last = 0
Expand Down

0 comments on commit 0305aca

Please sign in to comment.