This project forked from apple/ml-aim

This repository provides the code and model checkpoints of the research paper: Scalable Pre-training of Large Autoregressive Image Models

License: Other

Python 100.00%

AIM: Autoregressive Image Models

Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin

[Paper] [BibTex]

This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.

We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart (i.e., large language models). Specifically, we highlight two findings:

  1. the model capacity can be trivially scaled to billions of parameters, and
  2. AIM effectively leverages large collections of uncurated image data.

Installation

Please install PyTorch using the official installation instructions. Afterward, install the package as:

pip install git+https://git@github.com/apple/ml-aim.git

We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:

pip install mlx

Usage

Below we provide an example of usage in PyTorch:

from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, features = model(inp)

and in MLX:

from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, features = model(inp)

and in JAX:

from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = jnp.array(inp.numpy())  # convert the torch tensor via NumPy, as in the MLX example
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
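
All three backends return `logits` alongside `features`. As a minimal, framework-agnostic sketch of turning those logits into ranked predictions, the helper below applies a softmax and keeps the top `k` classes. It assumes the logits for one image have already been converted to a flat list of floats holding one score per class (for instance with something like `logits[0].tolist()`); the function name `top_k_predictions` and the example scores are hypothetical and not part of the `aim` package.

```python
import math


def top_k_predictions(logits, k=5):
    """Return the k most likely (class_index, probability) pairs for one image.

    `logits` is assumed to be a flat sequence of floats, one score per class.
    """
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by descending probability and keep the first k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Hypothetical scores for a 4-class problem, standing in for real model output.
example_logits = [0.1, 2.0, -1.0, 0.5]
for idx, p in top_k_predictions(example_logits, k=2):
    print(f"class {idx}: {p:.3f}")
```

Working on a plain list keeps the helper independent of the backend; in practice the equivalent softmax and top-k routines of the chosen framework would likely be faster.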

Pre-trained checkpoints

The pre-trained models can be accessed via PyTorch Hub as:

import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b   = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b   = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b   = torch.hub.load("apple/ml-aim", "aim_7B")

or via HuggingFace Hub as:

from aim.torch.models import AIMForImageClassification

aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b   = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b   = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b   = AIMForImageClassification.from_pretrained("apple/aim-7B")

Pre-trained backbones

The following table contains pre-trained backbones used in our paper.

model      #params   IN-1k top-1, attn probe (best layer)   backbone, SHA256
AIM-0.6B   0.6B      79.4%                                  link, 0d6f6b8f
AIM-1B     1B        82.3%                                  link, d254ecd3
AIM-3B     3B        83.3%                                  link, 8475ce4e
AIM-7B     7B        84.0%                                  link, 184ed94c
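
The SHA256 values listed next to each checkpoint can be used to check a download. The sketch below is one way to do that with the standard library; it assumes the 8-character values in the table are the leading characters of the file's SHA256 hex digest, which the table does not state explicitly. The helper name `sha256_matches` is hypothetical.

```python
import hashlib


def sha256_matches(path, expected_prefix, chunk_size=1 << 20):
    """Return True if the SHA256 hex digest of `path` starts with `expected_prefix`."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so multi-gigabyte checkpoints need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected_prefix.lower())
```

For example, `sha256_matches("/path/to/backbone_ckpt.pth", "184ed94c")` would check a downloaded AIM-7B backbone against the value in the table, under the prefix assumption above.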

Pre-trained attention heads

The table below contains the classification results on the ImageNet-1k validation set.

model      top-1 IN-1k                attention head, SHA256
           last layer   best layer    last layer       best layer
AIM-0.6B   78.5%        79.4%         link, 5ce5a341   link, ebd45c05
AIM-1B     80.6%        82.3%         link, db3be2ad   link, f1ed7852
AIM-3B     82.2%        83.3%         link, 5c057b30   link, ad380e16
AIM-7B     82.4%        84.0%         link, 1e5c99ba   link, 73ecd732

Reproducing the IN-1k classification results

The command below reproduces the attention probe results on the ImageNet-1k validation set. We run the evaluation using 1 node with 8 GPUs:

torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
  --model=aim-7B \
  --batch-size=64 \
  --data-path=/path/to/imagenet \
  --probe-layers=best \
  --backbone-ckpt-path=/path/to/backbone_ckpt.pth \
  --head-ckpt-path=/path/to/head_ckpt.pth

By default, we probe features from the 6 intermediate layers that provide the best performance. To change this, pass --probe-layers=last.
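
The top-1 figures reported for IN-1k are the fraction of validation images whose highest-scoring class matches the ground-truth label. As a rough illustration of that metric only (this is not the evaluation code in main_attnprobe.py), the sketch below computes it from per-image score lists and integer labels. The function name `top1_accuracy` and the toy data are hypothetical.

```python
def top1_accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(batch_logits) != len(labels):
        raise ValueError("batch_logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for scores, label in zip(batch_logits, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Toy 3-class example: the first two predictions are right, the third is wrong.
toy_logits = [[0.2, 1.5, -0.3], [2.0, 0.1, 0.0], [0.0, 0.3, 0.9]]
toy_labels = [1, 0, 1]
print(f"top-1 accuracy: {top1_accuracy(toy_logits, toy_labels):.3f}")
```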

Citation

If you find our work useful, please consider citing us as:

@article{el2024scalable,
  title={Scalable Pre-training of Large Autoregressive Image Models},
  author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
  journal={arXiv preprint arXiv:2401.08541},
  year={2024}
}
