Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new features to the ndcube._add_ method #794

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/794.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allows addition of an ``NDCube`` and ``NDData`` (with the WCS of ``NDData`` being set to None), and combines their uncertainties and masks.
43 changes: 40 additions & 3 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import astropy.nddata
import astropy.units as u
from astropy.nddata import NDData
from astropy.units import UnitsError
from astropy.wcs.utils import _split_matrix

Expand Down Expand Up @@ -1034,23 +1035,59 @@ def __neg__(self):
return self._new_instance(data=-self.data)

def __add__(self, value):
kwargs = {}
if isinstance(value, NDData) and value.wcs is None:
if self.unit is not None and value.unit is not None:
value_data = (value.data * value.unit).to_value(self.unit)
elif self.unit is None:
value_data = value.data
else:
raise TypeError("Cannot add unitless NDData to a unitful NDCube.")

# use the format of the output of np.ma.MaskedArray, for combining mask
self_ma = np.ma.MaskedArray(self.data, mask=self.mask)
value_ma = np.ma.MaskedArray(value_data, mask=value.mask)

# addition, (and combining mask)
result_ma = self_ma + value_ma

# extract new mask and new data
kwargs["mask"] = result_ma.mask
kwargs["data"] = result_ma

# combine the uncertainty
if self.uncertainty is not None and value.uncertainty is not None:
new_uncertainty = self.uncertainty.propagate(
np.add, value, result_data = kwargs["data"], correlation=0
)
kwargs["uncertainty"] = new_uncertainty
elif self.uncertainty is not None:
new_uncertainty = self.uncertainty
kwargs["uncertainty"] = new_uncertainty
elif value.uncertainty is not None:
new_uncertainty = value.uncertainty
else:
new_uncertainty = None

if hasattr(value, 'unit'):
if isinstance(value, u.Quantity):
# NOTE: if the cube does not have units, we cannot
# perform arithmetic between a unitful quantity.
# This forces a conversion to a dimensionless quantity
# so that an error is thrown if value is not dimensionless
cube_unit = u.Unit('') if self.unit is None else self.unit
new_data = self.data + value.to_value(cube_unit)
kwargs["data"] = self.data + value.to_value(cube_unit)
else:
# NOTE: This explicitly excludes other NDCube objects and NDData objects
# which could carry a different WCS than the NDCube
return NotImplemented
elif self.unit not in (None, u.Unit("")):
raise TypeError("Cannot add a unitless object to an NDCube with a unit.")
else:
new_data = self.data + value
return self._new_instance(data=new_data)
kwargs["data"] = self.data + value

# return the new NDCube instance
return self._new_instance(**kwargs)

def __radd__(self, value):
return self.__add__(value)
Expand Down
27 changes: 26 additions & 1 deletion ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import astropy.wcs
from astropy.coordinates import SkyCoord, SpectralCoord
from astropy.io import fits
from astropy.nddata import UnknownUncertainty
from astropy.nddata import NDData, StdDevUncertainty, UnknownUncertainty
from astropy.tests.helper import assert_quantity_allclose
from astropy.time import Time
from astropy.units import UnitsError
Expand Down Expand Up @@ -1124,6 +1124,31 @@
check_arithmetic_value_and_units(new_cube, cube_quantity + value)


@pytest.mark.parametrize('value', [
NDData(np.random.rand(10, 12),
unit=u.ct,
wcs=None,
uncertainty=StdDevUncertainty(np.random.rand(10, 12)),
mask=np.random.choice([True, False], size=(10, 12))),
])
def test_cube_add_uncertainty_and_mask(ndcube_2d_ln_lt_units, value):
new_cube = ndcube_2d_ln_lt_units + value
# Check uncertainty propagation
expected_uncertainty = ndcube_2d_ln_lt_units.uncertainty.propagate(
operation=np.add,
other_nddata=value,
result_data=new_cube.data,
correlation=0,
)
assert np.allclose(new_cube.uncertainty.array, expected_uncertainty), \
f"Expected uncertainty: {expected_uncertainty}, but got: {new_cube.uncertainty.array}"
# Check mask combination
expected_mask = (np.ma.MaskedArray(ndcube_2d_ln_lt_units.data, mask=ndcube_2d_ln_lt_units.mask) + \
np.ma.MaskedArray(ndcube_2d_ln_lt_units.data, mask=ndcube_2d_ln_lt_units.mask)).mask
assert np.array_equal(new_cube.mask, expected_mask), \
f"Expected mask: {expected_mask}, but got: {new_cube.mask}"


Check warning on line 1151 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L1151

Added line #L1151 was not covered by tests
@pytest.mark.parametrize('value', [
10 * u.ct,
u.Quantity([10], u.ct),
Expand Down
Loading