
mlp

This is a very exciting module because we tie a lot of things together and train a Multi-layer Perceptron (MLP) to be an n-gram Language Model, following the paper A Neural Probabilistic Language Model by Bengio et al. (2003).
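The Bengio et al. (2003) architecture can be sketched in a few lines of PyTorch: embed the previous n-1 tokens, concatenate the embeddings, pass them through a tanh hidden layer, and read out logits over the vocabulary. The sizes below (vocab of 27, context of 3 tokens, etc.) are illustrative placeholders, not the module's actual hyperparameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# hypothetical sizes, chosen only for illustration
vocab_size, context_len, embed_dim, hidden_dim = 27, 3, 10, 64

# parameters of a Bengio-style MLP language model
C  = torch.randn(vocab_size, embed_dim)                # token embedding table
W1 = torch.randn(context_len * embed_dim, hidden_dim)  # hidden layer weights
b1 = torch.randn(hidden_dim)
W2 = torch.randn(hidden_dim, vocab_size)               # output layer weights
b2 = torch.randn(vocab_size)

# a batch of 4 contexts (3 previous token ids each) and their target tokens
idx = torch.randint(0, vocab_size, (4, context_len))
targets = torch.randint(0, vocab_size, (4,))

emb = C[idx].view(4, -1)       # look up and concatenate the embeddings
h = torch.tanh(emb @ W1 + b1)  # hidden layer
logits = h @ W2 + b2           # unnormalized log-probabilities over the vocab
loss = F.cross_entropy(logits, targets)
```

Training then just repeats this forward pass, backpropagates `loss`, and nudges the parameters, which is exactly what the parallel implementations below do in their different ways.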

We have multiple parallel implementations that all get to the exact same results but in very different ways:

  • micrograd, following the previous module micrograd. This is a highly inefficient approach but it uses our own scalar-valued gradient engine.
  • numpy, where we use the array object of numpy but implement our own forward and backward pass using numpy operations.
  • C, which is the same as the numpy code but it fully spells out all the individual operations in C code.
  • PyTorch, where we use the pytorch library to implement the forward pass only. Just like micrograd, PyTorch will handle the backward pass for us.
  • mlx/JAX? (would be nice to look into)

In this module, two critical abstractions get explored and are tied together in depth:

  1. The idea of an Array (in numpy parlance) or Tensor (in PyTorch parlance): a multi-dimensional array that stores data and has operations defined on it.
  2. The idea of a Module: a class that has both a forward() and a backward() method. The forward pass computes the output given the input, and the backward pass computes the gradient of the loss with respect to the input (and any parameters). The "autograd engine" keeps track of the computational graph that is constructed in the forward pass, and then after the forward pass iterates over it in reverse order and calls backward() on each module, implementing backpropagation.
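The Module idea above can be made concrete with a minimal sketch: a tanh layer that caches its output in forward() so that backward() can chain the incoming gradient through the local derivative. The class and variable names here are illustrative, not the module's actual API.

```python
import numpy as np

class Tanh:
    """Minimal Module sketch: forward caches what backward needs."""

    def forward(self, x):
        self.out = np.tanh(x)  # cache the output for the backward pass
        return self.out

    def backward(self, dout):
        # local derivative: d/dx tanh(x) = 1 - tanh(x)^2,
        # chained with the incoming gradient dout (chain rule)
        return dout * (1.0 - self.out ** 2)

x = np.array([0.0, 1.0, -1.0])
layer = Tanh()
y = layer.forward(x)
dx = layer.backward(np.ones_like(y))  # gradient of sum(y) w.r.t. x
```

An autograd engine simply records which modules ran in the forward pass and calls their backward() methods in reverse order, passing each gradient along.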

The services offered by PyTorch then become clear: it gives you both an efficient Array/Tensor object, and it has an Autograd engine (just like micrograd) that computes gradients for you. Only brushed on in this module is a third major offering of PyTorch: the fact that PyTorch Tensors can be moved to different devices (like GPUs) in a transparent way, greatly speeding up all the computations.
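Both services fit in a few lines: set requires_grad on a Tensor, and PyTorch records the graph during the forward pass so that backward() fills in the gradients, and the same code runs on a GPU just by moving the Tensor. A minimal sketch:

```python
import torch

# autograd: the forward pass builds a graph, backward() traverses it
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = x0^2 + x1^2
y.backward()        # autograd computes dy/dx = 2x, stored in x.grad

# device transparency: the identical code runs on a GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
x_dev = x.detach().to(device)
```

In micrograd we had to write that backward pass ourselves, one scalar at a time; PyTorch does the same bookkeeping over whole Tensors.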

As a result of our efforts, we will get to enjoy a much lower validation loss than we saw in the ngram module, and with significantly fewer parameters. However, we're also doing this at a much higher computational cost at training time (we're essentially compressing the dataset into the model parameters), and also to some extent at inference time.

TODOs:

  • tune the hyperparameters so they are not terrible; I just winged it (currently seeing val loss 2.06; recall the count-based 4-gram was 2.11)
  • implement all the other versions so they match the PyTorch reference

License

MIT

Contributors

karpathy
