Skip to content

Commit

Permalink
feat: add standard scaler transform (#204)
Browse files Browse the repository at this point in the history
This PR adds a standard scaler transform to the augurs-forecaster
transforms. It's similar to the existing `MinMaxScaler`, but instead
scales the data such that it has a mean of 0 and a standard deviation
of 1. This is similar to the `StandardScaler` in scikit-learn.
  • Loading branch information
sd2k authored Dec 18, 2024
1 parent d4a3f80 commit 96739a5
Showing 1 changed file with 185 additions and 0 deletions.
185 changes: 185 additions & 0 deletions crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub enum Transform {
LinearInterpolator,
/// Min-max scaling.
MinMaxScaler(MinMaxScaleParams),
/// Standard scaling.
StandardScaler(StandardScaleParams),
/// Logit transform.
Logit,
/// Log transform.
Expand Down Expand Up @@ -92,6 +94,22 @@ impl Transform {
Self::MinMaxScaler(min_max_params)
}

/// Create a new standard scaler.
///
/// This scaler standardizes features by removing the mean and scaling to unit variance.
///
/// The standard score of a sample x is calculated as:
///
/// ```text
/// z = (x - u) / s
/// ```
///
/// where u is the mean and s is the standard deviation in the provided
/// `StandardScaleParams`.
pub fn standard_scaler(scale_params: StandardScaleParams) -> Self {
Self::StandardScaler(scale_params)
}

/// Create a new logit transform.
///
/// This transform applies the logit function to each item.
Expand Down Expand Up @@ -173,6 +191,7 @@ impl Transform {
match self {
Self::LinearInterpolator => Box::new(input.interpolate(LinearInterpolator::default())),
Self::MinMaxScaler(params) => Box::new(input.min_max_scale(params)),
Self::StandardScaler(params) => Box::new(input.standard_scale(params)),
Self::Logit => Box::new(input.logit()),
Self::Log => Box::new(input.log()),
Self::BoxCox { lambda } => Box::new(input.box_cox(*lambda)),
Expand Down Expand Up @@ -202,6 +221,7 @@ impl Transform {
match self {
Self::LinearInterpolator => Box::new(input),
Self::MinMaxScaler(params) => Box::new(input.inverse_min_max_scale(params)),
Self::StandardScaler(params) => Box::new(input.inverse_standard_scale(params)),
Self::Logit => Box::new(input.logistic()),
Self::Log => Box::new(input.exp()),
Self::BoxCox { lambda } => Box::new(input.inverse_box_cox(*lambda)),
Expand Down Expand Up @@ -354,6 +374,126 @@ trait InverseMinMaxScaleExt: Iterator<Item = f64> {

impl<T> InverseMinMaxScaleExt for T where T: Iterator<Item = f64> {}

/// Parameters for the standard scaler.
#[derive(Debug, Clone)]
pub struct StandardScaleParams {
/// The mean of the data.
pub mean: f64,
/// The standard deviation of the data.
pub std_dev: f64,
}

impl StandardScaleParams {
/// Create a new `StandardScaleParams` with the given mean and standard deviation.
pub fn new(mean: f64, std_dev: f64) -> Self {
Self { mean, std_dev }
}

/// Create a new `StandardScaleParams` from the given data.
///
/// Note: this uses Welford's online algorithm to compute mean and variance in a single pass,
/// since we only have an iterator. The standard deviation is calculated using the
/// biased estimator, for parity with the [scikit-learn implementation][sklearn].
///
/// [sklearn]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
pub fn from_data<T>(data: T) -> Self
where
T: Iterator<Item = f64>,
{
// Use Welford's online algorithm to compute mean and variance in a single pass,
// since we only have an iterator.
let mut count = 0_u64;
let mut mean = 0.0;
let mut m2 = 0.0;

for x in data {
count += 1;
let delta = x - mean;
mean += delta / count as f64;
let delta2 = x - mean;
m2 += delta * delta2;
}

// Handle empty iterator case
if count == 0 {
return Self::new(0.0, 1.0);
}

// Calculate standard deviation
let std_dev = (m2 / count as f64).sqrt();

Self { mean, std_dev }
}
}

/// Iterator adapter that scales each item using the given mean and standard deviation,
/// so that (assuming the adapter was created using the same data), the output items
/// have zero mean and unit standard deviation.
#[derive(Debug, Clone)]
struct StandardScale<T> {
inner: T,
mean: f64,
std_dev: f64,
}

impl<T> Iterator for StandardScale<T>
where
T: Iterator<Item = f64>,
{
type Item = f64;
fn next(&mut self) -> Option<Self::Item> {
self.inner.next().map(|x| (x - self.mean) / self.std_dev)
}
}

trait StandardScaleExt: Iterator<Item = f64> {
fn standard_scale(self, params: &StandardScaleParams) -> StandardScale<Self>
where
Self: Sized,
{
StandardScale {
inner: self,
mean: params.mean,
std_dev: params.std_dev,
}
}
}

impl<T> StandardScaleExt for T where T: Iterator<Item = f64> {}

/// Iterator adapter that applies the inverse standard scaling transformation.
#[derive(Debug, Clone)]
struct InverseStandardScale<T> {
inner: T,
mean: f64,
std_dev: f64,
}

impl<T> Iterator for InverseStandardScale<T>
where
T: Iterator<Item = f64>,
{
type Item = f64;
fn next(&mut self) -> Option<Self::Item> {
self.inner.next().map(|x| (x * self.std_dev) + self.mean)
}
}

trait InverseStandardScaleExt: Iterator<Item = f64> {
fn inverse_standard_scale(self, params: &StandardScaleParams) -> InverseStandardScale<Self>
where
Self: Sized,
{
InverseStandardScale {
inner: self,
mean: params.mean,
std_dev: params.std_dev,
}
}
}

impl<T> InverseStandardScaleExt for T where T: Iterator<Item = f64> {}

// Logit and logistic functions.

/// Returns the logistic function of the given value.
Expand Down Expand Up @@ -793,6 +933,51 @@ mod test {
assert_all_close(&expected, &actual);
}

#[test]
fn standard_scale() {
let data = vec![1.0, 2.0, 3.0];
let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1
let expected = vec![-1.0, 0.0, 1.0];
let actual: Vec<_> = data.into_iter().standard_scale(&params).collect();
assert_all_close(&expected, &actual);
}

#[test]
fn inverse_standard_scale() {
let data = vec![-1.0, 0.0, 1.0];
let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1
let expected = vec![1.0, 2.0, 3.0];
let actual: Vec<_> = data.into_iter().inverse_standard_scale(&params).collect();
assert_all_close(&expected, &actual);
}

#[test]
fn standard_scale_params_from_data() {
// Test case 1: Simple sequence
let data = vec![1.0, 2.0, 3.0];
let params = StandardScaleParams::from_data(data.into_iter());
assert_approx_eq!(params.mean, 2.0);
assert_approx_eq!(params.std_dev, 0.816496580927726);

// Test case 2: More complex data
let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
let params = StandardScaleParams::from_data(data.into_iter());
assert_approx_eq!(params.mean, 5.0);
assert_approx_eq!(params.std_dev, 2.0);

// Test case 3: Empty iterator should return default values
let data: Vec<f64> = vec![];
let params = StandardScaleParams::from_data(data.into_iter());
assert_approx_eq!(params.mean, 0.0);
assert_approx_eq!(params.std_dev, 1.0);

// Test case 4: Single value
let data = vec![42.0];
let params = StandardScaleParams::from_data(data.into_iter());
assert_approx_eq!(params.mean, 42.0);
assert_approx_eq!(params.std_dev, 0.0); // technically undefined, but we return 0
}

#[test]
fn min_max_scale_params_from_data() {
let data = [1.0, 2.0, f64::NAN, 3.0];
Expand Down

0 comments on commit 96739a5

Please sign in to comment.