fhaghighi / dira

Official PyTorch Implementation for DiRA: Discriminative, Restorative, and Adversarial Learning for Self-supervised Medical Image Analysis - CVPR 2022

License: Other

Python 100.00%
collaborative-learning contrastive-learning instance-discrimination medical-imaging self-supervised-learning transfer-learning

dira's Introduction

[CVPR'22] DiRA: Discriminative, Restorative, and Adversarial Learning for Self-supervised Medical Image Analysis

This repository provides a PyTorch implementation of DiRA: Discriminative, Restorative, and Adversarial Learning for Self-supervised Medical Image Analysis, which was published at CVPR 2022 (main conference).

Discriminative learning, restorative learning, and adversarial learning have proven beneficial for self-supervised learning schemes in computer vision and medical imaging. Existing efforts, however, omit their synergistic effects on each other in a ternary setup, which, we envision, can significantly benefit deep semantic representation learning. To realize this vision, we have developed DiRA, the first framework that unites discriminative, restorative, and adversarial learning in a unified manner to collaboratively glean complementary visual information from unlabeled medical images for fine-grained semantic representation learning.



Publication

DiRA: Discriminative, Restorative, and Adversarial Learning for Self-supervised Medical Image Analysis
Fatemeh Haghighi1*, Mohammad Reza Hosseinzadeh Taher1*, Michael B. Gotway2, Jianming Liang1
1 Arizona State University, 2 Mayo Clinic
* Equal contributors ordered alphabetically.
Published in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022.

Paper | Code | Poster | Presentation

Major results from our work

  1. DiRA enriches discriminative learning.
  2. DiRA improves robustness to small data regimes.
  3. DiRA improves weakly-supervised localization.
  4. DiRA outperforms fully-supervised baselines.


Credit to Scott Lowe for the MATLAB code of superbar.

Requirements

Installation

Clone the repository and install the dependencies using the following commands:

$ git clone https://github.com/fhaghighi/DiRA.git
$ cd DiRA/
$ pip install -r requirements.txt

Self-supervised pre-training

1. Preparing data

We used the training set of the ChestX-ray14 dataset for pre-training 2D DiRA models, which can be downloaded from this link.

  • The downloaded ChestX-ray14 should have a directory structure as follows:
ChestX-ray14/
    |--  images/ 
         |-- 00000012_000.png
         |-- 00000017_002.png
         ... 

We use 10% of the training data for validation. We also provide the lists of training and validation images in dataset/Xray14_train_official.txt and dataset/Xray14_val_official.txt, respectively. The training set is based on the official split provided by the ChestX-ray14 dataset. Training labels are not used during the pre-training stage. The path to the images folder is required for the pre-training stage.
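Before launching pre-training, it can help to confirm that every image named in a split file is present in the images folder. The sketch below is one way to do that with the standard library. The helper names `read_image_list` and `find_missing_images` are hypothetical, and the parsing assumes one entry per line with the file name as the first whitespace-separated token, which should be checked against the actual list files.

```python
from pathlib import Path


def read_image_list(list_file):
    """Return the image file names in a split file.

    Assumption: one entry per line, with the file name as the first
    whitespace-separated token (any trailing columns are ignored).
    """
    names = []
    for line in Path(list_file).read_text().splitlines():
        tokens = line.split()
        if tokens:  # skip blank lines
            names.append(tokens[0])
    return names


def find_missing_images(images_dir, list_file):
    """List the entries of a split file that have no matching file on disk."""
    images_dir = Path(images_dir)
    return [
        name
        for name in read_image_list(list_file)
        if not (images_dir / name).is_file()
    ]
```

For example, `find_missing_images("ChestX-ray14/images", "dataset/Xray14_train_official.txt")` would return an empty list when the download is complete with respect to the training split.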

2. Pre-training DiRA

This implementation only supports multi-GPU, DistributedDataParallel training, which is faster and simpler; single-GPU or DataParallel training is not supported. The instance discrimination setup follows MoCo. The checkpoints with the lowest validation loss are used for fine-tuning. We perform unsupervised pre-training of a U-Net model with a ResNet-50 backbone on ChestX-ray14 using 4 NVIDIA V100 GPUs.
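The instance discrimination branch follows MoCo, which scores a query against one positive key and a set of negative keys and applies a temperature-scaled softmax (the InfoNCE loss). The sketch below illustrates that loss in plain Python for a single query. It is not the repository's implementation: the function name `info_nce_loss` is hypothetical, and the default temperature of 0.2 simply mirrors the `--moco-t 0.2` flag used in the pre-training commands.

```python
import math


def info_nce_loss(query, positive_key, negative_keys, temperature=0.2):
    """InfoNCE loss for a single query, using L2-normalised feature vectors."""

    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    q = normalize(query)
    # Logit 0 is the positive pair; the remaining logits are the negatives
    # (in MoCo these come from the queue of past keys).
    logits = [dot(q, normalize(positive_key)) / temperature]
    logits += [dot(q, normalize(k)) / temperature for k in negative_keys]
    # Cross-entropy with the positive at index 0, computed in a stable way.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[0]
```

The loss is small when the query is close to its positive key and far from the negatives, and it grows when a negative key is more similar to the query than the positive one. A lower temperature sharpens that contrast.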

To stabilize the adversarial training process, we first warm up the encoder and decoder by training the discriminative and restorative components. To do so, run the following command:

python main_DiRA_moco.py /path/to/images/folder --dist-url 'tcp://localhost:10001' --multiprocessing-distributed \
--world-size 1 --rank 0 --mlp --moco-t 0.2  --cos --mode dir 

Next, we add adversarial learning to jointly train the whole framework. To do so, run the following command:

python main_DiRA_moco.py /path/to/images/folder --dist-url 'tcp://localhost:10001' --multiprocessing-distributed \
--world-size 1 --rank 0 --mlp --moco-t 0.2  --cos --mode dira --batch-size 16   --epochs 400 --generator_pre_trained_weights checkpoint/DiRA_moco/dir/checkpoint.pth 
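Both commands pass `--cos`. Assuming the flag behaves as in the MoCo reference implementation that this setup follows, it enables a half-cosine learning-rate decay over the training epochs. A minimal sketch of that schedule is given below; the function name `cosine_lr` and the base learning rate of 0.03 are illustrative choices, not values taken from this repository.

```python
import math


def cosine_lr(base_lr, epoch, total_epochs):
    """Half-cosine decay: base_lr at epoch 0, approaching 0 at total_epochs."""
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


if __name__ == "__main__":
    # Illustrative base learning rate; 400 epochs matches --epochs 400 above.
    for epoch in (0, 100, 200, 300, 399):
        print(epoch, cosine_lr(0.03, epoch, 400))
```

With this schedule the learning rate is halved at the midpoint of training and decays smoothly rather than in steps.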

Fine-tuning on downstream tasks

For downstream tasks, we use the code provided by a recent transfer learning benchmark in medical imaging.

DiRA provides a pre-trained U-Net model: its encoder can be used for classification downstream tasks, and its encoder-decoder for segmentation downstream tasks.

For classification tasks, a ResNet-50 encoder can be initialized with the pre-trained encoder of DiRA as follows:

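import torch
import torchvision.models as models

num_classes = 14  # number of target task classes (example value; set this for your task)
weight = "/path/to/DiRA/checkpoint.pth"  # path to DiRA pre-trained model (placeholder)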
model = models.__dict__['resnet50'](num_classes=num_classes)
state_dict = torch.load(weight, map_location="cpu")
if "state_dict" in state_dict:
   state_dict = state_dict["state_dict"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items()}
for k in list(state_dict.keys()):
   if k.startswith('fc') or k.startswith('segmentation_head') or k.startswith('decoder') :
      del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print("=> loaded pre-trained model '{}'".format(weight))
print("missing keys:", msg.missing_keys)
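To make the key clean-up in the snippet above easier to follow, the sketch below applies the same prefix stripping and filtering to a plain dictionary, so it runs without PyTorch or a checkpoint. The helper name `clean_encoder_keys` and the example keys are hypothetical; the actual key names in a DiRA checkpoint may differ.

```python
def clean_encoder_keys(state_dict):
    """Apply the same key clean-up as the classification snippet to a plain dict."""
    cleaned = {}
    for key, value in state_dict.items():
        # Strip the wrapper prefixes in the same order as the snippet.
        for prefix in ("module.", "backbone.", "encoder."):
            key = key.replace(prefix, "")
        # Drop everything that does not belong to the ResNet-50 encoder.
        if key.startswith(("fc", "segmentation_head", "decoder")):
            continue
        cleaned[key] = value
    return cleaned


# Hypothetical checkpoint keys, for illustration only.
example = {
    "module.backbone.encoder.conv1.weight": "w0",
    "module.backbone.encoder.layer1.0.conv1.weight": "w1",
    "module.backbone.decoder.blocks.0.conv1.0.weight": "w2",
    "module.backbone.segmentation_head.0.weight": "w3",
}
print(sorted(clean_encoder_keys(example)))
# → ['conv1.weight', 'layer1.0.conv1.weight']
```

Only the encoder weights survive, under the names that a torchvision ResNet-50 expects, which is why `load_state_dict` is called with `strict=False` and the classification head is left at its fresh initialization.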

For segmentation tasks, a U-Net can be initialized with the pre-trained encoder and decoder of DiRA as follows:

import torch
import segmentation_models_pytorch as smp

backbone = 'resnet50'
weight = "/path/to/DiRA/checkpoint.pth"  # path to DiRA pre-trained model (placeholder)
model = smp.Unet(backbone)
state_dict = torch.load(weight, map_location="cpu")
if "state_dict" in state_dict:
   state_dict = state_dict["state_dict"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
for k in list(state_dict.keys()):
   if k.startswith('fc') or k.startswith('segmentation_head'):
      del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print("=> loaded pre-trained model '{}'".format(weight))
print("missing keys:", msg.missing_keys)
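Both loading snippets print `msg.missing_keys`. Because the `fc` and `segmentation_head` entries are removed from the checkpoint before loading, the corresponding layers of the target model are expected to appear in that list; anything else usually points to a key-naming mismatch. The helper below is a small sanity check along those lines. Its name `unexpected_missing_keys` is hypothetical, and the prefixes in the usage note are assumptions about the target model's layer names.

```python
def unexpected_missing_keys(missing_keys, allowed_prefixes):
    """Return the missing keys that are not covered by an allowed prefix."""
    allowed = tuple(allowed_prefixes)
    return [key for key in missing_keys if not key.startswith(allowed)]
```

For the segmentation snippet, `unexpected_missing_keys(msg.missing_keys, ["segmentation_head."])` should ideally return an empty list; for the classification snippet, the allowed prefix would be `"fc."`.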

Citation

If you use this code or use our pre-trained weights for your research, please cite our paper:

@misc{haghighi2022dira,
      title={DiRA: Discriminative, Restorative, and Adversarial Learning for Self-supervised Medical Image Analysis}, 
      author={Fatemeh Haghighi and Mohammad Reza Hosseinzadeh Taher and Michael B. Gotway and Jianming Liang},
      year={2022},
      eprint={2204.10437},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

With the help of Zongwei Zhou, Zuwei Guo started implementing the earlier ideas behind "United & Unified", which has branched out into DiRA. We thank them for their feasibility exploration, especially their initial evaluation on TransVW and various training strategies. This research has been supported in part by ASU and Mayo Clinic through a Seed Grant and an Innovation Grant and in part by the NIH under Award Number R01HL128785. The content is solely the responsibility of the authors and does not necessarily represent the official views of the NIH. This work utilized the GPUs provided in part by ASU Research Computing and in part by the Extreme Science and Engineering Discovery Environment (XSEDE) funded by the National Science Foundation (NSF) under grant number ACI-1548562. Paper content is covered by patents pending. We built the U-Net architecture for segmentation tasks by referring to the released code at segmentation_models.pytorch. The instance discrimination is based on MoCo.

License

Released under the ASU GitHub Project License.

dira's Issues

pre-trained weights

Hi, thanks for open-sourcing this.
Can you share the pre-trained weights, such as the ResNet-50 backbone and others?
Thanks.

DIRA code for 3D volume pretraining

Thank you very much for sharing the codes. Could you please share the code for training the 3D models by applying DIRA to TransVW? If already available, could you please provide the link to the shared code?

Question about training details of ChestX-ray14

I think this is great work for medical image analysis! However, I have some questions about the details of training on 2D images.
① You did a warm-up and then joint training, which took up to 800 epochs in total as shown in your paper. Could you please tell me how many epochs the warm-up takes and how many epochs the joint training takes?
② In the joint training stage, you set the batch size to 16 instead of 256 in the repository, but this does not seem to be mentioned in your paper. I found that it takes little GPU memory when I set the batch size to 16, and that it also works well with a batch size of 256. Is there any reason to set the batch size smaller? What happens if a larger batch size such as 256 is used during joint training?
Thanks in advance for your time!
