Skip to content

Commit

Permalink
Accept either arrays or maps as input where appropriate (#135)
Browse files Browse the repository at this point in the history
Co-authored-by: Nabil Freij <[email protected]>
  • Loading branch information
dstansby and nabobalis authored Aug 9, 2023
1 parent b053b07 commit 26c7ef7
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 44 deletions.
10 changes: 10 additions & 0 deletions changelog/135.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Several functions have been updated to accept either numpy array or sunpy map inputs.
The following functions now accept either a numpy array or sunpy map, and return the same data type:

- `sunkit_image.enhance.mgn`
- `sunkit_image.trace.bandpass_filter`
- `sunkit_image.trace.smooth`

The following functions now accept either a numpy array or sunpy map, and their return type is unchanged:

- `sunkit_image.trace.occult2`
10 changes: 7 additions & 3 deletions sunkit_image/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import numpy as np
import scipy.ndimage as ndimage

from sunkit_image.utils.decorators import accept_array_or_map

__all__ = ["mgn"]


@accept_array_or_map(arg_name="data")
def mgn(
data,
sigma=[1.25, 2.5, 5, 10, 20, 40],
Expand Down Expand Up @@ -41,7 +44,7 @@ def mgn(
Parameters
----------
data : `numpy.ndarray`
data : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be transformed.
sigma : `list`, optional
Range of Gaussian widths (i.e. the standard deviation of the Gaussian kernel) to transform over.
Expand Down Expand Up @@ -72,8 +75,9 @@ def mgn(
Returns
-------
`numpy.ndarray`
Normalized image.
`numpy.ndarray`, `sunpy.map.GenericMap`
Normalized image. If a map is input, a map is returned with new data
and the same metadata.
References
----------
Expand Down
11 changes: 11 additions & 0 deletions sunkit_image/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import astropy
import astropy.config.paths
import sunpy.data.sample
import sunpy.map
from astropy.utils.data import get_pkg_data_filename
from sunpy.coordinates import Helioprojective, get_earth
Expand Down Expand Up @@ -155,3 +156,13 @@ def granule_minimap3():
)
map = sunpy.map.GenericMap(arr, header)
return map


@pytest.fixture(params=["array", "map"])
@pytest.mark.remote_data
def aia_171(request):
smap = sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE)
if request.param == "map":
return smap
elif request.param == "array":
return smap.data
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{
"sunkit_image.tests.test_enhance.test_mgn": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e",
"sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0",
"sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc",
"sunkit_image.tests.test_trace.test_occult2_fig": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5"
}
"sunkit_image.tests.test_enhance.test_mgn[array]": "67b17ac7c343a1d568d9dfa2bb12e0041e2fc53e8ca1c0e1c2c7693eda2bd0d8",
"sunkit_image.tests.test_enhance.test_mgn[map]": "4ef8a14c6b3290280b6fc9bc0748c8aa170760b2d43a5bbd67bb4d46ffd5408e",
"sunkit_image.tests.test_radial.test_fig_nrgf": "52b448fa9d845841066d71a1fc6e7bfc471a658d726af1cf16403ddb858fcfc0",
"sunkit_image.tests.test_radial.test_fig_fnrgf": "62621426ff8c3ee03cb1de004f9e08f63c4869dbd8b55d440adc73c74bfc5bcc",
"sunkit_image.tests.test_trace.test_occult2_fig[array]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5",
"sunkit_image.tests.test_trace.test_occult2_fig[map]": "c5ff1416cb51dc51ad1705174e5b13f8cb02486dc3c58fd79dc1b80534f6caf5"
}
16 changes: 5 additions & 11 deletions sunkit_image/tests/test_enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@
from sunkit_image.tests.helpers import figure_test


@pytest.fixture
@pytest.mark.remote_data
def smap():
return sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE)


@figure_test
@pytest.mark.remote_data
def test_mgn(smap):
out = enhance.mgn(smap.data)
out = sunpy.map.Map(out, smap.meta)

