Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve behavior of SED class, especially with respect to pickling and spline interpolation. #1257

Merged
merged 17 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,18 @@ Bug Fixes
- Fixed a slight inaccuracy in the FFT phase shifts for single-precision images. (#1231, #1234)
- Fixed a bug that prevented a convolution of two PhaseScreenPSF objects from being drawn with
photon shooting. (#1242)


Changes from v2.5.0 to v2.5.1
=============================

- Fixed an incompatibility with Python 3.12 that we had missed.
- Fixed a bug in the SED class normalization when using astropy.units for flux_type. Thanks
to Sid Mau for finding and fixing this bug. (#1254, #1256)
- Fixed a bug in the `EmissionLine.atRedshift` method. (#1257)
- Added interpolant option to `SED` and `Bandpass` classes to use when reading from a file.
(#1257)
- Improved the behavior of SEDs when using spline interpolant. (#1187, #1257)
- No longer pickle the SED of chromatic objects when the SED is a derived value. (#1257)
- Added interpolant option to `utilities.trapz`. (#1257)
- Added clip_neg option to `DistDeviate` class. (#1257)
28 changes: 18 additions & 10 deletions galsim/bandpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ class Bandpass:
zeropoint: Set the zero-point for this Bandpass. Here, this can only be a float
value. See the method `withZeropoint` for other options for how to
set this using a particular spectrum (AB, Vega, etc.) [default: None]
interpolant: If reading from a file, what interpolant to use. [default: 'linear']
"""
def __init__(self, throughput, wave_type, blue_limit=None, red_limit=None,
zeropoint=None, _wave_list=None, _tp=None):
zeropoint=None, interpolant='linear', _wave_list=None, _tp=None):
# Note that `_wave_list` acts as a private construction variable that overrides the way that
# `wave_list` is normally constructed (see `Bandpass.__mul__` below)

Expand All @@ -122,6 +123,7 @@ def __init__(self, throughput, wave_type, blue_limit=None, red_limit=None,
self.blue_limit = blue_limit # These may change as we go through this.
self.red_limit = red_limit
self.zeropoint = zeropoint
self.interpolant = interpolant

# Parse the various options for wave_type
if isinstance(wave_type, str):
Expand Down Expand Up @@ -222,7 +224,7 @@ def _initialize_tp(self):
elif isinstance(self._orig_tp, basestring):
isfile, filename = utilities.check_share_file(self._orig_tp, 'bandpasses')
if isfile:
self._tp = LookupTable.from_file(filename, interpolant='linear')
self._tp = LookupTable.from_file(filename, interpolant=self.interpolant)
else:
if self.blue_limit is None or self.red_limit is None:
raise GalSimIncompatibleValuesError(
Expand Down Expand Up @@ -433,8 +435,9 @@ def withZeropoint(self, zeropoint):
raise TypeError(
"Don't know how to handle zeropoint of type: {0}".format(type(zeropoint)))

return Bandpass(self._orig_tp, self.wave_type, self.blue_limit, self.red_limit, zeropoint,
self.wave_list, self._tp)
return Bandpass(self._orig_tp, self.wave_type, self.blue_limit, self.red_limit,
zeropoint=zeropoint, interpolant=self.interpolant,
_wave_list=self.wave_list, _tp=self._tp)

def truncate(self, blue_limit=None, red_limit=None, relative_throughput=None,
preserve_zp='auto'):
Expand Down Expand Up @@ -528,9 +531,11 @@ def truncate(self, blue_limit=None, red_limit=None, relative_throughput=None,

if preserve_zp:
return Bandpass(self._orig_tp, self.wave_type, blue_limit, red_limit,
_wave_list=wave_list, _tp=self._tp, zeropoint=self.zeropoint)
zeropoint=self.zeropoint, interpolant=self.interpolant,
_wave_list=wave_list, _tp=self._tp)
else:
return Bandpass(self._orig_tp, self.wave_type, blue_limit, red_limit,
interpolant=self.interpolant,
_wave_list=wave_list, _tp=self._tp)

def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search=True,
Expand Down Expand Up @@ -586,11 +591,14 @@ def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search=
if len(self.wave_list) > 0:
x = self.wave_list
f = self(x)
interpolant = (self.interpolant if not isinstance(self._tp, LookupTable)
else self._tp.interpolant)
newx, newf = utilities.thin_tabulated_values(x, f, rel_err=rel_err,
trim_zeros=trim_zeros,
preserve_range=preserve_range,
fast_search=fast_search)
tp = _LookupTable(newx, newf, 'linear')
fast_search=fast_search,
interpolant=interpolant)
tp = _LookupTable(newx, newf, interpolant)
blue_limit = np.min(newx)
red_limit = np.max(newx)
wave_list = np.array(newx)
Expand Down Expand Up @@ -622,9 +630,9 @@ def __hash__(self):

def __repr__(self):
return ('galsim.Bandpass(%r, wave_type=%r, blue_limit=%r, red_limit=%r, zeropoint=%r, '
'_wave_list=array(%r))')%(
self._orig_tp, self.wave_type, self.blue_limit, self.red_limit, self.zeropoint,
self.wave_list.tolist())
'interpolant=%r, _wave_list=array(%r))')%(
self._orig_tp, self.wave_type, self.blue_limit, self.red_limit,
self.zeropoint, self.interpolant, self.wave_list.tolist())

def __str__(self):
orig_tp = repr(self._orig_tp)
Expand Down
48 changes: 48 additions & 0 deletions galsim/chromatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,14 @@ def __str__(self):
s += '.atRedshift(%s)'%(self._redshift)
return s

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def _getTransformations(self, wave):
if hasattr(self._jac, '__call__'):
jac = self._jac(wave)
Expand Down Expand Up @@ -2084,6 +2092,14 @@ def __repr__(self):
def __str__(self):
return str(self.original) + ' * ' + str(self.sed)

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def _getTransformations(self, wave):
flux_ratio = self._flux_ratio(wave)
return self._jac, self._offset, flux_ratio
Expand Down Expand Up @@ -2269,6 +2285,14 @@ def __str__(self):
str_list = [ str(obj) for obj in self.obj_list ]
return 'galsim.ChromaticSum([%s])'%', '.join(str_list)

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def evaluateAtWavelength(self, wave):
"""Evaluate this chromatic object at a particular wavelength ``wave``.

Expand Down Expand Up @@ -2580,6 +2604,14 @@ def __str__(self):
str_list = [ str(obj) for obj in self.obj_list ]
return 'galsim.ChromaticConvolution([%s])'%', '.join(str_list)

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def _approxWavelength(self, wave):
# If any of the components prefer a different wavelength, use that for all.
achrom_objs = []
Expand Down Expand Up @@ -2985,6 +3017,14 @@ def __repr__(self):
def __str__(self):
return 'galsim.ChromaticAutoConvolution(%s)'%self._obj

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def evaluateAtWavelength(self, wave):
"""Evaluate this chromatic object at a particular wavelength ``wave``.

Expand Down Expand Up @@ -3082,6 +3122,14 @@ def __repr__(self):
def __str__(self):
return 'galsim.ChromaticAutoCorrelation(%s)'%self._obj

def __getstate__(self):
d = self.__dict__.copy()
d.pop('sed',None)
return d

def __setstate__(self, d):
self.__dict__ = d

def evaluateAtWavelength(self, wave):
"""Evaluate this chromatic object at a particular wavelength ``wave``.

Expand Down
12 changes: 9 additions & 3 deletions galsim/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,11 @@ class DistDeviate(BaseDeviate):
npoints: Number of points DistDeviate should create for its internal interpolation
tables. [default: 256, unless the function is a non-log `LookupTable`, in
which case it uses the table's x values]
clip_neg: Clip any negative input values to zero. [default: False; an error will
be raised if any negative probabilities are found.]
"""
def __init__(self, seed=None, function=None, x_min=None,
x_max=None, interpolant=None, npoints=None):
x_max=None, interpolant=None, npoints=None, clip_neg=False):

# Set up the PRNG
self._rng_type = _galsim.UniformDeviateImpl
Expand Down Expand Up @@ -823,7 +825,7 @@ def __init__(self, seed=None, function=None, x_min=None,
"Cannot provide an interpolant with a callable function argument",
interpolant=interpolant, function=function)
if isinstance(function, LookupTable):
if x_min or x_max:
if (x_min not in (None, function.x_min)) or (x_max not in (None, function.x_max)):
raise GalSimIncompatibleValuesError(
"Cannot provide x_min or x_max with a LookupTable function",
function=function, x_min=x_min, x_max=x_max)
Expand Down Expand Up @@ -859,7 +861,11 @@ def __init__(self, seed=None, function=None, x_min=None,
pdf = np.array(pdf)

# Check that the probability is nonnegative
if not np.all(pdf >= 0.):
if clip_neg:
# Write it this way so nan -> 0 as well as negative values.
w = np.where(~(pdf >= 0))
pdf[w] = 0.
elif not np.all(pdf >= 0.):
raise GalSimValueError('Negative probability found in DistDeviate.',function)

# Compute the cumulative distribution function = int(pdf(x),x)
Expand Down
4 changes: 2 additions & 2 deletions galsim/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .gsparams import GSParams
from .position import PositionD
from .bounds import BoundsI
from .utilities import lazy_property, doc_inherit, convert_interpolant
from .utilities import lazy_property, doc_inherit, convert_interpolant, merge_sorted
from .interpolant import Quintic
from .interpolatedimage import InterpolatedImage, _InterpolatedKImage
from .convolve import Convolve, Deconvolve
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def _poly_SEDs(bands):
# Use polynomial SEDs by default; up to the number of bands provided.
waves = []
for bp in bands:
waves = np.union1d(waves, bp.wave_list)
waves = merge_sorted([waves, bp.wave_list])
SEDs = []
for i in range(len(bands)):
SEDs.append(
Expand Down
Loading