Skip to content

Commit

Permalink
Implementing LogisticRegression (#224)
Browse files Browse the repository at this point in the history
* Adding new module, classical

* Implementing LinearRegression

* Adding new example for LinearRegression classical model

* Updating stuff in classical module

* Setting gradient_descent as private

* Adding rust docs

* Adding some comments

* Adding LogisticRegression

* Updating stuff

* Adding license
  • Loading branch information
mjovanc authored Jan 23, 2025
1 parent a3d2f16 commit 7a4305b
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 9 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ members = [
"examples/image_classification/cifar10",
"examples/image_classification/mnist",
"examples/image_classification/imagenet_v2",
"examples/classical/linear_regression",
"examples/classical/linear_regression",
"examples/classical/logistic_regression",

]
resolver = "2"

Expand Down
66 changes: 63 additions & 3 deletions delta/src/classical/classification.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
//! BSD 3-Clause License
//!
//! Copyright (c) 2025, BlackPortal ○
//!
//! Redistribution and use in source and binary forms, with or without
//! modification, are permitted provided that the following conditions are met:
//!
//! 1. Redistributions of source code must retain the above copyright notice, this
//! list of conditions and the following disclaimer.
//!
//! 2. Redistributions in binary form must reproduce the above copyright notice,
//! this list of conditions and the following disclaimer in the documentation
//! and/or other materials provided with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its
//! contributors may be used to endorse or promote products derived from
//! this software without specific prior written permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
use ndarray::{Array1, Array2};

use crate::classical::{calculate_loss, gradient_descent};
use crate::classical::{
batch_gradient_descent, calculate_log_loss, calculate_mse_loss, logistic_gradient_descent,
};

use super::Classical;

Expand All @@ -9,6 +40,11 @@ pub struct LinearRegression {
bias: f64,
}

/// Logistic regression model storing the parameters used by its `Classical`
/// implementation: `predict` applies the sigmoid to `x · weights + bias`.
pub struct LogisticRegression {
/// Coefficient vector that `predict` dots with the input matrix; `new`
/// initialises it to a single zero.
weights: Array1<f64>,
/// Scalar offset added to the linear output before the sigmoid; `new`
/// initialises it to `0.0`.
bias: f64,
}

impl Classical for LinearRegression {
fn new() -> Self {
LinearRegression { weights: Array1::zeros(1), bias: 0.0 }
Expand All @@ -17,10 +53,10 @@ impl Classical for LinearRegression {
fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>, learning_rate: f64, epochs: usize) {
for _ in 0..epochs {
let predictions = self.predict(x);
let loss = calculate_loss(&predictions, y);
let loss = calculate_mse_loss(&predictions, y);
// Using Batch Gradient Descent here, we might want the user to have the option
// to change optimizer such as SGD, Adam etc
let gradients = gradient_descent(x, y, &self.weights, self.bias);
let gradients = batch_gradient_descent(x, y, &self.weights, self.bias);

self.weights = &self.weights - &(gradients.0 * learning_rate);
self.bias -= gradients.1 * learning_rate;
Expand All @@ -33,3 +69,27 @@ impl Classical for LinearRegression {
x.dot(&self.weights) + self.bias
}
}

impl Classical for LogisticRegression {
    /// Creates an untrained model with a single zero weight and a zero bias.
    ///
    /// `fit` re-allocates the weight vector when the training data has a
    /// different number of feature columns.
    fn new() -> Self {
        LogisticRegression { weights: Array1::zeros(1), bias: 0.0 }
    }

    /// Trains the model by repeatedly stepping the parameters against the
    /// log-loss gradients.
    ///
    /// # Arguments
    ///
    /// * `x` - Input matrix; rows are samples and columns are features.
    /// * `y` - Target labels, one per row of `x`.
    /// * `learning_rate` - Factor applied to each gradient before it is subtracted.
    /// * `epochs` - Number of full passes over the data.
    ///
    /// The loss of every epoch is printed to stdout.
    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>, learning_rate: f64, epochs: usize) {
        // `new` only allocates one weight. Without this guard, any input with
        // a different number of feature columns makes `x.dot(&self.weights)`
        // panic on the shape mismatch, so start from zeros of the right length.
        if self.weights.len() != x.ncols() {
            self.weights = Array1::zeros(x.ncols());
        }

        for _ in 0..epochs {
            let predictions = self.predict(x);
            let loss = calculate_log_loss(&predictions, y);
            let (grad_weights, grad_bias) =
                logistic_gradient_descent(x, y, &self.weights, self.bias);

            self.weights = &self.weights - &(grad_weights * learning_rate);
            self.bias -= grad_bias * learning_rate;

            println!("Loss: {}", loss);
        }
    }

    /// Returns, for every row of `x`, the sigmoid of `x · weights + bias`.
    fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        let linear_output = x.dot(&self.weights) + self.bias;
        linear_output.mapv(|z| 1.0 / (1.0 + (-z).exp()))
    }
}
28 changes: 28 additions & 0 deletions delta/src/classical/clustering.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//! BSD 3-Clause License
//!
//! Copyright (c) 2025, BlackPortal ○
//!
//! Redistribution and use in source and binary forms, with or without
//! modification, are permitted provided that the following conditions are met:
//!
//! 1. Redistributions of source code must retain the above copyright notice, this
//! list of conditions and the following disclaimer.
//!
//! 2. Redistributions in binary form must reproduce the above copyright notice,
//! this list of conditions and the following disclaimer in the documentation
//! and/or other materials provided with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its
//! contributors may be used to endorse or promote products derived from
//! this software without specific prior written permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 changes: 28 additions & 0 deletions delta/src/classical/dimensionality_reduction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//! BSD 3-Clause License
//!
//! Copyright (c) 2025, BlackPortal ○
//!
//! Redistribution and use in source and binary forms, with or without
//! modification, are permitted provided that the following conditions are met:
//!
//! 1. Redistributions of source code must retain the above copyright notice, this
//! list of conditions and the following disclaimer.
//!
//! 2. Redistributions in binary form must reproduce the above copyright notice,
//! this list of conditions and the following disclaimer in the documentation
//! and/or other materials provided with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its
//! contributors may be used to endorse or promote products derived from
//! this software without specific prior written permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
134 changes: 131 additions & 3 deletions delta/src/classical/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
//! BSD 3-Clause License
//!
//! Copyright (c) 2025, BlackPortal ○
//!
//! Redistribution and use in source and binary forms, with or without
//! modification, are permitted provided that the following conditions are met:
//!
//! 1. Redistributions of source code must retain the above copyright notice, this
//! list of conditions and the following disclaimer.
//!
//! 2. Redistributions in binary form must reproduce the above copyright notice,
//! this list of conditions and the following disclaimer in the documentation
//! and/or other materials provided with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its
//! contributors may be used to endorse or promote products derived from
//! this software without specific prior written permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
pub mod classification;
pub mod clustering;
pub mod dimensionality_reduction;
pub mod regression;

pub use classification::LinearRegression;
pub use classification::LogisticRegression;

use ndarray::{Array1, Array2};

Expand Down Expand Up @@ -64,13 +94,68 @@ pub trait Classical {
/// # Returns
///
/// Returns a `f64` representing the Mean Squared Error loss.
pub fn calculate_loss(predictions: &Array1<f64>, actuals: &Array1<f64>) -> f64 {
pub fn calculate_mse_loss(predictions: &Array1<f64>, actuals: &Array1<f64>) -> f64 {
    // Element-wise squared residuals between predictions and targets.
    let squared_errors = (predictions - actuals).mapv(|err| err.powi(2));
    // Average over the number of predictions.
    squared_errors.sum() / predictions.len() as f64
}

/// Performs gradient descent to compute the gradients for weights and bias.
/// Computes the average binary cross-entropy (log loss) of a set of predictions.
///
/// Each sample contributes `-(y * ln(p) + (1 - y) * ln(1 - p))`, where `p` is
/// the predicted probability and `y` the label. Probabilities are clamped to
/// `[1e-15, 1 - 1e-15]` before the logarithms are taken, so an exact `0.0` or
/// `1.0` never reaches `ln`.
///
/// # Parameters:
/// - `predictions`: Predicted probabilities for the positive class.
/// - `actuals`: Labels paired position-by-position with `predictions`.
///
/// # Returns:
/// The summed per-sample loss divided by `predictions.len()`.
pub fn calculate_log_loss(predictions: &Array1<f64>, actuals: &Array1<f64>) -> f64 {
    // Bound that keeps probabilities strictly inside (0, 1).
    const EPSILON: f64 = 1e-15;

    let sample_count = predictions.len() as f64;
    let total_loss = predictions
        .iter()
        .zip(actuals.iter())
        .map(|(&raw_prob, &label)| {
            let prob = raw_prob.clamp(EPSILON, 1.0 - EPSILON);
            -(label * prob.ln() + (1.0 - label) * (1.0 - prob).ln())
        })
        .sum::<f64>();

    total_loss / sample_count
}

/// Computes the fraction of samples whose thresholded prediction equals its label.
///
/// Predicted probabilities are first turned into hard labels: values `>= 0.5`
/// become `1.0`, everything else becomes `0.0`. A sample counts as correct when
/// the difference between its hard label and its actual label is exactly `0.0`.
///
/// # Parameters:
/// - `predictions`: Predicted probabilities for the positive class.
/// - `actuals`: True class labels, expected to be `0.0` or `1.0`.
///
/// # Returns:
/// The number of correct samples divided by `actuals.len()`.
pub fn calculate_accuracy(predictions: &Array1<f64>, actuals: &Array1<f64>) -> f64 {
    // Threshold the probabilities at 0.5 to obtain hard class labels.
    let predicted_labels: Array1<f64> = predictions.mapv(|p| if p >= 0.5 { 1.0 } else { 0.0 });

    // A zero residual marks a correct prediction; map each sample to a 1.0/0.0 hit flag.
    let residuals = predicted_labels - actuals;
    let correct = residuals.mapv(|r| if r == 0.0 { 1.0 } else { 0.0 }).sum();

    correct / actuals.len() as f64
}

/// Performs batch gradient descent to compute the gradients for weights and bias.
///
/// This function calculates the gradients for updating the model parameters in a linear regression
/// context. It computes the predictions based on the current weights and bias, then calculates
Expand All @@ -89,7 +174,7 @@ pub fn calculate_loss(predictions: &Array1<f64>, actuals: &Array1<f64>) -> f64 {
/// Returns a tuple where:
/// - The first element is an `Array1<f64>` representing the gradient for the weights.
/// - The second element is a `f64` representing the gradient for the bias.
fn gradient_descent(
fn batch_gradient_descent(
x: &Array2<f64>,
y: &Array1<f64>,
weights: &Array1<f64>,
Expand All @@ -103,3 +188,46 @@ fn gradient_descent(

(grad_weights, grad_bias)
}

/// Computes the log-loss gradients for the weights and bias of a logistic
/// regression model.
///
/// This function does not update any parameters itself; it only returns the
/// gradients for one step. The predictions are the sigmoid of
/// `x · weights + bias`, and the gradients are built from the residuals
/// `sigmoid(x · weights + bias) - y`, averaged over the number of rows in `x`.
///
/// # Parameters:
/// - `x`: Input data matrix, where each row is a training example and each
///   column is a feature.
/// - `y`: Actual labels (0 or 1) for the training examples.
/// - `weights`: Current model weights, one per feature column of `x`.
/// - `bias`: Current bias term.
///
/// # Returns:
/// A tuple `(grad_weights, grad_bias)` where:
/// - `grad_weights`: `xᵀ · residuals / m`, the gradient for the weights.
/// - `grad_bias`: `sum(residuals) / m`, the gradient for the bias.
fn logistic_gradient_descent(
    x: &Array2<f64>,
    y: &Array1<f64>,
    weights: &Array1<f64>,
    bias: f64,
) -> (Array1<f64>, f64) {
    let m = x.shape()[0] as f64;

    // Sigmoid applied to the linear output gives the predicted probabilities.
    let sigmoid_preds = (x.dot(weights) + bias).mapv(|z| 1.0 / (1.0 + (-z).exp()));

    // Compute the residuals once and share them between both gradients,
    // instead of cloning the predictions and subtracting `y` twice.
    let errors = sigmoid_preds - y;

    let grad_weights = x.t().dot(&errors) / m;
    let grad_bias = errors.sum() / m;

    (grad_weights, grad_bias)
}
28 changes: 28 additions & 0 deletions delta/src/classical/regression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//! BSD 3-Clause License
//!
//! Copyright (c) 2025, BlackPortal ○
//!
//! Redistribution and use in source and binary forms, with or without
//! modification, are permitted provided that the following conditions are met:
//!
//! 1. Redistributions of source code must retain the above copyright notice, this
//! list of conditions and the following disclaimer.
//!
//! 2. Redistributions in binary form must reproduce the above copyright notice,
//! this list of conditions and the following disclaimer in the documentation
//! and/or other materials provided with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its
//! contributors may be used to endorse or promote products derived from
//! this software without specific prior written permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4 changes: 2 additions & 2 deletions examples/classical/linear_regression/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use deltaml::{
classical::{Classical, LinearRegression, calculate_loss},
classical::{Classical, LinearRegression, calculate_mse_loss},
common::ndarray::{Array1, Array2},
};

Expand All @@ -24,6 +24,6 @@ async fn main() {
println!("Predictions for new data: {:?}", predictions);

// Calculate accuracy or loss for the test data for demonstration
let test_loss = calculate_loss(&model.predict(&x_data), &y_data);
let test_loss = calculate_mse_loss(&model.predict(&x_data), &y_data);
println!("Test Loss after training: {:.6}", test_loss);
}
9 changes: 9 additions & 0 deletions examples/classical/logistic_regression/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "logistic_regression"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
deltaml = { path = "../../../delta" }
tokio = { workspace = true, features = ["full"] }
Loading

0 comments on commit 7a4305b

Please sign in to comment.