Skip to content

Commit

Permalink
attributes -> properties + orth centre check method
Browse files Browse the repository at this point in the history
  • Loading branch information
meandmytram committed Oct 26, 2023
1 parent 8a0f9c4 commit d8fbed7
Showing 1 changed file with 74 additions and 51 deletions.
125 changes: 74 additions & 51 deletions mdopt/mps/canonical.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This module contains the :class:`CanonicalMPS` class.
Hereafter, saying the MPS is in a canonical form will mean one of the following.
Hereafter, saying a MPS is in a canonical form will mean one of the following.
1) Right-canonical: all tensors are right isometries, i.e.::
Expand All @@ -21,12 +21,22 @@
3) Mixed-canonical: all but one tensors are left or right isometries.
This exceptional tensor will be hereafter called the **orthogonality centre**.
Note that, in the diagrams, a tensor with a star inside means that it is complex-conjugated.
Note, that in the diagrams, a tensor with a star inside means that it is complex-conjugated.
The Matrix Product State is stored as a list of three-dimensional tensors.
Essentially, it corresponds to storing each ``A[i]`` or ``B[i]`` as shown in
A Matrix Product State is thus stored as a list of three-dimensional tensors
as shown in the following diagram::
i i+1
...---( )---( )---...
| |
| |
Essentially, this corresponds to storing each ``A[i]`` or ``B[i]`` as shown in
fig.4c in reference `[1]`_.
Note, that we enumerate the bonds from the right side of the tensors. For example,
the bond ``0`` is a bond to the right of tensor ``0``.
.. _[1]: https://arxiv.org/abs/1805.00055
"""

Expand All @@ -37,7 +47,7 @@
from opt_einsum import contract

import mdopt
from mdopt.utils.utils import kron_tensors, split_two_site_tensor
from mdopt.utils.utils import svd, kron_tensors, split_two_site_tensor


class CanonicalMPS:
Expand All @@ -64,7 +74,7 @@ class CanonicalMPS:
num_sites : int
Number of sites.
num_bonds : int
Number of bonds.
Number of bonds. Note, that the "ghost" bonds are not included.
Raises
------
Expand All @@ -84,10 +94,6 @@ def __init__(
self.tensors = tensors
self.num_sites = len(tensors)
self.num_bonds = self.num_sites - 1
self.bond_dimensions = [
self.tensors[i].shape[-1] for i in range(self.num_bonds)
]
self.phys_dimensions = [self.tensors[i].shape[1] for i in range(self.num_sites)]
self.orth_centre = orth_centre
self.dtype = tensors[0].dtype
self.tolerance = tolerance
Expand All @@ -106,6 +112,20 @@ def __init__(
f"while the one at site {i} has {len(tensor.shape)}."
)

@property
def bond_dimensions(self) -> List[int]:
"""
Returns the list of all bond dimensions of the MPS.
"""
return [self.tensors[i].shape[-1] for i in range(self.num_bonds)]

@property
def phys_dimensions(self) -> List[int]:
"""
Returns the list of all physical dimensions of the MPS.
"""
return [self.tensors[i].shape[1] for i in range(self.num_sites)]

def __len__(self) -> int:
"""
Returns the number of sites in the MPS.
Expand All @@ -122,7 +142,7 @@ def copy(self) -> "CanonicalMPS":

def reverse(self) -> "CanonicalMPS":
"""
Returns a reversed version of a given MPS.
Returns a reversed version of the current MPS.
"""

reversed_tensors = [np.transpose(tensor) for tensor in reversed(self.tensors)]
Expand Down Expand Up @@ -301,6 +321,36 @@ def entanglement_entropy(
"""
return self.explicit(tolerance=tolerance).entanglement_entropy()

def check_orth_centre(self) -> Optional[int]:
"""
Checks the current position of the orthogonality centre by checking each tensor
for the isometry conditions.
Note, this method does not update the current instance's ``orth_centre`` attribute.
"""

_, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore
self, return_orth_flags=True
)

if flags_left in (
[True] + [False] * (self.num_sites - 1),
[False] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
return 0

if flags_left in (
[True] * (self.num_sites - 1) + [False],
[True] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
return self.num_sites - 1

if all(flags_left) and all(flags_right):
return 0

return None

def move_orth_centre(
self,
final_pos: int,
Expand All @@ -327,6 +377,8 @@ def move_orth_centre(
------
ValueError
If ``final_pos`` does not match the MPS length.
ValueError
If ``self.orth_centre`` is still ``None`` after the search.
"""

if final_pos not in range(self.num_sites):
Expand All @@ -338,26 +390,7 @@ def move_orth_centre(
singular_values = []

if self.orth_centre is None:
_, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore
self, return_orth_flags=True
)

if flags_left in (
[True] + [False] * (self.num_sites - 1),
[False] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
self.orth_centre = 0

if flags_left in (
[True] * (self.num_sites - 1) + [False],
[True] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
self.orth_centre = self.num_sites - 1

if all(flags_left) and all(flags_right):
self.orth_centre = 0
self.orth_centre = self.check_orth_centre()

if self.orth_centre is None:
raise ValueError("The orthogonality centre value is set to None.")
Expand All @@ -380,6 +413,7 @@ def move_orth_centre(
two_site_tensor,
chi_max=self.chi_max,
renormalise=renormalise,
strategy="svd",
)
singular_values.append(singular_values_bond)
mps.tensors[i] = u_l
Expand Down Expand Up @@ -411,31 +445,19 @@ def move_orth_centre_to_border(
Returns a new version of the current :class:`CanonicalMPS` instance with
the orthogonality centre moved to the closest (from the current position) border.
"""

if self.orth_centre is None:
_, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore
self, return_orth_flags=True
)

if flags_left in (
[True] + [False] * (self.num_sites - 1),
[False] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
return self.copy(), "first"
self.orth_centre = self.check_orth_centre()

if flags_left in (
[True] * (self.num_sites - 1) + [False],
[True] * self.num_sites,
):
if flags_right == [not flag for flag in flags_left]:
return self.copy(), "last"
if self.orth_centre == 0:
return self.copy(), "first"

if all(flags_left) and all(flags_right):
return self.copy(), "first"
elif self.orth_centre == self.num_sites - 1:
return self.copy(), "last"

else:
if self.orth_centre <= int(self.num_bonds / 2):
if (self.orth_centre is not None) and (
self.orth_centre <= int(self.num_bonds / 2)
):
mps = self.move_orth_centre(
final_pos=0,
return_singular_values=False,
Expand All @@ -448,6 +470,7 @@ def move_orth_centre_to_border(
return_singular_values=False,
renormalise=renormalise,
)

return cast("CanonicalMPS", mps), "last"

def explicit(
Expand Down

0 comments on commit d8fbed7

Please sign in to comment.