Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve interpolation #67

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions src/seemps/analysis/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import copy
import numpy as np
from math import sqrt
from ..qft import qft_mpo
from ..state import MPS, MPSSum, Strategy, DEFAULT_STRATEGY
from typing import List

import numpy as np

from ..operators import MPO
from ..qft import qft_mpo
from ..state import DEFAULT_STRATEGY, MPS, CanonicalMPS, MPSSum, Strategy
from ..truncate import simplify
from .finite_differences import mpo_combined
from .space import Space, mpo_flip
Expand All @@ -24,7 +28,14 @@ def twoscomplement(L, **kwdargs):
return MPO([A0] + [A] * (L - 2) + [Aend], **kwdargs)


def fourier_interpolation_1D(ψ0mps, space, M0, Mf, dim, strategy=DEFAULT_STRATEGY):
def fourier_interpolation_1D(
ψ0mps: MPS,
space: Space,
M0: int,
Mf: int,
dim: int,
strategy: Strategy = DEFAULT_STRATEGY,
):
"""Obtain the Fourier interpolated MPS over the chosen dimension
with a new number of sites Mf.

Expand Down Expand Up @@ -73,12 +84,19 @@ def fourier_interpolation_1D(ψ0mps, space, M0, Mf, dim, strategy=DEFAULT_STRATE
)
U2c = new_space.extend(mpo_flip(twoscomplement(Mf, strategy=strategy)), dim)
ψfmps = iQFT_op @ (U2c @ Fψfmps)
ψfmps = ψfmps * (1 / sqrt(ψfmps.norm_squared()))

ψfmps = (Mf / M0) * ψfmps
if strategy.get_normalize_flag():
ψfmps = ψfmps.normalize_inplace()
return ψfmps, new_space


def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):
def fourier_interpolation(
ψmps: MPS,
space: Space,
old_sites: List,
new_sites: List,
strategy: Strategy = DEFAULT_STRATEGY,
):
"""Fourier interpolation on an MPS.

Parameters
Expand All @@ -91,8 +109,8 @@ def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):
List of integers with the original number of sites for each dimension.
new_sites : list[int]
List of integers with the new number of sites for each dimension.
**kwargs :
Arguments accepted by :class:`MPO`
strategy : Strategy, optional
Truncation strategy, defaults to DEFAULT_STRATEGY.

Returns
-------
Expand All @@ -101,9 +119,11 @@ def fourier_interpolation(ψmps, space, old_sites, new_sites, **kwargs):

"""
space = copy.copy(space)
if not isinstance(ψmps, CanonicalMPS):
ψmps = CanonicalMPS(ψmps, strategy=strategy)
for i, sites in enumerate(new_sites):
ψmps, space = fourier_interpolation_1D(
ψmps, space, old_sites[i], sites, dim=i, **kwargs
ψmps, space, old_sites[i], sites, dim=i, strategy=strategy
)
return ψmps

Expand Down Expand Up @@ -232,7 +252,9 @@ def finite_differences_interpolation_1D(
return simplify(odd + even, strategy=strategy), new_space


def finite_differences_interpolation(ψmps, space, **kwargs):
def finite_differences_interpolation(
ψmps: MPS, space: Space, strategy: Strategy = DEFAULT_STRATEGY
):
"""Finite differences interpolation of an MPS representing
a multidimensional function.

Expand All @@ -242,8 +264,8 @@ def finite_differences_interpolation(ψmps, space, **kwargs):
MPS representing a multidimensional function.
space : Space
Space on which the function is defined.
**kwargs :
Other arguments accepted by :class:`MPO`
strategy : Strategy, optional
Truncation strategy, defaults to DEFAULT_STRATEGY.

Returns
-------
Expand All @@ -252,5 +274,7 @@ def finite_differences_interpolation(ψmps, space, **kwargs):
"""
space = copy.deepcopy(space)
for i, q in enumerate(space.qubits_per_dimension):
ψmps, space = finite_differences_interpolation_1D(ψmps, space, dim=i, **kwargs)
ψmps, space = finite_differences_interpolation_1D(
ψmps, space, dim=i, strategy=strategy
)
return ψmps
5 changes: 4 additions & 1 deletion src/seemps/truncate/simplify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations
from typing import Optional, Union

from math import sqrt
from typing import Optional, Union

import numpy as np

from .. import tools
from ..state import (
DEFAULT_TOLERANCE,
Expand Down
Loading