Skip to content

Commit

Permalink
Replace left/right_orth_2site with _update_canonical_2site_left/right
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Apr 12, 2024
1 parent 0ab4cfc commit 74e318b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 34 deletions.
10 changes: 4 additions & 6 deletions src/seemps/state/canonical_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
Strategy,
_update_in_canonical_form_right,
_update_in_canonical_form_left,
_update_canonical_2site_left,
_update_canonical_2site_right,
_canonicalize,
)
from .mps import MPS
Expand Down Expand Up @@ -270,9 +272,7 @@ def update_2site_right(self, AA: Tensor4, site: int, strategy: Strategy) -> None
Truncation strategy, including relative tolerances and maximum
bond dimensions
"""
self._data[site], self._data[site + 1], err = schmidt.left_orth_2site(
AA, strategy
)
err = _update_canonical_2site_left(self._data, AA, site, strategy)
self.center = site + 1
self.update_error(err)

Expand All @@ -293,9 +293,7 @@ def update_2site_left(self, AA: Tensor4, site: int, strategy: Strategy) -> None:
Truncation strategy, including relative tolerances and maximum
bond dimensions
"""
self._data[site], self._data[site + 1], err = schmidt.right_orth_2site(
AA, strategy
)
err = _update_canonical_2site_right(self._data, AA, site, strategy)
self.center = site
self.update_error(err)

Expand Down
2 changes: 2 additions & 0 deletions src/seemps/state/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ PYBIND11_MODULE(core, m) {
m.def("_canonicalize", &_canonicalize,
R"doc(Update a list of `Tensor3` objects to be in canonical form
with respect to `center`.)doc");
m.def("_update_canonical_2site_left", &_update_canonical_2site_left);
m.def("_update_canonical_2site_right", &_update_canonical_2site_right);
m.def("left_orth_2site", &left_orth_2site);
m.def("right_orth_2site", &right_orth_2site);
}
6 changes: 6 additions & 0 deletions src/seemps/state/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def _update_in_canonical_form_left(
def _update_in_canonical_form_right(
state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy
) -> tuple[int, float]: ...
def _update_canonical_2site_left(
state: list[Tensor3], A: Tensor4, site: int, truncation: Strategy
) -> float: ...
def _update_canonical_2site_right(
state: list[Tensor3], A: Tensor4, site: int, truncation: Strategy
) -> float: ...
def _canonicalize(state: list[Tensor3], center: int, truncation: Strategy) -> float: ...

from .mps import MPS # noqa: E402
5 changes: 4 additions & 1 deletion src/seemps/state/mps.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,8 @@ std::tuple<py::object, py::object, double>
left_orth_2site(py::object AA, const Strategy &strategy);
std::tuple<py::object, py::object, double>
right_orth_2site(py::object AA, const Strategy &strategy);

double _update_canonical_2site_left(py::list state, py::object A, int site,
const Strategy &strategy);
double _update_canonical_2site_right(py::list state, py::object A, int site,
const Strategy &strategy);
} // namespace seemps
67 changes: 41 additions & 26 deletions src/seemps/state/schmidt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,36 +104,51 @@ right_orth_2site(py::object AA, const Strategy &strategy) {
as_3tensor(matrix_resize(V, D, -1), D, d2, b), err};
}

/*
double _update_canonical_2site_left(py::list state, py::object A, int site,
const Strategy &strategy) {
if (!is_array(A) || array_ndim(A) != 4) {
throw std::invalid_argument(
"Invalid tensor passed to _update_canonical_2site_left");
}
py::object tensor = array_getcontiguous(A);
auto a = array_dim(tensor, 0);
auto d1 = array_dim(tensor, 1);
auto d2 = array_dim(tensor, 2);
auto b = array_dim(tensor, 3);

// Split tensor
auto [U, s, V] =
destructive_svd(array_reshape(tensor, array_dims_t{a * d1, d2 * b}));
auto err = destructively_truncate_vector(s, strategy);
auto D = array_size(s);

def left_orth_2site(AA, strategy: Strategy):
"""Split a tensor AA[a,b,c,d] into B[a,b,r] and C[r,c,d] such
that 'B' is a left-isometry, truncating the size 'r' according
to the given 'strategy'. Tensor 'AA' may be overwritten."""
α, d1, d2, β = AA.shape
U, S, V = _destructive_svd(AA.reshape(α * d1, β * d2))
err = destructively_truncate_vector(S, strategy)
D = S.size
return (
U[:, :D].reshape(α, d1, D),
(S.reshape(D, 1) * V[:D, :]).reshape(D, d2, β),
err,
)
state[site] = as_3tensor(matrix_resize(U, -1, D), a, d1, D);
state[site + 1] =
as_3tensor(as_matrix(s, D, 1) * matrix_resize(V, D, -1), D, d2, b);
return err;
}

double _update_canonical_2site_right(py::list state, py::object A, int site,
const Strategy &strategy) {
if (!is_array(A) || array_ndim(A) != 4) {
throw std::invalid_argument(
"Invalid tensor passed to _update_canonical_2site_left");
}
py::object tensor = array_getcontiguous(A);
auto a = array_dim(tensor, 0);
auto d1 = array_dim(tensor, 1);
auto d2 = array_dim(tensor, 2);
auto b = array_dim(tensor, 3);

def right_orth_2site(AA, strategy: Strategy):
"""Split a tensor AA[a,b,c,d] into B[a,b,r] and C[r,c,d] such
that 'C' is a right-isometry, truncating the size 'r' according
to the given 'strategy'. Tensor 'AA' may be overwritten."""
α, d1, d2, β = AA.shape
U, S, V = _destructive_svd(AA.reshape(α * d1, β * d2))
err = destructively_truncate_vector(S, strategy)
D = S.size
return (U[:, :D] * S).reshape(α, d1, D), V[:D, :].reshape(D, d2, β), err
// Split tensor
auto [U, s, V] =
destructive_svd(array_reshape(tensor, array_dims_t{a * d1, d2 * b}));
auto err = destructively_truncate_vector(s, strategy);
auto D = array_size(s);

*/
state[site] = as_3tensor(matrix_resize(U, -1, D) * s, a, d1, D);
state[site + 1] = as_3tensor(matrix_resize(V, D, -1), D, d2, b);
return err;
}

} // namespace seemps
2 changes: 1 addition & 1 deletion src/seemps/state/schmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
right_orth_2site,
)

__all__ = ["schmidt_weights", "left_orth_2site", "right_orth_2site", "vector2mps"]
__all__ = ["schmidt_weights", "vector2mps"]


def vector2mps(
Expand Down

0 comments on commit 74e318b

Please sign in to comment.