Skip to content

Commit

Permalink
orthonormalization now with separate max ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
PGelss authored Jul 17, 2024
1 parent e368156 commit 49cf618
Showing 1 changed file with 46 additions and 19 deletions.
65 changes: 46 additions & 19 deletions scikit_tt/tensor_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def matricize(self) -> np.ndarray:
def ortho_left(self, start_index: int=0,
end_index: Optional[int]=None,
threshold: float=0.0,
max_rank: int=np.infty,
max_rank: Union[int, List[int]]=np.infty,
progress: bool=False,
string: str='Left-orthonormalization') -> 'TT':
"""
Expand All @@ -1106,8 +1106,8 @@ def ortho_left(self, start_index: int=0,
end index for orthonormalization, default is the index of the penultimate core
threshold : float, optional
threshold for reduced SVD decompositions, default is 0
max_rank : int, optional
maximum rank of the left-orthonormalized tensor train, default is np.infty
max_rank : int or list[int], optional
maximum rank(s) of the left-orthonormalized tensor train, default is np.infty
progress : bool, optional
whether to show progress bar, default is False
string : string, optional
Expand Down Expand Up @@ -1139,7 +1139,19 @@ def ortho_left(self, start_index: int=0,

if isinstance(threshold, (int, float)) and threshold >= 0:

if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty:
# check for correct max_rank argument and set max_ranks
max_rank_tf = True
if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty):
max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1]
else:
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank

if max_rank_tf:

for i in range(start_index, end_index + 1):

Expand All @@ -1161,10 +1173,10 @@ def ortho_left(self, start_index: int=0,
u = u[:, indices]
s = s[indices]
v = v[indices, :]
if max_rank != np.infty:
u = u[:, :np.minimum(u.shape[1], max_rank)]
s = s[:np.minimum(s.shape[0], max_rank)]
v = v[:np.minimum(v.shape[0], max_rank), :]
if max_ranks[i+1] != np.infty:
u = u[:, :np.minimum(u.shape[1], max_ranks[i+1])]
s = s[:np.minimum(s.shape[0], max_ranks[i+1])]
v = v[:np.minimum(v.shape[0], max_ranks[i+1]), :]

# define updated rank and core
self.ranks[i + 1] = u.shape[1]
Expand Down Expand Up @@ -1192,7 +1204,7 @@ def ortho_left(self, start_index: int=0,
def ortho_right(self, start_index: Optional[int]=None,
end_index: int=1,
threshold: float=0,
max_rank: int=np.infty) -> 'TT':
max_rank: Union[int, List[int]]=np.infty) -> 'TT':
"""
Right-orthonormalization of tensor trains.
Expand All @@ -1204,8 +1216,8 @@ def ortho_right(self, start_index: Optional[int]=None,
end index for orthonormalization, default is 1
threshold : float, optional
threshold for reduced SVD decompositions, default is 0
max_rank : int, optional
maximum rank of the left-orthonormalized tensor train, default is np.infty
max_rank : int or list[int], optional
maximum rank(s) of the left-orthonormalized tensor train, default is np.infty
Returns
-------
Expand All @@ -1230,7 +1242,19 @@ def ortho_right(self, start_index: Optional[int]=None,

if isinstance(threshold, (int, np.int32, np.int64, float, np.float32, np.float64)) and threshold >= 0:

if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty:
# check for correct max_rank argument and set max_ranks
max_rank_tf = True
if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty):
max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1]
else:
if len(max_rank) == self.order+1:
for i in range(self.order+1):
if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty):
max_rank_tf = False
if max_rank_tf:
max_ranks = max_rank

if max_rank_tf:

for i in range(start_index, end_index - 1, -1):

Expand All @@ -1252,10 +1276,13 @@ def ortho_right(self, start_index: Optional[int]=None,
u = u[:, indices]
s = s[indices]
v = v[indices, :]
if max_rank != np.infty:
u = u[:, :np.minimum(u.shape[1], max_rank)]
s = s[:np.minimum(s.shape[0], max_rank)]
v = v[:np.minimum(v.shape[0], max_rank), :]
if max_ranks[i] != np.infty:
print(u.shape, v.shape)
print(max_ranks[i])

u = u[:, :np.minimum(u.shape[1], max_ranks[i])]
s = s[:np.minimum(s.shape[0], max_ranks[i])]
v = v[:np.minimum(v.shape[0], max_ranks[i]), :]

# define updated rank and core
self.ranks[i] = v.shape[0]
Expand All @@ -1279,16 +1306,16 @@ def ortho_right(self, start_index: Optional[int]=None,
else:
raise TypeError('Start and end indices must be integers.')

def ortho(self, threshold: float=0, max_rank: int=np.infty) -> 'TT':
def ortho(self, threshold: float=0, max_rank: Union[int, List[int]]=np.infty) -> 'TT':
"""
Left- and right-orthonormalization of tensor trains.
Parameters
----------
threshold : float, optional
threshold for reduced SVD decompositions, default is 0
max_rank : int
maximum rank of the right-orthonormalized tensor train
max_rank : int or list[int], optional
maximum rank(s) of the left-orthonormalized tensor train, default is np.infty
Returns
-------
Expand Down

0 comments on commit 49cf618

Please sign in to comment.