Skip to content

Commit

Permalink
More type checking/promotion in polynomials
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeyers314 committed Jul 24, 2024
1 parent 0ad82fd commit 150010a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
11 changes: 11 additions & 0 deletions galsim/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,11 @@ def horner(x, coef, dtype=None):
Returns:
a numpy array of the evaluated polynomial. Will be the same shape as x.
"""
if dtype is None:
dtype = np.result_type(
np.min_scalar_type(x),
np.min_scalar_type(coef)
)
result = np.empty_like(x, dtype=dtype)
# Make sure everything is an array
if result.dtype == float:
Expand Down Expand Up @@ -616,6 +621,12 @@ def horner2d(x, y, coefs, dtype=None, triangle=False):
Returns:
a numpy array of the evaluated polynomial. Will be the same shape as x and y.
"""
if dtype is None:
dtype = np.result_type(
np.min_scalar_type(x),
np.min_scalar_type(y),
np.min_scalar_type(coefs)
)
result = np.empty_like(x, dtype=dtype)
temp = np.empty_like(x, dtype=dtype)
# Make sure everything is an array
Expand Down
4 changes: 2 additions & 2 deletions galsim/zernike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ def _from_uvxy(
coefficients.
"""
ret = DoubleZernike.__new__(DoubleZernike)
ret._coef_array_uvxy = uvxy
ret._coef_array_uvxy = np.asarray(uvxy, dtype=float)
ret.uv_outer = uv_outer
ret.uv_inner = uv_inner
ret.xy_outer = xy_outer
Expand Down Expand Up @@ -1192,7 +1192,7 @@ def __call__(self, u, v, x=None, y=None):
a_ij = np.zeros(self._coef_array_uvxy.shape[2:4])
for i, j in np.ndindex(a_ij.shape):
a_ij[i, j] = horner2d(
u, v, self._coef_array_uvxy[..., i, j]
u, v, self._coef_array_uvxy[..., i, j], dtype=float
)
return Zernike._from_coef_array_xy(
a_ij,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_zernike.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,33 @@ def test_Zernike_rotate():
)


@timer
def test_zernike_eval():
for coef in [
np.ones(4),
np.ones(4, dtype=float),
np.ones(4, dtype=np.float32)
]:
Z = Zernike(coef)
assert Z.coef.dtype == np.float64
assert Z(0.0, 0.0) == 1.0
assert Z(0, 0) == 1.0

for coefs in [
np.ones((4, 4)),
np.ones((4, 4), dtype=float),
np.ones((4, 4), dtype=np.float32)
]:
dz = DoubleZernike(coefs)
assert dz.coef.dtype == np.float64
assert dz(0.0, 0.0) == dz(0, 0)

# Make sure we cast to float in _from_uvxy
uvxy = dz._coef_array_uvxy
dz2 = DoubleZernike._from_uvxy(uvxy.astype(int))
np.testing.assert_array_equal(dz2._coef_array_uvxy, dz._coef_array_uvxy)


@timer
def test_ne():
objs = [
Expand Down

0 comments on commit 150010a

Please sign in to comment.