diff --git a/src/caustics/tests.py b/src/caustics/tests.py
index 0e26dc13..7456e61d 100644
--- a/src/caustics/tests.py
+++ b/src/caustics/tests.py
@@ -10,8 +10,15 @@
 
 __all__ = ["test"]
 
+# Pseudo device as mentioned in https://github.com/pytorch/pytorch/issues/61654
+# using this device, no actual computation is done, but will check
+# for where the tensors are located
+META_DEVICE = torch.device("meta")
 
-def _test_simulator_runs():
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+def _test_simulator_runs(device=DEVICE):
     # Model
     cosmology = FlatLambdaCDM(name="cosmo")
     lensmass = SIE(
@@ -32,7 +39,7 @@ def _test_simulator_runs():
         name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0
     )
 
-    psf = gaussian(0.05, 11, 11, 0.2, upsample=2)
+    psf = gaussian(0.05, 11, 11, 0.2, upsample=2, device=device)
 
     sim = Lens_Source(
         lens=lensmass,
@@ -44,6 +51,9 @@ def _test_simulator_runs():
         z_s=2.0,
     )
 
+    # Send to device
+    sim = sim.to(device=device)
+
     assert torch.all(torch.isfinite(sim()))
     assert torch.all(
         torch.isfinite(
@@ -91,15 +101,18 @@ def _test_simulator_runs():
     )
 
 
-def _test_jacobian_autograd_vs_finitediff():
+def _test_jacobian_autograd_vs_finitediff(device=DEVICE):
     # Models
     cosmology = FlatLambdaCDM(name="cosmo")
     lens = SIE(name="sie", cosmology=cosmology)
-    thx, thy = get_meshgrid(0.01, 20, 20)
+    thx, thy = get_meshgrid(0.01, 20, 20, device=device)
 
     # Parameters
-    z_s = torch.tensor(1.2)
-    x = torch.tensor([0.5, 0.912, -0.442, 0.7, pi / 3, 1.4])
+    z_s = torch.tensor(1.2, device=device)
+    x = torch.tensor([0.5, 0.912, -0.442, 0.7, pi / 3, 1.4], device=device)
+
+    # Send to device
+    lens = lens.to(device=device)
 
     # Evaluate Jacobian
     J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x))
@@ -113,11 +126,11 @@ def _test_jacobian_autograd_vs_finitediff():
     )
 
 
-def _test_multiplane_jacobian():
+def _test_multiplane_jacobian(device=DEVICE):
     # Setup
-    z_s = torch.tensor(1.5, dtype=torch.float32)
+    z_s = torch.tensor(1.5, dtype=torch.float32, device=device)
     cosmology = FlatLambdaCDM(name="cosmo")
-    cosmology.to(dtype=torch.float32)
+    cosmology.to(dtype=torch.float32, device=device)
 
     # Parameters
     xs = [
@@ -125,27 +138,31 @@ def _test_multiplane_jacobian():
         [0.7, 0.0, 0.5, 0.9999, -pi / 6, 0.7],
         [1.1, 0.4, 0.3, 0.9999, pi / 4, 0.9],
     ]
-    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32)
+    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32, device=device)
 
     lens = Multiplane(
         name="multiplane",
         cosmology=cosmology,
         lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))],
     )
-    thx, thy = get_meshgrid(0.1, 10, 10)
+
+    # Send to device
+    lens = lens.to(device=device)
+
+    thx, thy = get_meshgrid(0.1, 10, 10, device=device)
 
     # Parameters
-    z_s = torch.tensor(1.2)
-    x = torch.tensor(xs).flatten()
+    z_s = torch.tensor(1.2, device=device)
+    x = torch.tensor(xs, device=device).flatten()
     A = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x))
     assert A.shape == (10, 10, 2, 2)
 
 
-def _test_multiplane_jacobian_autograd_vs_finitediff():
+def _test_multiplane_jacobian_autograd_vs_finitediff(device=DEVICE):
     # Setup
-    z_s = torch.tensor(1.5, dtype=torch.float32)
+    z_s = torch.tensor(1.5, dtype=torch.float32, device=device)
     cosmology = FlatLambdaCDM(name="cosmo")
-    cosmology.to(dtype=torch.float32)
+    cosmology.to(dtype=torch.float32, device=device)
 
     # Parameters
     xs = [
@@ -153,18 +170,22 @@ def _test_multiplane_jacobian_autograd_vs_finitediff():
         [0.7, 0.0, 0.5, 0.9999, -pi / 6, 0.7],
         [1.1, 0.4, 0.3, 0.9999, pi / 4, 0.9],
     ]
-    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32)
+    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32, device=device)
 
     lens = Multiplane(
         name="multiplane",
         cosmology=cosmology,
         lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))],
     )
-    thx, thy = get_meshgrid(0.01, 10, 10)
+
+    # Send to device
+    lens = lens.to(device=device)
+
+    thx, thy = get_meshgrid(0.01, 10, 10, device=device)
 
     # Parameters
-    z_s = torch.tensor(1.2)
-    x = torch.tensor(xs).flatten()
+    z_s = torch.tensor(1.2, device=device)
+    x = torch.tensor(xs, device=device).flatten()
 
     # Evaluate Jacobian
     J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x))
@@ -178,11 +199,11 @@ def _test_multiplane_jacobian_autograd_vs_finitediff():
     )
 
 
-def _test_multiplane_effective_convergence():
+def _test_multiplane_effective_convergence(device=DEVICE):
     # Setup
-    z_s = torch.tensor(1.5, dtype=torch.float32)
+    z_s = torch.tensor(1.5, dtype=torch.float32, device=device)
     cosmology = FlatLambdaCDM(name="cosmo")
-    cosmology.to(dtype=torch.float32)
+    cosmology.to(dtype=torch.float32, device=device)
 
     # Parameters
     xs = [
@@ -190,25 +211,29 @@ def _test_multiplane_effective_convergence():
         [0.7, 0.0, 0.5, 0.9999, -pi / 6, 0.7],
         [1.1, 0.4, 0.3, 0.9999, pi / 4, 0.9],
     ]
-    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32)
+    x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32, device=device)
 
     lens = Multiplane(
         name="multiplane",
         cosmology=cosmology,
         lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))],
     )
-    thx, thy = get_meshgrid(0.1, 10, 10)
+
+    # Send to device
+    lens = lens.to(device=device)
+
+    thx, thy = get_meshgrid(0.1, 10, 10, device=device)
 
     # Parameters
-    z_s = torch.tensor(1.2)
-    x = torch.tensor(xs).flatten()
+    z_s = torch.tensor(1.2, device=device)
+    x = torch.tensor(xs, device=device).flatten()
     C = lens.effective_convergence_div(thx, thy, z_s, lens.pack(x))
     assert C.shape == (10, 10)
     curl = lens.effective_convergence_curl(thx, thy, z_s, lens.pack(x))
     assert curl.shape == (10, 10)
 
 
-def test():
+def test(device=DEVICE):
     """
     Run tests for caustics basic functionality. Run this function
     to ensure that caustics is working properly.
@@ -222,9 +247,9 @@ def test():
     To run the checks.
     """
 
-    _test_simulator_runs()
-    _test_jacobian_autograd_vs_finitediff()
-    _test_multiplane_jacobian()
-    _test_multiplane_jacobian_autograd_vs_finitediff()
-    _test_multiplane_effective_convergence()
+    _test_simulator_runs(device=device)
+    _test_jacobian_autograd_vs_finitediff(device=device)
+    _test_multiplane_jacobian(device=device)
+    _test_multiplane_jacobian_autograd_vs_finitediff(device=device)
+    _test_multiplane_effective_convergence(device=device)
     print("all tests passed!")
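
Usage sketch (not part of the patch): the snippet below shows one way the device-aware entry point added above could be exercised, and what the "meta" pseudo device is for. It assumes the patched module is importable as caustics.tests; test(), DEVICE, and META_DEVICE come from the diff, everything else is illustrative.

    import torch
    from caustics.tests import test

    # Mirror the DEVICE fallback added in the patch: prefer CUDA, else CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    test(device=device)  # runs every _test_* helper on that device

    # The "meta" pseudo device allocates no storage and runs no kernels,
    # so it only tracks where tensors would live (useful for placement checks).
    x = torch.empty(3, 3, device=torch.device("meta"))
    print(x.device)  # -> meta

Running the full test() on the meta device would not work, since the assertions need real values; the meta device is only meant for checking tensor placement.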