
TransInvNet: Combining Transformer and Involution Network for Polyp Segmentation


1. Overview

1.1 Introduction

Prevention of colorectal cancer has become a worldwide health issue. In clinical practice, doctors usually rely on colonoscopy to detect polyps, but accurately segmenting polyps from colonoscopy images remains a challenging task. Many CNN-based methods have been proposed to address this challenge, but pure CNN-based methods have limitations. To overcome them, we propose a novel architecture, TransInvNet, for accurate polyp segmentation in colonoscopy images. Specifically, we combine the recently proposed involution network (RedNet) with a vision transformer (ViT) in two parallel branches and then merge their output features. On top of the merged features, a simple decoder with skip connections increases the resolution while decreasing the number of channels step by step. Finally, we propose an attention segmentation module that combines an attention map with a reverse attention map, which helps distinguish a polyp from its surrounding tissue and improves segmentation accuracy. Our method achieves strong results on the Kvasir-SEG dataset (mDice 0.910) and generalizes well to unseen datasets (ETIS, CVC-ColonDB, EndoScene).
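The dual-branch fusion described above can be sketched as follows. This is an illustrative sketch only, not the repository's implementation: the shapes (14x14 patch grid, 768-dim ViT-B/16 tokens, a 512-channel CNN feature map) and the function name `fuse_branches` are assumptions chosen for the example.

```python
import numpy as np

def fuse_branches(vit_tokens, cnn_features):
    """Concatenate transformer and involution-branch features channel-wise.

    vit_tokens:   (B, N, C) patch tokens, N a perfect square (e.g. 14*14)
    cnn_features: (B, C', H, W) feature map with H = W = sqrt(N)
    """
    b, n, c = vit_tokens.shape
    h = w = int(np.sqrt(n))
    # Reshape the token sequence back into a spatial grid: (B, C, H, W)
    vit_map = vit_tokens.reshape(b, h, w, c).transpose(0, 3, 1, 2)
    assert cnn_features.shape[2:] == vit_map.shape[2:], "spatial sizes must match"
    # Merge the two branches along the channel axis
    return np.concatenate([vit_map, cnn_features], axis=1)

vit_tokens = np.random.rand(2, 14 * 14, 768)    # hypothetical ViT-B/16 tokens
cnn_features = np.random.rand(2, 512, 14, 14)   # hypothetical RedNet feature map
fused = fuse_branches(vit_tokens, cnn_features)  # (2, 768 + 512, 14, 14)
```

In the actual network the concatenated features would then be passed through the decoder; here the sketch only shows the shape bookkeeping of the fusion step.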

1.2 Network Architecture

alt Figure1

Figure 1: Architecture of the proposed TransInvNet, which consists of two parallel branches (RedNet and ViT) followed by a simple decoder.

alt Figure2

Figure 2: Architecture of the attention segmentation module.
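The core idea of the attention segmentation module, combining an attention map with its reverse, can be illustrated with a minimal sketch. This is not the exact module from Figure 2; it only shows the standard reverse-attention construction, where the reverse map highlights the regions a coarse prediction considers background, helping separate a polyp from surrounding tissue.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attention_and_reverse(coarse_logits):
    """Return the attention map and its reverse from coarse logits."""
    attn = sigmoid(coarse_logits)   # foreground attention map
    rev_attn = 1.0 - attn           # reverse attention map (background focus)
    return attn, rev_attn

# A tiny 2x2 logit map: positive logits -> likely polyp, negative -> background
logits = np.array([[2.0, -2.0], [0.0, 4.0]])
attn, rev = attention_and_reverse(logits)
# By construction attn + rev == 1 at every pixel
```

The real module would learn how to weight and combine these two maps with the decoder features; the sketch only fixes the definitions.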

1.3 Quantitative Results

Our train/test split follows PraNet: Parallel Reverse Attention Network for Polyp Segmentation. 900 images from Kvasir-SEG and 550 images from CVC-ClinicDB are used for training, while the remaining images from these two datasets, together with CVC-ColonDB, ETIS, and the test set of EndoScene, are used for testing.
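The split above can be written out explicitly. The file names below are hypothetical placeholders; the dataset sizes (1000 images in Kvasir-SEG, 612 in CVC-ClinicDB) are the standard published counts used by PraNet's split.

```python
# Hypothetical file lists standing in for the real dataset directories
kvasir = [f"kvasir_{i}.png" for i in range(1000)]      # Kvasir-SEG: 1000 images
clinicdb = [f"clinicdb_{i}.png" for i in range(612)]   # CVC-ClinicDB: 612 images

# PraNet-style split: 900 + 550 for training, the remainder for testing
train_set = kvasir[:900] + clinicdb[:550]
test_seen = kvasir[900:] + clinicdb[550:]

# CVC-ColonDB, ETIS and the EndoScene test set are used purely as unseen test data
print(len(train_set), len(test_seen))  # 1450 162
```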

alt Figure3

Figure 3: Quantitative results on Kvasir-SEG and CVC-ClinicDB datasets.

alt Figure4

Figure 4: Quantitative results on the ETIS, EndoScene and CVC-ColonDB datasets.

1.4 Qualitative Results

alt Figure5

Figure 5: Qualitative results of our proposed TransInvNet compared to PraNet and HarDNet-MSEG.

1.5 Directory Tree for TransInvNet

.
├── cal_params.py
├── eval.py
├── images
│   ├── framework.png
│   ├── qualitiveresult.png
│   ├── quantitativeresult1.png
│   ├── quantitativeresult2.png
│   └── segmentationhead.png
├── inference.py
├── README.md
├── requirements.txt
├── train.py
└── TransInvNet
    ├── model
    │   ├── backbone
    │   │   ├── base_backbone.py
    │   │   ├── builder.py
    │   │   └── rednet.py
    │   ├── basic_blocks.py
    │   ├── config.py
    │   ├── decoder
    │   │   └── decoder.py
    │   ├── model.py
    │   └── vit
    │       └── vit.py
    └── utils
        ├── dataloader.py
        ├── involution_cuda.py
        └── utils.py

2. Installation & Usage

In our experiments, all training and testing were conducted in PyTorch on a single RTX 2080 Ti GPU.

2.1 Installation

  • Install required libraries:
    • pip install -r requirements.txt
  • Download necessary data:
    We use five datasets in our experiments: Kvasir-SEG, CVC-ClinicDB, CVC-ColonDB, ETIS and EndoScene. We use the same split policy as PraNet, and you can download these datasets from their repo. Thanks for their great work.
    • Download train dataset. This dataset can be downloaded from this link (Google Drive). Configure your train_path to the directory of train dataset.
    • Download test dataset. This dataset can be downloaded from this link (Google Drive). Configure your test_path to the directory of test dataset.
  • Download Pretrained weights:
    Download pretrained weights for ViT and RedNet. A large part of our code is from ViT-Pytorch and Involution. Thanks for their wonderful work.
    • Download pretrained weights for Vision Transformer at this link. We use ViT-B_16 for our experiments. Place pretrained weights into TransInvNet/lib.
    • Download pretrained weights for RedNet at this link. We use RedNet-50 for our experiments. Place pretrained weights into TransInvNet/lib.

2.2 Training

python train.py --epoch --lr --batch_size --accumulation --img_size --clip --cfg --train_path --test_path --output_path --seed

For detailed information about each argument, please use python train.py --help.

2.3 Testing/Inference

To run inference with our proposed TransInvNet, you can either download our pretrained weights from this link or train a model yourself. After downloading the pretrained TransInvNet weights or finishing training, set your weight_path to the trained weights and test_path to the images you would like to run inference on. Use this command to run inference:

python inference.py --img_size --weight_path --test_path --output_path --threshold

For detailed information about each argument, please use python inference.py --help.
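A hedged sketch of what the --threshold flag typically controls: binarizing the network's sigmoid output into a 0/1 polyp mask. The exact behaviour is defined in inference.py; the function below is illustrative only.

```python
import numpy as np

def binarize(prob_map, threshold=0.5):
    """Turn a probability map into a binary polyp mask."""
    return (prob_map > threshold).astype(np.uint8)

# A tiny 2x2 probability map standing in for a real sigmoid output
probs = np.array([[0.9, 0.3], [0.6, 0.1]])
mask = binarize(probs, threshold=0.5)
# mask == [[1, 0], [1, 0]]
```

Raising the threshold makes the predicted mask more conservative (fewer pixels labelled polyp); lowering it increases recall at the cost of false positives.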

2.4 Evaluation

Our evaluation code is modified from link. To evaluate a model, set weight_path to the trained weights and test_path to the dataset you would like to evaluate on. You can use this command to run the evaluation script:

python eval.py --img_size --weight_path --test_path

For detailed information about each argument, please use python eval.py --help.
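The mDice score reported in Section 1.1 is the mean Dice coefficient over a test set. The sketch below shows the standard Dice definition; eval.py's exact implementation (e.g. its smoothing term) may differ.

```python
import numpy as np

def dice(pred, gt, eps=1e-8):
    """Dice coefficient between a binary prediction and ground-truth mask."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    # 2 * |P ∩ G| / (|P| + |G|), with eps guarding the empty-mask case
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [0, 0]])
print(dice(pred, gt))  # 2*1 / (2 + 1) ≈ 0.6667
```

mDice is then simply the average of this score over all images in a test set.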

2.5 Acknowledgements

  • The code of Vision Transformer part is borrowed from ViT-Pytorch.
  • The code of Involution part is borrowed from involution.
  • Datasets used for experiments are from PraNet.

3. Reference
