Skip to content

Commit

Permalink
Improve interpolation from PaulaGarciaMolina/main
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll authored Apr 25, 2024
2 parents 4ab498a + 61c02bd commit 97b2fa4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
52 changes: 38 additions & 14 deletions src/seemps/analysis/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import copy
import numpy as np
from math import sqrt
from ..qft import qft_mpo
from ..state import MPS, MPSSum, Strategy, DEFAULT_STRATEGY
from typing import List

import numpy as np

from ..operators import MPO
from ..qft import qft_mpo
from ..state import DEFAULT_STRATEGY, MPS, CanonicalMPS, MPSSum, Strategy
from ..truncate import simplify
from .finite_differences import mpo_combined
from .space import Space, mpo_flip
Expand All @@ -24,7 +28,14 @@ def twoscomplement(L, **kwdargs):
return MPO([A0] + [A] * (L - 2) + [Aend], **kwdargs)


def fourier_interpolation_1D(ψ0mps, space, M0, Mf, dim, strategy=DEFAULT_STRATEGY):
def fourier_interpolation_1D(
ψ0mps: MPS,
space: Space,
M0: int,
Mf: int,
dim: int,
strategy: Strategy = DEFAULT_STRATEGY,
):
"""Obtain the Fourier interpolated MPS over the chosen dimension
with a new number of sites Mf.
Expand Down Expand Up @@ -73,12 +84,19 @@ def fourier_interpolation_1D(ψ0mps, space, M0, Mf, dim, strategy=DEFAULT_STRATE
)
U2c = new_space.extend(mpo_flip(twoscomplement(Mf, strategy=strategy)), dim)
ψfmps = iQFT_op @ (U2c @ Fψfmps)
ψfmps = ψfmps * (1 / sqrt(ψfmps.norm_squared()))

ψfmps = (Mf / M0) * ψfmps
if strategy.get_normalize_flag():
ψfmps = ψfmps.normalize_inplace()
return ψfmps, new_space


def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):
def fourier_interpolation(
ψmps: MPS,
space: Space,
old_sites: List,
new_sites: List,
strategy: Strategy = DEFAULT_STRATEGY,
):
"""Fourier interpolation on an MPS.
Parameters
Expand All @@ -91,8 +109,8 @@ def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):
List of integers with the original number of sites for each dimension.
new_sites : list[int]
List of integers with the new number of sites for each dimension.
**kwargs :
Arguments accepted by :class:`MPO`
strategy : Strategy, optional
Truncation strategy, defaults to DEFAULT_STRATEGY.
Returns
-------
Expand All @@ -101,9 +119,11 @@ def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):
"""
space = copy.copy(space)
if not isinstance(ψmps, CanonicalMPS):
ψmps = CanonicalMPS(ψmps, strategy=strategy)
for i, sites in enumerate(new_sites):
ψmps, space = fourier_interpolation_1D(
ψmps, space, old_sites[i], sites, dim=i, **kwargs
ψmps, space, old_sites[i], sites, dim=i, strategy=strategy
)
return ψmps

Expand Down Expand Up @@ -232,7 +252,9 @@ def finite_differences_interpolation_1D(
return simplify(odd + even, strategy=strategy), new_space


def finite_differences_interpolation(ψmps, space, **kwargs):
def finite_differences_interpolation(
ψmps: MPS, space: Space, strategy: Strategy = DEFAULT_STRATEGY
):
"""Finite differences interpolation of an MPS representing
a multidimensional function.
Expand All @@ -242,8 +264,8 @@ def finite_differences_interpolation(ψmps, space, **kwargs):
MPS representing a multidimensional function.
space : Space
Space on which the function is defined.
**kwargs :
Other arguments accepted by :class:`MPO`
strategy : Strategy, optional
Truncation strategy, defaults to DEFAULT_STRATEGY.
Returns
-------
Expand All @@ -252,5 +274,7 @@ def finite_differences_interpolation(ψmps, space, **kwargs):
"""
space = copy.deepcopy(space)
for i, q in enumerate(space.qubits_per_dimension):
ψmps, space = finite_differences_interpolation_1D(ψmps, space, dim=i, **kwargs)
ψmps, space = finite_differences_interpolation_1D(
ψmps, space, dim=i, strategy=strategy
)
return ψmps
5 changes: 4 additions & 1 deletion src/seemps/truncate/simplify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations
from typing import Optional, Union

from math import sqrt
from typing import Optional, Union

import numpy as np

from .. import tools
from ..state import (
DEFAULT_TOLERANCE,
Expand Down

0 comments on commit 97b2fa4

Please sign in to comment.