Skip to content

Commit

Permalink
Entirely replaced ApplyJitter with the multivariate jitter class
Browse files Browse the repository at this point in the history
  • Loading branch information
maxecharles committed Sep 19, 2023
1 parent 8abed43 commit 0f0ad97
Showing 1 changed file with 19 additions and 72 deletions.
91 changes: 19 additions & 72 deletions dLux/detector_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,68 +102,16 @@ class ApplyJitter(DetectorLayer):
Attributes
----------
sigma : Array, pixels
The standard deviation of the Gaussian kernel, in units of pixels.
kernel_size : int
The size of the convolution kernel to use.
The size in pixels of the convolution kernel to use.
r : float, arcseconds
The magnitude of the jitter.
shear : float
The shear of the jitter. A radially symmetric Gaussian kernel would have a shear value of 0.
phi : float, degrees
The angle of the jitter.
"""

kernel_size: int
sigma: Array

def __init__(self: DetectorLayer, sigma: Array, kernel_size: int = 10):
"""
Constructor for the ApplyJitter class.
Parameters
----------
sigma : Array, pixels
The standard deviation of the Gaussian kernel, in units of pixels.
kernel_size : int = 10
The size of the convolution kernel to use.
"""
super().__init__()
self.kernel_size = int(kernel_size)
self.sigma = np.asarray(sigma, dtype=float)
if self.sigma.ndim != 0:
raise ValueError("sigma must be a scalar array.")

def generate_kernel(self: DetectorLayer, pixel_scale: Array) -> Array:
"""
Generates the normalised Gaussian kernel.
Returns
-------
kernel : Array
The Gaussian kernel.
"""

extent = self.kernel_size * pixel_scale
x = np.linspace(0, extent, self.kernel_size) - 0.5 * extent
kernel = norm.pdf(x, scale=self.sigma) * norm.pdf(
x[:, None], scale=self.sigma
)
return kernel / np.sum(kernel)

def __call__(self: DetectorLayer, image: Image()) -> Image():
"""
Applies the layer to the Image.
Parameters
----------
image : Image
The image to operate on.
Returns
-------
image : Image
The transformed image.
"""
kernel = self.generate_kernel(image.pixel_scale)
return image.convolve(kernel)


class ApplyAsymmetricJitter(DetectorLayer):
kernel_size: int
r: float = None
shear: float = None
Expand All @@ -184,7 +132,7 @@ def __init__(
r : float, arcseconds
The magnitude of the jitter.
shear : float
The shear of the jitter.
The shear of the jitter. A radially symmetric Gaussian kernel would have a shear value of 0.
phi : float, degrees
The angle of the jitter.
kernel_size : int = 10
Expand All @@ -198,6 +146,14 @@ def __init__(

@property
def covariance_matrix(self):
"""
Generates the covariance matrix for the multivariate normal distribution.
Returns
-------
covariance_matrix : Array
The covariance matrix.
"""
rot_angle = np.radians(self.phi) - np.pi / 4

# Construct the rotation matrix
Expand All @@ -218,20 +174,11 @@ def covariance_matrix(self):
np.dot(rotation_matrix, skew_matrix), rotation_matrix.T
)

# Ensure positive semi-definiteness
try:
# Attempt Cholesky decomposition
jax.scipy.linalg.cholesky(covariance_matrix)
return covariance_matrix
except:
# TODO don't think this works
raise ValueError(
"Covariance matrix is not positive semi-definite."
)
return covariance_matrix

def generate_kernel(self, pixel_scale: float) -> Array:
"""
Generates the normalised Gaussian kernel.
Generates the normalised multivariate Gaussian kernel.
Parameters
----------
Expand Down Expand Up @@ -270,7 +217,7 @@ def __call__(self: DetectorLayer, image: Image()) -> Image():
The transformed image.
"""
kernel = self.generate_kernel(
dl.utils.rad_to_arcsec(image.pixel_scale)
dLux.utils.rad_to_arcsec(image.pixel_scale)
)

return image.convolve(kernel)
Expand Down

0 comments on commit 0f0ad97

Please sign in to comment.