From d8fbed7657b9ddc8af701e37124718d574490630 Mon Sep 17 00:00:00 2001 From: meandmytram Date: Thu, 26 Oct 2023 16:49:18 -0400 Subject: [PATCH] attributes -> properties + orth centre check method --- mdopt/mps/canonical.py | 125 ++++++++++++++++++++++++----------------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/mdopt/mps/canonical.py b/mdopt/mps/canonical.py index 93599dbb..eb227366 100644 --- a/mdopt/mps/canonical.py +++ b/mdopt/mps/canonical.py @@ -1,6 +1,6 @@ """ This module contains the :class:`CanonicalMPS` class. -Hereafter, saying the MPS is in a canonical form will mean one of the following. +Hereafter, saying a MPS is in a canonical form will mean one of the following. 1) Right-canonical: all tensors are right isometries, i.e.:: @@ -21,12 +21,22 @@ 3) Mixed-canonical: all but one tensors are left or right isometries. This exceptional tensor will be hereafter called the **orthogonality centre**. - Note that, in the diagrams, a tensor with a star inside means that it is complex-conjugated. + Note, that in the diagrams, a tensor with a star inside means that it is complex-conjugated. - The Matrix Product State is stored as a list of three-dimensional tensors. - Essentially, it corresponds to storing each ``A[i]`` or ``B[i]`` as shown in + A Matrix Product State is thus stored as a list of three-dimensional tensors + as shown in the following diagram:: + + i i+1 + ...---( )---( )---... + | | + | | + + Essentially, this corresponds to storing each ``A[i]`` or ``B[i]`` as shown in fig.4c in reference `[1]`_. + Note, that we enumerate the bonds from the right side of the tensors. For example, + the bond ``0`` is a bond to the right of tensor ``0``. + .. _[1]: https://arxiv.org/abs/1805.00055 """ @@ -37,7 +47,7 @@ from opt_einsum import contract import mdopt -from mdopt.utils.utils import kron_tensors, split_two_site_tensor +from mdopt.utils.utils import svd, kron_tensors, split_two_site_tensor class CanonicalMPS: @@ -64,7 +74,7 @@ class CanonicalMPS: num_sites : int Number of sites. num_bonds : int - Number of bonds. + Number of bonds. Note, that the "ghost" bonds are not included. Raises ------ @@ -84,10 +94,6 @@ def __init__( self.tensors = tensors self.num_sites = len(tensors) self.num_bonds = self.num_sites - 1 - self.bond_dimensions = [ - self.tensors[i].shape[-1] for i in range(self.num_bonds) - ] - self.phys_dimensions = [self.tensors[i].shape[1] for i in range(self.num_sites)] self.orth_centre = orth_centre self.dtype = tensors[0].dtype self.tolerance = tolerance @@ -106,6 +112,20 @@ def __init__( f"while the one at site {i} has {len(tensor.shape)}." ) + @property + def bond_dimensions(self) -> List[int]: + """ + Returns the list of all bond dimensions of the MPS. + """ + return [self.tensors[i].shape[-1] for i in range(self.num_bonds)] + + @property + def phys_dimensions(self) -> List[int]: + """ + Returns the list of all physical dimensions of the MPS. + """ + return [self.tensors[i].shape[1] for i in range(self.num_sites)] + def __len__(self) -> int: """ Returns the number of sites in the MPS. @@ -122,7 +142,7 @@ def copy(self) -> "CanonicalMPS": def reverse(self) -> "CanonicalMPS": """ - Returns a reversed version of a given MPS. + Returns a reversed version of the current MPS. """ reversed_tensors = [np.transpose(tensor) for tensor in reversed(self.tensors)] @@ -301,6 +321,36 @@ def entanglement_entropy( """ return self.explicit(tolerance=tolerance).entanglement_entropy() + def check_orth_centre(self) -> Optional[int]: + """ + Checks the current position of the orthogonality centre by checking each tensor + for the isometry conditions. + Note, this method does not update the current instance's ``orth_centre`` attribute. + """ + + _, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore + self, return_orth_flags=True + ) + + if flags_left in ( + [True] + [False] * (self.num_sites - 1), + [False] * self.num_sites, + ): + if flags_right == [not flag for flag in flags_left]: + return 0 + + if flags_left in ( + [True] * (self.num_sites - 1) + [False], + [True] * self.num_sites, + ): + if flags_right == [not flag for flag in flags_left]: + return self.num_sites - 1 + + if all(flags_left) and all(flags_right): + return 0 + + return None + def move_orth_centre( self, final_pos: int, @@ -327,6 +377,8 @@ def move_orth_centre( ------ ValueError If ``final_pos`` does not match the MPS length. + ValueError + If ``self.orth_centre`` is still ``None`` after the search. """ if final_pos not in range(self.num_sites): @@ -338,26 +390,7 @@ def move_orth_centre( singular_values = [] if self.orth_centre is None: - _, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore - self, return_orth_flags=True - ) - - if flags_left in ( - [True] + [False] * (self.num_sites - 1), - [False] * self.num_sites, - ): - if flags_right == [not flag for flag in flags_left]: - self.orth_centre = 0 - - if flags_left in ( - [True] * (self.num_sites - 1) + [False], - [True] * self.num_sites, - ): - if flags_right == [not flag for flag in flags_left]: - self.orth_centre = self.num_sites - 1 - - if all(flags_left) and all(flags_right): - self.orth_centre = 0 + self.orth_centre = self.check_orth_centre() if self.orth_centre is None: raise ValueError("The orthogonality centre value is set to None.") @@ -380,6 +413,7 @@ def move_orth_centre( two_site_tensor, chi_max=self.chi_max, renormalise=renormalise, + strategy="svd", ) singular_values.append(singular_values_bond) mps.tensors[i] = u_l @@ -411,31 +445,19 @@ def move_orth_centre_to_border( Returns a new version of the current :class:`CanonicalMPS` instance with the orthogonality centre moved to the closest (from the current position) border. """ - if self.orth_centre is None: - _, flags_left, flags_right = mdopt.mps.utils.find_orth_centre( # type: ignore - self, return_orth_flags=True - ) - - if flags_left in ( - [True] + [False] * (self.num_sites - 1), - [False] * self.num_sites, - ): - if flags_right == [not flag for flag in flags_left]: - return self.copy(), "first" + self.orth_centre = self.check_orth_centre() - if flags_left in ( - [True] * (self.num_sites - 1) + [False], - [True] * self.num_sites, - ): - if flags_right == [not flag for flag in flags_left]: - return self.copy(), "last" + if self.orth_centre == 0: + return self.copy(), "first" - if all(flags_left) and all(flags_right): - return self.copy(), "first" + elif self.orth_centre == self.num_sites - 1: + return self.copy(), "last" else: - if self.orth_centre <= int(self.num_bonds / 2): + if (self.orth_centre is not None) and ( + self.orth_centre <= int(self.num_bonds / 2) + ): mps = self.move_orth_centre( final_pos=0, return_singular_values=False, @@ -448,6 +470,7 @@ def move_orth_centre_to_border( return_singular_values=False, renormalise=renormalise, ) + return cast("CanonicalMPS", mps), "last" def explicit(