Skip to content

Commit

Permalink
Rename Transform to Transformer; impl for Pipeline; improve docs
Browse files Browse the repository at this point in the history
`Transformer` matches the scikit-learn API.
  • Loading branch information
sd2k committed Dec 22, 2024
1 parent c467695 commit 15aa106
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 96 deletions.
8 changes: 4 additions & 4 deletions book/src/getting-started/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,26 @@ use augurs::{
ets::AutoETS,
forecaster::{
transforms::{LinearInterpolator, Log, MinMaxScaler},
Forecaster, Transform,
Forecaster, Transformer,
},
mstl::MSTLModel,
};

fn main() {
let data = &[1.0, 1.2, 1.4, 1.5, f64::NAN, 1.4, 1.2, 1.5, 1.6, 2.0, 1.9, 1.8];

// Set up model and transforms
// Set up model and transformers
let ets = AutoETS::non_seasonal().into_trend_model();
let mstl = MSTLModel::new(vec![2], ets);

let transforms = vec![
let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Log::new().boxed(),
];

// Create and fit forecaster
let mut forecaster = Forecaster::new(mstl).with_transforms(transforms);
let mut forecaster = Forecaster::new(mstl).with_transformers(transformers);
forecaster.fit(data).expect("model should fit");

// Generate forecasts
Expand Down
8 changes: 4 additions & 4 deletions crates/augurs-forecaster/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ augurs-mstl = "*"
use augurs::{
ets::{AutoETS, trend::AutoETSTrendModel},
forecaster::{
Forecaster, Transform,
Forecaster, Transformer,
transforms::{LinearInterpolator, Logit, MinMaxScaler},
},
mstl::MSTLModel
Expand All @@ -34,15 +34,15 @@ let data = &[
let ets = AutoETS::non_seasonal().into_trend_model();
let mstl = MSTLModel::new(vec![2], ets);

// Set up the transforms.
let transforms = vec![
// Set up the transformers.
let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Logit::new().boxed(),
];

// Create a forecaster using the transforms.
let mut forecaster = Forecaster::new(mstl).with_transforms(transforms);
let mut forecaster = Forecaster::new(mstl).with_transformers(transformers);

// Fit the forecaster. This will transform the training data by
// running the transforms in order, then fit the MSTL model.
Expand Down
18 changes: 9 additions & 9 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use augurs_core::{Fit, Forecast, Predict};

use crate::{Data, Error, Pipeline, Result, Transform};
use crate::{Data, Error, Pipeline, Result, Transformer};

/// A high-level API to fit and predict time series forecasting models.
///
Expand Down Expand Up @@ -31,8 +31,8 @@ where
}

/// Set the transformations to be applied to the input data.
pub fn with_transforms(mut self, transforms: Vec<Box<dyn Transform>>) -> Self {
self.pipeline = Pipeline::new(transforms);
pub fn with_transformers(mut self, transformers: Vec<Box<dyn Transformer>>) -> Self {
self.pipeline = Pipeline::new(transformers);
self
}

Expand Down Expand Up @@ -92,13 +92,13 @@ mod test {
#[test]
fn test_forecaster() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let transforms = vec![
let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Logit::new().boxed(),
];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
let mut forecaster = Forecaster::new(model).with_transformers(transformers);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_close(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]);
Expand All @@ -107,9 +107,9 @@ mod test {
#[test]
fn test_forecaster_power_positive() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let transforms = vec![BoxCox::new().boxed()];
let transformers = vec![BoxCox::new().boxed()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
let mut forecaster = Forecaster::new(model).with_transformers(transformers);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_close(
Expand All @@ -126,9 +126,9 @@ mod test {
#[test]
fn test_forecaster_power_non_positive() {
let data = &[0.0, 2.0, 3.0, 4.0, 5.0];
let transforms = vec![YeoJohnson::new().boxed()];
let transformers = vec![YeoJohnson::new().boxed()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
let mut forecaster = Forecaster::new(model).with_transformers(transformers);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_close(
Expand Down
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pub mod transforms;
pub use data::Data;
pub use error::Error;
pub use forecaster::Forecaster;
pub use transforms::{Pipeline, Transform};
pub use transforms::{Pipeline, Transformer};

type Result<T> = std::result::Result<T, Error>;
102 changes: 62 additions & 40 deletions crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,58 @@ pub use scale::{MinMaxScaler, StandardScaleParams, StandardScaler};

/// A transformation pipeline.
///
/// The `Pipeline` struct is a collection of heterogeneous `Transform` instances
/// that can be applied to a time series. Calling the `fit` or `fit_transform`
/// methods will fit each transformation to the output of the previous one in turn
/// starting by passing the input to the first transformation.
/// A `Pipeline` is a collection of heterogeneous [`Transformer`] instances
/// that can be applied to a time series. Calling [`Pipeline::fit`] or [`Pipeline::fit_transform`]
/// will fit each transformation to the output of the previous one in turn
/// starting by passing the input to the first transformation. The
/// [`Pipeline::inverse_transform`] can then be used to back-transform data
/// to the original scale.
#[derive(Debug, Default)]
pub struct Pipeline {
transforms: Vec<Box<dyn Transform>>,
transformers: Vec<Box<dyn Transformer>>,
is_fitted: bool,
}

impl Pipeline {
/// Create a new `Pipeline` with the given transforms.
pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
/// Create a new `Pipeline` with the given transformers.
pub fn new(transformers: Vec<Box<dyn Transformer>>) -> Self {
Self {
transforms,
transformers,
is_fitted: false,
}
}

// Helper function for actually doing the fit then transform steps.
fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transformers.iter_mut() {
t.fit_transform(input)?;
}
self.is_fitted = true;
Ok(())
}

/// Apply the inverse transformations to the given forecast.
///
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> {
for t in self.transformers.iter().rev() {
t.inverse_transform(&mut forecast.point)?;
if let Some(intervals) = forecast.intervals.as_mut() {
t.inverse_transform(&mut intervals.lower)?;
t.inverse_transform(&mut intervals.upper)?;
}
}
Ok(())
}
}

impl Transformer for Pipeline {
/// Fit the transformations to the given time series.
///
/// Prefer `fit_transform` if possible, as it avoids copying the input.
pub fn fit(&mut self, input: &[f64]) -> Result<(), Error> {
fn fit(&mut self, input: &[f64]) -> Result<(), Error> {
// Copy the input to avoid mutating the original.
// We need to do this so we can call `fit_transform` on each
// transformation in the pipeline without mutating the input.
Expand All @@ -57,19 +86,11 @@ impl Pipeline {
Ok(())
}

fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transforms.iter_mut() {
t.fit_transform(input)?;
}
self.is_fitted = true;
Ok(())
}

/// Fit and transform the given time series.
///
/// This is equivalent to calling `fit` and then `transform` on the pipeline,
/// but is more efficient because it avoids copying the input.
pub fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> {
fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> {
self.fit_transform_inner(input)?;
Ok(())
}
Expand All @@ -79,8 +100,8 @@ impl Pipeline {
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub fn transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transforms.iter() {
fn transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transformers.iter() {
t.transform(input)?;
}
Ok(())
Expand All @@ -91,42 +112,43 @@ impl Pipeline {
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transforms.iter().rev() {
fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transformers.iter().rev() {
t.inverse_transform(input)?;
}
Ok(())
}

/// Apply the inverse transformations to the given forecast.
///
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> {
for t in self.transforms.iter().rev() {
t.inverse_transform(&mut forecast.point)?;
if let Some(intervals) = forecast.intervals.as_mut() {
t.inverse_transform(&mut intervals.lower)?;
t.inverse_transform(&mut intervals.upper)?;
}
}
Ok(())
}
}

/// A transformation that can be applied to a time series.
pub trait Transform: fmt::Debug + Sync + Send {
pub trait Transformer: fmt::Debug + Sync + Send {
/// Fit the transformation to the given time series.
///
/// For example, for a min-max scaler, this would find
/// the min and max of the provided data and store it on the
/// scaler ready for use in transforming and back-transforming.
fn fit(&mut self, data: &[f64]) -> Result<(), Error>;

/// Apply the transformation to the given time series.
///
/// # Errors
///
/// This function should return an error if the transform has not been fitted,
/// and may return other errors specific to the implementation.
fn transform(&self, data: &mut [f64]) -> Result<(), Error>;

/// Apply the inverse transformation to the given time series.
///
/// # Errors
///
/// This function should return an error if the transform has not been fitted,
/// and may return other errors specific to the implementation.
fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error>;

/// Fit the transformation to the given time series and then apply it.
///
/// The default implementation just calls [`Self::fit`] then [`Self::transform`]
/// but it can be overridden to be more efficient if desired.
fn fit_transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
self.fit(data)?;
self.transform(data)?;
Expand All @@ -137,7 +159,7 @@ pub trait Transform: fmt::Debug + Sync + Send {
///
/// This is useful for creating a `Transform` instance that can be used as
/// part of a [`Pipeline`].
fn boxed(self) -> Box<dyn Transform>
fn boxed(self) -> Box<dyn Transformer>
where
Self: Sized + 'static,
{
Expand Down
6 changes: 3 additions & 3 deletions crates/augurs-forecaster/src/transforms/exp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::fmt;

use super::{Error, Transform};
use super::{Error, Transformer};

// Logit and logistic functions.

Expand Down Expand Up @@ -35,7 +35,7 @@ impl fmt::Debug for Logit {
}
}

impl Transform for Logit {
impl Transformer for Logit {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}
Expand Down Expand Up @@ -70,7 +70,7 @@ impl fmt::Debug for Log {
}
}

impl Transform for Log {
impl Transformer for Log {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions crates/augurs-forecaster/src/transforms/interpolate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
ops::{Add, Div, Mul, Sub},
};

use super::{Error, Transform};
use super::{Error, Transformer};

/// A type that can be used to interpolate between values.
pub trait Interpolater {
Expand Down Expand Up @@ -59,7 +59,7 @@ impl Interpolater for LinearInterpolator {
}
}

impl Transform for LinearInterpolator {
impl Transformer for LinearInterpolator {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}
Expand Down
6 changes: 3 additions & 3 deletions crates/augurs-forecaster/src/transforms/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use argmin::core::{CostFunction, Executor};
use argmin::solver::brent::BrentOpt;

use super::{Error, Transform};
use super::{Error, Transformer};

/// Returns the Box-Cox transformation of the given value.
/// Assumes x > 0.
Expand Down Expand Up @@ -175,7 +175,7 @@ impl Default for BoxCox {
}
}

impl Transform for BoxCox {
impl Transformer for BoxCox {
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
self.lambda = optimize_box_cox_lambda(data)?;
Expand Down Expand Up @@ -344,7 +344,7 @@ impl Default for YeoJohnson {
}
}

impl Transform for YeoJohnson {
impl Transformer for YeoJohnson {
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
self.lambda = optimize_yeo_johnson_lambda(data)?;
Expand Down
Loading

0 comments on commit 15aa106

Please sign in to comment.