Skip to content

Commit

Permalink
Add fit method to Transform trait
Browse files Browse the repository at this point in the history
And provide `fit_transform` methods on `Transform` and `Pipeline`.

This makes the API more similar to scikit-learn. It also makes the
fitting of a transformation more explicit, as it was previously
happening implicitly when the `transform` method was first called.
  • Loading branch information
sd2k committed Dec 21, 2024
1 parent 435e6c8 commit c467695
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 54 deletions.
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ where
/// Fit the model to the given time series.
pub fn fit<D: Data + Clone>(&mut self, y: D) -> Result<()> {
let mut y = y.as_slice().to_vec();
self.pipeline.transform(&mut y)?;
self.pipeline.fit_transform(&mut y)?;
self.fitted = Some(self.model.fit(&y).map_err(|e| Error::Fit {
source: Box::new(e) as _,
})?);
Expand Down
90 changes: 73 additions & 17 deletions crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
/*!
Data transformations.
This module contains the `Transform` enum, which contains various
predefined transformations. The enum contains various methods for
creating new instances of the various transformations, as well as
the `transform` and `inverse_transform` methods, which allow you to
apply a transformation to a time series and its inverse, respectively.
*/

// Note: implementations of the various transforms are in the
Expand All @@ -27,38 +21,90 @@ pub use interpolate::{InterpolateExt, LinearInterpolator};
pub use power::{BoxCox, YeoJohnson};
pub use scale::{MinMaxScaler, StandardScaleParams, StandardScaler};

/// Transforms and Transform implementations.
/// A transformation pipeline.
///
/// The `Transforms` struct is a collection of `Transform` instances that can be applied to a time series.
/// The `Transform` enum represents a single transformation that can be applied to a time series.
/// The `Pipeline` struct is a collection of heterogeneous `Transform` instances
/// that can be applied to a time series. Calling the `fit` or `fit_transform`
/// methods will fit each transformation to the output of the previous one in turn
/// starting by passing the input to the first transformation.
#[derive(Debug, Default)]
pub struct Pipeline(Vec<Box<dyn Transform>>);
pub struct Pipeline {
transforms: Vec<Box<dyn Transform>>,
is_fitted: bool,
}

impl Pipeline {
/// Create a new `Pipeline` with the given transforms.
pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
Self(transforms)
Self {
transforms,
is_fitted: false,
}
}

/// Fit the transformations to the given time series.
///
/// Prefer `fit_transform` if possible, as it avoids copying the input.
pub fn fit(&mut self, input: &[f64]) -> Result<(), Error> {
// Copy the input to avoid mutating the original.
// We need to do this so we can call `fit_transform` on each
// transformation in the pipeline without mutating the input.
// This is required because each transformation needs to be
// fit after previous transformations have been applied.
let mut input = input.to_vec();
// Reuse `fit_transform_inner`, and just discard the result.
self.fit_transform_inner(&mut input)?;
Ok(())
}

/// Apply the transformations to the given time series.
pub fn transform(&mut self, input: &mut [f64]) -> Result<(), Error> {
for t in self.0.iter_mut() {
fn fit_transform_inner(&mut self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transforms.iter_mut() {
t.fit_transform(input)?;
}
self.is_fitted = true;
Ok(())
}

/// Fit and transform the given time series.
///
/// This is equivalent to calling `fit` and then `transform` on the pipeline,
/// but is more efficient because it avoids copying the input.
pub fn fit_transform(&mut self, input: &mut [f64]) -> Result<(), Error> {
self.fit_transform_inner(input)?;
Ok(())
}

/// Apply the fitted transformations to the given time series.
///
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub fn transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.transforms.iter() {
t.transform(input)?;
}
Ok(())
}

/// Apply the inverse transformations to the given time series.
///
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub fn inverse_transform(&self, input: &mut [f64]) -> Result<(), Error> {
for t in self.0.iter().rev() {
for t in self.transforms.iter().rev() {
t.inverse_transform(input)?;
}
Ok(())
}

/// Apply the inverse transformations to the given forecast.
///
/// # Errors
///
/// This function will return an error if the pipeline has not been fitted.
pub(crate) fn inverse_transform_forecast(&self, forecast: &mut Forecast) -> Result<(), Error> {
for t in self.0.iter().rev() {
for t in self.transforms.iter().rev() {
t.inverse_transform(&mut forecast.point)?;
if let Some(intervals) = forecast.intervals.as_mut() {
t.inverse_transform(&mut intervals.lower)?;
Expand All @@ -71,12 +117,22 @@ impl Pipeline {

/// A transformation that can be applied to a time series.
pub trait Transform: fmt::Debug + Sync + Send {
/// Fit the transformation to the given time series.
fn fit(&mut self, data: &[f64]) -> Result<(), Error>;

/// Apply the transformation to the given time series.
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error>;
fn transform(&self, data: &mut [f64]) -> Result<(), Error>;

/// Apply the inverse transformation to the given time series.
fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error>;

/// Fit the transformation to the given time series and then apply it.
fn fit_transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
self.fit(data)?;
self.transform(data)?;
Ok(())
}

/// Create a boxed version of the transformation.
///
/// This is useful for creating a `Transform` instance that can be used as
Expand Down
12 changes: 10 additions & 2 deletions crates/augurs-forecaster/src/transforms/exp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ impl fmt::Debug for Logit {
}

impl Transform for Logit {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
data.iter_mut().for_each(|x| *x = logit(*x));
Ok(())
}
Expand Down Expand Up @@ -67,7 +71,11 @@ impl fmt::Debug for Log {
}

impl Transform for Log {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
data.iter_mut().for_each(|x| *x = f64::ln(*x));
Ok(())
}
Expand Down
6 changes: 5 additions & 1 deletion crates/augurs-forecaster/src/transforms/interpolate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ impl Interpolater for LinearInterpolator {
}

impl Transform for LinearInterpolator {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
let interpolated: Vec<_> = data.iter().copied().interpolate(*self).collect();
data.copy_from_slice(&interpolated);
Ok(())
Expand Down
22 changes: 18 additions & 4 deletions crates/augurs-forecaster/src/transforms/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,17 @@ impl Default for BoxCox {
}

impl Transform for BoxCox {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
self.lambda = optimize_box_cox_lambda(data)?;
}
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
return Err(Error::NotFitted);
}
for x in data.iter_mut() {
*x = box_cox(*x, self.lambda)?;
}
Expand Down Expand Up @@ -338,10 +345,17 @@ impl Default for YeoJohnson {
}

impl Transform for YeoJohnson {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
self.lambda = optimize_yeo_johnson_lambda(data)?;
}
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
if self.lambda.is_nan() {
return Err(Error::NotFitted);
}
for x in data.iter_mut() {
*x = yeo_johnson(*x, self.lambda)?;
}
Expand Down Expand Up @@ -427,7 +441,7 @@ mod test {
let lambda = 0.5;
let mut box_cox = BoxCox::new().with_lambda(lambda).unwrap();
let expected = vec![0.0, 0.8284271247461903, 1.4641016151377544];
box_cox.transform(&mut data).unwrap();
box_cox.fit_transform(&mut data).unwrap();
assert_all_close(&expected, &data);
}

Expand All @@ -445,7 +459,7 @@ mod test {
fn yeo_johnson_test() {
let mut data = vec![-1.0, 0.0, 1.0];
let lambda = 0.5;
let mut yeo_johnson = YeoJohnson::new().with_lambda(lambda).unwrap();
let yeo_johnson = YeoJohnson::new().with_lambda(lambda).unwrap();
let expected = vec![-1.2189514164974602, 0.0, 0.8284271247461903];
yeo_johnson.transform(&mut data).unwrap();
assert_all_close(&expected, &data);
Expand Down
60 changes: 31 additions & 29 deletions crates/augurs-forecaster/src/transforms/scale.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,33 +97,37 @@ impl MinMaxScaler {
self.params = Some(FittedMinMaxScalerParams::new(data_range, self.output_scale));
self
}
}

fn fit(&self, data: &[f64]) -> Result<FittedMinMaxScalerParams, Error> {
match data
impl Transform for MinMaxScaler {
/// Fit the scaler to the given data.
///
/// This will compute the min and max values of the data and store them
/// in the `params` field of the scaler.
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
let params = match data
.iter()
.copied()
.minmax_by(|a, b| a.partial_cmp(b).unwrap())
{
e @ MinMaxResult::NoElements | e @ MinMaxResult::OneElement(_) => Err(e.into()),
MinMaxResult::MinMax(min, max) => Ok(FittedMinMaxScalerParams::new(
MinMax { min, max },
self.output_scale,
)),
}
e @ MinMaxResult::NoElements | e @ MinMaxResult::OneElement(_) => return Err(e.into()),
MinMaxResult::MinMax(min, max) => {
FittedMinMaxScalerParams::new(MinMax { min, max }, self.output_scale)
}
};
self.params = Some(params);
Ok(())
}
}

impl Transform for MinMaxScaler {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
let params = match &mut self.params {
Some(p) => p,
None => self.params.get_or_insert(self.fit(data)?),
};
/// Apply the scaler to the given data.
fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
let params = self.params.as_ref().ok_or(Error::NotFitted)?;
data.iter_mut()
.for_each(|x| *x = *x * params.scale_factor + params.offset);
Ok(())
}

/// Apply the inverse of the scaler to the given data.
fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
let params = self.params.as_ref().ok_or(Error::NotFitted)?;
data.iter_mut()
Expand Down Expand Up @@ -212,7 +216,7 @@ impl StandardScaleParams {
///
/// let mut data = vec![1.0, 2.0, 3.0];
/// let mut scaler = StandardScaler::new();
/// scaler.transform(&mut data);
/// scaler.fit_transform(&mut data);
///
/// assert_eq!(data, vec![-1.224744871391589, 0.0, 1.224744871391589]);
/// ```
Expand All @@ -237,18 +241,16 @@ impl StandardScaler {
self.params = Some(params);
self
}

fn fit(&self, data: &[f64]) -> StandardScaleParams {
StandardScaleParams::from_data(data.iter().copied())
}
}

impl Transform for StandardScaler {
fn transform(&mut self, data: &mut [f64]) -> Result<(), Error> {
let params = match &mut self.params {
Some(p) => p,
None => self.params.get_or_insert(self.fit(data)),
};
fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
self.params = Some(StandardScaleParams::from_data(data.iter().copied()));
Ok(())
}

fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
let params = self.params.as_ref().ok_or(Error::NotFitted)?;
data.iter_mut()
.for_each(|x| *x = (*x - params.mean) / params.std_dev);
Ok(())
Expand All @@ -273,7 +275,7 @@ mod test {
let mut data = vec![1.0, 2.0, 3.0];
let expected = vec![0.0, 0.5, 1.0];
let mut scaler = MinMaxScaler::new();
scaler.transform(&mut data).unwrap();
scaler.fit_transform(&mut data).unwrap();
assert_all_close(&expected, &data);
}

Expand All @@ -282,7 +284,7 @@ mod test {
let mut data = vec![1.0, 2.0, 3.0];
let expected = vec![0.0, 5.0, 10.0];
let mut scaler = MinMaxScaler::new().with_scaled_range(0.0, 10.0);
scaler.transform(&mut data).unwrap();
scaler.fit_transform(&mut data).unwrap();
assert_all_close(&expected, &data);
}

Expand Down Expand Up @@ -313,7 +315,7 @@ mod test {
// not necessarily obvious.
let expected = vec![-1.224744871391589, 0.0, 1.224744871391589];
let mut scaler = StandardScaler::new(); // 2.0, 1.0); // mean=2, std=1
scaler.transform(&mut data).unwrap();
scaler.fit_transform(&mut data).unwrap();
assert_all_close(&expected, &data);
}

Expand All @@ -322,7 +324,7 @@ mod test {
let mut data = vec![1.0, 2.0, 3.0];
let expected = vec![-1.0, 0.0, 1.0];
let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1
let mut scaler = StandardScaler::new().with_parameters(params);
let scaler = StandardScaler::new().with_parameters(params);
scaler.transform(&mut data).unwrap();
assert_all_close(&expected, &data);
}
Expand Down

0 comments on commit c467695

Please sign in to comment.