From 15aa106ca8a8366845e6532cc709153727485297 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Sun, 22 Dec 2024 08:18:40 +0000 Subject: [PATCH] Rename `Transform` to `Transformer`; impl for `Pipeline`; improve docs `Transformer` matches the scikit-learn API. --- book/src/getting-started/quick-start.md | 8 +- crates/augurs-forecaster/README.md | 8 +- crates/augurs-forecaster/src/forecaster.rs | 18 ++-- crates/augurs-forecaster/src/lib.rs | 2 +- crates/augurs-forecaster/src/transforms.rs | 102 +++++++++++------- .../augurs-forecaster/src/transforms/exp.rs | 6 +- .../src/transforms/interpolate.rs | 4 +- .../augurs-forecaster/src/transforms/power.rs | 6 +- .../augurs-forecaster/src/transforms/scale.rs | 8 +- crates/augurs/tests/integration.rs | 4 +- examples/forecasting/examples/forecaster.rs | 12 +-- .../examples/prophet_forecaster.rs | 12 +-- js/augurs-mstl-js/src/lib.rs | 16 +-- js/augurs-transforms-js/src/lib.rs | 8 +- 14 files changed, 118 insertions(+), 96 deletions(-) diff --git a/book/src/getting-started/quick-start.md b/book/src/getting-started/quick-start.md index e0b751f..c1e059c 100644 --- a/book/src/getting-started/quick-start.md +++ b/book/src/getting-started/quick-start.md @@ -51,7 +51,7 @@ use augurs::{ ets::AutoETS, forecaster::{ transforms::{LinearInterpolator, Log, MinMaxScaler}, - Forecaster, Transform, + Forecaster, Transformer, }, mstl::MSTLModel, }; @@ -59,18 +59,18 @@ use augurs::{ fn main() { let data = &[1.0, 1.2, 1.4, 1.5, f64::NAN, 1.4, 1.2, 1.5, 1.6, 2.0, 1.9, 1.8]; - // Set up model and transforms + // Set up model and transformers let ets = AutoETS::non_seasonal().into_trend_model(); let mstl = MSTLModel::new(vec![2], ets); - let transforms = vec![ + let transformers = vec![ LinearInterpolator::new().boxed(), MinMaxScaler::new().boxed(), Log::new().boxed(), ]; // Create and fit forecaster - let mut forecaster = Forecaster::new(mstl).with_transforms(transforms); + let mut forecaster = Forecaster::new(mstl).with_transformers(transformers); forecaster.fit(data).expect("model should fit"); // Generate forecasts diff --git a/crates/augurs-forecaster/README.md b/crates/augurs-forecaster/README.md index 320531a..d5b6b86 100644 --- a/crates/augurs-forecaster/README.md +++ b/crates/augurs-forecaster/README.md @@ -17,7 +17,7 @@ augurs-mstl = "*" use augurs::{ ets::{AutoETS, trend::AutoETSTrendModel}, forecaster::{ - Forecaster, Transform, + Forecaster, Transformer, transforms::{LinearInterpolator, Logit, MinMaxScaler}, }, mstl::MSTLModel @@ -34,15 +34,15 @@ let data = &[ let ets = AutoETS::non_seasonal().into_trend_model(); let mstl = MSTLModel::new(vec![2], ets); -// Set up the transforms. -let transforms = vec![ +// Set up the transformers. +let transformers = vec![ LinearInterpolator::new().boxed(), MinMaxScaler::new().boxed(), Logit::new().boxed(), ]; // Create a forecaster using the transforms. -let mut forecaster = Forecaster::new(mstl).with_transforms(transforms); +let mut forecaster = Forecaster::new(mstl).with_transformers(transformers); // Fit the forecaster. This will transform the training data by // running the transforms in order, then fit the MSTL model. diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs index a3cf8ba..fd40439 100644 --- a/crates/augurs-forecaster/src/forecaster.rs +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -1,6 +1,6 @@ use augurs_core::{Fit, Forecast, Predict}; -use crate::{Data, Error, Pipeline, Result, Transform}; +use crate::{Data, Error, Pipeline, Result, Transformer}; /// A high-level API to fit and predict time series forecasting models. /// @@ -31,8 +31,8 @@ where } /// Set the transformations to be applied to the input data. - pub fn with_transforms(mut self, transforms: Vec>) -> Self { - self.pipeline = Pipeline::new(transforms); + pub fn with_transformers(mut self, transformers: Vec>) -> Self { + self.pipeline = Pipeline::new(transformers); self } @@ -92,13 +92,13 @@ mod test { #[test] fn test_forecaster() { let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0]; - let transforms = vec![ + let transformers = vec![ LinearInterpolator::new().boxed(), MinMaxScaler::new().boxed(), Logit::new().boxed(), ]; let model = MSTLModel::new(vec![2], NaiveTrend::new()); - let mut forecaster = Forecaster::new(model).with_transforms(transforms); + let mut forecaster = Forecaster::new(model).with_transformers(transformers); forecaster.fit(data).unwrap(); let forecasts = forecaster.predict(4, None).unwrap(); assert_all_close(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]); @@ -107,9 +107,9 @@ mod test { #[test] fn test_forecaster_power_positive() { let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0]; - let transforms = vec![BoxCox::new().boxed()]; + let transformers = vec![BoxCox::new().boxed()]; let model = MSTLModel::new(vec![2], NaiveTrend::new()); - let mut forecaster = Forecaster::new(model).with_transforms(transforms); + let mut forecaster = Forecaster::new(model).with_transformers(transformers); forecaster.fit(data).unwrap(); let forecasts = forecaster.predict(4, None).unwrap(); assert_all_close( @@ -126,9 +126,9 @@ mod test { #[test] fn test_forecaster_power_non_positive() { let data = &[0.0, 2.0, 3.0, 4.0, 5.0]; - let transforms = vec![YeoJohnson::new().boxed()]; + let transformers = vec![YeoJohnson::new().boxed()]; let model = MSTLModel::new(vec![2], NaiveTrend::new()); - let mut forecaster = Forecaster::new(model).with_transforms(transforms); + let mut forecaster = Forecaster::new(model).with_transformers(transformers); forecaster.fit(data).unwrap(); let forecasts = forecaster.predict(4, None).unwrap(); assert_all_close( diff --git a/crates/augurs-forecaster/src/lib.rs b/crates/augurs-forecaster/src/lib.rs index 6c9f6a8..f9b129d 100644 --- a/crates/augurs-forecaster/src/lib.rs +++ b/crates/augurs-forecaster/src/lib.rs @@ -8,6 +8,6 @@ pub mod transforms; pub use data::Data; pub use error::Error; pub use forecaster::Forecaster; -pub use transforms::{Pipeline, Transform}; +pub use transforms::{Pipeline, Transformer}; type Result = std::result::Result; diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index 70db57c..0de4a78 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -23,29 +23,58 @@ pub use scale::{MinMaxScaler, StandardScaleParams, StandardScaler}; /// A transformation pipeline. /// -/// The `Pipeline` struct is a collection of heterogeneous `Transform` instances -/// that can be applied to a time series. Calling the `fit` or `fit_transform` -/// methods will fit each transformation to the output of the previous one in turn -/// starting by passing the input to the first transformation. +/// A `Pipeline` is a collection of heterogeneous [`Transformer`] instances +/// that can be applied to a time series. Calling [`Pipeline::fit`] or [`Pipeline::fit_transform`] +/// will fit each transformation to the output of the previous one in turn +/// starting by passing the input to the first transformation. The +/// [`Pipeline::inverse_transform`] can then be used to back-transform data +/// to the original scale. #[derive(Debug, Default)] pub struct Pipeline { - transforms: Vec>, + transformers: Vec>, is_fitted: bool, } impl Pipeline { - /// Create a new `Pipeline` with the given transforms. - pub fn new(transforms: Vec>) -> Self { + /// Create a new `Pipeline` with the given transformers. + pub fn new(transformers: Vec>) -> Self { Self { - transforms, + transformers, is_fitted: false, } } + // Helper function for actually doing the fit then transform steps. + fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> { + for t in self.transformers.iter_mut() { + t.fit_transform(input)?; + } + self.is_fitted = true; + Ok(()) + } + + /// Apply the inverse transformations to the given forecast. + /// + /// # Errors + /// + /// This function will return an error if the pipeline has not been fitted. + pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> { + for t in self.transformers.iter().rev() { + t.inverse_transform(&mut forecast.point)?; + if let Some(intervals) = forecast.intervals.as_mut() { + t.inverse_transform(&mut intervals.lower)?; + t.inverse_transform(&mut intervals.upper)?; + } + } + Ok(()) + } +} + +impl Transformer for Pipeline { /// Fit the transformations to the given time series. /// /// Prefer `fit_transform` if possible, as it avoids copying the input. - pub fn fit(&mut self, input: &[f64]) -> Result<(), Error> { + fn fit(&mut self, input: &[f64]) -> Result<(), Error> { // Copy the input to avoid mutating the original. // We need to do this so we can call `fit_transform` on each // transformation in the pipeline without mutating the input. @@ -57,19 +86,11 @@ impl Pipeline { Ok(()) } - fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> { - for t in self.transforms.iter_mut() { - t.fit_transform(input)?; - } - self.is_fitted = true; - Ok(()) - } - /// Fit and transform the given time series. /// /// This is equivalent to calling `fit` and then `transform` on the pipeline, /// but is more efficient because it avoids copying the input. - pub fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> { + fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> { self.fit_transform_inner(input)?; Ok(()) } @@ -79,8 +100,8 @@ impl Pipeline { /// # Errors /// /// This function will return an error if the pipeline has not been fitted. - pub fn transform(&self, input: &mut [f64]) -> Result<(), Error> { - for t in self.transforms.iter() { + fn transform(&self, input: &mut [f64]) -> Result<(), Error> { + for t in self.transformers.iter() { t.transform(input)?; } Ok(()) @@ -91,42 +112,43 @@ impl Pipeline { /// # Errors /// /// This function will return an error if the pipeline has not been fitted. - pub fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> { - for t in self.transforms.iter().rev() { + fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> { + for t in self.transformers.iter().rev() { t.inverse_transform(input)?; } Ok(()) } - - /// Apply the inverse transformations to the given forecast. - /// - /// # Errors - /// - /// This function will return an error if the pipeline has not been fitted. - pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> { - for t in self.transforms.iter().rev() { - t.inverse_transform(&mut forecast.point)?; - if let Some(intervals) = forecast.intervals.as_mut() { - t.inverse_transform(&mut intervals.lower)?; - t.inverse_transform(&mut intervals.upper)?; - } - } - Ok(()) - } } /// A transformation that can be applied to a time series. -pub trait Transform: fmt::Debug + Sync + Send { +pub trait Transformer: fmt::Debug + Sync + Send { /// Fit the transformation to the given time series. + /// + /// For example, for a min-max scaler, this would find + /// the min and max of the provided data and store it on the + /// scaler ready for use in transforming and back-transforming. fn fit(&mut self, data: &[f64]) -> Result<(), Error>; /// Apply the transformation to the given time series. + /// + /// # Errors + /// + /// This function should return an error if the transform has not been fitted, + /// and may return other errors specific to the implementation. fn transform(&self, data: &mut [f64]) -> Result<(), Error>; /// Apply the inverse transformation to the given time series. + /// + /// # Errors + /// + /// This function should return an error if the transform has not been fitted, + /// and may return other errors specific to the implementation. fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error>; /// Fit the transformation to the given time series and then apply it. + /// + /// The default implementation just calls [`Self::fit`] then [`Self::transform`] + /// but it can be overridden to be more efficient if desired. fn fit_transform(&mut self, data: &mut [f64]) -> Result<(), Error> { self.fit(data)?; self.transform(data)?; @@ -137,7 +159,7 @@ pub trait Transform: fmt::Debug + Sync + Send { /// /// This is useful for creating a `Transform` instance that can be used as /// part of a [`Pipeline`]. - fn boxed(self) -> Box + fn boxed(self) -> Box where Self: Sized + 'static, { diff --git a/crates/augurs-forecaster/src/transforms/exp.rs b/crates/augurs-forecaster/src/transforms/exp.rs index b48e40a..1ede4c5 100644 --- a/crates/augurs-forecaster/src/transforms/exp.rs +++ b/crates/augurs-forecaster/src/transforms/exp.rs @@ -2,7 +2,7 @@ use std::fmt; -use super::{Error, Transform}; +use super::{Error, Transformer}; // Logit and logistic functions. @@ -35,7 +35,7 @@ impl fmt::Debug for Logit { } } -impl Transform for Logit { +impl Transformer for Logit { fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { Ok(()) } @@ -70,7 +70,7 @@ impl fmt::Debug for Log { } } -impl Transform for Log { +impl Transformer for Log { fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { Ok(()) } diff --git a/crates/augurs-forecaster/src/transforms/interpolate.rs b/crates/augurs-forecaster/src/transforms/interpolate.rs index 501e82f..0dede7b 100644 --- a/crates/augurs-forecaster/src/transforms/interpolate.rs +++ b/crates/augurs-forecaster/src/transforms/interpolate.rs @@ -12,7 +12,7 @@ use std::{ ops::{Add, Div, Mul, Sub}, }; -use super::{Error, Transform}; +use super::{Error, Transformer}; /// A type that can be used to interpolate between values. pub trait Interpolater { @@ -59,7 +59,7 @@ impl Interpolater for LinearInterpolator { } } -impl Transform for LinearInterpolator { +impl Transformer for LinearInterpolator { fn fit(&mut self, _data: &[f64]) -> Result<(), Error> { Ok(()) } diff --git a/crates/augurs-forecaster/src/transforms/power.rs b/crates/augurs-forecaster/src/transforms/power.rs index 59fade8..9c99b2a 100644 --- a/crates/augurs-forecaster/src/transforms/power.rs +++ b/crates/augurs-forecaster/src/transforms/power.rs @@ -3,7 +3,7 @@ use argmin::core::{CostFunction, Executor}; use argmin::solver::brent::BrentOpt; -use super::{Error, Transform}; +use super::{Error, Transformer}; /// Returns the Box-Cox transformation of the given value. /// Assumes x > 0. @@ -175,7 +175,7 @@ impl Default for BoxCox { } } -impl Transform for BoxCox { +impl Transformer for BoxCox { fn fit(&mut self, data: &[f64]) -> Result<(), Error> { if self.lambda.is_nan() { self.lambda = optimize_box_cox_lambda(data)?; @@ -344,7 +344,7 @@ impl Default for YeoJohnson { } } -impl Transform for YeoJohnson { +impl Transformer for YeoJohnson { fn fit(&mut self, data: &[f64]) -> Result<(), Error> { if self.lambda.is_nan() { self.lambda = optimize_yeo_johnson_lambda(data)?; diff --git a/crates/augurs-forecaster/src/transforms/scale.rs b/crates/augurs-forecaster/src/transforms/scale.rs index 915b57a..9bf1801 100644 --- a/crates/augurs-forecaster/src/transforms/scale.rs +++ b/crates/augurs-forecaster/src/transforms/scale.rs @@ -4,7 +4,7 @@ use core::f64; use itertools::{Itertools, MinMaxResult}; -use super::{Error, Transform}; +use super::{Error, Transformer}; /// Helper struct holding the min and max for use in a `MinMaxScaler`. #[derive(Debug, Clone, Copy)] @@ -99,7 +99,7 @@ impl MinMaxScaler { } } -impl Transform for MinMaxScaler { +impl Transformer for MinMaxScaler { /// Fit the scaler to the given data. /// /// This will compute the min and max values of the data and store them @@ -212,7 +212,7 @@ impl StandardScaleParams { /// ## Using the default constructor /// /// ``` -/// use augurs_forecaster::transforms::{StandardScaler, Transform}; +/// use augurs_forecaster::transforms::{StandardScaler, Transformer}; /// /// let mut data = vec![1.0, 2.0, 3.0]; /// let mut scaler = StandardScaler::new(); @@ -243,7 +243,7 @@ impl StandardScaler { } } -impl Transform for StandardScaler { +impl Transformer for StandardScaler { fn fit(&mut self, data: &[f64]) -> Result<(), Error> { self.params = Some(StandardScaleParams::from_data(data.iter().copied())); Ok(()) diff --git a/crates/augurs/tests/integration.rs b/crates/augurs/tests/integration.rs index e04f533..068132a 100644 --- a/crates/augurs/tests/integration.rs +++ b/crates/augurs/tests/integration.rs @@ -102,7 +102,7 @@ fn test_ets() { #[test] fn test_forecaster() { use augurs::{ - forecaster::{transforms::MinMaxScaler, Forecaster, Transform}, + forecaster::{transforms::MinMaxScaler, Forecaster, Transformer}, mstl::{MSTLModel, NaiveTrend}, }; use augurs_forecaster::transforms::{LinearInterpolator, Logit}; @@ -114,7 +114,7 @@ fn test_forecaster() { Logit::new().boxed(), ]; let model = MSTLModel::new(vec![2], NaiveTrend::new()); - let mut forecaster = Forecaster::new(model).with_transforms(transforms); + let mut forecaster = Forecaster::new(model).with_transformers(transforms); forecaster.fit(AIR_PASSENGERS).unwrap(); let forecasts = forecaster.predict(4, None).unwrap(); dbg!(&forecasts.point); diff --git a/examples/forecasting/examples/forecaster.rs b/examples/forecasting/examples/forecaster.rs index 33891b4..961e32f 100644 --- a/examples/forecasting/examples/forecaster.rs +++ b/examples/forecasting/examples/forecaster.rs @@ -9,7 +9,7 @@ use augurs::{ ets::AutoETS, forecaster::{ transforms::{LinearInterpolator, Log, MinMaxScaler}, - Forecaster, Transform, + Forecaster, Transformer, }, mstl::MSTLModel, }; @@ -39,20 +39,20 @@ fn main() { let ets = AutoETS::non_seasonal().into_trend_model(); let mstl = MSTLModel::new(vec![2], ets); - // Set up the transforms. - // These are just illustrative examples; you can use whatever transforms + // Set up the transformers. + // These are just illustrative examples; you can use whatever transformers // you want. - let transforms = vec![ + let transformers = vec![ LinearInterpolator::new().boxed(), MinMaxScaler::new().boxed(), Log::new().boxed(), ]; // Create a forecaster using the transforms. - let mut forecaster = Forecaster::new(mstl).with_transforms(transforms); + let mut forecaster = Forecaster::new(mstl).with_transformers(transformers); // Fit the forecaster. This will transform the training data by - // running the transforms in order, then fit the MSTL model. + // running the transformers in order, then fit the MSTL model. forecaster.fit(DATA).expect("model should fit"); // Generate some in-sample predictions with 95% prediction intervals. diff --git a/examples/forecasting/examples/prophet_forecaster.rs b/examples/forecasting/examples/prophet_forecaster.rs index 5fca5bd..91e39ca 100644 --- a/examples/forecasting/examples/prophet_forecaster.rs +++ b/examples/forecasting/examples/prophet_forecaster.rs @@ -1,7 +1,7 @@ //! Example of using the Prophet model with the wasmstan optimizer. use augurs::{ - forecaster::{transforms::MinMaxScaler, Forecaster, Transform}, + forecaster::{transforms::MinMaxScaler, Forecaster, Transformer}, prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData}, }; @@ -18,10 +18,10 @@ fn main() -> Result<(), Box> { ]; let data = TrainingData::new(ds, y.clone())?; - // Set up the transforms. - // These are just illustrative examples; you can use whatever transforms + // Set up the transformers. + // These are just illustrative examples; you can use whatever transformers // you want. - let transforms = vec![MinMaxScaler::new().boxed()]; + let transformers = vec![MinMaxScaler::new().boxed()]; // Set up the model. Create the Prophet model as normal, then convert it to a // `ProphetForecaster`. @@ -29,10 +29,10 @@ fn main() -> Result<(), Box> { let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default()); // Finally create a Forecaster using those transforms. - let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms); + let mut forecaster = Forecaster::new(prophet_forecaster).with_transformers(transformers); // Fit the forecaster. This will transform the training data by - // running the transforms in order, then fit the Prophet model. + // running the transformers in order, then fit the Prophet model. forecaster.fit(&y).expect("model should fit"); // Generate some in-sample predictions with 95% prediction intervals. diff --git a/js/augurs-mstl-js/src/lib.rs b/js/augurs-mstl-js/src/lib.rs index b865e7f..bbe7933 100644 --- a/js/augurs-mstl-js/src/lib.rs +++ b/js/augurs-mstl-js/src/lib.rs @@ -7,7 +7,7 @@ use wasm_bindgen::prelude::*; use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; use augurs_forecaster::{ transforms::{LinearInterpolator, Logit}, - Forecaster, Transform, + Forecaster, Transformer, }; use augurs_mstl::{MSTLModel, TrendModel}; @@ -50,8 +50,8 @@ impl MSTL { let ets: Box = Box::new(AutoETSTrendModel::from(AutoETS::non_seasonal())); let model = MSTLModel::new(periods.convert()?, ets); - let forecaster = - Forecaster::new(model).with_transforms(options.unwrap_or_default().into_transforms()); + let forecaster = Forecaster::new(model) + .with_transformers(options.unwrap_or_default().into_transformers()); Ok(MSTL { forecaster }) } @@ -104,15 +104,15 @@ pub struct ETSOptions { } impl ETSOptions { - fn into_transforms(self) -> Vec> { - let mut transforms = vec![]; + fn into_transformers(self) -> Vec> { + let mut transformers = vec![]; if self.impute.unwrap_or_default() { - transforms.push(LinearInterpolator::new().boxed()) + transformers.push(LinearInterpolator::new().boxed()) } if self.logit_transform.unwrap_or_default() { - transforms.push(Logit::new().boxed()); + transformers.push(Logit::new().boxed()); } - transforms + transformers } } diff --git a/js/augurs-transforms-js/src/lib.rs b/js/augurs-transforms-js/src/lib.rs index a9cbb5a..03865b2 100644 --- a/js/augurs-transforms-js/src/lib.rs +++ b/js/augurs-transforms-js/src/lib.rs @@ -5,7 +5,7 @@ use tsify_next::Tsify; use wasm_bindgen::prelude::*; use augurs_core_js::VecF64; -use augurs_forecaster::transforms; +use augurs_forecaster::transforms::{StandardScaler, Transformer, YeoJohnson}; /// A transformation to be applied to the data. /// @@ -21,10 +21,10 @@ pub enum Transform { } impl Transform { - fn into_transform(self) -> Box { + fn into_transform(self) -> Box { match self { - Transform::StandardScaler => Box::new(transforms::StandardScaler::new()), - Transform::YeoJohnson => Box::new(transforms::YeoJohnson::new()), + Transform::StandardScaler => Box::new(StandardScaler::new()), + Transform::YeoJohnson => Box::new(YeoJohnson::new()), } } }