Skip to content

Commit

Permalink
Merge pull request #37 from Ciela-Institute/simulatorclass
Browse files Browse the repository at this point in the history
Added Simulator class object
  • Loading branch information
ConnorStoneAstro authored May 29, 2023
2 parents 1e386d7 + b32e4c0 commit 8b63f29
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
7 changes: 3 additions & 4 deletions src/caustic/lenses/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Any, Optional
from functools import partial

import torch
from torch import Tensor
Expand All @@ -11,7 +12,6 @@

__all__ = ("ThinLens", "ThickLens")


class ThickLens(Parametrized):
"""
Base class for modeling gravitational lenses that cannot be treated using the thin lens approximation.
Expand Down Expand Up @@ -117,8 +117,7 @@ def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor:
Returns:
Tensor: The gravitational lensing magnification at the given coordinates.
"""
return get_magnification(self.raytrace, thx, thy, z_s, x)

return get_magnification(partial(self.raytrace, x = x), thx, thy, z_s)

class ThinLens(Parametrized):
"""Base class for thin gravitational lenses.
Expand Down Expand Up @@ -328,4 +327,4 @@ def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor:
Returns:
Tensor: Gravitational magnification at the given coordinates.
"""
return get_magnification(self.raytrace, thx, thy, z_s, x)
return get_magnification(partial(self.raytrace, x = x), thx, thy, z_s)
14 changes: 7 additions & 7 deletions src/caustic/lenses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_pix_jacobian(
raytrace, thx, thy, z_s, x
raytrace, thx, thy, z_s
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
"""Computes the Jacobian matrix of the partial derivatives of the
image position with respect to the source position
Expand All @@ -27,11 +27,11 @@ def get_pix_jacobian(
The Jacobian matrix of the image position with respect to the source position at the given point.
"""
jac = torch.func.jacfwd(raytrace, (0, 1))(thx, thy, z_s, x) # type: ignore
jac = torch.func.jacfwd(raytrace, (0, 1))(thx, thy, z_s) # type: ignore
return jac


def get_pix_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
def get_pix_magnification(raytrace, thx, thy, z_s) -> Tensor:
"""
Computes the magnification at a single point on the lensing plane. The magnification is derived from the determinant
of the Jacobian matrix of the image position with respect to the source position.
Expand All @@ -46,11 +46,11 @@ def get_pix_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
Returns:
The magnification at the given point on the lensing plane.
"""
jac = get_pix_jacobian(raytrace, thx, thy, z_s, x)
jac = get_pix_jacobian(raytrace, thx, thy, z_s)
return 1 / (jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]).abs()


def get_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
def get_magnification(raytrace, thx, thy, z_s) -> Tensor:
"""
Computes the magnification over a grid on the lensing plane. This is done by calling `get_pix_magnification`
for each point on the grid.
Expand All @@ -65,6 +65,6 @@ def get_magnification(raytrace, thx, thy, z_s, x) -> Tensor:
Returns:
A tensor representing the magnification at each point on the grid.
"""
return vmap_n(get_pix_magnification, 2, (None, 0, 0, None, None))(
raytrace, thx, thy, z_s, x
return vmap_n(get_pix_magnification, 2, (None, 0, 0, None))(
raytrace, thx, thy, z_s
)
1 change: 0 additions & 1 deletion src/caustic/packed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import OrderedDict


class Packed(OrderedDict):
"""
Dummy wrapper for `x` so other functions can check its type.
Expand Down
8 changes: 3 additions & 5 deletions src/caustic/parametrized.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

__all__ = ("Parametrized",)


class Parametrized:
"""
Represents a class with Param and Parametrized attributes, typically used to construct parts of a simulator
Expand Down Expand Up @@ -282,7 +281,7 @@ def pack(
ValueError: If the number of dynamic arguments does not match the expected number.
ValueError: If the input is a tensor and the shape does not match the expected shape.
"""
if isinstance(x, dict):
if isinstance(x, (dict, Packed)):
missing_names = [
name for name in chain([self.name], self._descendants) if name not in x
]
Expand All @@ -302,7 +301,7 @@ def pack(
# TODO: give component and arg names
raise ValueError(
f"{n_passed} dynamic args were passed, but {n_expected} are "
"required"
"required."
)

cur_offset = self.n_dynamic
Expand Down Expand Up @@ -559,8 +558,7 @@ def add_params(p: Parametrized, dot):
add_params(desc, dot)

return dot



# class ParametrizedList(Parametrized):
# """
# TODO
Expand Down
18 changes: 18 additions & 0 deletions src/caustic/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .parameterized import Parametrized

__all__ = ("Simulator", )

class Simulator(Parametrized):
"""A caustic simulator using Parametrized framework.
Defines a simulator class which is a callable function that
operates on the Parametrized framework. Users define the `forward`
method which takes as its first argument an object which can be
packed, all other args and kwargs are simply passed to the forward
method.
See `Parametrized` for details on how to add/access parameters.
"""
def __call__(self, *args, **kwargs):
return self.forward(self.pack(args[0]), *args[1:], **kwargs)

0 comments on commit 8b63f29

Please sign in to comment.