Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept either arrays or maps as input where appropriate #135

Merged
merged 12 commits into from
Aug 9, 2023
10 changes: 10 additions & 0 deletions changelog/135.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Several functions have been updated to accept either numpy array or sunpy map inputs.
The following functions now accept either a numpy array or sunpy map, and return the same data type:

- `sunkit_image.enhance.mgn`
- `sunkit_image.trace.bandpass_filter`
- `sunkit_image.trace.smooth`

The following functions now accept either a numpy array or sunpy map, and their return type is unchanged:

- `sunkit_image.trace.occult2`
10 changes: 7 additions & 3 deletions sunkit_image/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import numpy as np
import scipy.ndimage as ndimage

from sunkit_image.utils.decorators import accept_array_or_map

__all__ = ["mgn"]


@accept_array_or_map(arg_name="data")
def mgn(
data,
sigma=[1.25, 2.5, 5, 10, 20, 40],
Expand Down Expand Up @@ -41,7 +44,7 @@ def mgn(

Parameters
----------
data : `numpy.ndarray`
data : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be transformed.
sigma : `list`, optional
Range of Gaussian widths (i.e. the standard deviation of the Gaussian kernel) to transform over.
Expand Down Expand Up @@ -72,8 +75,9 @@ def mgn(

Returns
-------
`numpy.ndarray`
Normalized image.
`numpy.ndarray`, `sunpy.map.GenericMap`
Normalized image. If a map is input, a map is returned with new data
and the same metadata.

References
----------
Expand Down
11 changes: 11 additions & 0 deletions sunkit_image/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import astropy
import astropy.config.paths
import sunpy.data.sample
import sunpy.map
from astropy.utils.data import get_pkg_data_filename
from sunpy.coordinates import Helioprojective, get_earth
Expand Down Expand Up @@ -155,3 +156,13 @@ def granule_minimap3():
)
map = sunpy.map.GenericMap(arr, header)
return map


@pytest.fixture(params=["array", "map"])
@pytest.mark.remote_data
def aia_171(request):
smap = sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE)
if request.param == "map":
return smap
elif request.param == "array":
return smap.data
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{
"sunkit_image.tests.test_enhance.test_mgn": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e",
"sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0",
"sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc",
"sunkit_image.tests.test_trace.test_occult2_fig": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5"
}
"sunkit_image.tests.test_enhance.test_mgn[array]": "67b17ac7c343a1d568d9dfa2bb12e0041e2fc53e8ca1c0e1c2c7693eda2bd0d8",
"sunkit_image.tests.test_enhance.test_mgn[map]": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e",
"sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0",
"sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc",
"sunkit_image.tests.test_trace.test_occult2_fig[array]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5",
"sunkit_image.tests.test_trace.test_occult2_fig[map]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5"
}
16 changes: 5 additions & 11 deletions sunkit_image/tests/test_enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@
from sunkit_image.tests.helpers import figure_test


@pytest.fixture
@pytest.mark.remote_data
def smap():
return sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE)


@figure_test
@pytest.mark.remote_data
def test_mgn(smap):
out = enhance.mgn(smap.data)
out = sunpy.map.Map(out, smap.meta)

out.plot()
def test_mgn(aia_171):
out = enhance.mgn(aia_171)
assert type(out) == type(aia_171)
if isinstance(out, sunpy.map.GenericMap):
out.plot()


@pytest.fixture
Expand Down
61 changes: 43 additions & 18 deletions sunkit_image/tests/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest

import sunpy.map
from astropy.io import fits

import sunkit_image.data.test as data
Expand All @@ -19,13 +20,20 @@
)


@pytest.fixture
@pytest.fixture(params=["array", "map"])
@pytest.mark.remote_data
def image_remote():
def image_remote(request):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=fits.verify.VerifyWarning)
im = fits.getdata("http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits")
return im
data, header = fits.getdata(
"http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits", header=True
)
if request.param == "map":
return sunpy.map.Map((data, header))
elif request.param == "array":
return data
else:
raise ValueError(f"Invalid request parameter {request.param}")

Check warning on line 36 in sunkit_image/tests/test_trace.py

View check run for this annotation

Codecov / codecov/patch

sunkit_image/tests/test_trace.py#L36

Added line #L36 was not covered by tests


@pytest.fixture
Expand Down Expand Up @@ -161,16 +169,18 @@


@pytest.fixture
def image():
def test_map_ones():
return np.ones((4, 4), dtype=np.float32)


def test_bandpass_filter(image, test_map):
def test_bandpass_filter_ones(test_map_ones):
expect = np.zeros((4, 4))
result = bandpass_filter(image)
result = bandpass_filter(test_map_ones)

assert np.allclose(expect, result)


