PRCV'24 | Official PyTorch implementation of "Pseudo-Prompt Generating in Pre-trained Vision-Language Models for Multi-Label Medical Image Classification"


PsPG - Pseudo Prompt Generating



Here is the official implementation of Pseudo Prompt Generating (PsPG).

Yaoqin Ye, Junjie Zhang and Hongwei Shi. 2024. Pseudo-Prompt Generating in Pre-trained Vision-Language Models for Multi-Label Medical Image Classification. In Pattern Recognition and Computer Vision: 7th Chinese Conference, PRCV 2024

Intro

Pseudo-Prompt Generating (PsPG) is designed to address the multi-label medical image classification task by capitalizing on the prior knowledge of multi-modal features. Featuring an RNN-based decoder, PsPG autoregressively generates class-tailored embedding vectors, i.e., pseudo-prompts.
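The autoregressive idea can be illustrated with a toy, dependency-free sketch: a recurrent step emits one embedding vector at a time, each conditioned on a hidden state initialized from the class feature. The update rule and dimensions here are illustrative only and do not reproduce the paper's decoder.

```python
def generate_pseudo_prompt(class_feature, seq_length=16):
    """Toy sketch of autoregressive pseudo-prompt generation.

    Each step mixes the hidden state with the previous output
    (GRU-like) and emits one embedding vector; the mixing weights
    are arbitrary and purely illustrative.
    """
    hidden = list(class_feature)           # init hidden state from the class feature
    prev = [0.0] * len(class_feature)      # zero "start token"
    prompt = []
    for _ in range(seq_length):
        # recurrent update conditioned on the previous output
        hidden = [0.6 * h + 0.4 * p for h, p in zip(hidden, prev)]
        prev = hidden[:]                   # this step's generated embedding
        prompt.append(prev)
    return prompt
```

The real decoder operates on learned high-dimensional embeddings; this sketch only shows the step-by-step conditioning that makes the prompts class-tailored.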

Via Python

We recommend using conda and provide an environment.yaml. You can create and activate the environment with:

conda env create -f environment.yaml
conda activate pspg

Get Started

You can run python train.py --help or python val.py --help to learn the usage of the Python scripts for training and validating. The next section introduces some details.

For reference, we provide two typical commands for training and validating, respectively:

# train
python train.py -nc configs/model/RN50.yaml -dc configs/dataset/example.yaml \
  --dataset_dir datasets/example --checkpoint weights/PsPG.pth --decoder_hidden 512 \
  --max_epochs 50 --output_dir output  --test_file_path test.jsonl \
  --val_file_path val.jsonl --train_file_path train.jsonl --start_afresh 
# validate
python val.py -nc configs/model/RN50.yaml -dc configs/dataset/example.yaml \
  --dataset_dir datasets/example --checkpoint weights/PsPG.pth --decoder_hidden 512 \
  --output_dir output  --test_file_path test.jsonl \
  --model_name example-RN50 --save_pred

Via make (recommended)

We provide a Makefile to simplify training and validating the PsPG model, and we recommend using make instead of executing the Python scripts directly.

To use make, you need a Linux, MSYS2 (on Windows), or macOS (untested) environment with make and conda installed. Make sure that conda is in your PATH, or that CONDA_EXE is set to the path of the conda executable (usually path/to/Conda/Scripts/conda); otherwise, make will report that it cannot find conda.

Get started

On first use, you can create the conda environment for running the PsPG model with this command:

make init

Next, you can train the model with the default settings using this command:

make train

Then, you can validate the model with the weights you just trained with this command:

make val

By default, the weights and prediction results are written to the output folder. Use this command to clear the output folder (it clears __pycache__ as well):

make clean

Plug in Dataset

We provide an example dataset so that users can get running quickly on first use. For other datasets, please refer to Datasets.

To plug in a dataset, create a folder named after the dataset inside the datasets folder, create two (or three) datamap files (explained below), and place the datamap files and the images into the folder.

A datamap file consists of lines of JSON objects (JSONL), each describing one data entry of the dataset. A data entry consists of:

  • id (optional)
  • path to the image (relative to the dataset root directory)
  • text (optional)
  • ground truth labels

The datamap files required by a dataset should be named test.jsonl, val.jsonl (if needed), and train.jsonl (or you can modify the Makefile to use a custom naming).

You can refer to datasets/example to further learn the format of a dataset.
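For illustration, a datamap file could be written as below. The JSON key names are assumptions, since the README only lists the fields a data entry contains; check datasets/example for the actual schema.

```python
import json

# Hypothetical datamap writer. Each entry has: an optional id, a path
# to the image (relative to the dataset root), optional text, and
# ground-truth labels. The key names ("id", "image", "text", "labels")
# are guesses, not the repo's verified schema.
entries = [
    {"id": "0001", "image": "images/0001.png", "text": "frontal view",
     "labels": [1, 0, 0, 1]},
    {"image": "images/0002.png", "labels": [0, 1, 0, 0]},  # id and text are optional
]

with open("train.jsonl", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")  # one JSON object per line (JSONL)
```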

Append DATASET=[dataset] to your make command to specify the dataset; the variable is set to example by default.

Select The Weights

By default, make checks whether training has already produced a checkpoint (specifically, whether last.pth exists in the output folder). If so, make selects it as the checkpoint for training and validating.

Otherwise, make checks for weights/PsPG.pth and uses it if present. Failing that, make trains the model from scratch if the train target is specified, or reports an error if val is specified.

If you want to specify a target checkpoint, append CHECKPOINT=[yourCheckpointPath] to your make command.
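The selection order described above can be sketched as a small function. This mirrors the described behavior only; it is not the Makefile's actual code.

```python
import os

def select_checkpoint(output_dir="output", default="weights/PsPG.pth",
                      explicit=None, target="train"):
    """Sketch of the checkpoint-selection order: an explicit
    CHECKPOINT wins, then output/last.pth, then weights/PsPG.pth;
    otherwise train starts from scratch and val fails."""
    if explicit:                           # CHECKPOINT=[yourCheckpointPath]
        return explicit
    last = os.path.join(output_dir, "last.pth")
    if os.path.exists(last):               # checkpoint left by a previous run
        return last
    if os.path.exists(default):            # provided pre-trained weights
        return default
    if target == "train":
        return None                        # train from scratch
    raise FileNotFoundError("no checkpoint available for validation")
```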

Change the Image Encoder of CLIP

You can change the backbone type by appending MODEL=model to the make command; MODEL=RN50 is the default. The provided weights are also for ResNet-50.

You can use other pre-trained CLIP image encoders provided by OpenAI, such as RN50, RN101, or ViT-B.

Advanced usage

The Makefile defines other variables for more flexible usage. For more detailed customization, please refer to utils/cfg_builder.py.

EPOCH

You can control the maximum number of training epochs by specifying the EPOCH variable.

MODEL_NAME

The model name is used only when saving validation results. Results are written to output/$(MODEL_NAME) by default.

OUTPUT_DIR

You can change the name of the output folder by appending OUTPUT_DIR=output_dir. This is useful when you want to run two or more processes simultaneously.

EXTRA_FLAGS

You can add extra flags to the Python scripts by setting the EXTRA_FLAGS variable in the Makefile (see the Makefile for details).

Alternatively, append EXTRA_FLAGS="--option1 arg1 --option2 arg2..." (do not add extra spaces!) to the make command.

If you want to see what flags the Python scripts provide, execute:

  • make val VAL_FLAGS=--help (for val.py)
  • make train TRAIN_FLAGS=--help (for train.py)

If you want to override the flags specified by make, use VAL_FLAGS="..." or TRAIN_FLAGS="...".

We recommend you refer to the source code.
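As a rough illustration of the kind of flags the scripts accept, here is a simplified argparse stand-in built only from flags that appear in the example commands earlier. The long option names paired with -nc and -dc are guesses, and the real parser (see utils/cfg_builder.py) accepts many more options.

```python
import argparse

# Simplified stand-in for the training script's CLI; the long names
# "--net_config" and "--data_config" are hypothetical, chosen only so
# the short -nc/-dc options from the example commands have a home.
parser = argparse.ArgumentParser()
parser.add_argument("-nc", "--net_config")
parser.add_argument("-dc", "--data_config")
parser.add_argument("--dataset_dir")
parser.add_argument("--checkpoint")
parser.add_argument("--decoder_hidden", type=int, default=512)
parser.add_argument("--max_epochs", type=int, default=50)
parser.add_argument("--output_dir", default="output")
parser.add_argument("--start_afresh", action="store_true")

# Parse one of the example command lines from the README
args = parser.parse_args([
    "-nc", "configs/model/RN50.yaml", "-dc", "configs/dataset/example.yaml",
    "--dataset_dir", "datasets/example", "--max_epochs", "50", "--start_afresh",
])
```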

Data and Processing

Datasets

Datasets for training or testing can be accessed via the following links:

To access some of these datasets, you must be credentialed.

Note that, due to restrictions, Private-CXR is not provided. For research requests, please contact the corresponding author.

Processing Scripts

We provide some standalone scripts for preprocessing datasets and calculating metrics. For more information, please refer to utils/independ_metrics and utils/preprocess.
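As an example of the kind of metric such scripts might compute for multi-label classification, here is a from-scratch macro-averaged F1 at a fixed threshold. It is illustrative only, not the repo's metric code.

```python
def macro_f1(y_true, y_prob, threshold=0.5):
    """Macro-averaged F1 for multi-label predictions.

    y_true: list of binary label vectors (one per sample).
    y_prob: list of per-class probability vectors (same shape).
    Probabilities are binarized at `threshold`, F1 is computed per
    class, then averaged uniformly over classes.
    """
    n_classes = len(y_true[0])
    f1s = []
    for c in range(n_classes):
        tp = fp = fn = 0
        for truth, prob in zip(y_true, y_prob):
            pred = 1 if prob[c] >= threshold else 0
            if pred and truth[c]:
                tp += 1
            elif pred and not truth[c]:
                fp += 1
            elif not pred and truth[c]:
                fn += 1
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)  # F1 = 2TP / (2TP + FP + FN)
    return sum(f1s) / n_classes
```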

Pre-trained Weights

We provide one set of pre-trained weights: PsPG. It was trained on CheXpert using RN50 as the backbone, with a hidden size of 512 and a sequence length of 16.

In addition, we provide another set without the prompt learner's parameters: CLIP-finetuned. These weights were fine-tuned on MIMIC-CXR.


pspg's Issues

spatial attention module

Hi there!
I recently read your paper and noticed the spatial attention module. I'm a bit confused: does this module compute channel attention only over the local-level features, and then concatenate the result with the global features?

Code Release?

Your research is so excellent, can you release the code?
