From c4676951be660fbcb2ea0ba4a66a39f57aa7567a Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Sat, 21 Dec 2024 15:53:29 +0000 Subject: [PATCH] Add `fit` method to `Transform` trait And provide `fit_transform` methods on `Transform` and `Pipeline`. This makes the API more similar to scikit-learn. It also makes the fitting of a transformation more explicit, as it was previously happening implicitly when the `transform` method was first called. --- crates/augurs-forecaster/src/forecaster.rs | 2 +- crates/augurs-forecaster/src/transforms.rs | 90 +++++++++++++++---- .../augurs-forecaster/src/transforms/exp.rs | 12 ++- .../src/transforms/interpolate.rs | 6 +- .../augurs-forecaster/src/transforms/power.rs | 22 ++++- .../augurs-forecaster/src/transforms/scale.rs | 60 +++++++------ 6 files changed, 138 insertions(+), 54 deletions(-) diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs index 1cecfae..a3cf8ba 100644 --- a/crates/augurs-forecaster/src/forecaster.rs +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -39,7 +39,7 @@ where /// Fit the model to the given time series. pub fn fit(&mut self, y: D) -> Result<()> { let mut y = y.as_slice().to_vec(); - self.pipeline.transform(&mut y)?; + self.pipeline.fit_transform(&mut y)?; self.fitted = Some(self.model.fit(&y).map_err(|e| Error::Fit { source: Box::new(e) as _, })?); diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index 6674740..70db57c 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -1,11 +1,5 @@ /*! Data transformations. - -This module contains the `Transform` enum, which contains various -predefined transformations. The enum contains various methods for -creating new instances of the various transformations, as well as -the `transform` and `inverse_transform` methods, which allow you to -apply a transformation to a time series and its inverse, respectively. 
*/ // Note: implementations of the various transforms are in the @@ -27,38 +21,90 @@ pub use interpolate::{InterpolateExt, LinearInterpolator}; pub use power::{BoxCox, YeoJohnson}; pub use scale::{MinMaxScaler, StandardScaleParams, StandardScaler}; -/// Transforms and Transform implementations. +/// A transformation pipeline. /// -/// The `Transforms` struct is a collection of `Transform` instances that can be applied to a time series. -/// The `Transform` enum represents a single transformation that can be applied to a time series. +/// The `Pipeline` struct is a collection of heterogeneous `Transform` instances +/// that can be applied to a time series. Calling the `fit` or `fit_transform` +/// methods will fit each transformation to the output of the previous one in turn +/// starting by passing the input to the first transformation. #[derive(Debug, Default)] -pub struct Pipeline(Vec<Box<dyn Transform>>); +pub struct Pipeline { + transforms: Vec<Box<dyn Transform>>, + is_fitted: bool, +} impl Pipeline { /// Create a new `Pipeline` with the given transforms. pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self { - Self(transforms) + Self { + transforms, + is_fitted: false, + } + } + + /// Fit the transformations to the given time series. + /// + /// Prefer `fit_transform` if possible, as it avoids copying the input. + pub fn fit(&mut self, input: &[f64]) -> Result<(), Error> { + // Copy the input to avoid mutating the original. + // We need to do this so we can call `fit_transform` on each + // transformation in the pipeline without mutating the input. + // This is required because each transformation needs to be + // fit after previous transformations have been applied. + let mut input = input.to_vec(); + // Reuse `fit_transform_inner`, and just discard the result. + self.fit_transform_inner(&mut input)?; + Ok(()) } - /// Apply the transformations to the given time series. 
- pub fn transform(&mut self, input: &mut [f64]) -> Result<(), Error> { - for t in self.0.iter_mut() { + fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> { + for t in self.transforms.iter_mut() { + t.fit_transform(input)?; + } + self.is_fitted = true; + Ok(()) + } + + /// Fit and transform the given time series. + /// + /// This is equivalent to calling `fit` and then `transform` on the pipeline, + /// but is more efficient because it avoids copying the input. + pub fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> { + self.fit_transform_inner(input)?; + Ok(()) + } + + /// Apply the fitted transformations to the given time series. + /// + /// # Errors + /// + /// This function will return an error if the pipeline has not been fitted. + pub fn transform(&self, input: &mut [f64]) -> Result<(), Error> { + for t in self.transforms.iter() { t.transform(input)?; } Ok(()) } /// Apply the inverse transformations to the given time series. + /// + /// # Errors + /// + /// This function will return an error if the pipeline has not been fitted. pub fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> { - for t in self.0.iter().rev() { + for t in self.transforms.iter().rev() { t.inverse_transform(input)?; } Ok(()) } /// Apply the inverse transformations to the given forecast. + /// + /// # Errors + /// + /// This function will return an error if the pipeline has not been fitted. pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> { - for t in self.0.iter().rev() { + for t in self.transforms.iter().rev() { t.inverse_transform(&mut forecast.point)?; if let Some(intervals) = forecast.intervals.as_mut() { t.inverse_transform(&mut intervals.lower)?; @@ -71,12 +117,22 @@ impl Pipeline { /// A transformation that can be applied to a time series. pub trait Transform: fmt::Debug + Sync + Send { + /// Fit the transformation to the given time series. 
+ fn fit(&mut self, data: &[f64]) -> Result<(), Error>; + /// Apply the transformation to the given time series. - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error>; + fn transform(&self, data: &mut [f64]) -> Result<(), Error>; /// Apply the inverse transformation to the given time series. fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error>; + /// Fit the transformation to the given time series and then apply it. + fn fit_transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + self.fit(data)?; + self.transform(data)?; + Ok(()) + } + /// Create a boxed version of the transformation. /// /// This is useful for creating a `Transform` instance that can be used as diff --git a/crates/augurs-forecaster/src/transforms/exp.rs b/crates/augurs-forecaster/src/transforms/exp.rs index e1c280b..b48e40a 100644 --- a/crates/augurs-forecaster/src/transforms/exp.rs +++ b/crates/augurs-forecaster/src/transforms/exp.rs @@ -36,7 +36,11 @@ impl fmt::Debug for Logit { } impl Transform for Logit { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { data.iter_mut().for_each(|x| *x = logit(*x)); Ok(()) } @@ -67,7 +71,11 @@ impl fmt::Debug for Log { } impl Transform for Log { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { data.iter_mut().for_each(|x| *x = f64::ln(*x)); Ok(()) } diff --git a/crates/augurs-forecaster/src/transforms/interpolate.rs b/crates/augurs-forecaster/src/transforms/interpolate.rs index 4b9c329..501e82f 100644 --- a/crates/augurs-forecaster/src/transforms/interpolate.rs +++ b/crates/augurs-forecaster/src/transforms/interpolate.rs @@ -60,7 +60,11 @@ impl Interpolater for LinearInterpolator { } impl Transform for LinearInterpolator { - 
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { let interpolated: Vec<_> = data.iter().copied().interpolate(*self).collect(); data.copy_from_slice(&interpolated); Ok(()) diff --git a/crates/augurs-forecaster/src/transforms/power.rs b/crates/augurs-forecaster/src/transforms/power.rs index 2fb53bb..59fade8 100644 --- a/crates/augurs-forecaster/src/transforms/power.rs +++ b/crates/augurs-forecaster/src/transforms/power.rs @@ -176,10 +176,17 @@ impl Default for BoxCox { } impl Transform for BoxCox { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + fn fit(&mut self, data: &[f64]) -> Result<(), Error> { if self.lambda.is_nan() { self.lambda = optimize_box_cox_lambda(data)?; } + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { + if self.lambda.is_nan() { + return Err(Error::NotFitted); + } for x in data.iter_mut() { *x = box_cox(*x, self.lambda)?; } @@ -338,10 +345,17 @@ impl Default for YeoJohnson { } impl Transform for YeoJohnson { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { + fn fit(&mut self, data: &[f64]) -> Result<(), Error> { if self.lambda.is_nan() { self.lambda = optimize_yeo_johnson_lambda(data)?; } + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { + if self.lambda.is_nan() { + return Err(Error::NotFitted); + } for x in data.iter_mut() { *x = yeo_johnson(*x, self.lambda)?; } @@ -427,7 +441,7 @@ mod test { let lambda = 0.5; let mut box_cox = BoxCox::new().with_lambda(lambda).unwrap(); let expected = vec![0.0, 0.8284271247461903, 1.4641016151377544]; - box_cox.transform(&mut data).unwrap(); + box_cox.fit_transform(&mut data).unwrap(); assert_all_close(&expected, &data); } @@ -445,7 +459,7 @@ mod test { fn yeo_johnson_test() { let mut data = vec![-1.0, 0.0, 1.0]; let lambda = 0.5; - let mut yeo_johnson = 
YeoJohnson::new().with_lambda(lambda).unwrap(); + let yeo_johnson = YeoJohnson::new().with_lambda(lambda).unwrap(); let expected = vec![-1.2189514164974602, 0.0, 0.8284271247461903]; yeo_johnson.transform(&mut data).unwrap(); assert_all_close(&expected, &data); diff --git a/crates/augurs-forecaster/src/transforms/scale.rs b/crates/augurs-forecaster/src/transforms/scale.rs index 6c69773..915b57a 100644 --- a/crates/augurs-forecaster/src/transforms/scale.rs +++ b/crates/augurs-forecaster/src/transforms/scale.rs @@ -97,33 +97,37 @@ impl MinMaxScaler { self.params = Some(FittedMinMaxScalerParams::new(data_range, self.output_scale)); self } +} - fn fit(&self, data: &[f64]) -> Result<FittedMinMaxScalerParams, Error> { - match data +impl Transform for MinMaxScaler { + /// Fit the scaler to the given data. + /// + /// This will compute the min and max values of the data and store them + /// in the `params` field of the scaler. + fn fit(&mut self, data: &[f64]) -> Result<(), Error> { + let params = match data .iter() .copied() .minmax_by(|a, b| a.partial_cmp(b).unwrap()) { - e @ MinMaxResult::NoElements | e @ MinMaxResult::OneElement(_) => Err(e.into()), - MinMaxResult::MinMax(min, max) => Ok(FittedMinMaxScalerParams::new( - MinMax { min, max }, - self.output_scale, - )), - } + e @ MinMaxResult::NoElements | e @ MinMaxResult::OneElement(_) => return Err(e.into()), + MinMaxResult::MinMax(min, max) => { + FittedMinMaxScalerParams::new(MinMax { min, max }, self.output_scale) + } + }; + self.params = Some(params); + Ok(()) } -} -impl Transform for MinMaxScaler { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { - let params = match &mut self.params { - Some(p) => p, - None => self.params.get_or_insert(self.fit(data)?), - }; + /// Apply the scaler to the given data. 
+ fn transform(&self, data: &mut [f64]) -> Result<(), Error> { + let params = self.params.as_ref().ok_or(Error::NotFitted)?; data.iter_mut() .for_each(|x| *x = *x * params.scale_factor + params.offset); Ok(()) } + /// Apply the inverse of the scaler to the given data. fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> { let params = self.params.as_ref().ok_or(Error::NotFitted)?; data.iter_mut() @@ -212,7 +216,7 @@ impl StandardScaleParams { /// /// let mut data = vec![1.0, 2.0, 3.0]; /// let mut scaler = StandardScaler::new(); -/// scaler.transform(&mut data); +/// scaler.fit_transform(&mut data); /// /// assert_eq!(data, vec![-1.224744871391589, 0.0, 1.224744871391589]); /// ``` @@ -237,18 +241,16 @@ impl StandardScaler { self.params = Some(params); self } - - fn fit(&self, data: &[f64]) -> StandardScaleParams { - StandardScaleParams::from_data(data.iter().copied()) - } } impl Transform for StandardScaler { - fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> { - let params = match &mut self.params { - Some(p) => p, - None => self.params.get_or_insert(self.fit(data)), - }; + fn fit(&mut self, data: &[f64]) -> Result<(), Error> { + self.params = Some(StandardScaleParams::from_data(data.iter().copied())); + Ok(()) + } + + fn transform(&self, data: &mut [f64]) -> Result<(), Error> { + let params = self.params.as_ref().ok_or(Error::NotFitted)?; data.iter_mut() .for_each(|x| *x = (*x - params.mean) / params.std_dev); Ok(()) @@ -273,7 +275,7 @@ mod test { let mut data = vec![1.0, 2.0, 3.0]; let expected = vec![0.0, 0.5, 1.0]; let mut scaler = MinMaxScaler::new(); - scaler.transform(&mut data).unwrap(); + scaler.fit_transform(&mut data).unwrap(); assert_all_close(&expected, &data); } @@ -282,7 +284,7 @@ mod test { let mut data = vec![1.0, 2.0, 3.0]; let expected = vec![0.0, 5.0, 10.0]; let mut scaler = MinMaxScaler::new().with_scaled_range(0.0, 10.0); - scaler.transform(&mut data).unwrap(); + scaler.fit_transform(&mut data).unwrap(); 
assert_all_close(&expected, &data); } @@ -313,7 +315,7 @@ mod test { // not necessarily obvious. let expected = vec![-1.224744871391589, 0.0, 1.224744871391589]; let mut scaler = StandardScaler::new(); // 2.0, 1.0); // mean=2, std=1 - scaler.transform(&mut data).unwrap(); + scaler.fit_transform(&mut data).unwrap(); assert_all_close(&expected, &data); } @@ -322,7 +324,7 @@ mod test { let mut data = vec![1.0, 2.0, 3.0]; let expected = vec![-1.0, 0.0, 1.0]; let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1 - let mut scaler = StandardScaler::new().with_parameters(params); + let scaler = StandardScaler::new().with_parameters(params); scaler.transform(&mut data).unwrap(); assert_all_close(&expected, &data); }