Mixture of Depths

An unofficial implementation of "Mixture-of-Depths: Dynamically allocating compute in transformer-based language models"

Setup

  • First, follow the instructions for setting up your Llama 2 environment here.
  • Then install the one additional dependency:
pip install einops

Details

  • Implements MoD in Llama 2.

  • Follows the paper's configuration, with some assumptions:

    • Route every other layer.
    • Training configurations for both proposed causal-inference methods (sketched after this list).
  • Notes on the auxiliary router for causal inference:

    • Currently, we train it separately, after MoD Llama is trained.
    • This is a simple task: the router reaches high token-prediction accuracy quickly, and the simple dataset makes it easier still.
  • MoD_training.ipynb demonstrates training and was used for the results below.

  • MoD_sampling.ipynb demonstrates generation with each method.
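
For orientation, below is a minimal sketch of what the top-k routing around a single layer can look like; it uses einops (installed above) for index broadcasting. Every name here (MoDBlock, capacity, the sigmoid scaling) is an illustrative assumption rather than this repo's actual API, and block is assumed to return the layer's contribution without its own residual connection:

import torch
import torch.nn as nn
from einops import repeat

class MoDBlock(nn.Module):
    """Top-k token routing around a transformer layer (illustrative sketch)."""

    def __init__(self, block: nn.Module, dim: int, capacity: float = 0.125):
        super().__init__()
        self.block = block                            # wrapped Llama 2 layer
        self.router = nn.Linear(dim, 1, bias=False)   # scalar score per token
        self.capacity = capacity                      # fraction of tokens processed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, dim = x.shape
        k = max(1, int(self.capacity * seq_len))

        scores = self.router(x).squeeze(-1)           # (bsz, seq_len)
        idx = scores.topk(k, dim=-1).indices.sort(dim=-1).values  # restore order
        idx_d = repeat(idx, "b k -> b k d", d=dim)

        # Only the top-k tokens pass through the layer; the rest ride the residual.
        selected = torch.gather(x, 1, idx_d)          # (bsz, k, dim)
        weights = torch.gather(scores, 1, idx).unsqueeze(-1)

        # Scaling the layer's output by the router score gives the router a
        # gradient; the exact scaling function (sigmoid here) is an assumption.
        update = torch.sigmoid(weights) * self.block(selected)
        return x.scatter(1, idx_d, selected + update)

Because the top-k is taken over the full sequence, routing decisions during training depend on future tokens. The paper's two causal-inference fixes, both configured here, are an auxiliary loss on the main router's outputs and a separate auxiliary router trained to predict top-k membership from the current token alone (sketched after the Results section).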

Results

  • 50-million-parameter model
    • C4
      • Baseline after 1 epoch:
        • Loss: 3.73
        • Samples/sec: 6.79
      • MoD w/ Auxiliary Loss after 1 epoch:
        • Loss: 3.81
        • Samples/sec: 8.15
      • MoD w/ Auxiliary Router after 1 epoch:
        • Loss: 4.19
        • Samples/sec: 7.64
    • TinyStories
      • Baseline after 5 epochs:
        • Loss: 2.46
        • Samples/sec: 11.22
      • MoD w/ Auxiliary Loss after 5 epochs:
        • Loss: 2.55
        • Samples/sec: 11.33
      • MoD w/ Auxiliary Router after 5 epochs:
        • Loss: 2.48
        • Auxiliary Router Causal Loss: 0.15 (see the sketch after these results)
        • Samples/sec: 11.54
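
The "Auxiliary Router Causal Loss" above corresponds to training the auxiliary router as a binary classifier over the main router's top-k decisions, as the paper describes. A hypothetical, self-contained sketch, with an assumed width and architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 512  # assumed width; the repo's 50M-parameter model may differ

# Small MLP predicting, from the current token only, whether it would
# have made the top-k cut; trained after the MoD model is frozen.
aux_router = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, 1))

def aux_router_loss(hidden: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
    # hidden: (bsz, seq_len, dim) activations entering the MoD layer
    # topk_idx: (bsz, k) indices the main top-k router selected
    logits = aux_router(hidden).squeeze(-1)           # (bsz, seq_len)
    targets = torch.zeros_like(logits).scatter(1, topk_idx, 1.0)
    return F.binary_cross_entropy_with_logits(logits, targets)

At generation time, a token is routed through the block when the auxiliary router's sigmoid output exceeds 0.5, so routing no longer depends on future tokens.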

TODO

  • Validate sampling methods:
    • Auxiliary loss
    • "Second" router

Citations

@misc{raposo2024mixtureofdepths,
    title={Mixture-of-Depths: Dynamically allocating compute in transformer-based language models}, 
    author={David Raposo and Sam Ritter and Blake Richards and Timothy Lillicrap and Peter Conway Humphreys and Adam Santoro},
    year={2024},
    eprint={2404.02258},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
