cvmi-lab / slotcon

(NeurIPS 2022) Self-Supervised Visual Representation Learning with Semantic Grouping

Home Page: https://wen-xin.info/slotcon/

License: Apache License 2.0

Python 95.56% Shell 4.44%
contrastive-learning neurips-2022 object-discovery pre-training self-supervised-learning slotcon

slotcon's Introduction

Self-Supervised Visual Representation Learning with Semantic Grouping

Self-Supervised Visual Representation Learning with Semantic Grouping (NeurIPS 2022)
By Xin Wen, Bingchen Zhao, Anlin Zheng, Xiangyu Zhang, and Xiaojuan Qi.

Introduction

We propose contrastive learning from data-driven semantic slots, namely SlotCon, for joint semantic grouping and representation learning. The semantic grouping is performed by assigning pixels to a set of learnable prototypes, which can adapt to each sample by attentive pooling over the features to form new slots. Based on the learned data-dependent slots, a contrastive objective is employed for representation learning, which enhances the discriminability of features and, conversely, facilitates grouping semantically coherent pixels together.
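The grouping-then-contrast idea above can be illustrated with a small NumPy sketch. It is not the SlotCon implementation: it assumes cosine-similarity logits with hypothetical temperatures (`tau`), a softmax assignment of each pixel over the prototypes, slots formed as assignment-weighted averages of pixel features, and an InfoNCE-style loss in which the slots pooled with the same prototype in two views are treated as positives. The encoder, projection heads, and training loop are omitted, and all helper names are made up for illustration.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-8):
    """Scale vectors to unit length so dot products become cosine similarities."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax."""
    m = x.max(axis=axis, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=axis, keepdims=True))


def group_into_slots(feats, prototypes, tau=0.1):
    """Assign pixels to prototypes and pool them into slots.

    feats: (N, D) pixel features of one view; prototypes: (K, D).
    Returns the soft assignment (N, K) and the slots (K, D).
    """
    logits = l2_normalize(feats) @ l2_normalize(prototypes).T / tau  # (N, K)
    assign = np.exp(log_softmax(logits, axis=1))  # each pixel's distribution over prototypes
    # Attentive pooling: slot k is the assignment-weighted average of pixel features.
    weights = assign / (assign.sum(axis=0, keepdims=True) + 1e-8)  # columns sum to 1
    slots = weights.T @ feats  # (K, D)
    return assign, slots


def slot_contrastive_loss(slots_a, slots_b, tau=0.2):
    """InfoNCE over slots: slot k of view A and slot k of view B (pooled with the
    same prototype) form the positive pair; the other slots of view B are negatives."""
    logits = l2_normalize(slots_a) @ l2_normalize(slots_b).T / tau  # (K, K)
    return -np.mean(np.diag(log_softmax(logits, axis=1)))


rng = np.random.default_rng(0)
prototypes = rng.normal(size=(4, 8))  # K=4 prototypes (random here, learned in practice)
view_a = rng.normal(size=(16, 8))  # N=16 pixel features for the first view
view_b = view_a + 0.05 * rng.normal(size=(16, 8))  # a slightly perturbed second view
_, slots_a = group_into_slots(view_a, prototypes)
_, slots_b = group_into_slots(view_b, prototypes)
print(round(float(slot_contrastive_loss(slots_a, slots_b)), 4))
```

Because both views are pooled with the same prototypes, slot index k refers to the same prototype in each view, which is what makes same-index slots a natural positive pair in this sketch.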

[Figure: the SlotCon framework]

Compared with previous efforts, by simultaneously optimizing the two coupled objectives of semantic grouping and contrastive learning, our approach bypasses the disadvantages of hand-crafted priors and is able to learn object/group-level representations from scene-centric images. Experiments show our approach effectively decomposes complex scenes into semantic groups for feature learning and significantly benefits downstream tasks, including object detection, instance segmentation, and semantic segmentation.

Pretrained models

Method   Dataset      Epochs  Arch       AP^b  AP^m  Download
SlotCon  COCO         800     ResNet-50  41.0  37.0  script | backbone only | full ckpt
SlotCon  COCO+        800     ResNet-50  41.8  37.8  script | backbone only | full ckpt
SlotCon  ImageNet-1K  100     ResNet-50  41.4  37.2  script | backbone only | full ckpt
SlotCon  ImageNet-1K  200     ResNet-50  41.8  37.8  script | backbone only | full ckpt

AP^b: box AP; AP^m: mask AP (COCO object detection and instance segmentation).

Folder containing all the checkpoints: [link].

Getting started

Requirements

This project was developed with python==3.9 and pytorch==1.10.0; please be aware of possible code compatibility issues if you use other versions.

The following is an example of setting up the experimental environment:

  • Create the environment
conda create -n slotcon python=3.9 -y
conda activate slotcon
  • Install pytorch & torchvision (you can also pick your favorite version)
conda install pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.3 -c pytorch
  • Clone our repo
git clone https://github.com/CVMI-Lab/SlotCon && cd ./SlotCon
  • (Optional) Create a soft link for the datasets
mkdir datasets
ln -s ${PATH_TO_COCO} ./datasets/coco
ln -s ${PATH_TO_IMAGENET} ./datasets/imagenet
  • (Optional) Install other requirements
pip install -r requirements.txt
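Before running anything, it can help to compare your environment against the versions listed above. The following stdlib-only sketch reads installed package versions with `importlib.metadata` and reports mismatches; the `TESTED` table and helper names are illustrative and not part of the repo.

```python
import sys
from importlib import metadata

# Versions this README says the project was developed with.
TESTED = {"python": "3.9", "torch": "1.10.0", "torchvision": "0.11.0"}


def installed_versions(packages=("torch", "torchvision")):
    """Return the installed version of each package, or None if it is missing."""
    found = {"python": "{}.{}".format(*sys.version_info[:2])}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def report(found, tested=TESTED):
    """List human-readable mismatches against the tested versions."""
    notes = []
    for name, want in tested.items():
        have = found.get(name)
        if have is None:
            notes.append(f"{name}: not installed (tested with {want})")
        elif not have.startswith(want):  # e.g. "1.10.0+cu113" still matches "1.10.0"
            notes.append(f"{name}: {have} (tested with {want})")
    return notes


for line in report(installed_versions()) or ["all versions match the tested setup"]:
    print(line)
```

A mismatch is not necessarily an error; it only flags where compatibility issues could come from.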

Run pre-training

By default, we train with DDP over 8 GPUs on a single machine. The following examples reproduce SlotCon pre-training on COCO, COCO+, and ImageNet:

  • Train SlotCon on COCO for 800 epochs
./scripts/slotcon_coco_r50_800ep.sh
  • Train SlotCon on COCO+ for 800 epochs
./scripts/slotcon_cocoplus_r50_800ep.sh
  • Train SlotCon on ImageNet-1K for 100 epochs
./scripts/slotcon_imagenet_r50_100ep.sh

Evaluation: Object Detection & Instance Segmentation

Please install detectron2 and prepare the dataset first following the official instructions: [installation] [data preparation]

The following example evaluates a pre-trained model on COCO:

  • First, link COCO to the required path:
mkdir transfer/detection/datasets
ln -s ${PATH_TO_COCO} transfer/detection/datasets/
  • Then, convert the pre-trained model to detectron2's format:
python transfer/detection/convert_pretrain_to_d2.py output/${EXP_NAME}/ckpt_epoch_xxx.pth ${EXP_NAME}.pkl
  • Finally, train a detector with the converted checkpoint:
cd transfer/detection &&
python train_net.py --config-file configs/COCO_R_50_FPN_1x_SlotCon.yaml --num-gpus 8 --resume MODEL.WEIGHTS ../../${EXP_NAME}.pkl OUTPUT_DIR ../../output/COCO_R_50_FPN_1x_${EXP_NAME}
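The conversion step renames the backbone's parameters to detectron2's naming scheme. The sketch below assumes the key-renaming convention of MoCo's public detectron2 conversion script (torchvision ResNet names such as `layer1` and `bn1` become `res2` and `conv1.norm`). Whether `convert_pretrain_to_d2.py` does exactly this, and which key prefix the SlotCon checkpoint uses, are assumptions, so `prefix` is a hypothetical parameter. Using `argparse` for the two positional arguments means a missing argument produces a usage message rather than an `IndexError`.

```python
import argparse
import pickle


def rename_resnet_key(key):
    """Map a torchvision-style ResNet parameter name to detectron2's naming."""
    if "layer" not in key:
        key = "stem." + key  # conv1 / bn1 belong to detectron2's stem
    for t in (1, 2, 3, 4):
        key = key.replace(f"layer{t}", f"res{t + 1}")  # layer1..4 -> res2..5
    for t in (1, 2, 3):
        key = key.replace(f"bn{t}", f"conv{t}.norm")  # BN attached to conv{t}
    key = key.replace("downsample.0", "shortcut")
    key = key.replace("downsample.1", "shortcut.norm")
    return key


def convert_state_dict(state_dict, prefix):
    """Keep parameters under `prefix` (hypothetical), strip it, and rename them.

    Tensors should be turned into NumPy arrays (v.detach().cpu().numpy()) before pickling.
    """
    model = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            model[rename_resnet_key(k[len(prefix):])] = v
    # detectron2 reads pickled checkpoints containing "model" and "__author__";
    # matching_heuristics asks it to align names that do not match exactly.
    return {"model": model, "__author__": "SlotCon (converted)", "matching_heuristics": True}


def parse_args(argv=None):
    """Two required positionals; omitting one prints usage instead of raising IndexError."""
    parser = argparse.ArgumentParser(description="Convert a pre-trained ResNet to detectron2 format.")
    parser.add_argument("input", help="path to the pre-training checkpoint (.pth)")
    parser.add_argument("output", help="path of the converted checkpoint (.pkl)")
    return parser.parse_args(argv)


def save_pkl(converted, path):
    """Write the converted checkpoint as a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(converted, f)
```

With PyTorch available, one would load the checkpoint with `torch.load(args.input, map_location="cpu")`, select its state dict, call `convert_state_dict`, convert each tensor to NumPy, and write the result with `save_pkl`.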

Evaluation: Semantic Segmentation

Please install mmsegmentation and prepare the datasets first following the official instructions: [installation] [data preparation]

  • First, link the datasets for evaluation to the required path:
mkdir transfer/segmentation/data
ln -s ${PATH_TO_DATA} transfer/segmentation/data/
  • Then, convert the pre-trained model to mmsegmentation's format:
python transfer/segmentation/convert_pretrain_to_mm.py output/${EXP_NAME}/ckpt_epoch_xxx.pth ${EXP_NAME}.pth
  • Finally, run semantic segmentation on the following datasets: PASCAL VOC, Cityscapes, and ADE20K.
    • By default, we run PASCAL VOC and Cityscapes with 2 GPUs and ADE20K with 4 GPUs, keeping the total batch size fixed at 16.
# run pascal voc
cd transfer/segmentation &&
bash mim_dist_train.sh configs/voc12aug/fcn_d6_r50-d16_513x513_30k_voc12aug_moco.py ../../${EXP_NAME}.pth 2
# run cityscapes
cd transfer/segmentation &&
bash mim_dist_train.sh configs/cityscapes/fcn_d6_r50-d16_769x769_90k_cityscapes_moco.py ../../${EXP_NAME}.pth 2
# run ade20k
cd transfer/segmentation &&
bash mim_dist_train.sh configs/ade20k/fcn_r50-d8_512x512_80k_ade20k.py ../../${EXP_NAME}.pth 4
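Keeping the total batch size at 16 means the per-GPU batch depends on the GPU count passed as the last argument. The arithmetic is sketched below; the helper name is illustrative. If you change the number of GPUs, you would likely need to adjust the per-GPU setting in the config accordingly (in mmsegmentation 0.x configs this is typically `samples_per_gpu` in the `data` dict, which is an assumption about the config version used here).

```python
TOTAL_BATCH = 16  # total batch size stated above


def samples_per_gpu(num_gpus, total=TOTAL_BATCH):
    """Per-GPU batch such that num_gpus * per_gpu == total."""
    if total % num_gpus:
        raise ValueError(f"{total} is not divisible by {num_gpus} GPUs")
    return total // num_gpus


for name, gpus in [("PASCAL VOC", 2), ("Cityscapes", 2), ("ADE20K", 4)]:
    print(f"{name}: {gpus} GPUs x {samples_per_gpu(gpus)} images = {TOTAL_BATCH}")
```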

Prototype Visualization

We also provide code for visualizing the nearest neighbors of the learned prototypes. The following command requires a full checkpoint (not a backbone-only one).

# --topk 5:      retrieve 5 nearest neighbors for each prototype
# --sampling 20: randomly sample 20 prototypes to visualize
python viz_slots.py \
    --data_dir ${PATH_TO_COCO} \
    --model_path ${PATH_TO_MODEL} \
    --save_path ${PATH_TO_SAVE} \
    --topk 5 \
    --sampling 20
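Conceptually, nearest-neighbor visualization ranks candidate features by their similarity to each prototype. The sketch below is a hypothetical stand-in, not `viz_slots.py`: it assumes cosine similarity and one feature vector per candidate (for example, per image or region), and it reuses the names `topk` and `sampling` only to mirror the command-line flags above.

```python
import numpy as np


def prototype_neighbors(prototypes, features, topk=5, sampling=20, seed=0):
    """For `sampling` randomly chosen prototypes, return the indices of the
    `topk` features with the highest cosine similarity to each of them.

    prototypes: (K, D); features: (M, D).
    Returns {prototype_index: array of topk feature indices, most similar first}.
    """
    p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sims = p @ f.T  # (K, M) cosine similarities
    rng = np.random.default_rng(seed)
    chosen = rng.choice(len(prototypes), size=min(sampling, len(prototypes)), replace=False)
    # argsort is ascending, so reverse it to put the most similar features first
    return {int(k): np.argsort(sims[k])[::-1][:topk] for k in chosen}


rng = np.random.default_rng(1)
neighbors = prototype_neighbors(rng.normal(size=(256, 16)), rng.normal(size=(1000, 16)))
print(len(neighbors), "prototypes sampled")
```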

[Figure: visualization of concepts]

Citing this work

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{wen2022slotcon,
  title={Self-Supervised Visual Representation Learning with Semantic Grouping},
  author={Wen, Xin and Zhao, Bingchen and Zheng, Anlin and Zhang, Xiangyu and Qi, Xiaojuan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}

Acknowledgment

Our codebase builds upon several existing publicly available codes. Specifically, we have modified and integrated the following repos into this project:

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

slotcon's People

Contributors

dtennant, xwen99

slotcon's Issues

ViT training?

Hello,

Thank you for your very interesting work! I'm currently trying to replicate your results with your provided codebase, and I was wondering whether you also tested a Vision Transformer architecture as the encoder. You compared with DINO in the paper, but I wanted to know whether you were able to get properties close to what they obtained (a kind of saliency map, with the attention map around the object of interest).

Thank you again for your response!

/home/sri/SlotCon/detectron2

[Screenshot from 2023-02-17 12-13-53]
When I run "python train_net.py --config-file /home/sri/SlotCon/transfer/detection/configs/COCO_R_50_FPN_1x_SlotCon.yaml --num-gpus 4 --resume MODEL.WEIGHTS /home/sri/SlotCon/sekhar.pkl OUTPUT_DIR /home/sri/SlotCon/output/COCO_R_50_FPN_1x_sekhar", it cannot find the COCO JSON file.

Index out of Range

Respected Sir,

While executing the code, I get the following error. It would be very helpful if you could help me resolve it.

Traceback (most recent call last):
  File "/transfer/detection/convert_pretrain_to_d2.py", line 19, in <module>
    input = sys.argv[1]
IndexError: list index out of range

How to get the positive and negative pairs of slots?

Hi Xin,

Thanks for the great and insightful work.

When I read the code, I am confused by the label generation for contrastive learning of slots.
As shown in https://github.com/CVMI-Lab/SlotCon/blob/main/models/slotcon.py#L186, slots with the same index are treated as positive pairs. However, these slots are generated by masked pooling over the features, so their indices may not be related to semantic classes. Maybe I have missed something.

Looking forward to your reply!

About the prototypes

Hi Xin Wen,

Thanks for your great work! Regarding SlotCon, I have two questions:
(1) I notice the prototypes are initialized with nn.Embedding. I am wondering how to ensure that the trainable prototypes are optimized to be meaningful semantic groups via backpropagation. Since the loss functions do not explicitly ensure this, I am a little bit confused about the optimization of prototypes.
(2) Have you tested how much the resize operation in the data augmentation matters? That is, if you only apply cropping along with the other augmentations, without resizing, does the performance drop significantly?

Thanks for your reply!

Question about 3D model

Hi, Xin Wen!

I am trying to use SlotCon to process 3D data, but some problems are holding me back. For example, torchvision.ops.roi_align does not provide a 3D version. Do you have any solutions for that?

I need help

I am currently following the script for "Evaluation: Semantic Segmentation." However, when I try to follow the last "# run cityscapes" part, I encounter an error similar to the image below. I would like to know the solution to this issue.

[Screenshot 2023-07-25 23-09-25]
[Screenshot 2023-07-25 23-10-10]
[Screenshot 2023-07-25 23-10-37]

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.