
factorizationmachines.jl's People

Contributors

btwardow, hydrotoast


factorizationmachines.jl's Issues

Implement ALS and MCMC

Implement ALS/MCMC training similar to libfm and fastFM. These training methods do not require a learning rate alpha. Both libfm and fastFM reuse a single class for both training methods because they share most of their code.

For reference:

I have implemented the ALS algorithm in the past; however, I have not implemented MCMC before. I would proceed by implementing the ALS algorithm first and then attempting the MCMC algorithm.
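For my own reference, below is a minimal sketch of the ALS coordinate update I have in mind, assuming the standard FM formulation in which the prediction is linear in each individual parameter theta; the function and variable names are hypothetical and not part of the current API.

# Hypothetical sketch of one ALS coordinate update for a single parameter theta
# (w0, one entry of w, or one entry of V). Because the FM prediction is linear
# in theta, the regularized 1-D least-squares problem has a closed form.
#   e[i] = y[i] - yhat[i]          current residual for sample i
#   h[i] = d(yhat[i]) / d(theta)   "design" value of theta for sample i
#   lambda                         L2 regularization for this parameter group
function als_update(theta, e::Vector{Float64}, h::Vector{Float64}, lambda::Float64)
    theta_new = sum((e .+ theta .* h) .* h) / (sum(abs2, h) + lambda)
    # Residuals can be updated incrementally instead of being recomputed
    e_new = e .- (theta_new - theta) .* h
    return theta_new, e_new
end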

SGD Improvements: Minibatch and Scheduling

SGD can use two simple improvements:

  1. Minibatch sampling
  2. Scheduling

Minibatch sampling

Minibatching can help control the tradeoff between convergence and training time (I may be wrong about the convergence claim).

I am not entirely certain what an optimized implementation would look like; however, here is a sketch of a naive implementation given a matrix of feature vectors X and a vector of labels y.

number_of_samples = 10000
minibatch_size = 100 # or minibatch_fraction instead?
sample_indices = sample_without_replacement(number_of_samples, minibatch_size)
X_minibatch = X[:, sample_indices]
y_minibatch = y[sample_indices]

# One SGD step on the minibatch; fSum and fSumSqr are the intermediate factor
# sums reused by the predictor and the update.
yhat = fmPredict(fm, X_minibatch, fSum, fSumSqr)
mult = fmLoss(yhat, y_minibatch)
fmSGDUpdate!(fm, alpha, sample_indices, X_minibatch, mult, fSum)

Some changes to consider:

  1. Fixed minibatch size vs. minibatch fraction of the dataset?
  2. The minibatch sampling algorithm sample_without_replacement: could this be implemented efficiently? Perhaps it would be more efficient to slide a window over the dataset, e.g.
X_minibatch = X[:, i:i + minibatch_size - 1]
y_minibatch = y[i:i + minibatch_size - 1]

However, the minibatches will not vary across epochs in this scenario; one way to address that is sketched below.
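One option is to keep the cheap sliding window but shuffle the column order once per epoch, so the windows cover different minibatches each time (a sketch, assuming X stores one sample per column as above):

using Random: shuffle

perm = shuffle(1:number_of_samples)  # new permutation at the start of each epoch
X_shuffled = X[:, perm]
y_shuffled = y[perm]
# ... then slide the window over X_shuffled / y_shuffled as above ...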

Scheduling

Decrease the learning rate alpha in proportion to 1/sqrt(iteration), where iteration is the current iteration count. Since iteration > 0, there is no division by zero.

alpha = alpha0 / sqrt(iteration)
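Putting minibatching and scheduling together, one epoch of the training loop might look roughly like the following. This is a sketch only: the fmPredict/fmLoss/fmSGDUpdate! names follow the pseudocode above, and StatsBase's sample is used for the without-replacement draw, which may not be the most efficient choice.

using StatsBase: sample

alpha0 = 0.01
nbatches = div(number_of_samples, minibatch_size)
for epoch in 1:num_epochs
    for batch in 1:nbatches
        iteration = (epoch - 1) * nbatches + batch
        alpha = alpha0 / sqrt(iteration)  # learning-rate schedule; iteration >= 1
        sample_indices = sample(1:number_of_samples, minibatch_size; replace = false)
        X_minibatch = X[:, sample_indices]
        y_minibatch = y[sample_indices]
        yhat = fmPredict(fm, X_minibatch, fSum, fSumSqr)
        mult = fmLoss(yhat, y_minibatch)
        fmSGDUpdate!(fm, alpha, sample_indices, X_minibatch, mult, fSum)
    end
end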

Benchmarks against other implementations

Some other implementations to compare to:

Tasks:

  • Select 2-3 datasets for comparison
  • Set up experiments

Experiment approach

  1. Select a dataset and split it into training (X_train, y_train) and test (X_test, y_test) sets
  2. Download both libraries
  3. Train both libraries on X_train, y_train (and measure the training time)
  4. Verify that the test set evaluations are close enough on X_test, y_test
  5. Repeat the test 10 times

Implementing the experiment (up for discussion/alternatives)

  1. Write a benchmark script in Bash and use simple wall-clock time for measurements; a rough Julia timing sketch is shown below
  2. Save script in a new benchmarks/ folder
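On the Julia side, the per-run measurement could be as simple as the following. The fmTrain and fmPredict names follow the other sketches here and are assumptions about the eventual API; rmse is a hypothetical helper.

# Hypothetical timing of a single benchmark run
rmse(yhat, y) = sqrt(sum(abs2, yhat .- y) / length(y))

train_time = @elapsed fm = fmTrain(X_train, y_train)
yhat_test = fmPredict(fm, X_test)
println("train time: $(train_time) s, test RMSE: $(rmse(yhat_test, y_test))")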

Add Classification via Logistic Loss

For comparisons with other implementations:

For prediction (during training):

For training loss function:

Both implementations are structurally the same and use enums to distinguish between classification and regression tasks.

  • task = 0 implies regression
  • task = 1 implies classification
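For reference, the logistic-loss formulation I would expect to add is sketched below, assuming labels y in {-1, +1} and a raw FM score yhat; this is the usual textbook form, not necessarily the exact libfm/fastFM code.

sigmoid(z) = 1.0 / (1.0 + exp(-z))

# Predicted probability of the positive class given the raw FM score
predict_proba(yhat) = sigmoid(yhat)

# Logistic loss for a single sample with label y in {-1, +1}
logistic_loss(yhat, y) = -log(sigmoid(y * yhat))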

A brief roadmap:

  1. Decide how the API should change for training; the prediction API should remain the same. A simple proposal: add an enum parameter to decide whether regression or classification should be used.
  2. Implement the corresponding branches for the training.

Alternative API ideas:

  1. Separate training tasks: fm_train_sgd_regression(...) and fm_train_sgd_classification(...)
  2. OOP: make the training tasks objects, FMRegression and FMClassification, which are constructed with their own training parameters and support the interface fm_train_sgd to produce a single FMModel. The FMModel will reference the training task that produced it.

Personally, I am in favor of approach (2) since it leads to fewer branches in the code; a rough sketch follows below.
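A very rough sketch of what approach (2) could look like; the type and field names are hypothetical, and the loss multipliers assume squared loss for regression and logistic loss for classification.

sigmoid(z) = 1.0 / (1.0 + exp(-z))

abstract type FMTask end

# Each task carries its own training parameters (hypothetical fields).
struct FMRegression <: FMTask
    alpha::Float64
end

struct FMClassification <: FMTask
    alpha::Float64
end

# Dispatch on the task type instead of branching on an enum inside the SGD loop.
loss_multiplier(::FMRegression, yhat, y) = yhat - y
loss_multiplier(::FMClassification, yhat, y) = -y * (1.0 - sigmoid(y * yhat))

struct FMModel
    task::FMTask  # the model keeps a reference to the task that produced it
    # ... weights, latent factor matrix, etc. ...
end

Multiple dispatch on loss_multiplier keeps the SGD loop itself free of task-specific branches, which is the main appeal of this approach.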

Hyperparameter optimization utilities

There are several hyperparameters for factorization machines that are crucial to convergence (at least in the tests I have been running), especially for SGD:

  • alpha the learning rate
  • initMean and initStd for the Gaussian distribution used in latent vector initialization

A simple candidate utility that is easy to implement would be random search, as proposed by Bergstra and Bengio. scikit-learn has a reference implementation that we may use as a guide.

Implementation Plan

Still a work in progress

Hyperparameter optimization utilities may live in src/hyperparams.jl.

There are two important use cases that we may implement as two separate functions:

  1. Analyzing the choice of hyperparameters and how it affects the evaluation
  2. Building an optimal model with the best hyperparameters

We use the Distributions package to define parameter distributions.

using Distributions

param_distributions = Dict(:alpha => Uniform(0.01, 1.0), :initStd => Gamma(1.0, 1.0))
num_samples = 100
result = fm_randomsearch(X, y, param_distributions = param_distributions, num_samples = num_samples)

fm = result.model
info(result.param_scores)

And a corresponding implementation sketch:

function fm_randomsearch(X, y; param_distributions = Dict(), num_samples = 10)
    param_scores = Tuple[]         # (sampled parameters, score) per trial
    best_score = typemax(Float64)  # lower score (e.g. RMSE) is better
    best_model = nothing
    for _ in 1:num_samples
        # Draw one value from each parameter distribution, keyed by parameter name
        params = Dict(name => rand(dist) for (name, dist) in param_distributions)
        model = fmTrain(X, y; params...)
        score = evaluate(model, X, y)
        if score < best_score
            best_model, best_score = model, score
        end
        push!(param_scores, (params, score))
    end
    RandomSearchResult(best_model, param_scores)
end