out.plot()
def test_mgn(aia_171):
out = enhance.mgn(aia_171)
assert type(out) == type(aia_171)
if isinstance(out, sunpy.map.GenericMap):
out.plot()


@pytest.fixture
Expand Down
61 changes: 43 additions & 18 deletions sunkit_image/tests/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest

import sunpy.map
from astropy.io import fits

import sunkit_image.data.test as data
Expand All @@ -19,13 +20,20 @@
)


@pytest.fixture
@pytest.fixture(params=["array", "map"])
@pytest.mark.remote_data
def image_remote():
def image_remote(request):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=fits.verify.VerifyWarning)
im = fits.getdata("http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits")
return im
data, header = fits.getdata(
"http://data.sunpy.org/sunkit-image/trace_1998-05-19T22:21:43.000_171_1024.fits", header=True
)
if request.param == "map":
return sunpy.map.Map((data, header))
elif request.param == "array":
return data
else:
raise ValueError(f"Invalid request parameter {request.param}")


@pytest.fixture
Expand Down Expand Up @@ -161,16 +169,18 @@ def test_map():


@pytest.fixture
def image():
def test_map_ones():
return np.ones((4, 4), dtype=np.float32)


def test_bandpass_filter(image, test_map):
def test_bandpass_filter_ones(test_map_ones):
expect = np.zeros((4, 4))
result = bandpass_filter(image)
result = bandpass_filter(test_map_ones)

assert np.allclose(expect, result)


def test_bandpass_filter(test_map):
expect = np.array(
[
[0.0, 0.0, 0.0, 0.0],
Expand All @@ -181,22 +191,30 @@ def test_bandpass_filter(image, test_map):
)

result = bandpass_filter(test_map)

assert np.allclose(expect, result)

with pytest.raises(ValueError) as record:
_ = bandpass_filter(image, 5, 1)

assert str(record.value) == "nsm1 should be less than nsm2"
@pytest.mark.remote_data
def test_bandpass_filter_output(aia_171):
# Check that bandpass filter works with both arrays and maps
result = bandpass_filter(aia_171)
assert type(result) == type(aia_171)


def test_bandpass_filter_error(test_map_ones):
with pytest.raises(ValueError, match="nsm1 should be less than nsm2"):
bandpass_filter(test_map_ones, 5, 1)


def test_smooth_ones(test_map_ones):
filtered = smooth(test_map_ones, 1)
assert np.allclose(filtered, test_map_ones)

def test_smooth(image, test_map):
filtered = smooth(image, 1)
assert np.allclose(filtered, image)
filtered = smooth(test_map_ones, 4)
assert np.allclose(filtered, test_map_ones)

filtered = smooth(image, 4)
assert np.allclose(filtered, image)

def test_smooth(test_map):
filtered = smooth(test_map, 1)
assert np.allclose(filtered, test_map)

Expand All @@ -213,7 +231,14 @@ def test_smooth(image, test_map):
assert np.allclose(filtered, expect)


def test_erase_loop_in_image(image, test_map):
@pytest.mark.remote_data
def test_smooth_output(aia_171):
# Check that smooth works with both arrays and maps
result = smooth(aia_171, 1)
assert type(result) == type(aia_171)


def test_erase_loop_in_image(test_map_ones, test_map):
# The starting point of a dummy loop
istart = 0
jstart = 1
Expand All @@ -223,7 +248,7 @@ def test_erase_loop_in_image(image, test_map):
xloop = [1, 2, 3]
yloop = [1, 1, 1]

result = _erase_loop_in_image(image, istart, jstart, width, xloop, yloop)
result = _erase_loop_in_image(test_map_ones, istart, jstart, width, xloop, yloop)

expect = np.array([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]])

Expand Down
21 changes: 14 additions & 7 deletions sunkit_image/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@
import numpy as np
from scipy import interpolate

from sunkit_image.utils.decorators import accept_array_or_map

__all__ = [
"occult2",
"bandpass_filter",
"smooth",
]


