Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
test: Update tests to start using GPU device when available and skip …
Browse files Browse the repository at this point in the history
…CPU (Ciela-Institute#164)

* test: Start to update testing to use cpu and cuda

* test: Add 'device' fixture with values based on cuda availability

* fix: Add self return for FlatLambdaCDM.to method

* test: Update test_base to use 'device'

* test: Update test_batching to use 'device'

* test: Update test_cosmology to use 'device'

* test: Update test_epl to use 'device'

* test: Update test_external_shear to use 'device'

* test: Update test_interpolate_image to use 'device'

* test: Update tests_jacobian_lens_equation to use 'device'

* test: Move tensors to CPU first before converting to np.array

Moves tensors '.to("cpu")' explicitely before '.numpy()'
to avoid 'TypeError' from PyTorch:
'TypeError: can't convert cuda:0 device type tensor to numpy.
Use Tensor.cpu() to copy the tensor to host memory first.'

* fix: Fix base forward_raytrace with guesses not in same device

* feat: Added 'core' module with 'sync_device' decorator

Added 'core.py' module and a 'sync_device' decorator
that can be used on any function to sync the device
of the input tensors to the first non-cpu device it finds.

* test: Update test_comoving_dist to cast to numpy

* test: Add device to the rest of applicable test

* test: Update to use missing device

* test: Set device on plane params

* revert: Remove 'sync_device' decorator

* fix device batch_lm

* test: Update .to('cpu') to .cpu() from code review

* fix: Update tests/test_multiplane.py

---------

Co-authored-by: Connor Stone <[email protected]>
  • Loading branch information
lsetiawan and ConnorStoneAstro authored Feb 21, 2024
1 parent 0fdb2c9 commit f2f02ce
Show file tree
Hide file tree
Showing 25 changed files with 321 additions and 190 deletions.
2 changes: 2 additions & 0 deletions src/caustics/cosmology/FlatLambdaCDM.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def to(
self._comoving_distance_helper_x_grid.to(device, dtype)
)

return self

def hubble_distance(self, h0):
"""
Calculate the Hubble distance.
Expand Down
4 changes: 2 additions & 2 deletions src/caustics/lenses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def forward_raytrace(
raise ValueError("fov must be given to generate initial guesses")

# Random starting points in image plane
guesses = torch.as_tensor(fov) * (
torch.rand(n_init, 2) - 0.5
guesses = (torch.as_tensor(fov) * (torch.rand(n_init, 2) - 0.5)).to(
device=bxy.device
) # Has shape (n_init, Din:2)

# Optimize guesses in image plane
Expand Down
29 changes: 19 additions & 10 deletions src/caustics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def get_cluster_means(xs: Tensor, k: int):
return torch.stack(means)


def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon):
def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon, L_min, L_max):
# Forward
fY = f(X)
dY = Y - fY
Expand All @@ -592,7 +592,9 @@ def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon):

# Hessian
hess = J.T @ Cinv @ J
hess_perturb = L * (torch.diag(hess) + 0.1 * torch.eye(hess.shape[0]))
hess_perturb = L * (
torch.diag(hess) + 0.1 * torch.eye(hess.shape[0], device=hess.device)
)
hess = hess + hess_perturb

# Step
Expand All @@ -609,7 +611,7 @@ def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon):
# Update
X = torch.where(rho >= epsilon, X + h, X)
chi2 = torch.where(rho > epsilon, chi2_new, chi2)
L = torch.clamp(torch.where(rho >= epsilon, L / Ldn, L * Lup), 1e-9, 1e9)
L = torch.clamp(torch.where(rho >= epsilon, L / Ldn, L * Lup), L_min, L_max)

return X, L, chi2

Expand Down Expand Up @@ -637,16 +639,23 @@ def batch_lm(
raise ValueError("x and y must having matching batch dimension")

if C is None:
C = torch.eye(Dout).repeat(B, 1, 1)
C = torch.eye(Dout, device=X.device).repeat(B, 1, 1)
Cinv = torch.linalg.inv(C)

v_lm_step = torch.vmap(partial(_lm_step, lambda x: f(x, *f_args, **f_kwargs)))
L = L * torch.ones(B)
Lup = L_up * torch.ones(B)
Ldn = L_dn * torch.ones(B)
e = epsilon * torch.ones(B)
v_lm_step = torch.vmap(
partial(
_lm_step,
lambda x: f(x, *f_args, **f_kwargs),
Lup=L_up,
Ldn=L_dn,
epsilon=epsilon,
L_min=L_min,
L_max=L_max,
)
)
L = L * torch.ones(B, device=X.device)
for _ in range(max_iter):
Xnew, L, C = v_lm_step(X, Y, Cinv, L, Lup, Ldn, e)
Xnew, L, C = v_lm_step(X, Y, Cinv, L)
if (
torch.all((Xnew - X).abs() < stopping)
and torch.sum(L < 1e-2).item() > B / 3
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
import sys
import os
import torch
import pytest

# Add the helpers directory to the path so we can import the helpers
sys.path.append(os.path.join(os.path.dirname(__file__), "utils"))

CUDA_AVAILABLE = torch.cuda.is_available()


@pytest.fixture(
params=[
pytest.param(
"cpu", marks=pytest.mark.skipif(CUDA_AVAILABLE, reason="CUDA available")
),
pytest.param(
"cuda",
marks=pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available"),
),
]
)
def device(request):
return torch.device(request.param)
16 changes: 9 additions & 7 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from caustics import test as mini_test


def test():
z_l = torch.tensor(0.5, dtype=torch.float32)
z_s = torch.tensor(1.5, dtype=torch.float32)
def test(device):
z_l = torch.tensor(0.5, dtype=torch.float32, device=device)
z_s = torch.tensor(1.5, dtype=torch.float32, device=device)

# Model
cosmology = FlatLambdaCDM(name="cosmo")
Expand All @@ -22,10 +22,12 @@ def test():
phi=torch.tensor(np.pi / 5),
b=torch.tensor(1.0),
)
# Send to device
lens = lens.to(device)

# Point in the source plane
sp_x = torch.tensor(0.2)
sp_y = torch.tensor(0.2)
sp_x = torch.tensor(0.2, device=device)
sp_y = torch.tensor(0.2, device=device)

# Points in image plane
x, y = lens.forward_raytrace(sp_x, sp_y, z_s)
Expand All @@ -37,8 +39,8 @@ def test():
assert torch.all((sp_y - by).abs() < 1e-3)


def test_quicktest():
def test_quicktest(device):
"""
Quick test to check that the built-in `test` module is working
"""
mini_test()
mini_test(device=device)
9 changes: 5 additions & 4 deletions tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from utils import setup_image_simulator, setup_simulator


def test_vmapped_simulator():
def test_vmapped_simulator(device):
sim, (sim_params, cosmo_params, lens_params, source_params) = setup_simulator(
batched_params=True
batched_params=True,
device=device,
)
n_pix = sim.n_pix
print(sim.params)
Expand Down Expand Up @@ -35,9 +36,9 @@ def test_vmapped_simulator():
assert vmap(sim)(x_semantic).shape == torch.Size([2, n_pix, n_pix])


def test_vmapped_simulator_with_pixelated_modules():
def test_vmapped_simulator_with_pixelated_modules(device):
sim, (cosmo_params, lens_params, kappa, source) = setup_image_simulator(
batched_params=True
batched_params=True, device=device
)
n_pix = sim.n_pix
print(sim.params)
Expand Down
18 changes: 10 additions & 8 deletions tests/test_cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,29 @@ def get_cosmologies() -> List[Tuple[Cosmology, Cosmology_AP]]:
return cosmologies


def test_comoving_dist():
def test_comoving_dist(device):
rtol = 1e-3
atol = 0

zs = torch.linspace(0.05, 3, 10)
zs = torch.linspace(0.05, 3, 10, device=device)
for cosmology, cosmology_ap in get_cosmologies():
vals = cosmology.comoving_distance(zs).numpy()
vals_ref = cosmology_ap.comoving_distance(zs).value / 1e2 # type: ignore
assert np.allclose(vals, vals_ref, rtol, atol)
cosmology.to(device=device)

vals = cosmology.comoving_distance(zs)
vals_ref = cosmology_ap.comoving_distance(zs.cpu().numpy()).value / 1e2 # type: ignore
assert np.allclose(vals.cpu().numpy(), vals_ref, rtol, atol)

def test_to_method_flatlambdacdm():

def test_to_method_flatlambdacdm(device):
cosmo = CausticFlatLambdaCDM()
# Make sure private tensors are created on float32 by default
assert cosmo._comoving_distance_helper_x_grid.dtype == torch.float32
assert cosmo._comoving_distance_helper_y_grid.dtype == torch.float32
cosmo.to(dtype=torch.float64)
cosmo.to(dtype=torch.float64, device=device)
# Make sure distance helper get sent to proper dtype and device
assert cosmo._comoving_distance_helper_x_grid.dtype == torch.float64
assert cosmo._comoving_distance_helper_y_grid.dtype == torch.float64


if __name__ == "__main__":
test_comoving_dist()
test_comoving_dist(None)
42 changes: 28 additions & 14 deletions tests/test_epl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
from caustics.lenses import EPL


def test_lenstronomy():
def test_lenstronomy(device):
# Models
cosmology = FlatLambdaCDM(name="cosmo")
lens = EPL(name="epl", cosmology=cosmology)
lens = lens.to(device=device)
# There is also an EPL_NUMBA class lenstronomy, but it shouldn't matter much
lens_model_list = ["EPL"]
lens_ls = LensModel(lens_model_list=lens_model_list)

# Parameters
z_s = torch.tensor(1.0)
x = torch.tensor([0.7, 0.912, -0.442, 0.7, pi / 3, 1.4, 1.35])
z_s = torch.tensor(1.0, device=device)
x = torch.tensor([0.7, 0.912, -0.442, 0.7, pi / 3, 1.4, 1.35], device=device)

e1, e2 = param_util.phi_q2_ellipticity(phi=x[4].item(), q=x[3].item())
theta_E = (x[5] / x[3].sqrt()).item()
Expand All @@ -35,23 +36,30 @@ def test_lenstronomy():
]

# Different tolerances for difference quantities
alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=6e-5)
kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100)
Psi_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100)
alpha_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=6e-5, device=device
)
kappa_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device
)
Psi_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device
)


def test_special_case_sie():
def test_special_case_sie(device):
"""
Checks that the deflection field matches an SIE for `t=1`.
"""
cosmology = FlatLambdaCDM(name="cosmo")
lens = EPL(name="epl", cosmology=cosmology)
lens.to(device=device)
lens_model_list = ["SIE"]
lens_ls = LensModel(lens_model_list=lens_model_list)

# Parameters
z_s = torch.tensor(1.9)
x = torch.tensor([0.7, 0.912, -0.442, 0.7, pi / 3, 1.4, 1.0])
z_s = torch.tensor(1.9, device=device)
x = torch.tensor([0.7, 0.912, -0.442, 0.7, pi / 3, 1.4, 1.0], device=device)
e1, e2 = param_util.phi_q2_ellipticity(phi=x[4].item(), q=x[3].item())
theta_E = (x[5] / x[3].sqrt()).item()
kwargs_ls = [
Expand All @@ -65,11 +73,17 @@ def test_special_case_sie():
]

# Different tolerances for difference quantities
alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=6e-5)
kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=6e-5, atol=1e-100)
Psi_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100)
alpha_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=1e-100, atol=6e-5, device=device
)
kappa_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=6e-5, atol=1e-100, device=device
)
Psi_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol=3e-5, atol=1e-100, device=device
)


if __name__ == "__main__":
test_lenstronomy()
test_special_case_sie()
test_lenstronomy(None)
test_special_case_sie(None)
13 changes: 8 additions & 5 deletions tests/test_external_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
from caustics.lenses import ExternalShear


def test():
def test(device):
atol = 1e-5
rtol = 1e-5

# Models
cosmology = FlatLambdaCDM(name="cosmo")
lens = ExternalShear(name="shear", cosmology=cosmology)
lens.to(device=device)
lens_model_list = ["SHEAR"]
lens_ls = LensModel(lens_model_list=lens_model_list)
print(lens)

# Parameters
z_s = torch.tensor(2.0)
x = torch.tensor([0.7, 0.12, -0.52, -0.1, 0.1])
z_s = torch.tensor(2.0, device=device)
x = torch.tensor([0.7, 0.12, -0.52, -0.1, 0.1], device=device)
kwargs_ls = [
{
"ra_0": x[1].item(),
Expand All @@ -29,8 +30,10 @@ def test():
}
]

lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol, atol, test_kappa=False)
lens_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, rtol, atol, test_kappa=False, device=device
)


if __name__ == "__main__":
test()
test(None)
Loading

0 comments on commit f2f02ce

Please sign in to comment.