facebookresearch / ijepa

Official codebase for I-JEPA, the Image-based Joint-Embedding Predictive Architecture, first introduced in the CVPR paper "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture."


ijepa's Introduction

I-JEPA

Official PyTorch codebase for I-JEPA (the Image-based Joint-Embedding Predictive Architecture) published @ CVPR-23. [arXiv] [JEPAs] [blogpost]

Method

I-JEPA is a method for self-supervised learning. At a high level, I-JEPA predicts the representations of part of an image from the representations of other parts of the same image. Notably, this approach learns semantic image features:

  1. without relying on pre-specified invariances to hand-crafted data transformations, which tend to be biased for particular downstream tasks,
  2. and without having the model fill in pixel-level details, which tend to result in learning less semantically meaningful representations.

(Figure: I-JEPA method overview.)

Visualizations

As opposed to generative methods that have a pixel decoder, I-JEPA has a predictor that makes predictions in latent space. The predictor in I-JEPA can be seen as a primitive (and restricted) world-model that is able to model spatial uncertainty in a static image from a partially observable context. This world model is semantic in the sense that it predicts high level information about unseen regions in the image, rather than pixel-level details.

We trained a stochastic decoder that maps the I-JEPA predicted representations back into pixel space as sketches. The model correctly captures positional uncertainty and produces high-level object parts with the correct pose (e.g., a dog’s head, a wolf’s front legs).

Caption: Illustrating how the predictor learns to model the semantics of the world. For each image, the portion outside of the blue box is encoded and given to the predictor as context. The predictor outputs a representation for what it expects to be in the region within the blue box. To visualize the prediction, we train a generative model that produces a sketch of the contents represented by the predictor output, and we show a sample output within the blue box. The predictor recognizes the semantics of what parts should be filled in (the top of the dog’s head, the bird’s leg, the wolf’s legs, the other side of the building).

Evaluations

I-JEPA pretraining is also computationally efficient. It does not involve any overhead associated with applying more computationally intensive data augmentations to produce multiple views. Only one view of the image needs to be processed by the target encoder, and only the context blocks need to be processed by the context encoder. Empirically, I-JEPA learns strong off-the-shelf semantic representations without the use of hand-crafted view augmentations.

(Figures: ImageNet-1K 1% low-shot evaluation and linear-evaluation results.)

Pretrained models

arch.   patch size   resolution   epochs   data           download
ViT-H   14x14        224x224      300      ImageNet-1K    full checkpoint | logs | configs
ViT-H   16x16        448x448      300      ImageNet-1K    full checkpoint | logs | configs
ViT-H   14x14        224x224      66       ImageNet-22K   full checkpoint | logs | configs
ViT-g   16x16        224x224      44       ImageNet-22K   full checkpoint | logs | configs

Code Structure

.
├── configs                   # directory in which all experiment '.yaml' configs are stored
├── src                       # the package
│   ├── train.py              #   the I-JEPA training loop
│   ├── helper.py             #   helper functions for initializing models & optimizer, loading checkpoints
│   ├── transforms.py         #   pre-train data transforms
│   ├── datasets              #   datasets, data loaders, ...
│   ├── models                #   model definitions
│   ├── masks                 #   mask collators, masking utilities, ...
│   └── utils                 #   shared utilities
├── main_distributed.py       # entrypoint for launching distributed I-JEPA pretraining on a SLURM cluster
└── main.py                   # entrypoint for launching I-JEPA pretraining locally on your machine

Config files: Note that all experiment parameters are specified in config files (as opposed to command-line arguments). See the configs/ directory for example config files.

Launching I-JEPA pretraining

Single-GPU training

This implementation starts from main.py, which parses the experiment config file and runs the pre-training locally on a multi-GPU (or single-GPU) machine. For example, to run I-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config configs/in1k_vith14_ep300.yaml, type the command:

python main.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --devices cuda:0 cuda:1 cuda:2

Note: This example is just used for illustrative purposes, as the ViT-H/14 config should be run on 16 A100 80G GPUs for an effective batch-size of 2048, in order to reproduce our results.

Multi-GPU training

In the multi-GPU setting, the implementation starts from main_distributed.py, which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster.

For example, to pre-train on 16 A100 80G GPUs using the pre-training experiment configs specified inside configs/in1k_vith14_ep300.yaml, type the command:

python main_distributed.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --folder $path_to_save_submitit_logs \
  --partition $slurm_partition \
  --nodes 2 --tasks-per-node 8 \
  --time 1000

Requirements

  • Python 3.8 (or newer)
  • PyTorch 2.0
  • torchvision
  • Other dependencies: pyyaml, numpy, opencv, submitit
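
For reference, a minimal environment setup might look like the following; the exact pip package names (in particular opencv-python for the OpenCV bindings) are an assumption, not taken from the repository:

pip install torch torchvision pyyaml numpy opencv-python submitit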

License

See the LICENSE file for details about the license under which this code is made available.

Citation

If you find this repository useful in your research, please consider giving a star ⭐ and a citation

@article{assran2023self,
  title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author={Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  journal={arXiv preprint arXiv:2301.08243},
  year={2023}
}

ijepa's People

Contributors

midoassran


ijepa's Issues

NC license is non-free

Non-commercial (NC) clauses are non-free: such licenses are not considered free software by the FSF, open source by the OSI, or free culture by Freedom Defined. I would recommend using CC-BY-SA-4.0, CC-BY-4.0, or CC0-1.0 instead, which are free-culture Creative Commons licenses.

Logging Configuration and YAML Loading Issues in "main.py"

The current code snippet has two issues related to logging configuration and YAML loading. These issues should be addressed to ensure proper functionality and security. Here are the details of each issue:

Logging Configuration:
The logging configuration in the code is incomplete and needs improvement. It lacks the setup for the log handler and the desired log level. The code should be modified to include a suitable log handler and set the appropriate log level for different ranks. This will ensure consistent and effective logging throughout the application.

YAML Loading Security:
The code uses yaml.load to parse the YAML file, which can potentially introduce security vulnerabilities due to code injection. It is recommended to replace yaml.load with yaml.safe_load to safely load the YAML file. This will prevent potential risks associated with malicious YAML files.

To address these issues, the logging configuration should be updated to include the desired log handler and log level. Additionally, the usage of yaml.load should be replaced with yaml.safe_load for secure YAML loading.
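
A minimal sketch of the suggested change (the config path here is just an example):

import yaml

with open('configs/in1k_vith14_ep300.yaml', 'r') as f:
    # safe_load only constructs basic Python objects, avoiding the code-execution
    # risks of parsing untrusted YAML with yaml.load
    params = yaml.safe_load(f)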

Please consider reviewing and resolving these issues to enhance the code's functionality and ensure proper security practices.

yuedajiong-question #03: how well does this type of algorithm generalize, especially to reconstruction-like downstream tasks?

After the different encoders (on both the input and target side) map into a latent space where the loss is minimized, can we say that this type of algorithm is suited to tasks that make decisions based on highly abstract image information rather than on raw pixel space, such as classification, high-level shape, color, and motion/action?

If we have to use it for a reconstruction-like task, we would need to include the raw image:
O = f_reconstruction_based_on_LeCun_iJEPA(ijepa_high_level, raw_image)
That is, we would only use I-JEPA's high-level information.

All blocks have the same size within a batch

Hi,
I'm currently exploring I-JEPA and I have a question about the multiblock mask collator.
I see that in the collator you sample p_size and e_size outside of the loop that iterates over the batch, so all blocks from a batch have the same size.

Is this something required for the model to work, or is it just because it was easier to implement? Or is it perhaps to avoid padding issues?

Thanks a lot!

yuedajiong-question #05: is the encoding-space design advanced or limited?

Just from my limited understanding:

As we can see from the paper and the code, the mask is created during dataset creation, which means the mask lives in the original 2D pixel space.

Even after encoding, i.e. after the input-side encoder and before the predictor, the mask being used is still the raw mask in Euclidean 2D pixel space.

Even without a hierarchical encoder so far, this still limits the transforming/mapping ability and degrees of freedom of the encoder: we must preserve something (not only form/shape but also semantics) that we can still recognize and use after encoding, i.e. so that the mask can still be applied after encoding.

Is this the best design?
Is there no other design that maps the mask into a more abstract space? Or, if we cannot design a correct mask mapping, can we use some other fulcrum for self-supervision?

From my point of view, the raw, non-mappable mask is more limiting, especially for the later hierarchical encoders in LeCun's total JEPA.

In a word: on the one hand we want to construct the loss in latent space after the encoders, while on the other hand we still use the original pixel-space mask after the encoders for self-supervision. Right? Isn't that awkward?

yuedajiong-question #06: what is the better pathway to a unified vision system for super-AI?

I-JEPA?

I think a unified 3D task would be better (conditional generation/reconstruction for remembering priors, with implicit and explicit representations).

Image(s)/Video -> f_cond_gen_as_recon(timestep, ...) -> [implicit-representation-by-object] #here123 + camera-information(origin+direction) -> f_differentiable_render_not_only_nerf(cam_info, scene-representation, timestep, implicit-representation-by-object, ...) -> Image(s)/Video

scene-representation: for multi-object interaction.
timestep: for dynamics, not only training/reconstruction but also inference/generation.

#here123: if an explicit representation is needed, we can add a mapping branch that transforms an object's implicit representation into an explicit one, like NeRF to mesh.

This is end-to-end differentiable.

--- We can train this unified vision system; once trained:

  1. we can use the components in implicit-representation-by-object, and the latent tensors in f_cond_gen_as_recon, for the upper levels of LeCun's total JEPA.
  2. we have a vision-oriented (explicit) world model for higher-level reasoning; the system can imagine, for example, a car going off a cliff.
    --- and we can train this vision system independently.

Just from my point of view:

  1. Task definition is very important. Most research work will end up valueless, because the algorithms are built only for a limited, local task definition.
  2. What is the better unified vision task? (If we have a chance to rethink this before super-AI arrives.)
    P0: satisfy physical reality as much as possible: 3D in the physical world --> 2x2D in humans --> reconstructed 3D
    P1: multitask, so that the necessary information is kept in the network.

And this is not a joke, it is a tragedy:
most researchers and entrepreneurs focus on a LAAAAAAAAAAARGE model, rather than on an AAAAAAAAAAAdvanced small model first, scaling up only once the small guy is smart enough.

Under the large-model benchmark leaderboard, not a single blade of grass grows.

Difficulty continuing self-supervised pre-training on a custom dataset

Hello,

I'm experiencing challenges training the model on a custom dataset consisting of medical images.

Environment & Setup:

GPUs: 8 x A100 40GB
Batch Size: 32 per GPU (Total: 256)
Learning Rate: 0.001
Model: ViT-Huge, patch size 14
Epochs: 300
Warmup: Set to 0
Weights Initialization: Loaded from the provided checkpoint in this repository
Problem Description:
When training with the mentioned setup, I'm observing that it's difficult to get the desired learning rate and other hyperparameters to work effectively. The loss goes up, and the RankMe score and F1-score do not improve. Please see the evaluation schema below and the attached figures:

Training Metrics & Evaluation:
Here's a brief outline of the evaluation metrics from the training:

Evaluation methods: loss, RankMe.
Classification type: 3-class classification downstream task on the validation set
Metrics: F1 macro and accuracy (KNN)

I'd appreciate any guidance or recommendations to help resolve this.
Thank you for your assistance.

(Figures 1 and 2: training and evaluation curves.)

Best regards,
Christian

Bias in Multiblock Mask Collator

I've been sampling the multi-block mask collator and plotting the masks to understand how they look, and I believe I've found a bias that may have a significant impact on the training of any models using this class.

The following patterns appear consistently for batch sizes >= 128, and show that many patches near the center of the image are never masked by the enc_masks. Note that this behaviour only occurs for allow_overlap=False.

Here are four examples I've sampled using the code below, with no cherry picking.
Each image corresponds to a batch of 128 enc masks, generated using the default arguments for the multi-block mask collator.
Each pixel represents a patch, and is white iff that patch is included in any of the masks in its batch.
Repro code below.

(Four sample coverage visualizations; white pixels mark patches covered by at least one enc mask in the batch.)

import torch
from src.masks.multiblock import MaskCollator  # the multi-block mask collator lives in src/masks/multiblock.py
import matplotlib.pyplot as plt

collator = MaskCollator()

batch = [torch.randn(3, 224, 224) for _ in range(1024)]
batch = collator(batch)
batch, enc_masks, pred_masks = batch

def display_mask(mask):
    # mask is a tensor of indices from 0 to 195
    # can be individual mask, or multiple.
    # display a 14x14 grid, where each cell is on if the corresponding index is in the mask
    grid = torch.zeros(14,14)
    for i in range(196):
        grid[i // 14, i % 14] = 1 if i in mask else 0
    plt.imshow(grid, cmap='gray')
    plt.show()

# change second index from ':' to integer to visualise individual masks
display_mask(enc_masks[0][:])

Image resolution & folder structure for unsupervised pre-training

I am exploring I-JEPA and wanted to make sure I understand what it expects in terms of the structure of image_folder. For example, here's my config:

data:
  batch_size: 128
  color_jitter_strength: 0.0
  crop_scale:
  - 0.3
  - 1.0
  crop_size: 224
  image_folder: /data_home/datasets/custom_dataset/unlabeled
  num_workers: 10
  pin_mem: true
  root_path: /data_home/datasets/custom_dataset/unlabeled
  use_color_distortion: false
  use_gaussian_blur: false
  use_horizontal_flip: false

Imagine my image_folder is structured like this - where each batch is a folder containing several thousand unlabeled images:

/data_home/datasets/custom_dataset/unlabeled/batch_001
/data_home/datasets/custom_dataset/unlabeled/batch_002

Is structuring my dataset like that an incorrect way of pretraining? E.g., will I-JEPA be incorrectly influenced by the "grouping" of images in each batch folder (even though each folder contains randomly assembled unlabeled images)?
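
For reference, here is a rough sketch of how torchvision's ImageFolder (which, as far as I can tell, underlies the pretraining dataset class) would interpret that layout; each batch_* folder simply becomes a pseudo-class, and the labels it produces are not used by the self-supervised objective:

from torchvision import datasets, transforms

# Each subfolder (batch_001, batch_002, ...) is treated as a "class";
# the resulting labels are ignored during self-supervised pretraining.
dataset = datasets.ImageFolder(
    root='/data_home/datasets/custom_dataset/unlabeled',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.3, 1.0)),
        transforms.ToTensor(),
    ]),
)
print(dataset.classes)          # ['batch_001', 'batch_002', ...]
img, pseudo_label = dataset[0]  # pseudo_label just indexes the batch folder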

Additionally, for pre-training I-JEPA on a new dataset composed of unlabeled data, what resolution should those unlabeled images be?

Thank you!

Will the sketch decoder be released?

Will the stochastic decoder that maps the I-JEPA predicted representations back into pixel space as sketches be released? Your decoder looks like it works well, and I'd like to try it.

Is the ViT predictor important?

Hello, I read your paper and enjoyed it very much. I understand that the predictor is necessary to prevent representational collapse, since it is used to predict missing information based on the information it has. I am not sure if I missed it, but does the predictor have to be a ViT, or could the decoder/predictor be any reasonable architecture, such as a stack of linear layers? Would the decoder not being a ViT prevent the model from learning robust high-level semantic representations?

Some questions about context and target

Hi authors, this is amazing work; the idea is new and the results are impressive.
When reading the paper, I'm confused about how you obtain the context and targets.
(1) In the paper you mention that the image first goes through a ViT to get a sequence of patch-level features, and that you randomly sample M patch features. Up to this point I follow, but then you describe sampling the blocks with a random aspect ratio in the range (0.75, 1.5) and a random scale in the range (0.15, 0.2). My understanding is that you first use this scale/aspect ratio to pick a mask, and then use this mask to select the features of the patches inside it. Is that correct? (See the rough sketch after (2) below.)
(2) In the Context section you refer to a block; in my understanding a block should be rectangular, but in Figure 4 it seems not to be. You also mention "Since the target blocks are sampled independently from the context block, there may be significant overlap." Why would the target blocks be sampled from the context block? Aren't they sampled from the original image's patch-level representations?
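
For intuition about (1), here is a rough, unofficial sketch of block sampling over the grid of patch indices, using the scale and aspect-ratio ranges quoted above (the exact conventions in the official mask collator may differ):

import math
import random

grid_h = grid_w = 16                      # e.g. a 224x224 image with 14x14 patches
scale = random.uniform(0.15, 0.2)         # fraction of all patches covered by the block
ratio = random.uniform(0.75, 1.5)         # block aspect ratio
num_patches = scale * grid_h * grid_w
h = max(1, int(round(math.sqrt(num_patches / ratio))))
w = max(1, int(round(math.sqrt(num_patches * ratio))))
top = random.randint(0, grid_h - h)       # random placement of the rectangular block
left = random.randint(0, grid_w - w)
block = [(top + i) * grid_w + (left + j) for i in range(h) for j in range(w)]

The sampled block is just a set of patch indices; the corresponding context/target features are the encoder outputs at those indices.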

Thanks for answering these questions!

Insufficient Logging and Error Handling in "main_distributed.py"

The code lacks proper logging and error handling, which can make it difficult to troubleshoot issues and identify failures during execution. It is important to improve the logging mechanism and implement robust error handling to enhance the reliability and maintainability of the code. Here are the details of the issues:

Incomplete Logging:
The code only initializes a basic logging configuration without specifying the log output file or formatting. This can make it challenging to track and analyze the application's behavior and diagnose any potential errors or anomalies.

Lack of Error Handling:
The code does not include comprehensive error-handling mechanisms, which can lead to unhandled exceptions and unexpected termination of the program. Proper error handling should be implemented to catch and handle exceptions gracefully, providing useful error messages and appropriate actions to be taken.

Missing Input Validation:
The code does not perform sufficient validation and sanitization of the input arguments obtained from the command line. This leaves the application vulnerable to unexpected or malicious input, potentially leading to security vulnerabilities or incorrect behavior.

Deprecated YAML Loading Pattern:
The code uses the yaml.load function with yaml.FullLoader. Since PyYAML 5.1, calling yaml.load without an explicit safe loader has been discouraged, and yaml.safe_load remains the recommended, safer way to load these configuration files.
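
On the first point (logging), a minimal sketch of the kind of setup being suggested, with an explicit handler, format, and per-rank level; the names here are illustrative, not the repository's actual code:

import logging
import sys

def setup_logging(rank: int) -> logging.Logger:
    logger = logging.getLogger()
    # Only rank 0 logs at INFO; other ranks stay quiet unless something goes wrong.
    logger.setLevel(logging.INFO if rank == 0 else logging.ERROR)
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
    logger.addHandler(handler)
    return logger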

Downstream task

After training the model, can we use only the target-encoder for downstream tasks, e.g. image captioning?

Why doesn't the model collapse?

Hi, thanks for the great work. From the paper, the only objective function is the distance between predicted and encoded representations of target patches. Why does the model not converge to a trivial solution? e.g. predicting 0s all the time. I noticed there are some cited works on this issue, but do you have an intuitive explanation for this? Thanks!

How to load ijepa checkpoints?

I am trying to use this model for classification of CIFAR-10 in Google Colab. I was trying to load the model to study its layers, so I cloned this repo and I am using it as follows:

from vision_transformer import vit_huge
# Initialize the ViT-H model with the specified patch size and resolution
model = vit_huge(patch_size=14, num_classes=1000)  # Adjust num_classes if needed
import torch
# Load the state dictionary from the file
state_dict = torch.load('/content/drive/MyDrive/IN1K-vit.h.14-300e.pth.tar')

# Load the state dictionary into the model
model.load_state_dict(state_dict)

# Print the layers/modules of the model for inspection
def print_model_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Module):
            module_name = prefix + "." + name if prefix else name
            print(module_name)
            print_model_layers(module, prefix=module_name)

print_model_layers(model)

but I get the following error:

`RuntimeError Traceback (most recent call last)
in <cell line: 6>()
4
5 # Load the state dictionary into the model
----> 6 model.load_state_dict(state_dict)
7
8 # Print the layers/modules of the model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
2039
2040 if len(error_msgs) > 0:
-> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2042 self.__class__.__name__, "\n\t".join(error_msgs)))
2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "blocks.4.norm1.weight", "blocks.4.norm1.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "bl...
Unexpected key(s) in state_dict: "encoder", "predictor", "opt", "scaler", "target_encoder", "epoch", "loss", "batch_size", "world_size", "lr".`

I do not understand which ViT from vision_transformer.py I am supposed to use for the checkpoint (IN1K-vit.h.14-300e.pth.tar), because using vit_huge gives the error above.
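
For what it's worth, the "Unexpected key(s)" list above suggests the .pth.tar file is a full training checkpoint (a dict containing 'encoder', 'predictor', 'target_encoder', the optimizer state, and so on) rather than a bare encoder state_dict. A hedged sketch of how one might extract just the encoder weights (the 'module.' prefix stripping accounts for DistributedDataParallel wrapping; adjust names and paths as needed):

import torch
from src.models.vision_transformer import vit_huge

model = vit_huge(patch_size=14)  # architecture matching the IN1K-vit.h.14-300e checkpoint
checkpoint = torch.load('IN1K-vit.h.14-300e.pth.tar', map_location='cpu')

encoder_dict = checkpoint['target_encoder']  # or checkpoint['encoder']
encoder_dict = {k.replace('module.', ''): v for k, v in encoder_dict.items()}

msg = model.load_state_dict(encoder_dict, strict=False)
print(msg)  # inspect any missing/unexpected keys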

generative model decoder for target-encoder

Firstly, thank you for your great work. I froze the target-encoder weights and trained a decoder following the RCDM framework to map the average-pool of the target-encoder outputs back to pixel space, but I can't obtain visual results similar to those in the paper. Specifically, I transformed the output of the I-JEPA target encoder, a tensor of shape [batch_size, 256, 1280], into a tensor of shape [batch_size, 1280] via torch.mean(x, dim=1), and then trained the decoder of the generative model using the RCDM framework. Due to limitations in GPU memory and hardware, I set the batch size to 6 instead of the default 8 and used a single GPU for training. I used ImageNet as the training set and trained for approximately 1,000,000 steps, but I did not achieve good results. Do I need any special settings when training the decoder? By the way, will the pre-trained generative decoder be released? I would greatly appreciate a reply.

Torch version

Hi, thank you for sharing your great work.
I have a question.
According to the README file, you recommend torch 2.0.
I'm using torch 1.12; is there any chance that the torch version could cause problems?

imagenet1k Huggingface extraction

Hi. I downloaded the imagenet1k file from huggingface, which has the train_images_0.tar.gz file. However, whenever I extract it (by running tar -xf train_images_0.tar.gz), I only get images, and not subfolders, which torchvision expects. Therefore, whenever I run main.py, I get a FileNotFoundError(f"Couldn't find any class folder in {directory}.").

Does anyone know how to successfully curate the data so that I can get the ijepa code working?

Thanks!

Training loss increases

Hi.

I am trying to train ijepa vit-huge_16_448 on a dataset with medical images.
I use 24 A100 GPUS with 40Gb of memory.
I have adjusted the learning rate using the linear scaling rule
The original ijepa uses lr=0.001 with batch_size=16, gpu=16. This gives a total batch size of 256.
In my experiment I have batch_size=6 and gpus=24. This gives a total batch size of 144.
The fraction between these numbers is 144/256=0.56
So my learning rate should be 0.001*0.56 = 0.00056
The loss is decreasing in the beginning, but after 3 epochs it starts increasing:

1 epoch avg loss :0.028
2 epoch avg loss: 0.005
3 epoch avg loss: 0.005
4 epoch avg loss: 0.006
5 epoch avg loss: 0.008
6 epoch avg loss: 0.012
7 epoch avg loss: 0.015

Why does the loss increase at this point?

Difference between I-JEPA and data2vec2.0?

If I understand I-JEPA correctly, the key insight is to predict a more semantic representation instead of low-level pixels. How, then, is I-JEPA different from data2vec 2.0? Can you give me a high-level view?


Training from scratch

Hey everyone!

First off, thanks for the great work. I implemented my own version of I-JEPA (https://github.com/Ugenteraan/I-JEPA) by referencing this repository.

I used the Doges 77 Breeds dataset (https://www.kaggle.com/datasets/madibokishev/doges-77-breeds) for the training. The loss goes down in a convincing manner during the SSL training. However, during downstream evaluation, when I load the pre-trained encoder weights and use probing, the accuracy is no better than with randomly initialized encoder weights.

Does anyone have a clue on what might have been the cause of this?

Thanks in advance! Cheers.

Making the code a bit more installable

Hi,

  1. I made some small modifications to make the code more installable (using a package namespace instead of src, so it can be pip-installed) and usable via import (like ijepa.utils) in other packages, so that, for example, it is easy to load the trained model.
  2. Also, the ability to run training without data parallelism on macOS (or a machine without a GPU).

Will it be ok to open a pull request?

Thanks!

Typo in Paper

This should probably be target block. Great work!

yuedajiong-question #02: same or different encoding space?

What is the essence of the predictor? Can we use the same encoder on both the input side and the target side?

My understanding:

  1. LeCun's and MidoAssran's implementation:
    pixel space -> f_encode_input -> encoded-but-before-predict space #S_e -> f_predict -> predicted space #S_p
    pixel space -> f_encode_target -> encoded-for-loss-with-predicted space #S_t
    Based on my limited understanding, the distance between S_t and S_p is smaller than the distance between S_t and S_e.
  2. Another option:
    pixel space -> f_encode_input -> encoded-but-before-predict space #S_e -> f_predict -> predicted space #S_p
    'FORCE' the target encoder to be the same f_encode_input (which of course requires code modification). (We could even use another reconstruction-based or multi-task model to pre-train the encoder, and only train the predictor here.)
    In that case, the predictor predicts the masked information in the SAME latent space, with fewer parameters, especially with hierarchical encoders.

Is this still reasonable with respect to the min/max information-content argument in this figure:
https://scontent-hkt1-1.xx.fbcdn.net/v/t39.2365-6/274349327_328573009069618_3830687132316417744_n.jpg?_nc_cat=110&ccb=1-7&_nc_sid=ad8a9d&_nc_ohc=laxyPF2DPaUAX-t_SPX&_nc_ht=scontent-hkt1-1.xx&oh=00_AfDrZb2-IiXgN7GjAyo0DCvL8qhV5-W0Jc6wdnzs-AyPJw&oe=648F1A69

Struggling to replicate evaluation results

Hi folks,

I'm trying to replicate your linear probe evaluation results. I can only get your pre-trained model to score 77% (with the last layer) or 80.8% (with the last 4 layers) on a linear probe of CIFAR-100. I'm using the ViT-H with a patch size of 14 that was trained for 300 epochs on ImageNet-1k.

I am using the following transforms, taken from the VISSL repo. Training:

# Taken from the VISSL repo.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
])

And testing:

test_transforms = transforms.Compose([
    transforms.Resize(size=256, interpolation=3),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])    

I'm training a linear model using SGD w/ parameters lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True, and I'm decaying the learning rate by a factor of 10 on the 8th, 16th, and 24th epochs.

To get the "last four" embedding, I'm concatenating the output of the last four blocks to create a (B, 1024, 1280) tensor.

For the averaging, I'm averaging over the spatial dimension, i.e. turning the (B, S, 1280) tensor into a (B, 1280) tensor.
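
For concreteness, a small sketch of the two steps described above, with random tensors standing in for encoder outputs (shapes assume ViT-H's 1280-dimensional features and 256 patches):

import torch

B, S, D = 8, 256, 1280
last_block = torch.randn(B, S, D)                     # output of the final block
pooled = last_block.mean(dim=1)                       # (B, 1280): average over the spatial dim

last_four = [torch.randn(B, S, D) for _ in range(4)]  # outputs of the last four blocks
concat = torch.cat(last_four, dim=1)                  # (B, 1024, 1280): "last four" embedding
concat_pooled = concat.mean(dim=1)                    # (B, 1280)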

Would you be able to shed any light on what I could be missing in my reproduction?

Longer training schedule?

Hi Assran, thank you for your greatTTTTTTTTT work! I wonder if using a longer pre-training schedule (800/1200/1600 epochs), how much performance superiority can we get upon previous methods like MAE?

yuedajiong-question #04: can we use this I-JEPA pre-training for LeCun's total JEPA?

Just from my understanding:

  1. in I-JEPA, the fulcrum of the loss is the mask.
  2. LeCun's total JEPA is still end-to-end differentiable, but it may be trained on vision in a mini interactive environment, where the feedback signal comes from actions rather than directly from vision.

Can we directly use these pre-trained parameters, or do we need to re-train from scratch?

In other words: is the representation learned with a mask-fulcrum self-supervised objective similar to the representation learned with a sequential-actions-fulcrum self-supervised objective in the later total JEPA?

Capitalization correction in README.md

Description:

I noticed a minor capitalization inconsistency in the README.md file, specifically in the section describing the benefits of I-JEPA. The word "Without" is currently written in lowercase, and I propose capitalizing it for consistency and clarity.

Proposed Change:

Capitalize the word "Without" in the following sentence to maintain consistency with the capitalization of other sentences:

Original Sentence:

"...without relying on pre-specified invariances to hand-crafted data transformations..."

Proposed Sentence:

  1. Without relying on pre-specified invariances to hand-crafted data transformations, which tend to be biased for particular downstream tasks,
  2. And without having the model fill in pixel-level details, which tend to result in learning less semantically meaningful representations.

This change will help improve the readability and maintain a consistent writing style throughout the README.md file.

Linear probing

Hi, I have a question about linear probing: I haven't seen a CLS token. Is the classification performed directly on all of the outputs (which makes for a lot of parameters in a single layer) or on an average of the outputs? Thanks!
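
In case it helps frame the question, here is a sketch of the second option (mean-pooling the patch outputs and feeding them to a single linear layer); the 1280-dimensional features and 1000 classes are only illustrative:

import torch
import torch.nn as nn

features = torch.randn(32, 256, 1280)  # (B, num_patches, embed_dim) encoder outputs, no CLS token
head = nn.Linear(1280, 1000)           # linear probe
logits = head(features.mean(dim=1))    # classify the average of the patch representations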

FileNotFoundError: params-ijepa.yaml when run main.py

env

windows11
Python 3.9.13
torch 2.0

log

Here is the log when I run "python main.py --fname configs/in1k_vith14_ep300.yaml --devices cuda:0":

INFO:root:called-params configs/in1k_vith14_ep300.yaml
INFO:root:loaded params...
{   'data': {   'batch_size': 128,
                'color_jitter_strength': 0.0,
                'crop_scale': [0.3, 1.0],
                'crop_size': 224,
                'image_folder': 'imagenet_full_size/061417/',
                'num_workers': 10,
                'pin_mem': True,
                'root_path': '$replace_this_with_absolute_path_to_your_datasets_directory',
                'use_color_distortion': False,
                'use_gaussian_blur': False,
                'use_horizontal_flip': False},
    'logging': {   'folder': '$replace_this_with_path_for_experiment_logs/vith14.224-bs.2048-ep.300/',
                   'write_tag': 'jepa'},
    'mask': {   'allow_overlap': False,
                'aspect_ratio': [0.75, 1.5],
                'enc_mask_scale': [0.85, 1.0],
                'min_keep': 10,
                'num_enc_masks': 1,
                'num_pred_masks': 4,
                'patch_size': 14,
                'pred_mask_scale': [0.15, 0.2]},
    'meta': {   'copy_data': False,
                'load_checkpoint': False,
                'model_name': 'vit_huge',
                'pred_depth': 12,
                'pred_emb_dim': 384,
                'read_checkpoint': None,
                'use_bfloat16': True},
    'optimization': {   'ema': [0.996, 1.0],
                        'epochs': 300,
                        'final_lr': 1e-06,
                        'final_weight_decay': 0.4,
                        'ipe_scale': 1.0,
                        'lr': 0.001,
                        'start_lr': 0.0002,
                        'warmup': 40,
                        'weight_decay': 0.04}}
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
INFO:root:Running... (rank: 0/1)
Process Process-1:
Traceback (most recent call last):
  File "D:\Program Files\anaconda3\lib\multiprocessing\process.py", line 315, in _bootstrap
    self.run()
  File "D:\Program Files\anaconda3\lib\multiprocessing\process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "E:\deepLearn\ijepa\main.py", line 55, in process_main
    app_main(args=params)
  File "E:\deepLearn\ijepa\src\train.py", line 129, in main
    with open(dump, 'w') as f:
FileNotFoundError: [Errno 2] No such file or directory: '$replace_this_with_path_for_experiment_logs/vith14.224-bs.2048-ep.300/params-ijepa.yaml'

Has anyone else run into this too?
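
For anyone hitting this: the traceback points at the unreplaced '$replace_this_with_...' placeholders printed in the params above. A sketch of the config fields that need real paths before running (the values below are examples of my own, not defaults):

# configs/in1k_vith14_ep300.yaml (excerpt)
data:
  root_path: /absolute/path/to/your/datasets
  image_folder: imagenet_full_size/061417/
logging:
  folder: /absolute/path/to/your/experiment_logs/vith14.224-bs.2048-ep.300/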

Loading pre-trained model: state_dict key mismatch

Firstly, thanks for the amazing work!
I implemented my own code loading your pre-trained model, IN1K-vit.h.16-448px-300e.pth, and encountered this issue:

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
	Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", ......
        Unexpected key(s) in state_dict: "module.pos_embed", "module.patch_embed.proj.weight", "module.patch_embed.proj.bias", "module.blocks.0.norm1.weight", ......

I used the exact same model architecture in your vision_transformer.py file, and the problem was solved after I added this line before loading:

pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}

I wonder if there are any issues in your released model weights. Did you forget to update to the newest version?

Struggling to Train Downstream Classifier

Hi,

I'm working on training a downstream classification task from the ImageNet-22K checkpoint. When I use a TinyViT checkpoint, average over the first dimension of the output, and feed that into a linear classification head, the model trains appropriately. However, if I replace TinyViT with the target encoder of I-JEPA, again averaging over the first dimension of the final layer and feeding that into a linear classification head, the model fails to train at all. Has anyone been able to successfully train on a downstream task?

Thank you!

yuedajiong-question #01: no hierarchy?

There is no hierarchical encoder (input side) and predictor? (I can see the hierarchical design in LeCun's 10-year plan.)

Do we need an explicitly hierarchical encoder and predictor? Or is the current single-level encoder enough? (I think it is a kind of weakened/implicit hierarchical encoding, where different levels of abstraction arise from the transformer's residual connections.)

If we had a hierarchical encoder and predictor (encoding to different abstraction levels and then predicting), would we also need a hierarchical target_encoder?

Please Open Source ViT-B/16

Firstly, I would like to express my appreciation for your amazing work.
As the released ViT-H/ViT-g models are too large for our project, and I noticed that you conducted experiments with ViT-B in your paper, could you please release the ViT-Base model that you used for those experiments? It would be greatly appreciated if you could open source ViT-B/16 for the benefit of the wider community. Thank you very much for your consideration.

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Firstly, thanks for your great work!
But when I try to run in1k_vith14_ep300.yaml following the guidance, I run into an error:

File "/ijepa-main/main.py", line 52, in process_main
    app_main(args=params)
File "/ijepa-main/src/train.py", line 221, in main 
    encoder = DistributedDataParallel(encoder, static_graph=True)
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

I have a device with two GPUs, both of which are RTX 3090. How could I solve this problem? If you can provide assistance, I would greatly appreciate it.
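
For context, DistributedDataParallel requires an initialized process group before the encoder is wrapped. The repository normally sets this up in its own launch utilities, so the snippet below is only an illustrative, minimal single-process example of what must happen first:

import os
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='nccl', rank=0, world_size=1)
# ... only after this can DistributedDataParallel(encoder, static_graph=True) be constructed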

Can we use pretrained I-JEPA for a transfer-learning object detection task?

First of all, thank you for sharing such an amazing project 🙂. Can we use the embeddings generated by the context encoder of this pretrained model for a task like object detection on a custom dataset? If yes, do you have any suggestions on which type of head architecture (for bbox and class-label prediction) could be used with I-JEPA as the backbone for an object detection task?

Thank you
