Skip to content

Commit

Permalink
feat!: switch transform to a trait
Browse files Browse the repository at this point in the history
This is a breaking change which replaces the `Transform` enum with
a trait, which allows it to be extended.

The API is also changed from transforming iterators to transforming
slices. On the surface this seems suboptimal, but many transforms
actually need to know some information about the data before they
can be applied (i.e. they can't be applied in a simple streaming
manner). Previously, this meant that either:

1. the user had to pass a slice into the transform constructor to
   allow some parameters to be calculated, or
2. the transform had to collect the iterator anyway, which could
   cause issues with lifetimes and meant it was hard to compose
   transforms.

Now, the transform is passed a mutable slice instead, which is
reused by all transforms.

This commit also reworks errors, since they're required by the
`Transform` trait anyway.

I'm still not sure whether we want a `fit` method on the
transform; it does seem sensible if the user wants more control
over the fitting process. Currently it's unspecified how that
would work: all current transforms fit when `transform` is
first called, unless they provide a way to override their
parameters.
  • Loading branch information
sd2k committed Dec 20, 2024
1 parent e6cc399 commit 0d6e6a6
Show file tree
Hide file tree
Showing 16 changed files with 750 additions and 1,011 deletions.
11 changes: 7 additions & 4 deletions book/src/getting-started/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ For more complex scenarios, you can use the `Forecaster` API which supports data
# extern crate augurs;
use augurs::{
ets::AutoETS,
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
forecaster::{
transforms::{LinearInterpolator, Log, MinMaxScaler},
Forecaster, Transform,
},
mstl::MSTLModel,
};

Expand All @@ -61,9 +64,9 @@ fn main() {
let mstl = MSTLModel::new(vec![2], ets);

let transforms = vec![
Transform::linear_interpolator(),
Transform::min_max_scaler(MinMaxScaleParams::from_data(data.iter().copied())),
Transform::log(),
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Log::new().boxed(),
];

// Create and fit forecaster
Expand Down
1 change: 0 additions & 1 deletion crates/augurs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub mod prelude {

mod distance;
mod forecast;
pub mod interpolate;
mod traits;

use std::convert::Infallible;
Expand Down
10 changes: 10 additions & 0 deletions crates/augurs-forecaster/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use augurs_core::ModelError;

use crate::transforms;

/// Errors returned by this crate.
#[derive(Debug, thiserror::Error)]
pub enum Error {
Expand All @@ -18,4 +20,12 @@ pub enum Error {
/// The original error.
source: Box<dyn ModelError>,
},

/// An error occurred while running a transformation.
#[error("Transform error: {source}")]
Transform {
/// The original error.
#[from]
source: transforms::Error,
},
}
91 changes: 33 additions & 58 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use augurs_core::{Fit, Forecast, Predict};

use crate::{Data, Error, Result, Transform, Transforms};
use crate::{Data, Error, Pipeline, Result, Transform};

/// A high-level API to fit and predict time series forecasting models.
///
Expand All @@ -13,7 +13,7 @@ pub struct Forecaster<M: Fit> {
model: M,
fitted: Option<M::Fitted>,

transforms: Transforms,
pipeline: Pipeline,
}

impl<M> Forecaster<M>
Expand All @@ -26,23 +26,21 @@ where
Self {
model,
fitted: None,
transforms: Transforms::default(),
pipeline: Pipeline::default(),
}
}

/// Set the transformations to be applied to the input data.
pub fn with_transforms(mut self, transforms: Vec<Transform>) -> Self {
self.transforms = Transforms::new(transforms);
pub fn with_transforms(mut self, transforms: Vec<Box<dyn Transform>>) -> Self {
self.pipeline = Pipeline::new(transforms);
self
}

/// Fit the model to the given time series.
pub fn fit<D: Data + Clone>(&mut self, y: D) -> Result<()> {
let data: Vec<_> = self
.transforms
.transform(y.as_slice().iter().copied())
.collect();
self.fitted = Some(self.model.fit(&data).map_err(|e| Error::Fit {
let mut y = y.as_slice().to_vec();
self.pipeline.transform(&mut y)?;
self.fitted = Some(self.model.fit(&y).map_err(|e| Error::Fit {
source: Box::new(e) as _,
})?);
Ok(())
Expand All @@ -55,87 +53,66 @@ where
/// Predict the next `horizon` values, optionally including prediction
/// intervals at the given level.
pub fn predict(&self, horizon: usize, level: impl Into<Option<f64>>) -> Result<Forecast> {
self.fitted()?
.predict(horizon, level.into())
.map_err(|e| Error::Predict {
source: Box::new(e) as _,
})
.map(|f| self.transforms.inverse_transform(f))
let mut untransformed =
self.fitted()?
.predict(horizon, level.into())
.map_err(|e| Error::Predict {
source: Box::new(e) as _,
})?;
self.pipeline
.inverse_transform_forecast(&mut untransformed)?;
Ok(untransformed)
}

/// Produce in-sample forecasts, optionally including prediction intervals
/// at the given level.
pub fn predict_in_sample(&self, level: impl Into<Option<f64>>) -> Result<Forecast> {
self.fitted()?
let mut untransformed = self
.fitted()?
.predict_in_sample(level.into())
.map_err(|e| Error::Predict {
source: Box::new(e) as _,
})
.map(|f| self.transforms.inverse_transform(f))
})?;
self.pipeline
.inverse_transform_forecast(&mut untransformed)?;
Ok(untransformed)
}
}

#[cfg(test)]
mod test {
use itertools::{Itertools, MinMaxResult};

use augurs::mstl::{MSTLModel, NaiveTrend};
use augurs_testing::assert_all_close;

use crate::transforms::MinMaxScaleParams;
use crate::transforms::{BoxCox, LinearInterpolator, Logit, MinMaxScaler, YeoJohnson};

use super::*;

fn assert_approx_eq(a: f64, b: f64) -> bool {
if a.is_nan() && b.is_nan() {
return true;
}
(a - b).abs() < 0.001
}

fn assert_all_approx_eq(a: &[f64], b: &[f64]) {
if a.len() != b.len() {
assert_eq!(a, b);
}
for (ai, bi) in a.iter().zip(b) {
if !assert_approx_eq(*ai, *bi) {
assert_eq!(a, b);
}
}
}

#[test]
fn test_forecaster() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let MinMaxResult::MinMax(min, max) = data
.iter()
.copied()
.minmax_by(|a, b| a.partial_cmp(b).unwrap())
else {
unreachable!();
};
let transforms = vec![
Transform::linear_interpolator(),
Transform::min_max_scaler(MinMaxScaleParams::new(min - 1e-3, max + 1e-3)),
Transform::logit(),
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Logit::new().boxed(),
];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]);
assert_all_close(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]);
}

#[test]
fn test_forecaster_power_positive() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let transforms = vec![BoxCox::new().boxed()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
assert_all_close(
&forecasts.point,
&[
5.084499064884572,
Expand All @@ -149,14 +126,12 @@ mod test {
#[test]
fn test_forecaster_power_non_positive() {
let data = &[0.0, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let transforms = vec![YeoJohnson::new().boxed()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
assert_all_close(
&forecasts.point,
&[
5.205557727170964,
Expand Down
3 changes: 1 addition & 2 deletions crates/augurs-forecaster/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub mod transforms;
pub use data::Data;
pub use error::Error;
pub use forecaster::Forecaster;
pub use transforms::Transform;
pub(crate) use transforms::Transforms;
pub use transforms::{Pipeline, Transform};

type Result<T> = std::result::Result<T, Error>;
Loading

0 comments on commit 0d6e6a6

Please sign in to comment.