r-nn

[Demo: r-nn running inference for GPT-2]


r-nn is a Tensor library with full support for automatic differentiation and an API modelled after PyTorch, built using only the standard library (the rand and rand-distr crates are used only for weight initialization). The core backbone of r-nn is modelled after micrograd and operates on scalar values (modelled as Value objects). A Tensor is a collection of Value objects with support for the common PyTorch operations. r-nn is best used for educational purposes - to understand how Tensor operations work and the simplicity of backpropagation. It is fully tested against reference PyTorch implementations.
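The core design is easiest to see in code. Below is a minimal sketch of such a scalar autodiff node in the spirit of micrograd - illustrative only; the names, fields and the simplified recursive backward() are assumptions, not r-nn's actual definitions:

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical scalar autodiff node, micrograd-style. Each node stores its
// value, its accumulated gradient, and its parents with local gradients.
#[derive(Clone)]
struct Value(Rc<RefCell<Inner>>);

struct Inner {
    data: f32,
    grad: f32,
    parents: Vec<(Value, f32)>, // (parent node, local gradient)
}

impl Value {
    fn new(data: f32) -> Self {
        Value(Rc::new(RefCell::new(Inner { data, grad: 0.0, parents: Vec::new() })))
    }

    fn add(&self, other: &Value) -> Value {
        let out = Value::new(self.0.borrow().data + other.0.borrow().data);
        // Local gradients: d(a+b)/da = 1 and d(a+b)/db = 1.
        out.0.borrow_mut().parents = vec![(self.clone(), 1.0), (other.clone(), 1.0)];
        out
    }

    fn mul(&self, other: &Value) -> Value {
        let (a, b) = (self.0.borrow().data, other.0.borrow().data);
        let out = Value::new(a * b);
        // Local gradients: d(a*b)/da = b and d(a*b)/db = a.
        out.0.borrow_mut().parents = vec![(self.clone(), b), (other.clone(), a)];
        out
    }

    // Seed the output gradient and push it back through the graph.
    fn backward(&self) {
        self.0.borrow_mut().grad = 1.0;
        self.propagate();
    }

    fn propagate(&self) {
        let (grad, parents) = {
            let inner = self.0.borrow();
            (inner.grad, inner.parents.clone())
        };
        for (parent, local) in parents {
            // Chain rule: accumulate upstream grad times the local gradient.
            parent.0.borrow_mut().grad += grad * local;
            parent.propagate();
        }
    }
}

fn main() {
    let x = Value::new(2.0);
    let y = Value::new(3.0);
    let z = x.mul(&y).add(&x); // z = x*y + x
    z.backward();
    assert_eq!(x.0.borrow().grad, 4.0); // dz/dx = y + 1
    assert_eq!(y.0.borrow().grad, 2.0); // dz/dy = x
}
```

Note that a real implementation topologically sorts the graph and runs each node's backward step exactly once; the naive recursion above would double-count gradients when a node with parents feeds multiple consumers.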

Often, it is hard to visualise how gradients work for large matrices. It turns out that simply thinking of a matrix's gradient as another matrix, where each element of the gradient matrix corresponds to an element of the original matrix, simplifies the intuition. By implementing backward() only for +, *, pow(), exp() and ln(), this turns out to be sufficient for essentially all operations, including complex ones like softmax and tanh, as shown below.
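For example, tanh and softmax both decompose into these five primitives - subtraction is addition of a $(-1)$ multiple, and division is multiplication by a pow(-1):

$$\tanh(x) = \left(e^{2x} - 1\right)\left(e^{2x} + 1\right)^{-1}, \qquad \mathrm{softmax}(x)_i = e^{x_i}\Big(\sum_j e^{x_j}\Big)^{-1}$$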

For a matrix multiplication A @ B, the gradient with respect to A involves B transposed, and vice versa (see below). This is not particularly obvious at first, yet r-nn never needs to implement it explicitly: because matrix multiplication is built from scalar + and *, the correct gradients fall out automatically for matrices of any arbitrary size.
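Concretely, for $C = AB$ with upstream gradient $\partial L / \partial C$, the standard results that r-nn recovers automatically from its scalar primitives are:

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\, B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top}\, \frac{\partial L}{\partial C}$$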

Examples

Examples can be found in the /examples folder, including training a Multi-Layer Perceptron, training a Recurrent Neural Network, and running inference with GPT-2. As noted above, r-nn is recommended specifically for educational purposes.

Optimisations

For educational purposes, r-nn is written to be explicit - for example, it uses the naive matrix multiplication algorithm (in most cases) and saves all intermediate values. For an $N \times N$ matrix multiplication, this means storing $\sim 3N^3$ intermediate values in addition to the matrix product. Each Value object is wrapped in an Rc<RefCell>, which adds ~40 bytes of overhead to a single 32-bit float.
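As a rough estimate (assuming ~44 bytes per Value: the f32 plus the ~40 bytes of overhead above), multiplying two $256 \times 256$ matrices would allocate on the order of

$$3 \times 256^3 \approx 5.0 \times 10^7 \ \text{Values} \times 44\ \text{bytes} \approx 2.2\ \text{GB},$$

which is why the scalar representation does not scale beyond small models.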

Optimisations would thus involve moving away from scalar operations towards Tensor-level values, with a single pointer per Tensor rather than one per element. Gradient manipulations would then happen at the Tensor level.
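As a hypothetical sketch of that direction (not r-nn's current code), a Tensor would own contiguous data and gradient buffers behind a single reference-counted pointer:

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical tensor-level node: one allocation for all elements,
// instead of one Rc<RefCell<Value>> per scalar.
struct TensorInner {
    data: Vec<f32>, // rows * cols elements, contiguous
    grad: Vec<f32>, // gradient buffer of the same shape
    shape: (usize, usize),
}

struct Tensor(Rc<RefCell<TensorInner>>);

impl Tensor {
    fn zeros(rows: usize, cols: usize) -> Self {
        Tensor(Rc::new(RefCell::new(TensorInner {
            data: vec![0.0; rows * cols],
            grad: vec![0.0; rows * cols],
            shape: (rows, cols),
        })))
    }
}

fn main() {
    // Two Vec allocations total, versus 128*128 reference-counted
    // scalar nodes in the per-Value representation.
    let t = Tensor::zeros(128, 128);
    assert_eq!(t.0.borrow().shape, (128, 128));
}
```

Backward passes would then manipulate whole gradient buffers at once (for example, accumulating $\frac{\partial L}{\partial C} B^{\top}$ into A's grad buffer), amortising the pointer and bookkeeping overhead across the entire tensor.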

References

Acknowledgements go to the codebases referenced during development, including micrograd.
