Skip to content

Commit

Permalink
fix: typo in EMD_matlab
Browse files Browse the repository at this point in the history
  • Loading branch information
laszukdawid committed Sep 11, 2024
1 parent 9ca0bae commit 9372aa8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
30 changes: 24 additions & 6 deletions PyEMD/EMD_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def extractMaxMinSpline(self, T, S):
return [-1] * 4

# Extrapolation of signal (ober boundaries)
maxExtrema, minExtrema = self.preparePoints(S, T, maxPos, maxVal, minPos, minVal)
maxExtrema, minExtrema = self.preparePoints(
S, T, maxPos, maxVal, minPos, minVal
)

_, maxSpline = self.splinePoints(T, maxExtrema, self.splineKind)
_, minSpline = self.splinePoints(T, minExtrema, self.splineKind)
Expand Down Expand Up @@ -218,8 +220,12 @@ def preparePoints(self, S, T, maxPos, maxVal, minPos, minVal):
minExtrema = np.array([tmin, zmin], dtype=self.DTYPE)

# Make double sure, that each extremum is significant
maxExtrema = np.delete(maxExtrema, np.where(maxExtrema[0, 1:] == maxExtrema[0, :-1]), axis=1)
minExtrema = np.delete(minExtrema, np.where(minExtrema[0, 1:] == minExtrema[0, :-1]), axis=1)
maxExtrema = np.delete(
maxExtrema, np.where(maxExtrema[0, 1:] == maxExtrema[0, :-1]), axis=1
)
minExtrema = np.delete(
minExtrema, np.where(minExtrema[0, 1:] == minExtrema[0, :-1]), axis=1
)

return maxExtrema, minExtrema

Expand Down Expand Up @@ -251,7 +257,9 @@ def splinePoints(self, T, extrema, splineKind):

elif kind == "cubic":
if extrema.shape[1] > 3:
return t, interp1d(extrema[0], extrema[1], kind=kind)(t).astype(self.DTYPE)
return t, interp1d(extrema[0], extrema[1], kind=kind)(t).astype(
self.DTYPE
)
else:
return self.cubicSpline_3points(T, extrema)

Expand Down Expand Up @@ -435,8 +443,18 @@ def emd(self, S, T=None, maxImf=None):
The decomposition is limited to maxImf imf. No limitation as default.
Returns IMF functions in dic format. IMF = {0:imf0, 1:imf1...}.
*Note*: First argument `self` should be an instance of EMD class.
It should be resolved in future versions.
For example:
```
emd = EMD()
emd.emd(emd, S, T, maxImf)
```
Input:
---------
self: Instance of EMD class.
S: Signal.
T: Positions of signal. If none passed numpy arange is created.
maxImf: IMF number to which decomposition should be performed.
Expand All @@ -457,7 +475,7 @@ def emd(self, S, T=None, maxImf=None):
maxImf = -1

# Make sure same types are dealt
S, T = unify_type(S, T)
S, T = unify_types(S, T)
self.DTYPE = S.dtype

Res = S.astype(self.DTYPE)
Expand All @@ -479,7 +497,7 @@ def emd(self, S, T=None, maxImf=None):

if S.shape != T.shape:
info = "Time array should be the same size as signal."
raise Exception(info)
raise ValueError(info)

# Create arrays
IMF = {} # Dic for imfs signals
Expand Down
2 changes: 1 addition & 1 deletion PyEMD/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

__version__ = "1.6.3"
__version__ = "1.6.4"
logger = logging.getLogger("pyemd")

from PyEMD.CEEMDAN import CEEMDAN # noqa
Expand Down
49 changes: 49 additions & 0 deletions PyEMD/tests/test_emd_matlab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import numpy as np

from PyEMD.EMD_matlab import EMD


class EMDMatlabTest(unittest.TestCase):
@staticmethod
def test_default_call_EMD():
T = np.arange(0, 1, 0.01)
S = np.cos(2 * T * 2 * np.pi)
max_imf = 2

emd = EMD()
emd.emd(emd, S, T, max_imf)

def test_different_length_input(self):
T = np.arange(20)
S = np.random.random(len(T) + 7)

emd = EMD()
with self.assertRaises(ValueError):
emd.emd(emd, S, T)

def test_trend(self):
"""
Input is trend. Expeting no shifting process.
"""
emd = EMD()

T = np.arange(0, 1, 0.01)
S = np.cos(2 * T * 2 * np.pi)

# Input - linear function f(t) = 2*t
output = emd.emd(emd, S, T)
self.assertEqual(len(output), 4, "Expecting 4 outputs - IMF, EXT, ITER, imfNo")

IMF, EXT, ITER, imfNo = output
self.assertEqual(len(IMF), 2, "Expecting single IMF + residue")
self.assertEqual(len(IMF[0]), len(S), "Expecting single IMF")
self.assertTrue(np.allclose(S, IMF[0]))
self.assertLessEqual(ITER[0], 5, "Expecting 5 iterations at most")
self.assertEqual(imfNo, 2, "Expecting 1 IMF")
self.assertEqual(EXT[0], 3, "Expecting single EXT")


if __name__ == "__main__":
unittest.main()

0 comments on commit 9372aa8

Please sign in to comment.