A collection of boosting algorithms written in Rust 🦀

License: MIT



Documentation

MiniBoosts is a library for boosting algorithm developers.
Boosting is a repeated game between a Booster and a Weak Learner.

For each round of the game,

  1. The Booster chooses a distribution over training examples,
  2. Then the Weak Learner chooses a hypothesis (function) whose accuracy w.r.t. the distribution is slightly better than random guessing.

After sufficiently many rounds, the Booster outputs a combined hypothesis that performs significantly better than random guessing on the training examples.
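To make this concrete, here is a minimal sketch of one AdaBoost-style round in plain Rust. Every name below is illustrative; it is independent of the Booster and WeakLearner traits this crate actually defines.

// A self-contained sketch of one AdaBoost-style boosting round.
// `labels` and `preds` hold values in {-1.0, +1.0}.
fn boosting_round(dist: &mut [f64], labels: &[f64], preds: &[f64]) -> f64 {
    // Weighted error of the weak hypothesis w.r.t. the current distribution.
    let eps: f64 = dist.iter()
        .zip(labels.iter().zip(preds.iter()))
        .filter(|&(_, (&y, &p))| y != p)
        .map(|(d, _)| *d)
        .sum();

    // Weight of the weak hypothesis; finite whenever 0 < eps < 1/2,
    // i.e., whenever the hypothesis is better than random guessing.
    let alpha = 0.5 * ((1.0 - eps) / eps).ln();

    // Multiplicative update: up-weight misclassified examples.
    for ((d, &y), &p) in dist.iter_mut().zip(labels.iter()).zip(preds.iter()) {
        *d *= (-alpha * y * p).exp();
    }

    // Renormalize so `dist` remains a probability distribution.
    let z: f64 = dist.iter().sum();
    dist.iter_mut().for_each(|d| *d /= z);
    alpha
}

After T rounds, the combined hypothesis is the weighted majority vote sign(sum_t alpha_t * h_t(x)).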

Some Boosters require the extended feature flag in Cargo.toml, like this:

miniboosts = { version = "0.3.3", features = ["extended"] }

These boosting algorithms use the Gurobi optimizer to compute a distribution over training examples, so they require a Gurobi license. Gurobi offers free academic licenses, so you can use the extended feature at no cost if you are a student.

BOOSTER                                                                           FEATURE FLAG
AdaBoost (Freund and Schapire, 1997)
GBM, Gradient Boosting Machine (Jerome H. Friedman, 2001)
LPBoost (Demiriz, Bennett, and Shawe-Taylor, 2002)                                extended
SmoothBoost (Servedio, 2003)
AdaBoostV (Rätsch and Warmuth, 2005)
TotalBoost (Warmuth, Liao, and Rätsch, 2006)                                      extended
SoftBoost (Warmuth, Glocer, and Rätsch, 2007)                                     extended
ERLPBoost (Warmuth, Glocer, and Vishwanathan, 2008)                               extended
CERLPBoost, Corrective ERLPBoost (Shalev-Shwartz and Singer, 2010)                extended
MLPBoost (Mitsuboshi, Hatano, and Takimoto, 2022)                                 extended
GraphSepBoost, Graph Separation Boosting (Alon, Gonen, Hazan, and Moran, 2023)

If you invent a new boosting algorithm, you can introduce it by implementing the Booster trait. Run cargo doc -F extended --open for details.
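Schematically, the division of labor between the two parties looks like the following. These are hypothetical, simplified traits written only to illustrate the shape of the protocol; the actual Booster and WeakLearner traits in miniboosts have richer interfaces.

// Illustration only: simplified stand-ins, NOT the crate's actual traits.
trait ToyWeakLearner {
    type Hypothesis;

    // Return a hypothesis whose accuracy w.r.t. `dist` is slightly
    // better than random guessing.
    fn produce(&self, dist: &[f64]) -> Self::Hypothesis;
}

trait ToyBooster<W: ToyWeakLearner> {
    // Incorporate the newest weak hypothesis and return the next
    // distribution over training examples, or `None` to stop boosting.
    fn update(&mut self, h: &W::Hypothesis) -> Option<Vec<f64>>;
}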

Currently, no weak learner uses Gurobi, so you can use all weak learners without enabling the extended flag.

WEAK LEARNER
Decision Tree
Regression Tree
A worst-case weak learner for LPBoost
Gaussian Naive Bayes
Neural Network (Experimental)

Why MiniBoosts?

If you write a paper about boosting algorithms, you need to compare your algorithm against others. At this point, some issues arise.

  • Some boosting algorithms, such as LightGBM and XGBoost, are already implemented and freely available. They are very easy to use from Python3 but hard to compare against a new algorithm, since their cores are implemented in C++. Implementing your algorithm in Python3 makes running-time comparisons unfair (Python3 is significantly slower than C++), while implementing it in C++ is extremely hard (in my experience).
  • Most boosting implementations assume a decision-tree weak learner, even though the boosting protocol does not demand one.
  • There is no readily available implementation of margin-optimization boosting algorithms, even though margin optimization is a better goal than empirical risk minimization in binary classification.

MiniBoosts is a crate that addresses the above issues.
It provides the following:

  • Two main traits, named Booster and WeakLearner.
    • If you invent a new boosting algorithm, all you need to do is implement Booster.
    • If you invent a new weak-learning algorithm, all you need to do is implement WeakLearner.
  • Some famous boosting algorithms, including AdaBoost, LPBoost, ERLPBoost, etc.
  • Some weak learners, including Decision-Tree, Regression-Tree, etc.

MiniBoosts for research

Sometimes you may want to log each step of the boosting procedure. You can use the Logger struct to write the log to a CSV file while printing the current status, like this:

[Research feature example]

See the Research feature section below for details.

How to use

Documentation

Add the following to Cargo.toml:

miniboosts = { version = "0.3.3" }

If you want to use extended features, enable the flag:

miniboosts = { version = "0.3.3", features = ["extended"] }

Here is a sample code:

use miniboosts::prelude::*;


fn main() {
    // Set file name
    let file = "/path/to/input/data.csv";

    // Read the CSV file
    // The column named `class` corresponds to the labels (targets).
    let sample = SampleReader::new()
        .file(file)
        .has_header(true)
        .target_feature("class")
        .read()
        .unwrap();


    // Set tolerance parameter as `0.01`.
    let tol: f64 = 0.01;


    // Initialize Booster
    let mut booster = AdaBoost::init(&sample)
        .tolerance(tol); // Set the tolerance parameter.


    // Construct `DecisionTree` Weak Learner from `DecisionTreeBuilder`.
    let weak_learner = DecisionTreeBuilder::new(&sample)
        .max_depth(3) // Specify the max depth (default is 2)
        .criterion(Criterion::Twoing) // Choose the split rule.
        .build(); // Build `DecisionTree`.


    // Run the boosting algorithm
    // Each booster returns a combined hypothesis.
    let f = booster.run(&weak_learner);


    // Get the batch prediction for all examples in `sample`.
    let predictions = f.predict_all(&sample);


    // You can predict the `i`th instance.
    let i = 0_usize;
    let prediction = f.predict(&sample, i);

    // You can serialize the hypothesis `f` as a JSON string.
    let s = serde_json::to_string(&f)
        .expect("Failed to serialize the hypothesis");
}
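The last line assumes the serde_json crate is among your dependencies; if you want to serialize hypotheses this way, add, for example:

serde_json = "1"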

If you use boosting for soft margin optimization, initialize the booster like this:

let n_sample = sample.shape().0; // Get the number of training examples
let nu = n_sample as f64 * 0.2; // Set the upper-bound of the number of outliers.
let lpboost = LPBoost::init(&sample)
    .tolerance(tol)
    .nu(nu); // Set a capping parameter.

Note that the capping parameter must satisfy 1 <= nu && nu <= n_sample.
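For instance, you can guard against an out-of-range value before initializing the booster; here is a small sketch using the standard library's f64::clamp:

// Keep the capping parameter in the valid range [1, n_sample].
let n_sample = sample.shape().0;
let nu = (0.2 * n_sample as f64).clamp(1.0, n_sample as f64);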

Research feature

This crate can output a CSV file recording, at each step, the objective value, the training and test loss, and the time per iteration.

Here is an example:

use miniboosts::prelude::*;
use miniboosts::{
    Logger,
    LoggerBuilder,
    SoftMarginObjective,
};


// Define a loss function
fn zero_one_loss<H>(sample: &Sample, f: &H) -> f64
    where H: Classifier
{
    let n_sample = sample.shape().0 as f64;

    let target = sample.target();

    f.predict_all(sample)
        .into_iter()
        .zip(target.into_iter())
        .map(|(fx, &y)| if fx != y as i64 { 1.0 } else { 0.0 })
        .sum::<f64>()
        / n_sample
}


fn main() {
    // Read the training data
    let path = "/path/to/train/data.csv";
    let train = SampleReader::new()
        .file(path)
        .has_header(true)
        .target_feature("class")
        .read()
        .unwrap();

    // Set some parameters used later.
    let n_sample = train.shape().0 as f64;
    let nu = 0.01 * n_sample;


    // Read the test data
    let path = "/path/to/test/data.csv";
    let test = SampleReader::new()
        .file(path)
        .has_header(true)
        .target_feature("class")
        .read()
        .unwrap();


    let booster = LPBoost::init(&train)
        .tolerance(0.01)
        .nu(nu);

    let weak_learner = DecisionTreeBuilder::new(&train)
        .max_depth(2)
        .criterion(Criterion::Entropy)
        .build();

    // Set the objective function.
    // You can use your own objective by implementing the ObjectiveFunction trait.
    let objective = SoftMarginObjective::new(nu);

    let mut logger = LoggerBuilder::new()
        .booster(booster)
        .weak_learner(weak_learner)
        .train_sample(&train)
        .test_sample(&test)
        .objective_function(objective)
        .loss_function(zero_one_loss)
        .time_limit_as_secs(120) // Terminate after 120 seconds
        .print_every(10)         // Print log every 10 rounds.
        .build();

    // Each line of `logfile.csv` contains the following four values:
    // objective value, train loss, test loss, and time per iteration.
    // The returned value `f` is the combined hypothesis.
    let f = logger.run("logfile.csv")
        .expect("Failed to log");
}

Others

  • Currently, this crate mainly supports boosting algorithms for binary classification.
  • Some boosting algorithms rely on the Gurobi optimizer, so you need a Gurobi license to use them. If you have one, you can enable these boosters by specifying features = ["extended"] in Cargo.toml. Compilation fails if you enable the extended feature without a Gurobi license.
  • You can log your own algorithm by implementing the Research trait.
  • Run cargo doc -F extended --open to see more information.
  • GraphSepBoost only supports the aggregation rule shown in Lemma 4.2 of their paper.

Future work