def test_bandpass_filter(test_map):
expect = np.array(
[
[0.0, 0.0, 0.0, 0.0],
Expand All @@ -181,22 +191,30 @@
)

result = bandpass_filter(test_map)

assert np.allclose(expect, result)

with pytest.raises(ValueError) as record:
_ = bandpass_filter(image, 5, 1)

assert str(record.value) == "nsm1 should be less than nsm2"
@pytest.mark.remote_data
def test_bandpass_filter_output(aia_171):
# Check that bandpass filter works with both arrays and maps
result = bandpass_filter(aia_171)
assert type(result) == type(aia_171)


def test_bandpass_filter_error(test_map_ones):
with pytest.raises(ValueError, match="nsm1 should be less than nsm2"):
bandpass_filter(test_map_ones, 5, 1)


def test_smooth_ones(test_map_ones):
filtered = smooth(test_map_ones, 1)
assert np.allclose(filtered, test_map_ones)

def test_smooth(image, test_map):
filtered = smooth(image, 1)
assert np.allclose(filtered, image)
filtered = smooth(test_map_ones, 4)
assert np.allclose(filtered, test_map_ones)

filtered = smooth(image, 4)
assert np.allclose(filtered, image)

def test_smooth(test_map):
filtered = smooth(test_map, 1)
assert np.allclose(filtered, test_map)

Expand All @@ -213,7 +231,14 @@
assert np.allclose(filtered, expect)


def test_erase_loop_in_image(image, test_map):
@pytest.mark.remote_data
def test_smooth_output(aia_171):
# Check that smooth works with both arrays and maps
result = smooth(aia_171, 1)
assert type(result) == type(aia_171)


def test_erase_loop_in_image(test_map_ones, test_map):
# The starting point of a dummy loop
istart = 0
jstart = 1
Expand All @@ -223,7 +248,7 @@
xloop = [1, 2, 3]
yloop = [1, 1, 1]

result = _erase_loop_in_image(image, istart, jstart, width, xloop, yloop)
result = _erase_loop_in_image(test_map_ones, istart, jstart, width, xloop, yloop)

expect = np.array([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]])

Expand Down
21 changes: 14 additions & 7 deletions sunkit_image/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@
import numpy as np
from scipy import interpolate

from sunkit_image.utils.decorators import accept_array_or_map

__all__ = [
"occult2",
"bandpass_filter",
"smooth",
]


@accept_array_or_map(arg_name="image", output_to_map=False)
def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):
"""
Implements the Oriented Coronal CUrved Loop Tracing (OCCULT-2) algorithm
for loop tracing in images.

Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image in which loops are to be detected.
nsm1 : `int`
Low pass filter boxcar smoothing constant.
Expand All @@ -44,7 +47,7 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):
-------
`list`
A list of all loop where each element is itself a list of points containing
``x`` and ``y`` coordinates for each point.
``x`` and ``y`` pixel coordinates for each point.

References
----------
Expand Down Expand Up @@ -217,13 +220,14 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):


# The functions below this are subroutines for the OCCULT 2.
@accept_array_or_map(arg_name="image")
def bandpass_filter(image, nsm1=1, nsm2=3):
"""
Applies a band pass filter to the image.

Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be filtered.
nsm1 : `int`
Low pass filter boxcar smoothing constant.
Expand All @@ -236,7 +240,8 @@ def bandpass_filter(image, nsm1=1, nsm2=3):
Returns
-------
`numpy.ndarray`
Bandpass filtered image.
Bandpass filtered image. If a map is input, a map is returned with new data
and the same metadata.
"""
if nsm1 >= nsm2:
raise ValueError("nsm1 should be less than nsm2")
Expand All @@ -248,13 +253,14 @@ def bandpass_filter(image, nsm1=1, nsm2=3):
return smooth(image, nsm1, "replace") - smooth(image, nsm2, "replace")


@accept_array_or_map(arg_name="image")
def smooth(image, width, nanopt="replace"):
"""
Python implementation of the IDL's ``smooth``.

Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be filtered.
width : `int`
Width of the boxcar window. The `width` should always be odd but if even value is given then
Expand All @@ -264,8 +270,9 @@ def smooth(image, width, nanopt="replace"):

Returns
-------
`numpy.ndarray`
Smoothed image.
`numpy.ndarray`, `sunpy.map.GenericMap`
Smoothed image. If a map is input, a map is returned with new data
and the same metadata.

References
----------
Expand Down
56 changes: 56 additions & 0 deletions sunkit_image/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import inspect
from typing import Union, Callable

import numpy as np

from sunpy.map import GenericMap, Map


def accept_array_or_map(*, arg_name: str, output_to_map=True) -> Callable[[Callable], Callable]:
"""
Decorator that allows a function to accept an array or a
`sunpy.map.GenericMap` as an argument.

This can be applied to functions that:

- Take a single array or map as input
- Return a single array that has the same pixel coordinates
as the input array.

Parameters
----------
arg_name : `str`
Name of data/map argument in function signature.
output_to_map : `bool`
If `True` (the default), convert the function return to a map if a map
is given as input. For this to work the decorated function must return
an array where pixels have the same coordinates as the input map data.
"""

def decorate(f: Callable) -> Callable:
sig = inspect.signature(f)
if arg_name not in sig.parameters:
raise RuntimeError(f"Could not find '{arg_name}' in function signature")

Check warning on line 33 in sunkit_image/utils/decorators.py

View check run for this annotation

Codecov / codecov/patch

sunkit_image/utils/decorators.py#L33

Added line #L33 was not covered by tests

def inner(*args, **kwargs) -> Union[np.ndarray, GenericMap]:
sig_bound = sig.bind(*args, **kwargs)
map_arg = sig_bound.arguments[arg_name]
if isinstance(map_arg, GenericMap):
map_in = True
sig_bound.arguments[arg_name] = map_arg.data
elif isinstance(map_arg, np.ndarray):
map_in = False
else:
raise ValueError(f"'{arg_name}' argument must be a sunpy map or numpy array (got type {type(map_arg)})")

Check warning on line 44 in sunkit_image/utils/decorators.py

View check run for this annotation

Codecov / codecov/patch

sunkit_image/utils/decorators.py#L44

Added line #L44 was not covered by tests

# Run decorated function
array_out = f(*sig_bound.args, **sig_bound.kwargs)

if map_in and output_to_map:
return Map(array_out, map_arg.meta)
else:
return array_out

return inner

return decorate