diff --git a/PyEMD/EMD_matlab.py b/PyEMD/EMD_matlab.py index eaf3dac..21907f4 100644 --- a/PyEMD/EMD_matlab.py +++ b/PyEMD/EMD_matlab.py @@ -15,6 +15,7 @@ from scipy.interpolate import interp1d from PyEMD.splines import akima +from PyEMD.utils import deduce_common_type class EMD: @@ -429,7 +430,7 @@ def stop_sifting(self, imf, envMax, envMin, mean, extNo): @staticmethod def _common_dtype(x, y): - dtype = np.find_common_type([x.dtype, y.dtype], []) + dtype = deduce_common_type([x.dtype, y.dtype], []) if x.dtype != dtype: x = x.astype(dtype) if y.dtype != dtype: diff --git a/PyEMD/experimental/jitemd.py b/PyEMD/experimental/jitemd.py index c84d569..242145c 100644 --- a/PyEMD/experimental/jitemd.py +++ b/PyEMD/experimental/jitemd.py @@ -30,6 +30,8 @@ from numba.types import float64, int64, unicode_type from scipy.interpolate import Akima1DInterpolator, interp1d +from PyEMD.utils import deduce_common_type + FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] @@ -758,17 +760,6 @@ def check_imf( return False -# @nb.jit -def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Casts inputs (x, y) into a common numpy DTYPE.""" - dtype = np.find_common_type([x.dtype, y.dtype], []) - if x.dtype != dtype: - x = x.astype(dtype) - if y.dtype != dtype: - y = y.astype(dtype) - return x, y - - @nb.jit def _normalize_time(t: np.ndarray) -> np.ndarray: """ @@ -811,7 +802,7 @@ def emd( # T = _normalize_time(T) # Make sure same types are dealt - # S, T = _common_dtype(S, T) + # S, T = deduce_common_Types(S, T) MAX_ITERATION = config["MAX_ITERATION"] FIXE = config["FIXE"] FIXE_H = config["FIXE_H"]