Skip to content

Commit

Permalink
Added more entropy functions
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Jan 1, 2024
1 parent f4baa03 commit a701b4a
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 18 deletions.
26 changes: 25 additions & 1 deletion src/seemps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,31 @@
cgs,
)

from .state import *
from .state import (
Strategy,
Truncation,
Simplification,
DEFAULT_STRATEGY,
DEFAULT_TOLERANCE,
NO_TRUNCATION,
MAX_BOND_DIMENSION,
MPS,
MPSSum,
CanonicalMPS,
product_state,
GHZ,
W,
spin_wave,
graph,
AKLT,
random,
random_mps,
random_uniform_mps,
gaussian,
all_entanglement_entropies,
all_Renyi_entropies,
sample_mps,
)
from .mpo import MPO, MPOList
from .hamiltonians import *
from .tools import σx, σy, σz
Expand Down
20 changes: 10 additions & 10 deletions src/seemps/state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from . import array
from .core import (
Strategy,
Truncation,
Simplification,
DEFAULT_STRATEGY,
DEFAULT_TOLERANCE,
NO_TRUNCATION,
MAX_BOND_DIMENSION,
)
from .mps import MPS, MPSSum, Weight
from .factories import (
product_state,
Expand All @@ -13,13 +21,5 @@
gaussian,
)
from .canonical_mps import CanonicalMPS
from .core import (
Strategy,
Truncation,
Simplification,
DEFAULT_STRATEGY,
DEFAULT_TOLERANCE,
NO_TRUNCATION,
MAX_BOND_DIMENSION,
)
from .entropies import all_entanglement_entropies, all_Renyi_entropies
from .sampling import sample_mps
62 changes: 55 additions & 7 deletions src/seemps/state/canonical_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ def right_environment(self, site: int) -> Environment:
ρ = environments.update_right_environment(A, A, ρ)
return ρ

def entanglement_entropy(self, site: Optional[int] = None) -> float:
"""Compute the entanglement entropy of the MPS for a bipartition
around `site`.
def Schmidt_weights(self, site: Optional[int] = None) -> Vector:
"""Return the Schmidt weights for a bipartition around `site`.
Parameters
----------
Expand All @@ -182,13 +181,15 @@ def entanglement_entropy(self, site: Optional[int] = None) -> float:
Returns
-------
float
Von Neumann entropy of bipartition.
numbers: np.ndarray
Vector of non-negative Schmidt weights.
"""
if site is None:
site = self.center
else:
site = self._interpret_center(site)
if site != self.center:
return self.copy().recenter(site).entanglement_entropy()
return self.copy().recenter(site).Schmidt_weights()
# TODO: this is for [0, self.center] (self.center, self.size)
# bipartitions, but we can also optimizze [0, self.center) [self.center, self.size)
A = self._data[site]
Expand All @@ -200,7 +201,54 @@ def entanglement_entropy(self, site: Optional[int] = None) -> float:
check_finite=False,
lapack_driver=schmidt.SVD_LAPACK_DRIVER,
)
return -np.sum(2 * s * s * np.log2(s))
s *= s
s /= np.sum(s)
return s

def entanglement_entropy(self, site: Optional[int] = None) -> float:
"""Compute the entanglement entropy of the MPS for a bipartition
around `site`.
Parameters
----------
site : int, optional
Site in the range `[0, self.size)`, defaulting to `self.center`.
The system is diveded into `[0, self.site)` and `[self.site, self.size)`.
Returns
-------
float
Von Neumann entropy of bipartition.
"""
s = self.Schmidt_weights(site)
return -np.sum(s * np.log2(s))

def Renyi_entropy(self, site: Optional[int] = None, alpha: float = 2.0) -> float:
"""Compute the Renyi entropy of the MPS for a bipartition
around `site`.
Parameters
----------
site : int, optional
Site in the range `[0, self.size)`, defaulting to `self.center`.
The system is diveded into `[0, self.site)` and `[self.site, self.size)`.
alpha : float, default = 2
Power of the Renyi entropy.
Returns
-------
float
Von Neumann entropy of bipartition.
"""
s = self.Schmidt_weights(site)
if alpha < 0:
raise ValueError("Invalid Renyi entropy power")
if alpha == 0:
alpha = 1e-9
elif alpha == 1:
alpha = 1 - 1e-9
S = np.log(np.sum(s**alpha)) / (1 - alpha)
return S

def update_canonical(
self, A: Tensor3, direction: int, truncation: Strategy
Expand Down
24 changes: 24 additions & 0 deletions src/seemps/state/entropies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
from .mps import MPS
from .canonical_mps import CanonicalMPS
from ..typing import Vector


def all_entanglement_entropies(state: MPS) -> Vector:
cstate = CanonicalMPS(state, center=0)
L = len(cstate)
entropies = np.empty(L)
for i in range(L):
cstate = CanonicalMPS(cstate, center=i)
entropies[i] = cstate.entanglement_entropy(i)
return entropies


def all_Renyi_entropies(state: MPS, alpha: float) -> Vector:
cstate = CanonicalMPS(state, center=0)
L = len(cstate)
entropies = np.empty(L)
for i in range(L):
cstate = CanonicalMPS(cstate, center=i)
entropies[i] = cstate.Renyi_entropy(i, alpha)
return entropies
12 changes: 12 additions & 0 deletions tests/test_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,24 @@ def test_canonical_complains_if_center_out_of_bounds(self):
with self.assertRaises(Exception):
CanonicalMPS(mps, center=-11)

def test_canonical_Schmidt_weights(self):
mps = CanonicalMPS(product_state([1.0, 0.0], 10), center=0)
self.assertSimilar(mps.Schmidt_weights(), [1.0])
self.assertSimilar(mps.Schmidt_weights(0), [1.0])
self.assertSimilar(mps.Schmidt_weights(-1), [1.0])

def test_canonical_entanglement_entropy(self):
mps = CanonicalMPS(product_state([1.0, 0.0], 10), center=0)
self.assertAlmostEqual(mps.entanglement_entropy(), 0.0)
self.assertAlmostEqual(mps.entanglement_entropy(0), 0.0)
self.assertAlmostEqual(mps.entanglement_entropy(-1), 0.0)

def test_canonical_Renyi_entropy(self):
mps = CanonicalMPS(product_state([1.0, 0.0], 10), center=0)
self.assertAlmostEqual(mps.Renyi_entropy(alpha=2), 0.0)
self.assertAlmostEqual(mps.Renyi_entropy(0, alpha=2), 0.0)
self.assertAlmostEqual(mps.Renyi_entropy(-1, alpha=2), 0.0)

def test_canonical_from_vector(self):
state = self.rng.normal(size=2**8)
state /= np.linalg.norm(state)
Expand Down

0 comments on commit a701b4a

Please sign in to comment.