diff --git a/changelog/135.feature.rst b/changelog/135.feature.rst new file mode 100644 index 00000000..52b57e55 --- /dev/null +++ b/changelog/135.feature.rst @@ -0,0 +1,10 @@ +Several functions have been updated to accept either numpy array or sunpy map inputs. +The following functions now accept either a numpy array or sunpy map, and return the same data type: + +- `sunkit_image.enhance.mgn` +- `sunkit_image.trace.bandpass_filter` +- `sunkit_image.trace.smooth` + +The following functions now accept either a numpy array or sunpy map, and their return type is unchanged: + +- `sunkit_image.trace.occult2` diff --git a/sunkit_image/enhance.py b/sunkit_image/enhance.py index ff243e48..e8fa9590 100644 --- a/sunkit_image/enhance.py +++ b/sunkit_image/enhance.py @@ -6,9 +6,12 @@ import numpy as np import scipy.ndimage as ndimage +from sunkit_image.utils.decorators import accept_array_or_map + __all__ = ["mgn"] +@accept_array_or_map(arg_name="data") def mgn( data, sigma=[1.25, 2.5, 5, 10, 20, 40], @@ -41,7 +44,7 @@ def mgn( Parameters ---------- - data : `numpy.ndarray` + data : `numpy.ndarray`, `sunpy.map.GenericMap` Image to be transformed. sigma : `list`, optional Range of Gaussian widths (i.e. the standard deviation of the Gaussian kernel) to transform over. @@ -72,8 +75,9 @@ def mgn( Returns ------- - `numpy.ndarray` - Normalized image. + `numpy.ndarray`, `sunpy.map.GenericMap` + Normalized image. If a map is input, a map is returned with new data + and the same metadata. References ---------- diff --git a/sunkit_image/tests/conftest.py b/sunkit_image/tests/conftest.py index 794dc265..75bc61e4 100644 --- a/sunkit_image/tests/conftest.py +++ b/sunkit_image/tests/conftest.py @@ -8,6 +8,7 @@ import astropy import astropy.config.paths +import sunpy.data.sample import sunpy.map from astropy.utils.data import get_pkg_data_filename from sunpy.coordinates import Helioprojective, get_earth @@ -155,3 +156,13 @@ def granule_minimap3(): ) map = sunpy.map.GenericMap(arr, header) return map + + +@pytest.fixture(params=["array", "map"]) +@pytest.mark.remote_data +def aia_171(request): + smap = sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE) + if request.param == "map": + return smap + elif request.param == "array": + return smap.data diff --git a/sunkit_image/tests/figure_hashes_mpl_372_ft_261_sunkit_image_dev_sunpy_500.json b/sunkit_image/tests/figure_hashes_mpl_372_ft_261_sunkit_image_dev_sunpy_500.json index 05e308a6..e69f4d65 100644 --- a/sunkit_image/tests/figure_hashes_mpl_372_ft_261_sunkit_image_dev_sunpy_500.json +++ b/sunkit_image/tests/figure_hashes_mpl_372_ft_261_sunkit_image_dev_sunpy_500.json @@ -1,6 +1,8 @@ { - "sunkit_image.tests.test_enhance.test_mgn": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e", - "sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0", - "sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc", - "sunkit_image.tests.test_trace.test_occult2_fig": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5" - } + "sunkit_image.tests.test_enhance.test_mgn[array]": "67b17ac7c343a1d568d9dfa2bb12e0041e2fc53e8ca1c0e1c2c7693eda2bd0d8", + "sunkit_image.tests.test_enhance.test_mgn[map]": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e", + "sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0", + "sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc", + "sunkit_image.tests.test_trace.test_occult2_fig[array]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5", + "sunkit_image.tests.test_trace.test_occult2_fig[map]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5" +} diff --git a/sunkit_image/tests/test_enhance.py b/sunkit_image/tests/test_enhance.py index db58e1b7..a424b680 100644 --- a/sunkit_image/tests/test_enhance.py +++ b/sunkit_image/tests/test_enhance.py @@ -8,19 +8,13 @@ from sunkit_image.tests.helpers import figure_test -@pytest.fixture -@pytest.mark.remote_data -def smap(): - return sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE) - - @figure_test @pytest.mark.remote_data -def test_mgn(smap): - out = enhance.mgn(smap.data) - out = sunpy.map.Map(out, smap.meta) - - out.plot() +def test_mgn(aia_171): + out = enhance.mgn(aia_171) + assert type(out) == type(aia_171) + if isinstance(out, sunpy.map.GenericMap): + out.plot() @pytest.fixture diff --git a/sunkit_image/tests/test_trace.py b/sunkit_image/tests/test_trace.py index 34a02ad3..659c3a23 100644 --- a/sunkit_image/tests/test_trace.py +++ b/sunkit_image/tests/test_trace.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import sunpy.map from astropy.io import fits import sunkit_image.data.test as data @@ -19,13 +20,20 @@ ) -@pytest.fixture +@pytest.fixture(params=["array", "map"]) @pytest.mark.remote_data -def image_remote(): +def image_remote(request): with warnings.catch_warnings(): warnings.simplefilter("ignore", category=fits.verify.VerifyWarning) - im = fits.getdata("http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits") - return im + data, header = fits.getdata( + "http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits", header=True + ) + if request.param == "map": + return sunpy.map.Map((data, header)) + elif request.param == "array": + return data + else: + raise ValueError(f"Invalid request parameter {request.param}") @pytest.fixture @@ -161,16 +169,18 @@ def test_map(): @pytest.fixture -def image(): +def test_map_ones(): return np.ones((4, 4), dtype=np.float32) -def test_bandpass_filter(image, test_map): +def test_bandpass_filter_ones(test_map_ones): expect = np.zeros((4, 4)) - result = bandpass_filter(image) + result = bandpass_filter(test_map_ones) assert np.allclose(expect, result) + +def test_bandpass_filter(test_map): expect = np.array( [ [0.0, 0.0, 0.0, 0.0], @@ -181,22 +191,30 @@ def test_bandpass_filter(image, test_map): ) result = bandpass_filter(test_map) - assert np.allclose(expect, result) - with pytest.raises(ValueError) as record: - _ = bandpass_filter(image, 5, 1) - assert str(record.value) == "nsm1 should be less than nsm2" +@pytest.mark.remote_data +def test_bandpass_filter_output(aia_171): + # Check that bandpass filter works with both arrays and maps + result = bandpass_filter(aia_171) + assert type(result) == type(aia_171) + + +def test_bandpass_filter_error(test_map_ones): + with pytest.raises(ValueError, match="nsm1 should be less than nsm2"): + bandpass_filter(test_map_ones, 5, 1) + +def test_smooth_ones(test_map_ones): + filtered = smooth(test_map_ones, 1) + assert np.allclose(filtered, test_map_ones) -def test_smooth(image, test_map): - filtered = smooth(image, 1) - assert np.allclose(filtered, image) + filtered = smooth(test_map_ones, 4) + assert np.allclose(filtered, test_map_ones) - filtered = smooth(image, 4) - assert np.allclose(filtered, image) +def test_smooth(test_map): filtered = smooth(test_map, 1) assert np.allclose(filtered, test_map) @@ -213,7 +231,14 @@ def test_smooth(image, test_map): assert np.allclose(filtered, expect) -def test_erase_loop_in_image(image, test_map): +@pytest.mark.remote_data +def test_smooth_output(aia_171): + # Check that smooth works with both arrays and maps + result = smooth(aia_171, 1) + assert type(result) == type(aia_171) + + +def test_erase_loop_in_image(test_map_ones, test_map): # The starting point of a dummy loop istart = 0 jstart = 1 @@ -223,7 +248,7 @@ def test_erase_loop_in_image(image, test_map): xloop = [1, 2, 3] yloop = [1, 1, 1] - result = _erase_loop_in_image(image, istart, jstart, width, xloop, yloop) + result = _erase_loop_in_image(test_map_ones, istart, jstart, width, xloop, yloop) expect = np.array([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]) diff --git a/sunkit_image/trace.py b/sunkit_image/trace.py index 6db196c1..1e157a7a 100644 --- a/sunkit_image/trace.py +++ b/sunkit_image/trace.py @@ -6,6 +6,8 @@ import numpy as np from scipy import interpolate +from sunkit_image.utils.decorators import accept_array_or_map + __all__ = [ "occult2", "bandpass_filter", @@ -13,6 +15,7 @@ ] +@accept_array_or_map(arg_name="image", output_to_map=False) def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2): """ Implements the Oriented Coronal CUrved Loop Tracing (OCCULT-2) algorithm @@ -20,7 +23,7 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2): Parameters ---------- - image : `numpy.ndarray` + image : `numpy.ndarray`, `sunpy.map.GenericMap` Image in which loops are to be detected. nsm1 : `int` Low pass filter boxcar smoothing constant. @@ -44,7 +47,7 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2): ------- `list` A list of all loop where each element is itself a list of points containing - ``x`` and ``y`` coordinates for each point. + ``x`` and ``y`` pixel coordinates for each point. References ---------- @@ -217,13 +220,14 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2): # The functions below this are subroutines for the OCCULT 2. +@accept_array_or_map(arg_name="image") def bandpass_filter(image, nsm1=1, nsm2=3): """ Applies a band pass filter to the image. Parameters ---------- - image : `numpy.ndarray` + image : `numpy.ndarray`, `sunpy.map.GenericMap` Image to be filtered. nsm1 : `int` Low pass filter boxcar smoothing constant. @@ -236,7 +240,8 @@ def bandpass_filter(image, nsm1=1, nsm2=3): Returns ------- `numpy.ndarray` - Bandpass filtered image. + Bandpass filtered image. If a map is input, a map is returned with new data + and the same metadata. """ if nsm1 >= nsm2: raise ValueError("nsm1 should be less than nsm2") @@ -248,13 +253,14 @@ def bandpass_filter(image, nsm1=1, nsm2=3): return smooth(image, nsm1, "replace") - smooth(image, nsm2, "replace") +@accept_array_or_map(arg_name="image") def smooth(image, width, nanopt="replace"): """ Python implementation of the IDL's ``smooth``. Parameters ---------- - image : `numpy.ndarray` + image : `numpy.ndarray`, `sunpy.map.GenericMap` Image to be filtered. width : `int` Width of the boxcar window. The `width` should always be odd but if even value is given then @@ -264,8 +270,9 @@ def smooth(image, width, nanopt="replace"): Returns ------- - `numpy.ndarray` - Smoothed image. + `numpy.ndarray`, `sunpy.map.GenericMap` + Smoothed image. If a map is input, a map is returned with new data + and the same metadata. References ---------- diff --git a/sunkit_image/utils/decorators.py b/sunkit_image/utils/decorators.py new file mode 100644 index 00000000..48b8e910 --- /dev/null +++ b/sunkit_image/utils/decorators.py @@ -0,0 +1,56 @@ +import inspect +from typing import Union, Callable + +import numpy as np + +from sunpy.map import GenericMap, Map + + +def accept_array_or_map(*, arg_name: str, output_to_map=True) -> Callable[[Callable], Callable]: + """ + Decorator that allows a function to accept an array or a + `sunpy.map.GenericMap` as an argument. + + This can be applied to functions that: + + - Take a single array or map as input + - Return a single array that has the same pixel coordinates + as the input array. + + Parameters + ---------- + arg_name : `str` + Name of data/map argument in function signature. + output_to_map : `bool` + If `True` (the default), convert the function return to a map if a map + is given as input. For this to work the decorated function must return + an array where pixels have the same coordinates as the input map data. + """ + + def decorate(f: Callable) -> Callable: + sig = inspect.signature(f) + if arg_name not in sig.parameters: + raise RuntimeError(f"Could not find '{arg_name}' in function signature") + + def inner(*args, **kwargs) -> Union[np.ndarray, GenericMap]: + sig_bound = sig.bind(*args, **kwargs) + map_arg = sig_bound.arguments[arg_name] + if isinstance(map_arg, GenericMap): + map_in = True + sig_bound.arguments[arg_name] = map_arg.data + elif isinstance(map_arg, np.ndarray): + map_in = False + else: + raise ValueError(f"'{arg_name}' argument must be a sunpy map or numpy array (got type {type(map_arg)})") + + # Run decorated function + array_out = f(*sig_bound.args, **sig_bound.kwargs) + + if map_in and output_to_map: + return Map(array_out, map_arg.meta) + else: + return array_out + + return inner + + return decorate