@accept_array_or_map(arg_name="image", output_to_map=False)
def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):
"""
Implements the Oriented Coronal CUrved Loop Tracing (OCCULT-2) algorithm
for loop tracing in images.
Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image in which loops are to be detected.
nsm1 : `int`
Low pass filter boxcar smoothing constant.
Expand All @@ -44,7 +47,7 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):
-------
`list`
A list of all loop where each element is itself a list of points containing
``x`` and ``y`` coordinates for each point.
``x`` and ``y`` pixel coordinates for each point.
References
----------
Expand Down Expand Up @@ -217,13 +220,14 @@ def occult2(image, nsm1, rmin, lmin, nstruc, ngap, qthresh1, qthresh2):


# The functions below this are subroutines for the OCCULT 2.
@accept_array_or_map(arg_name="image")
def bandpass_filter(image, nsm1=1, nsm2=3):
"""
Applies a band pass filter to the image.
Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be filtered.
nsm1 : `int`
Low pass filter boxcar smoothing constant.
Expand All @@ -236,7 +240,8 @@ def bandpass_filter(image, nsm1=1, nsm2=3):
Returns
-------
`numpy.ndarray`
Bandpass filtered image.
Bandpass filtered image. If a map is input, a map is returned with new data
and the same metadata.
"""
if nsm1 >= nsm2:
raise ValueError("nsm1 should be less than nsm2")
Expand All @@ -248,13 +253,14 @@ def bandpass_filter(image, nsm1=1, nsm2=3):
return smooth(image, nsm1, "replace") - smooth(image, nsm2, "replace")


@accept_array_or_map(arg_name="image")
def smooth(image, width, nanopt="replace"):
"""
Python implementation of the IDL's ``smooth``.
Parameters
----------
image : `numpy.ndarray`
image : `numpy.ndarray`, `sunpy.map.GenericMap`
Image to be filtered.
width : `int`
Width of the boxcar window. The `width` should always be odd but if even value is given then
Expand All @@ -264,8 +270,9 @@ def smooth(image, width, nanopt="replace"):
Returns
-------
`numpy.ndarray`
Smoothed image.
`numpy.ndarray`, `sunpy.map.GenericMap`
Smoothed image. If a map is input, a map is returned with new data
and the same metadata.
References
----------
Expand Down
56 changes: 56 additions & 0 deletions sunkit_image/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import inspect
from typing import Union, Callable

import numpy as np

from sunpy.map import GenericMap, Map


def accept_array_or_map(*, arg_name: str, output_to_map=True) -> Callable[[Callable], Callable]:
"""
Decorator that allows a function to accept an array or a
`sunpy.map.GenericMap` as an argument.
This can be applied to functions that:
- Take a single array or map as input
- Return a single array that has the same pixel coordinates
as the input array.
Parameters
----------
arg_name : `str`
Name of data/map argument in function signature.
output_to_map : `bool`
If `True` (the default), convert the function return to a map if a map
is given as input. For this to work the decorated function must return
an array where pixels have the same coordinates as the input map data.
"""

def decorate(f: Callable) -> Callable:
sig = inspect.signature(f)
if arg_name not in sig.parameters:
raise RuntimeError(f"Could not find '{arg_name}' in function signature")

def inner(*args, **kwargs) -> Union[np.ndarray, GenericMap]:
sig_bound = sig.bind(*args, **kwargs)
map_arg = sig_bound.arguments[arg_name]
if isinstance(map_arg, GenericMap):
map_in = True
sig_bound.arguments[arg_name] = map_arg.data
elif isinstance(map_arg, np.ndarray):
map_in = False
else:
raise ValueError(f"'{arg_name}' argument must be a sunpy map or numpy array (got type {type(map_arg)})")

# Run decorated function
array_out = f(*sig_bound.args, **sig_bound.kwargs)

if map_in and output_to_map:
return Map(array_out, map_arg.meta)
else:
return array_out

return inner

return decorate

0 comments on commit 26c7ef7

Please sign in to comment.