
gloria's Introduction

GLoRIA: A Multimodal Global-Local Representation Learning Framework for Label-efficient Medical Image Recognition

GLoRIA (Global-Local Representations for Images using Attention) is a multimodal representation learning framework for label-efficient medical image recognition. Our results demonstrate high performance and label efficiency for image-text retrieval, classification (fine-tuning and zero-shot settings), and segmentation on different medical imaging datasets.

GLoRIA Manuscript
Shih-Cheng Huang (Mars), Liyue Shen, Matthew P. Lungren, Serena Yeung
Stanford University
Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2021

Approach

GLoRIA

Usage

Start by installing PyTorch 1.7.1 with the right CUDA version, then clone this repository and install the dependencies.

$ conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
$ pip install git+https://github.com/marshuang80/gloria.git
$ conda env create -f environment.yml

Make sure to download the pretrained weights from here and place them in the ./pretrained folder.

Load GLoRIA pretrained models

import torch
import gloria

# get device
device = "cuda" if torch.cuda.is_available() else "cpu"

# load classifier
num_class = 5   # 5 class classification
freeze = True   # freeze encoder and only train linear classifier (less likely to overfit when training data is limited)
model = gloria.load_img_classification_model(num_cls=num_class, freeze_encoder=freeze, device=device)

# load segmentation model (UNet)
seg_model = gloria.load_img_segmentation_model(device=device)

Zero-shot classification for CheXpert5x200

import torch
import gloria
import pandas as pd 

df = pd.read_csv(gloria.constants.CHEXPERT_5x200)

# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
gloria_model = gloria.load_gloria(device=device)

# generate class prompts
# cls_prompts = {
#    'Atelectasis': ['minimal residual atelectasis', 'mild atelectasis' ...]
#    'Cardiomegaly': ['cardiomegaly unchanged', 'cardiac silhouette enlarged' ...]
# ...
# }
cls_prompts = gloria.generate_chexpert_class_prompts()

# process input images and class prompts 
processed_txt = gloria_model.process_class_prompts(cls_prompts, device)
processed_imgs = gloria_model.process_img(df['Path'].tolist(), device)

# zero-shot classification on 1000 images
similarities = gloria.zero_shot_classification(
    gloria_model, processed_imgs, processed_txt)

print(similarities)
#      Atelectasis  Cardiomegaly  Consolidation     Edema  Pleural Effusion
# 0       1.371477     -0.416303      -1.023546 -1.460464          0.145969
# 1       1.550474      0.277534       1.743613  0.187523          1.166638
# ..           ...           ...            ...       ...               ...
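To turn these scores into predictions, a natural choice is to take, for each image, the class with the highest similarity. The sketch below does this in plain Python on the two example rows printed above and computes accuracy against ground-truth labels. The helper names and the ground-truth labels are hypothetical, used only for illustration. With the pandas DataFrame `similarities`, `similarities.idxmax(axis=1)` gives the same per-row argmax.

```python
# Class order matches the columns of the example output above.
CLASSES = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

def predict_labels(score_rows, classes):
    """Return, for each row of similarity scores, the class with the highest score."""
    preds = []
    for row in score_rows:
        best_idx = max(range(len(row)), key=lambda i: row[i])
        preds.append(classes[best_idx])
    return preds

def accuracy(preds, targets):
    """Fraction of predictions that match the ground-truth labels."""
    if not targets or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and the same length")
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)

# Score rows copied from the example output above.
scores = [
    [1.371477, -0.416303, -1.023546, -1.460464, 0.145969],
    [1.550474, 0.277534, 1.743613, 0.187523, 1.166638],
]
preds = predict_labels(scores, CLASSES)
print(preds)  # ['Atelectasis', 'Consolidation']

# Hypothetical ground truth, for illustration only.
targets = ["Atelectasis", "Cardiomegaly"]
print(accuracy(preds, targets))  # 0.5
```

With five balanced classes, as in CheXpert5x200, accuracy near 0.2 corresponds to chance level.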

Training

This codebase was developed with Python 3.7, PyTorch 1.7.1, CUDA 10.2, and pytorch-lightning 1.1.4. Example configurations for pretraining and downstream classification can be found in ./configs. All training and testing are done with the run.py script. For more documentation, run:

python run.py --help

The preprocessing steps for each dataset can be found in ./gloria/datasets/preprocess_datasets.py

Representation Learning

Train the representation learning model with the following command:

python run.py -c ./configs/chexpert_pretrain_config.yaml --train
  • Please note that the CheXpert radiology reports are still under PHI review for HIPAA compliance and are not yet publicly available.

Classification

Fine-tune the GLoRIA pretrained image model for classification with the following commands:

# chexpert
python run.py -c ./configs/chexpert_classification_config.yaml --train --test --train_pct 0.01
# pneumonia
python run.py -c ./configs/pneumonia_classification_config.yaml --train --test --train_pct 0.01

The train_pct flag randomly selects a fraction of the training data (e.g. 0.01 for 1%) to fine-tune the model. This is used to evaluate the model's performance in the low-data regime.
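The repository's exact sampling logic is not shown here, so the sketch below is a hypothetical stand-in that only illustrates the idea: draw a uniform random subset of the given fraction with a fixed seed, so that the low-data split is reproducible. The helper name `sample_train_subset` and the "keep at least one item" rule are assumptions; the real implementation may differ (for example, by sampling per class).

```python
import random

def sample_train_subset(items, train_pct, seed=0):
    """Randomly keep a fraction `train_pct` (e.g. 0.01 = 1%) of `items`.

    Hypothetical helper for illustration; not the repository's implementation.
    Always keeps at least one item so tiny fractions still yield a usable set.
    """
    if not 0 < train_pct <= 1:
        raise ValueError("train_pct must be in (0, 1]")
    if not items:
        return []
    k = max(1, int(len(items) * train_pct))
    rng = random.Random(seed)  # fixed seed -> the same subset on every run
    return rng.sample(items, k)  # uniform sample without replacement

paths = [f"img_{i:05d}.jpg" for i in range(10_000)]
subset = sample_train_subset(paths, 0.01, seed=42)
print(len(subset))  # 100
```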

Segmentation

Fine-tune the GLoRIA pretrained image model for segmentation with the following command:

# pneumothorax
python run.py -c ./configs/pneumothorax_segmentation_config.yaml --train --test --train_pct 0.01

Citation

@inproceedings{huang2021gloria,
  title={GLoRIA: A Multimodal Global-Local Representation Learning Framework for Label-Efficient Medical Image Recognition},
  author={Huang, Shih-Cheng and Shen, Liyue and Lungren, Matthew P and Yeung, Serena},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={3942--3951},
  year={2021}
}

Acknowledgements

This codebase is adapted from ControlGAN

gloria's People

Contributors

marshuang80, jbdel


gloria's Issues

Can not download the pretrained model

Hello!
Thanks for your excellent work and model. I want to build on your model and transfer it to another task. However, I cannot download the pretrained ResNet-50 or ResNet-18 weights because of the website proxy or VPN. Could you please upload the model weights to Google Drive or Aliyun Drive? My email address is [email protected].
Best wishes,
Li

Some questions about the similarities?

In gloria_model.py, get_local_similarities (line 130) computes row_sim with max:

 row_sim, max_row_idx = torch.max(row_sim, dim=1, keepdim=True)  # [48, 1]

and in gloria_loss.py, local_loss (line 120) computes row_sim with sum or mean:

 if agg == "sum":
     row_sim = row_sim.sum(dim=1, keepdim=True)  # [48, 1]
 else:
     row_sim = row_sim.mean(dim=1, keepdim=True)  # [48, 1]
  • What does this [48, 1]-dimensional vector represent?
  • Why are different operations used (max, sum, mean)?
  • What does the [48, 48]-dimensional matrix obtained by the subsequent concatenation represent?
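The sketch below is an illustration only, not the repository's code. It shows how max, mean, and a temperature-scaled log-sum-exp (a smooth maximum) each collapse a list of per-word similarity scores (one score per report word, measuring how well that word matches its attention-weighted image context) into a single image-report score. Whether and how gloria_loss.py applies a log after the sum/mean is not shown in the snippet above and should be checked in the source; the temperature `gamma` and its default are assumptions. Assuming 48 is the batch size, aggregating one report against every image in the batch gives a [48, 1] column, and stacking the columns for all 48 reports gives a [48, 48] matrix whose diagonal holds the matched pairs.

```python
import math

def aggregate(word_sims, mode="lse", gamma=5.0):
    """Collapse per-word similarity scores into one image-report score.

    mode="max":  keep only the best-matching word (hard max).
    mode="mean": average over all words.
    mode="lse":  temperature-scaled log-sum-exp, a smooth max:
                 (1 / gamma) * log(sum_i exp(gamma * s_i)).
    `gamma` is a hypothetical default chosen for illustration.
    """
    if not word_sims:
        raise ValueError("word_sims must be non-empty")
    if mode == "max":
        return max(word_sims)
    if mode == "mean":
        return sum(word_sims) / len(word_sims)
    if mode == "lse":
        # Subtract the max before exponentiating for numerical stability.
        m = max(word_sims)
        total = sum(math.exp(gamma * (s - m)) for s in word_sims)
        return m + math.log(total) / gamma
    raise ValueError(f"unknown mode: {mode}")

def similarity_matrix(word_sims_per_pair, mode="lse"):
    """word_sims_per_pair[i][j] holds the per-word scores of report i vs image j.

    Returns an N x N list of lists of aggregated scores.
    """
    return [[aggregate(ws, mode) for ws in row] for row in word_sims_per_pair]

sims = [0.25, 0.75, 0.5]
print(aggregate(sims, "max"))   # 0.75
print(aggregate(sims, "mean"))  # 0.5
print(aggregate(sims, "lse"))   # between max and max + log(3) / gamma
```

As `gamma` grows, the log-sum-exp approaches the hard max; as it shrinks, every word contributes more evenly.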

Hello, I can't find chexpert_8x200.csv

In gloria/gloria/constants.py, line 18 is:
CHEXPERT_5x200 = CHEXPERT_DATA_DIR / "chexpert_8x200.csv"
but the download with the pretrained weights (from here) does not contain chexpert_8x200.csv.
I look forward to your reply, and thank you!

Does the CheXpert dataset include reports now?

Hi,

Thank you very much for releasing the source code of your work. I noticed that you use CheXpert for multimodal pre-training of your model. However, as far as I'm aware, the CheXpert dataset does not include the actual reports, only labels they extracted from the reports. I understand that people doing research on images + reports usually use datasets like MIMIC-CXR and Open-I. In fact, a fellow researcher downloaded the CheXpert dataset a couple of years ago and confirmed that it only came with labels and images, but no reports. However, I'm checking the website of the dataset right now (here: https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2) and just noticed that they have included new labels, i.e., cheXbert and visualCheXbert generated labels, and they have also increased the size of CheXpert-v1.0.zip to 471.12 GB (we have a copy downloaded a couple of years ago that weighs 439 GB).

Question: does that mean that the CheXpert dataset includes reports now? If so, that would be spectacular, because it would mean we could experiment with CheXpert + MIMIC-CXR instead of just using MIMIC-CXR alone (or CheXpert alone for that matter).

In fact, I'm curious: is there a particular reason why you didn't include MIMIC-CXR in your experiments (https://physionet.org/content/mimic-cxr/2.0.0/)?

Kind regards,
Pablo

Considerations for CheXpert 5x200?

In the paper, it is mentioned that you picked 5 tasks, "Atelectasis, Cardiomegaly, Edema, Pleural, Effusion", out of 14, but when I look into the original CheXpert dataset, there is no "Pleural" or "Effusion" among the 14 categories, only "Pleural Effusion" and "Pleural Other"?

Finetuned model for segmentation?

Hi,
Thank you for your excellent work. Could you provide the fine-tuned segmentation model?
Also, how does the SIIM data from Kaggle map to the code you used? What are train.csv/val.csv/test.csv?
Thank you so much!

Pretrain with MIMIC-CXR Val Loss

Hi,

I am pretraining the model with the MIMIC-CXR JPG dataset, but the local alignment validation loss is kinda weird. Is this normal when you pretrain GLoRIA with MIMIC?

Thanks in advance!!

how to load those saved models to continue training

I'm sorry to bother you again

I want to know how to load the saved models, e.g.
./gloria/data/ckpt/pneumothorax_segmentation_0.01/2022_08_03_12_23_46/epoch=35-step=143.ckpt

What is the difference between this and the pretrained model?
How can I continue my previous training after training has completed?

I tried directly replacing the pretrained model with epoch=35-step=143.ckpt,
but the two files differ in size, and it did not work.
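A PyTorch Lightning `.ckpt` file is a dictionary that stores the weights under `state_dict` together with training state such as `optimizer_states`, `epoch`, and `global_step`, which is one reason it is larger than a weights-only file. Parameter names in `state_dict` are prefixed with the name of the LightningModule attribute that holds the network; in the gloria segmentation module that attribute appears as `model` (see the `model | Unet` row in the training summary of the RuntimeError issue), so keys look like `model.<...>`. The helper below is hypothetical, not part of the gloria API: it keeps only the keys under that prefix and strips it so the weights can be loaded into a bare network. To resume training with the optimizer state instead, pytorch-lightning 1.1.x accepts `Trainer(resume_from_checkpoint=...)`.

```python
def extract_submodule_weights(checkpoint, prefix="model."):
    """Return the weights of one submodule from a Lightning checkpoint dict.

    Hypothetical helper: keeps only keys under `prefix` and strips it, so the
    result can be passed to `network.load_state_dict(...)` of the bare model.
    """
    if "state_dict" not in checkpoint:
        raise KeyError("not a Lightning checkpoint: missing 'state_dict'")
    return {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }

# Typical use (requires torch; `unet` is a hypothetical bare network):
#   import torch
#   ckpt = torch.load("./gloria/data/ckpt/pneumothorax_segmentation_0.01/"
#                     "2022_08_03_12_23_46/epoch=35-step=143.ckpt", map_location="cpu")
#   print(sorted(ckpt.keys()))  # includes 'state_dict', 'optimizer_states', 'epoch', ...
#   unet.load_state_dict(extract_submodule_weights(ckpt))

# Self-contained demo with a fake checkpoint dict.
fake_ckpt = {
    "epoch": 35,
    "global_step": 143,
    "optimizer_states": [{}],
    "state_dict": {"model.encoder.conv1.weight": [0.1], "model.decoder.head.bias": [0.0]},
}
print(extract_submodule_weights(fake_ckpt))
# {'encoder.conv1.weight': [0.1], 'decoder.head.bias': [0.0]}
```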

image text Retriver

Great work. Could you please provide an example of how to use the Retriver class? Thanks.

The reported results of RSNA dataset

Hi~
Thank you very much for your impressive work!
My results on the RSNA dataset are quite different from those in the original paper, and I am confused about this.
I would greatly appreciate it if you could explain the reason.

RuntimeError: grad can be implicitly created only for scalar outputs

When I fine-tune the pre-trained model weights in SIIM segmentation tasks, the following error is reported:

  | Name  | Type      | Params
------------------------------------
0 | model | Unet      | 32.5 M
1 | loss  | MixedLoss | 0     
------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
Epoch 0:   0%|          | 0/669 [00:01<?, ?it/s]                      
Traceback (most recent call last):
  File "run.py", line 167, in <module>
    main(cfg, args)
  File "run.py", line 106, in main
    trainer.fit(model, dm)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 473, in fit
    results = self.accelerator_backend.train()
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 110, in train
    results = self.train_or_test()
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 69, in train_or_test
    results = self.trainer.train()
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 524, in train
    self.train_loop.run_training_epoch()
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 572, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 730, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 513, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1261, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 140, in __optimizer_step
    trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/plugins/native_amp.py", line 75, in optimizer_step
    closure()
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 725, in train_step_and_backward_closure
    self.trainer.hiddens
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 828, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 851, in backward
    result.closure_loss, optimizer, opt_idx, *args, **kwargs
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 99, in backward
    closure_loss, optimizer, opt_idx, *args, **kwargs
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/plugins/native_amp.py", line 47, in backward
    model.backward(closure_loss, optimizer, opt_idx)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1158, in backward
    loss.backward(*args, **kwargs)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/torch/autograd/__init__.py", line 126, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_)
  File "/home/wentaochen/anaconda3/envs/gloria/lib/python3.7/site-packages/torch/autograd/__init__.py", line 50, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

Hope you can give me some advice! Thank you very much!

Large CPU usage rate

I am trying to train GLoRIA on MIMIC-III. To fit the dataset class, I wrote a CSV file for MIMIC-III. The code runs, but CPU usage reaches 99%, which seems abnormal.

Link to CheXpert5x200

Hey

Your work is really amazing. Can you please share the link to download CheXpert5x200 dataset used for Image-text retrieval? I would like to run some experiments.

Thanks in advance.
Shivangi

In the segmentation task, EncodedPixels seems to have an extra space, which I remove, but...

When I was working on the segmentation task, I ran into a problem:
Traceback (most recent call last):
  File "run.py", line 167, in <module>
    main(cfg, args)
  File "run.py", line 106, in main
    trainer.fit(model, dm)
  File "/GPUFS/nsccgz_ywang_zfd/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 457, in fit
    self.accelerator_backend.setup(model)
  File "/GPUFS/nsccgz_ywang_zfd/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 56, in setup
    self.setup_optimizers(model)
  File "/GPUFS/nsccgz_ywang_zfd/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 145, in setup_optimizers
    optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
  File "/GPUFS/nsccgz_ywang_zfd/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 31, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/GPUFS/nsccgz_ywang_zfd/zxp/gloria/gloria/lightning/segmentation_model.py", line 37, in configure_optimizers
    scheduler = builder.build_scheduler(self.cfg, optimizer, self.dm)
  File "/GPUFS/nsccgz_ywang_zfd/zxp/gloria/gloria/builder.py", line 109, in build_scheduler
    num_iter = len(dm.train_dataloader().dataset)
  File "/GPUFS/nsccgz_ywang_zfd/zxp/gloria/gloria/datasets/data_module.py", line 110, in train_dataloader
    dataset = self.dataset(self.cfg, split="train", transform=transform)
  File "/GPUFS/nsccgz_ywang_zfd/zxp/gloria/gloria/datasets/image_dataset.py", line 192, in __init__
    neg_series_selected = np.random.choice(
  File "mtrand.pyx", line 908, in numpy.random.mtrand.RandomState.choice
ValueError: 'a' cannot be empty unless no samples are taken
I want to ask: what is this first parameter, and how should I deal with this error?
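The first parameter of `np.random.choice` is `a`, the population to sample from, and the error means that population was empty, i.e. no negative samples were found. One plausible cause, offered as an assumption rather than a confirmed diagnosis: in the SIIM pneumothorax CSV, negative (no-pneumothorax) rows are often marked with an `EncodedPixels` value of `-1` stored with a leading space. If the code selects negatives by comparing against the exact string with the space, removing the space (or the reverse) leaves the negative set empty. The standard-library sketch below normalizes whitespace in both column names and values before splitting positives from negatives, and fails with a clear message if there is nothing to sample. The column names, helper names, and sampling logic are assumptions for illustration, not the repository's code.

```python
import csv
import io
import random

def split_pos_neg(csv_text):
    """Split SIIM-style RLE rows into positive and negative image IDs.

    Assumes columns 'ImageId' and 'EncodedPixels' (possibly with stray spaces
    in the header or values) and that negatives are marked with '-1'.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    pos, neg = set(), set()
    for raw in reader:
        # Strip whitespace from keys and values, so ' -1' and '-1' both match.
        row = {k.strip(): (v or "").strip() for k, v in raw.items()}
        if row["EncodedPixels"] == "-1":
            neg.add(row["ImageId"])
        else:
            pos.add(row["ImageId"])
    return sorted(pos), sorted(neg)

def sample_negatives(neg_ids, k, seed=0):
    """Sample up to k negatives without replacement, failing loudly if none exist."""
    if not neg_ids:
        raise ValueError("no negative samples found; check how '-1' rows are parsed")
    rng = random.Random(seed)
    return rng.sample(neg_ids, min(k, len(neg_ids)))

data = "ImageId, EncodedPixels\nimg1, -1\nimg2, 10 5 20 3\nimg3, -1\n"
pos, neg = split_pos_neg(data)
print(pos, neg)  # ['img2'] ['img1', 'img3']
print(len(sample_negatives(neg, 1)))  # 1
```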

pretrained code

Hi, i am running your pretrained code and I got this error and do you know how to deal with this?

zero shot classification results

I tried using your script for zero-shot classification together with the pretrained weights (both ResNet-18 and ResNet-50). The classification results I got look essentially random (accuracy of 17-22% for each class). Is an additional step needed, or are the weights I downloaded not the trained weights?

PneumothoraxImageDataset for segmentation

Thank you for your impressive work.
Does PneumothoraxImageDataset include negative samples? I found the comment
# only keep positive samples for segmentation
in the dataset code, but when I read that code it looks like negative samples are included.
So does it include negative samples for segmentation, and what are the sizes (number of images) of the train/val/test sets in your experiments?
I would greatly appreciate it if you could share the details.

pip install failed

Howdy,

I am trying to install gloria via pip, and I ran pip install git@github.com:marshuang80/gloria.git. I got this error:

ERROR: Invalid requirement: 'git@github.com:marshuang80/gloria.git' Hint: It looks like a path. File 'git@github.com:marshuang80/gloria.git' does not exist.

Could you help to take a look at it?

Many thanks!

Reproduced results on RSNA Pneumonia dataset

Hi,
Thanks for your impressive work!

Recently, I tried to reproduce the classification results on the RSNA Pneumonia dataset. I used the released code and the pretrained model from your project and ran the experiments following your instructions. Surprisingly, I got results even higher than those in the paper. At first I thought this might be due to randomness, but I ran it several times with different random seeds and all the results were better, which confuses me.

Is this expected? Does your released pretrained model include any improvements?
