diff --git a/Makefile b/Makefile index e6c0a97..6475b0d 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ LINT_TARGET_DIRS := PyEMD doc example init: python -m venv .venv .venv/bin/pip install -r requirements.txt + .venv/bin/pip install -e .[dev] @echo "Run 'source .venv/bin/activate' to activate the virtual environment" test: @@ -17,6 +18,7 @@ doc: format: python -m black $(LINT_TARGET_DIRS) + python -m isort PyEMD lint-check: python -m isort --check PyEMD diff --git a/PyEMD/EMD.py b/PyEMD/EMD.py index d52305d..1d7e58f 100644 --- a/PyEMD/EMD.py +++ b/PyEMD/EMD.py @@ -13,7 +13,7 @@ from scipy.interpolate import interp1d from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip -from PyEMD.utils import get_timeline, deduce_common_type +from PyEMD.utils import deduce_common_type, get_timeline FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] @@ -199,14 +199,14 @@ def prepare_points( Position (1st row) and values (2nd row) of maxima. """ if self.extrema_detection == "parabol": - return self._prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val) + return self.prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val) elif self.extrema_detection == "simple": - return self._prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val) + return self.prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val) else: msg = "Incorrect extrema detection type. Please try: 'simple' or 'parabol'." raise ValueError(msg) - def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]: + def prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]: """ Performs mirroring on signal which extrema do not necessarily belong on the position array. @@ -324,7 +324,7 @@ def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> T return max_extrema, min_extrema - def _prepare_points_simple( + def prepare_points_simple( self, T: np.ndarray, S: np.ndarray, diff --git a/PyEMD/tests/test_utils.py b/PyEMD/tests/test_utils.py index f90a56b..e766ff2 100644 --- a/PyEMD/tests/test_utils.py +++ b/PyEMD/tests/test_utils.py @@ -2,7 +2,7 @@ import numpy as np -from PyEMD.utils import get_timeline, deduce_common_type +from PyEMD.utils import deduce_common_type, get_timeline class MyTestCase(unittest.TestCase): @@ -37,5 +37,6 @@ def test_deduce_common_types(self): self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32) self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64) + if __name__ == "__main__": unittest.main() diff --git a/PyEMD/tests/test_visualization.py b/PyEMD/tests/test_visualization.py index 882e6b6..982f9b3 100644 --- a/PyEMD/tests/test_visualization.py +++ b/PyEMD/tests/test_visualization.py @@ -19,8 +19,8 @@ def test_instantiation2(self): emd.emd(S, t) imfs, res = emd.get_imfs_and_residue() vis = Visualisation(emd) - self.assertTrue(np.alltrue(vis.imfs == imfs)) - self.assertTrue(np.alltrue(vis.residue == res)) + self.assertTrue(np.all(vis.imfs == imfs)) + self.assertTrue(np.all(vis.residue == res)) def test_check_imfs(self): vis = Visualisation() @@ -40,7 +40,7 @@ def test_check_imfs3(self): out_imfs, out_res = vis._check_imfs(imfs, None, False) - self.assertTrue(np.alltrue(imfs == out_imfs)) + self.assertTrue(np.all(imfs == out_imfs)) self.assertIsNone(out_res) def test_check_imfs4(self): @@ -57,8 +57,8 @@ def test_check_imfs5(self): imfs, res = emd.get_imfs_and_residue() vis = Visualisation(emd) imfs2, res2 = vis._check_imfs(imfs, res, False) - self.assertTrue(np.alltrue(imfs == imfs2)) - self.assertTrue(np.alltrue(res == res2)) + self.assertTrue(np.all(imfs == imfs2)) + self.assertTrue(np.all(res == res2)) def test_plot_imfs(self): vis = Visualisation() diff --git a/PyEMD/utils.py b/PyEMD/utils.py index 5dc9657..9dd71be 100644 --- a/PyEMD/utils.py +++ b/PyEMD/utils.py @@ -1,5 +1,5 @@ -from typing import Optional from functools import cache +from typing import Optional import numpy as np @@ -52,13 +52,13 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype: raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype)) + @cache def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype: if xtype == ytype: return xtype - if np.version.version[0] == '1': + if np.version.version[0] == "1": dtype = np.find_common_type([xtype, ytype], []) else: dtype = np.promote_types(xtype, ytype) return dtype - \ No newline at end of file