Skip to content

Commit

Permalink
Merge pull request #94 from kthyng/more_filters
Browse files Browse the repository at this point in the history
fixes and improvements for tidal filter
  • Loading branch information
kthyng authored Sep 13, 2023
2 parents 1805869 + a142637 commit f56a02c
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions oceans/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,7 @@ def medfilt1(x, L=3):
>>> L = 103
>>> xout = medfilt1(x=x, L=L)
>>> ax = plt.subplot(212)
>>> (
... l1,
... l2,
... ) = ax.plot(
>>> (l1, l2,) = ax.plot(
... x
... ), ax.plot(xout)
>>> ax.grid(True)
Expand Down Expand Up @@ -570,7 +567,7 @@ def md_trenberth(x):
return y


def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
def pl33tn(x, dt=1.0, T=33.0, mode="valid", t=None):
"""
Computes low-passed series from `x` using pl33 filter, with optional
sample interval `dt` (hours) and filter half-amplitude period T (hours)
Expand Down Expand Up @@ -608,14 +605,25 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
"""

import cf_xarray # noqa: F401
import pandas as pd
import xarray as xr

if isinstance(x, xr.Dataset):
raise TypeError("Input a DataArray not a Dataset.")
if isinstance(x, (xr.Dataset, pd.DataFrame)):
raise TypeError("Input a DataArray not a Dataset, or a Series not a DataFrame.")

if isinstance(x, pd.Series) and not isinstance(
x.index,
pd.core.indexes.datetimes.DatetimeIndex,
):
raise TypeError("Input Series needs to have parsed datetime indices.")

# find dt in units of hours
if isinstance(x, xr.DataArray):
# find dt in units of hours
dt = (x.cf["T"][1] - x.cf["T"][0]) * 1e-9 / 3600
dt = (x.cf["T"][1] - x.cf["T"][0]) / np.timedelta64(
360_000_000_000,
)
elif isinstance(x, pd.Series):
dt = (x.index[1] - x.index[0]) / pd.Timedelta("1H")

pl33 = np.array(
[
Expand Down Expand Up @@ -694,18 +702,20 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
dt = float(dt) * (33.0 / T)

filter_time = np.arange(0.0, 33.0, dt, dtype="d")
# N = len(filter_time)
Nt = len(filter_time)
filter_time = np.hstack((-filter_time[-1:0:-1], filter_time))

pl33 = np.interp(filter_time, _dt, pl33)
pl33 /= pl33.sum()

if isinstance(x, xr.DataArray):
x = x.interpolate_na(dim=x.cf["T"].name)

weight = xr.DataArray(pl33, dims=["window"])
xf = (
x.rolling({x.cf["T"].name: len(pl33)}, center=True)
.construct({x.cf["T"].name: "window"})
.dot(weight)
.dot(weight, dims="window")
)
# update attrs
attrs = {
Expand All @@ -715,7 +725,26 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
}
xf.attrs = attrs

elif isinstance(x, pd.Series):
xf = x.to_frame().apply(np.convolve, v=pl33, mode=mode)

# nan out edges which are not good values anyway
if mode == "same":
xf[: Nt - 1] = np.nan
xf[-Nt + 2 :] = np.nan

else: # use numpy
xf = np.convolve(x, pl33, mode=mode)

# times to match xf
if t is not None:
# Nt = len(filter_time)
tf = t[Nt - 1 : -Nt + 1]
return xf, tf

# nan out edges which are not good values anyway
if mode == "same":
xf[: Nt - 1] = np.nan
xf[-Nt + 2 :] = np.nan

return xf

0 comments on commit f56a02c

Please sign in to comment.