feat!: switch transform to a trait (#213)
sd2k authored Dec 23, 2024
1 parent e3cc55e commit ad31a8d
Showing 21 changed files with 1,028 additions and 1,003 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/rust.yml
@@ -42,8 +42,6 @@ jobs:

- name: Run cargo nextest
run: just test-all
# Run book tests before doc tests because otherwise there are multiple
# 'augurs' rlibs that the book could use as its augurs dependency.
- name: Run doc tests
run: just doctest

3 changes: 3 additions & 0 deletions book/src/SUMMARY.md
@@ -32,3 +32,6 @@

# Contributing
- [Contributing](./contributing.md)

# Appendix
- [Migration Guide](./migrating.md)
17 changes: 10 additions & 7 deletions book/src/getting-started/quick-start.md
@@ -49,25 +49,28 @@ For more complex scenarios, you can use the `Forecaster` API which supports data
# extern crate augurs;
use augurs::{
ets::AutoETS,
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
forecaster::{
transforms::{LinearInterpolator, Log, MinMaxScaler},
Forecaster, Transformer,
},
mstl::MSTLModel,
};

fn main() {
let data = &[1.0, 1.2, 1.4, 1.5, f64::NAN, 1.4, 1.2, 1.5, 1.6, 2.0, 1.9, 1.8];

// Set up model and transforms
// Set up model and transformers
let ets = AutoETS::non_seasonal().into_trend_model();
let mstl = MSTLModel::new(vec![2], ets);

let transforms = vec![
Transform::linear_interpolator(),
Transform::min_max_scaler(MinMaxScaleParams::from_data(data.iter().copied())),
Transform::log(),
let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Log::new().boxed(),
];

// Create and fit forecaster
let mut forecaster = Forecaster::new(mstl).with_transforms(transforms);
let mut forecaster = Forecaster::new(mstl).with_transformers(transformers);
forecaster.fit(data).expect("model should fit");

// Generate forecasts
137 changes: 137 additions & 0 deletions book/src/migrating.md
@@ -0,0 +1,137 @@
# Migration guide

This guide will help you migrate from the previous version of `augurs` to the latest version.

## From 0.7 to 0.8

### Transformations

In version 0.8 the `augurs::forecaster::Transform` enum was removed and replaced with the
`augurs::forecaster::Transformer` trait, which closely mirrors the scikit-learn `Transformer`
API. Each `Transform` enum variant was replaced with a corresponding `Transformer`
implementation, such as `augurs::forecaster::transforms::MinMaxScaler`. The new `Pipeline`
struct is itself a `Transformer` implementation that can be used to chain multiple
transformations together.

Some transformations previously needed the data to be passed to their constructor. This is
now handled by the `fit` method of the `Transformer` trait, so a transformation can be
constructed before the data is available.

It also makes it possible to implement custom transformations by implementing the `Transformer`
trait.
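
The fit-then-transform split described above can be sketched with a small, self-contained example. Everything in it is hypothetical: `SketchTransformer`, `SketchMinMax`, `SketchError` and `fit_then_transform` are illustrative names and are not the `augurs` trait or types, whose exact signatures may differ. The sketch only shows the idea that parameters are learned in `fit` rather than supplied to the constructor.

```rust
/// Hypothetical error type for this sketch (not the augurs error type).
#[derive(Debug, PartialEq)]
pub enum SketchError {
    /// `transform` was called before `fit`.
    NotFitted,
    /// `fit` saw no finite values.
    EmptyInput,
    /// All finite values were equal, so there is no range to scale by.
    ConstantInput,
}

/// Hypothetical trait illustrating the fit/transform split.
/// The real `Transformer` trait may have different methods and signatures.
pub trait SketchTransformer {
    /// Learn any parameters from the training data.
    fn fit(&mut self, data: &[f64]) -> Result<(), SketchError>;
    /// Apply the transformation in place using the learned parameters.
    fn transform(&self, data: &mut [f64]) -> Result<(), SketchError>;
    /// Undo the transformation in place.
    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), SketchError>;
}

/// A min-max scaler to [0, 1] that is constructed without any data.
#[derive(Debug, Default)]
pub struct SketchMinMax {
    /// `(min, max)` of the training data, set by `fit`.
    bounds: Option<(f64, f64)>,
}

impl SketchTransformer for SketchMinMax {
    fn fit(&mut self, data: &[f64]) -> Result<(), SketchError> {
        // Ignore NaN and infinite values when learning the bounds.
        let mut finite = data.iter().copied().filter(|x| x.is_finite());
        let first = finite.next().ok_or(SketchError::EmptyInput)?;
        let (min, max) = finite.fold((first, first), |(lo, hi), x| (lo.min(x), hi.max(x)));
        if min == max {
            return Err(SketchError::ConstantInput);
        }
        self.bounds = Some((min, max));
        Ok(())
    }

    fn transform(&self, data: &mut [f64]) -> Result<(), SketchError> {
        let (min, max) = self.bounds.ok_or(SketchError::NotFitted)?;
        for x in data.iter_mut() {
            *x = (*x - min) / (max - min);
        }
        Ok(())
    }

    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), SketchError> {
        let (min, max) = self.bounds.ok_or(SketchError::NotFitted)?;
        for x in data.iter_mut() {
            *x = *x * (max - min) + min;
        }
        Ok(())
    }
}

/// Fit a fresh scaler on `data`, then return the scaled copy.
pub fn fit_then_transform(data: &[f64]) -> Result<Vec<f64>, SketchError> {
    let mut scaler = SketchMinMax::default();
    scaler.fit(data)?;
    let mut out = data.to_vec();
    scaler.transform(&mut out)?;
    Ok(out)
}
```

In this sketch, calling `transform` on an unfitted scaler returns an error instead of panicking. That is one possible design choice and not necessarily the one `augurs` makes.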

Before:

```rust,ignore
# extern crate augurs;
use augurs::{
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
mstl::{MSTLModel, NaiveTrend},
};
let transforms = vec![
Transform::linear_interpolator(),
Transform::min_max_scaler(MinMaxScaleParams::new(0.0, 1.0)),
Transform::log(),
];
// use the transforms in a forecaster:
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
```

After:

```rust
# extern crate augurs;
use augurs::{
forecaster::{
transforms::{LinearInterpolator, Log, MinMaxScaler},
Forecaster, Transformer,
},
mstl::{MSTLModel, NaiveTrend},
};

let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().with_scaled_range(0.0, 1.0).boxed(),
Log::new().boxed(),
];
// use the transformers in a forecaster:
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transformers(transformers);
```

## From 0.6 to 0.7

### Prophet

Version 0.7 made changes to the way that holidays are treated in the Prophet model ([PR #181](https://github.com/grafana/augurs/pull/181)).

In versions prior to 0.7, holidays were implicitly assumed to last 1 day each, starting and
ending at midnight UTC. This stemmed from how the Python API works: holidays are passed as
a column of dates in a pandas DataFrame.

In version 0.7, each holiday is instead specified using a list of `HolidayOccurrence`s, which
each have a start and end time represented as Unix timestamps. This allows you to specify
holidays more flexibly:

- holidays lasting 1 day from midnight to midnight UTC can be specified using `HolidayOccurrence::for_day`.
This is the equivalent of the previous behavior.
- holidays lasting 1 day in a non-UTC timezone can be specified using `HolidayOccurrence::for_day_in_tz`.
The second argument is the offset in seconds from UTC, which can be calculated manually or using
the `chrono::FixedOffset::local_minus_utc` method.
- holidays lasting for custom periods, such as sub-daily or multi-day periods, can be specified using
`HolidayOccurrence::new` with a start and end time in seconds since the Unix epoch.
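
To make the role of the UTC offset concrete, here is a rough sketch of the arithmetic behind "a day in a non-UTC timezone", using only the standard library. The helper `local_day_bounds` is hypothetical and is not part of `augurs`; how `HolidayOccurrence::for_day_in_tz` resolves the day internally is not shown here and may differ. The sketch assumes the offset is given as local time minus UTC, in seconds.

```rust
const SECONDS_PER_DAY: i64 = 86_400;

/// Hypothetical helper (not part of augurs): return the `(start, end)` Unix
/// timestamps, in seconds, of the local calendar day containing `timestamp`
/// for a timezone at a fixed offset of `utc_offset_seconds` from UTC.
pub fn local_day_bounds(timestamp: i64, utc_offset_seconds: i64) -> (i64, i64) {
    // Shift onto the local clock, then floor to local midnight.
    // `div_euclid` rounds toward negative infinity, so timestamps before
    // 1970 are floored correctly as well.
    let local = timestamp + utc_offset_seconds;
    let local_midnight = local.div_euclid(SECONDS_PER_DAY) * SECONDS_PER_DAY;
    // Shift back to UTC-based Unix time.
    let start = local_midnight - utc_offset_seconds;
    (start, start + SECONDS_PER_DAY)
}
```

For example, with an offset of `3600` (UTC+1), the local day containing midnight UTC starts one hour before that instant and ends 23 hours after it. Bounds like these could then be passed to `HolidayOccurrence::new` as a custom start and end time.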

In short, you can replace the following code:

```rust,ignore
# extern crate augurs;
# extern crate chrono;
use augurs::prophet::Holiday;
use chrono::{prelude::*, Utc};
let holiday_date = Utc.with_ymd_and_hms(2022, 6, 12, 0, 0, 0).unwrap().timestamp();
let holiday = Holiday::new(vec![holiday_date]);
```

with the following code:

```rust,ignore
# extern crate augurs;
# extern crate chrono;
use augurs::prophet::{Holiday, HolidayOccurrence};
use chrono::{prelude::*, Utc};
let holiday_date = Utc.with_ymd_and_hms(2022, 6, 12, 0, 0, 0).unwrap().timestamp();
let occurrence = HolidayOccurrence::for_day(holiday_date);
let holiday = Holiday::new(vec![occurrence]);
```

Or use `HolidayOccurrence::for_day_in_tz` to specify a holiday in a non-UTC timezone:

```rust,ignore
# extern crate augurs;
# extern crate chrono;
use augurs::prophet::{Holiday, HolidayOccurrence};
use chrono::{prelude::*, Utc};
let holiday_date = Utc.with_ymd_and_hms(2022, 6, 12, 0, 0, 0).unwrap().timestamp();
// This holiday lasts for 1 day in UTC+1.
let occurrence = HolidayOccurrence::for_day_in_tz(holiday_date, 3600);
let holiday = Holiday::new(vec![occurrence]);
```

Or use `HolidayOccurrence::new` to specify a holiday with a custom start and end time:

```rust,ignore
# extern crate augurs;
# extern crate chrono;
use augurs::prophet::{Holiday, HolidayOccurrence};
use chrono::{prelude::*, Utc};
let holiday_date = Utc.with_ymd_and_hms(2022, 6, 12, 0, 0, 0).unwrap().timestamp();
// This holiday lasts for 1 hour.
let occurrence = HolidayOccurrence::new(holiday_date, holiday_date + 3600);
let holiday = Holiday::new(vec![occurrence]);
```
1 change: 0 additions & 1 deletion crates/augurs-core/src/lib.rs
@@ -8,7 +8,6 @@ pub mod prelude {

mod distance;
mod forecast;
pub mod interpolate;
mod traits;

use std::convert::Infallible;
17 changes: 10 additions & 7 deletions crates/augurs-forecaster/README.md
@@ -16,7 +16,10 @@ augurs-mstl = "*"
```rust
use augurs::{
ets::{AutoETS, trend::AutoETSTrendModel},
forecaster::{Forecaster, Transform, transforms::MinMaxScaleParams},
forecaster::{
Forecaster, Transformer,
transforms::{LinearInterpolator, Logit, MinMaxScaler},
},
mstl::MSTLModel
};

@@ -31,15 +34,15 @@ let data = &[
let ets = AutoETS::non_seasonal().into_trend_model();
let mstl = MSTLModel::new(vec![2], ets);

// Set up the transforms.
let transforms = vec![
Transform::linear_interpolator(),
Transform::min_max_scaler(MinMaxScaleParams::from_data(data.iter().copied())),
Transform::log(),
// Set up the transformers.
let transformers = vec![
LinearInterpolator::new().boxed(),
MinMaxScaler::new().boxed(),
Logit::new().boxed(),
];

// Create a forecaster using the transforms.
let mut forecaster = Forecaster::new(mstl).with_transforms(transforms);
let mut forecaster = Forecaster::new(mstl).with_transformers(transformers);

// Fit the forecaster. This will transform the training data by
// running the transforms in order, then fit the MSTL model.
10 changes: 10 additions & 0 deletions crates/augurs-forecaster/src/error.rs
@@ -1,5 +1,7 @@
use augurs_core::ModelError;

use crate::transforms;

/// Errors returned by this crate.
#[derive(Debug, thiserror::Error)]
pub enum Error {
@@ -18,4 +20,12 @@ pub enum Error {
/// The original error.
source: Box<dyn ModelError>,
},

/// An error occurred while running a transformation.
#[error("Transform error: {source}")]
Transform {
/// The original error.
#[from]
source: transforms::Error,
},
}