shogun-toolbox/shogun-rust

shogun-rust

This is a Rust crate with bindings to the Shogun machine learning framework.

Note: this crate is in very early development and supports only a small part of the Shogun library.
Note: this crate is a thin wrapper around the Shogun C++ library, so its internals and API are not very idiomatic Rust.

More information about the design can be found here.

Build

The build assumes that shogun-static and spdlog are installed locally; if either cannot be found, CMake will report an error.

To build, simply run:

cargo build

Then, to use it from another crate:

// Link the `shogun` crate into the dependent crate so its items can be used.
extern crate shogun;

Examples

Basic API

use shogun::shogun::{Kernel, Version};

/// Minimal tour of the shogun-rust API: query the library version, create an
/// object through a Shogun factory, then set and read back one of its parameters.
fn main() {
    // Query the version of the linked Shogun library.
    let version = Version::new();
    println!("Shogun version {}", version.main_version().unwrap());

    // shogun-rust supports Shogun's factory functions: objects are created by class name.
    let k = match Kernel::new("GaussianKernel") {
        Ok(obj) => obj,
        Err(msg) => panic!("No can do: {}", msg),
    };

    // It also supports `put` for setting a parameter by name. Report why it failed
    // instead of discarding the error.
    if let Err(msg) = k.put("log_width", &1.0) {
        println!("Failed to put value: {}", msg);
    }

    // ...and `get`, which returns a dynamically typed value that must be downcast
    // to the expected concrete type before it can be used.
    match k.get("log_width") {
        Ok(value) => match value.downcast_ref::<f64>() {
            Some(fvalue) => println!("GaussianKernel::log_width: {}", fvalue),
            None => println!("GaussianKernel::log_width not of type f64"),
        },
        Err(msg) => panic!("{}", msg),
    }
}

Training a Random Forest

let f_feats_train = File::read_csv("classifier_4class_2d_linear_features_train.dat".to_string())?;
let f_feats_test = File::read_csv("classifier_4class_2d_linear_features_test.dat".to_string())?;
let f_labels_train = File::read_csv("classifier_4class_2d_linear_labels_train.dat".to_string())?;
let f_labels_test = File::read_csv("classifier_4class_2d_linear_labels_test.dat".to_string())?;

let features_train = Features::from_file(&f_feats_train)?;
let features_test = Features::from_file(&f_feats_test)?;
let labels_train = Labels::from_file(&f_labels_train)?;
let labels_test = Labels::from_file(&f_labels_test)?;

let mut rand_forest = Machine::new("RandomForest")?;
let m_vote = CombinationRule::new("MajorityVote")?;

rand_forest.put("labels", &labels_train)?;
rand_forest.put("num_bags", &100)?;
rand_forest.put("combination_rule", &m_vote)?;
rand_forest.put("seed", &1)?;

rand_forest.train(&features_train)?;

let predictions = rand_forest.apply(&features_test)?;

let acc = Evaluation::new("MulticlassAccuracy")?;
rand_forest.put("oob_evaluation_metric", &acc)?;
let accuracy = acc.evaluate(&predictions, &labels_test)?;

println!("Model accuracy: {}", accuracy);