-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST fix tests for JAX-Galsim #1252
base: main
Are you sure you want to change the base?
Conversation
The commits here need to be rebased onto main I think. |
ok @rmjarvis This giant PR is ready for review. I can break it up into smaller PRs if that would help. My hope is that with this PR, we can run jax-galsim tests against the main branch going forward. |
np.testing.assert_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, | ||
'%s negative flux drawReal is not negative of +flux image'%name) | ||
np.testing.assert_array_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, | ||
'%s negative flux drawReal is not negative of +flux image'%name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this? I thought they were equivalent. I'd rather not use the more verbose one if we can avoid it.
@@ -298,6 +371,9 @@ def do_shoot(prof, img, name): | |||
print('nphot = ',nphot) | |||
img2 = img.copy() | |||
|
|||
if is_jax_galsim(): | |||
rtol *= 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a sign of a problem with the jax-galsim implementation?
assert_raises(TypeError,galsim.AngleUnit, 1, 3) | ||
assert_raises(TypeError,galsim.Angle, 3.4) | ||
assert_raises(TypeError,galsim.Angle, theta1, galsim.degrees) | ||
assert_raises(ValueError,galsim.Angle, 'spam', galsim.degrees) | ||
assert_raises((ValueError, TypeError), galsim.Angle, 'spam') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this should remove the last arg, galsim.degrees
.
(0.0, 0.0), | ||
rtol=0, | ||
atol=1e-16, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What computation are you doing in Jax-Galsim that makes this not be exactly zero. This should have been trivially true I would think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, the line bloat here seems rather gratuitous. Can we put each of these in a single line?
else: | ||
check_dep(galsim.GSParams, allowed_flux_variation=0.90) | ||
check_dep(galsim.GSParams, range_division_for_extrema=50) | ||
check_dep(galsim.GSParams, small_fraction_of_flux=1.e-6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly, I think there is no reason that JAX-GalSim needs to comport with our now-deprecated behavior. I'd be happy to just have:
if is_jax_galsim(): return
at the start of every test in this file.
assert_raises(ValueError, galsim.BoundsD, 11, 23, 17, "blue") | ||
if is_jax_galsim(): | ||
# jax doesn't raise for this | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I can guess what it does with floats. But I'm quite curious what it does with "blue".
|
||
if is_jax_galsim(): | ||
# jax doesn't raise for these things | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does it do for an undefined Bounds if you ask for the center?
im_list, offsets = deInterleaveImage(img,8) | ||
img1 = interleaveImages(im_list,8,offsets) | ||
im_list, offsets = galsim.utilities.deInterleaveImage(img,8) | ||
img1 = galsim.utilities.interleaveImages(im_list,8,offsets) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine I guess, but I'm curious why this was required?
@@ -499,7 +500,7 @@ def do_local_wcs(wcs, ufunc, vfunc, name): | |||
wcs2 = wcs.local() | |||
assert wcs == wcs2, name+' local() is not == the original' | |||
new_origin = galsim.PositionI(123,321) | |||
wcs3 = wcs.withOrigin(new_origin) | |||
wcs3 = wcs.shiftOrigin(new_origin) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For a LocalWCS, these are equivalent. And IMO withOrigin is preferred, since the WCS doesn't have an origin yet.
@@ -41,10 +40,10 @@ def test_Zernike_orthonormality(): | |||
y = y[w].ravel() | |||
area = np.pi*R_outer**2 | |||
for j1 in range(1, jmax+1): | |||
Z1 = Zernike([0]*(j1+1)+[1], R_outer=R_outer) | |||
Z1 = galsim.zernike.Zernike([0]*(j1+1)+[1], R_outer=R_outer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, why add the extra boilerplate?
This PR has two main features:
It has a generalization of the concept of a "galsim backend" with associated functions for the test suite only. In the test suite only, the backend is used to adjust the tests as needed for the jax-based galsim implementation.
It has adjustments of the tests for jax-galsim. Most of them are related to cases where jax cannot raise the same errors (e.g., for checking argument types) or when the fact that jax arrays cannot be changed in place causes the APIs to differ.