Skip to content

Commit

Permalink
[Fix] Squeeze y_values in signal_interpolate
Browse files Browse the repository at this point in the history
  • Loading branch information
jankwodnicki committed Dec 3, 2024
1 parent 874260b commit 7c8c895
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 3 additions & 1 deletion neurokit2/signal/signal_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def signal_interpolate(
x_values = np.squeeze(x_values.values)
if isinstance(x_new, pd.Series):
x_new = np.squeeze(x_new.values)
if isinstance(y_values, pd.Series):
y_values = np.squeeze(y_values.values)

if len(x_values) != len(y_values):
raise ValueError("x_values and y_values must be of the same length.")
Expand Down Expand Up @@ -158,7 +160,7 @@ def signal_interpolate(
# scipy.interpolate.PchipInterpolator for constant extrapolation akin to the behavior of
# scipy.interpolate.interp1d with fill_value=([y_values[0]], [y_values[-1]].
fill_value = ([interpolated[first_index]], [interpolated[last_index]])
elif isinstance(fill_value, float) or isinstance(fill_value, int):
elif isinstance(fill_value, (float, int)):
# if only a single integer or float is provided as a fill value, format as a tuple
fill_value = ([fill_value], [fill_value])

Expand Down
13 changes: 12 additions & 1 deletion tests/tests_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,25 @@ def test_signal_filter_with_missing():

def test_signal_interpolate():

# Test with arrays
x_axis = np.linspace(start=10, stop=30, num=10)
signal = np.cos(x_axis)
x_new = np.arange(1000)

interpolated = nk.signal_interpolate(x_axis, signal, x_new=np.arange(1000))
interpolated = nk.signal_interpolate(x_axis, signal, x_new)
assert len(interpolated) == 1000
assert interpolated[0] == signal[0]
assert interpolated[-1] == signal[-1]

# Test with Series
x_axis = pd.Series(x_axis)
signal = pd.Series(signal)
x_new = pd.Series(x_new)

interpolated = nk.signal_interpolate(x_axis, signal, x_new)
assert len(interpolated) == 1000
assert interpolated[0] == signal.iloc[0]
assert interpolated[-1] == signal.iloc[-1]

def test_signal_findpeaks():

Expand Down

0 comments on commit 7c8c895

Please sign in to comment.