Skip to content

Commit

Permalink
Replaced generic __class__.__init__() call with Coordinates.__init__(…
Browse files Browse the repository at this point in the history
…) call to avoid inheritance issues when using "as_coords=True" flag
  • Loading branch information
timbernat committed Apr 26, 2024
1 parent ac6e9e2 commit 5140047
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions polymerist/maths/lattices/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def linear_transformation(self, matrix : np.ndarray[Shape[N,N], float], as_coord
transformed_points = self.points @ matrix.T # NOTE: need to right-multiply and transpose, since ROWS of self.points need to be tranformed

if as_coords:
return self.__class__(transformed_points)
# return self.__class__(transformed_points) # TOSELF: this form is more general, but doesn't play nicely with inheritance
return Coordinates(transformed_points)
return transformed_points

def affine_transformation(self, matrix : np.ndarray[Shape[N,N], float], as_coords : bool=False) -> Union[np.ndarray[Shape[M, N], float], 'Coordinates']: # TOSELF: typehint on input matrix should be of shape N+1, N+1
Expand All @@ -142,7 +143,8 @@ def affine_transformation(self, matrix : np.ndarray[Shape[N,N], float], as_coord
transformed_points = aug_transformed[: , :self.n_dims] / aug_transformed[:, self.n_dims, None] # downcast augmented transformed points from homogeneous coordinates, normalizing by projective part

if as_coords:
return self.__class__(transformed_points)
# return self.__class__(transformed_points) # TOSELF: this form is more general, but doesn't play nicely with inheritance
return Coordinates(transformed_points)
return transformed_points


Expand All @@ -151,7 +153,7 @@ class BoundingBox(Coordinates):
def __init__(self, coords : Union[Coordinates, np.ndarray[Shape[M, N], Num]]) -> None:
if isinstance(coords, np.ndarray):
coords = Coordinates(coords) # allow passing of Coordinates-like classes
points = np.array([vertex for vertex in cartesian_product(*self.extrema.T)])
points = np.array([vertex for vertex in cartesian_product(*coords.extrema.T)])

super().__init__(points)

Expand Down

0 comments on commit 5140047

Please sign in to comment.