
emkd's Introduction

EMKD

The code repository for the IEEE TMI paper "Efficient Medical Image Segmentation Based on Knowledge Distillation".

Structure of this repository

This repository is organized as:

  • datasets contains the dataloaders for the different datasets
  • networks contains a model zoo of network architectures
  • scripts contains scripts for preparing data
  • utils contains APIs for training and data processing
  • train.py trains a single model
  • train_kd.py trains a model with knowledge distillation (KD)

Usage Guide

Requirements

All the code has been tested in the following environment:

  • pytorch 1.8.0
  • pytorch-lightning >= 1.3.7
  • OpenCV
  • nibabel

Dataset Preparation

KiTS

Download data here

Follow the instructions there; the data/ directory should then be structured as follows:

data
├── case_00000
|   ├── imaging.nii.gz
|   └── segmentation.nii.gz
├── case_00001
|   ├── imaging.nii.gz
|   └── segmentation.nii.gz
...
├── case_00209
|   ├── imaging.nii.gz
|   └── segmentation.nii.gz
└── kits.json

Cut the 3D volumes into 2D slices using scripts/SliceMaker.py:

python scripts/SliceMaker.py --inpath /data/kits19/data --outpath /data/kits/train --dataset kits --task tumor
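As a rough illustration of what the slicing step involves, here is a minimal NumPy sketch. It is not the code in scripts/SliceMaker.py. The function names, the assumption that KiTS19 volumes store slices along axis 0, the use of the KiTS19 label convention (0 = background, 1 = kidney, 2 = tumor) and the choice to skip slices without foreground are all assumptions made for this example.

```python
import os

import numpy as np

# Assumption: KiTS19 label convention, 0 = background, 1 = kidney, 2 = tumor.
KITS_TUMOR_LABEL = 2


def slice_volume(image, label, task="tumor", axis=0, keep_empty=False):
    """Split a 3D volume and its label map into 2D (image, mask) slice pairs.

    Hypothetical helper, not the repository's implementation.
    axis: slice axis (assumed to be 0 for KiTS19 volumes).
    task: 'tumor' keeps only the tumor label; any other value keeps all foreground.
    keep_empty: if False, slices without foreground are skipped (an assumption).
    """
    if image.shape != label.shape:
        raise ValueError(f"shape mismatch: {image.shape} vs {label.shape}")
    if task == "tumor":
        mask = (label == KITS_TUMOR_LABEL).astype(np.uint8)
    else:
        mask = (label > 0).astype(np.uint8)
    pairs = []
    for idx in range(image.shape[axis]):
        img2d = np.take(image, idx, axis=axis)
        msk2d = np.take(mask, idx, axis=axis)
        if keep_empty or msk2d.any():
            pairs.append((img2d.astype(np.float32), msk2d))
    return pairs


def load_case(case_dir):
    """Load one KiTS case; nibabel is imported lazily so the sketch runs without it."""
    import nibabel as nib

    image = nib.load(os.path.join(case_dir, "imaging.nii.gz")).get_fdata()
    label = nib.load(os.path.join(case_dir, "segmentation.nii.gz")).get_fdata()
    return image, label
```

For example, slice_volume(*load_case("/data/kits19/data/case_00000")) would return only the slices that contain tumor under these assumptions; how the real script writes slices to --outpath is not shown here.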

LiTS

The procedure is similar to KiTS, although you may need to adjust some arguments when running scripts/SliceMaker.py. Organize the data as follows, then run the slicing script:

lits
├── Training_Batch
└── Test-Data

python scripts/SliceMaker.py --inpath /data/lits/Training-Batch --outpath /data/lits/train --dataset lits --task tumor
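One adjustment LiTS typically needs is matching each CT volume with its label file. The sketch below assumes the public LiTS naming scheme (volume-<id>.nii and segmentation-<id>.nii); the helper name is hypothetical and not part of this repository. LiTS volumes are also commonly stored with slices along the last axis rather than the first, so the slice axis may need changing as well.

```python
import re

# Assumption: LiTS training files are named "volume-<id>.nii" / "segmentation-<id>.nii"
# (optionally gzipped). Other files in the folder are ignored.
_LITS_PATTERN = re.compile(r"^(volume|segmentation)-(\d+)\.nii(\.gz)?$")


def pair_lits_files(filenames):
    """Return {case_id: (volume_name, segmentation_name)} for cases that have both files."""
    found = {}
    for name in filenames:
        match = _LITS_PATTERN.match(name)
        if match is None:
            continue  # unrelated file, e.g. a readme
        kind, case_id = match.group(1), int(match.group(2))
        found.setdefault(case_id, {})[kind] = name
    return {
        case_id: (files["volume"], files["segmentation"])
        for case_id, files in sorted(found.items())
        if "volume" in files and "segmentation" in files
    }
```

Passing os.listdir("/data/lits/Training-Batch") would give the pairs to slice; cases missing either file are silently dropped in this sketch.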

Running

Training Teacher Model

Before knowledge distillation, a well-trained teacher model is required. train.py trains a single model without KD (either a teacher model or a student model).

RAUNet is recommended as the teacher model:

python train.py --model raunet --checkpoint_path /data/checkpoints

After training, the checkpoints are stored in /data/checkpoints, as specified by --checkpoint_path.

To try a different model, pass one of the following choices to --model:

'deeplabv3+', 'enet', 'erfnet', 'espnet', 'mobilenetv2', 'unet++', 'raunet', 'resnet18', 'unet', 'pspnet'

Training With Knowledge Distillation

For example, to use ENet as the student model:

python train_kd.py --tckpt /data/checkpoints/name_of_teacher_checkpoint.ckpt --smodel enet

--tckpt is the path to the teacher model checkpoint, and --smodel selects the student model.
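To make the idea of distilling from a teacher concrete, here is a generic Hinton-style pixel-wise KD loss written in NumPy. This is a standard textbook formulation shown for illustration only; it is not necessarily the loss implemented in train_kd.py or described in the paper, and the default temperature of 4.0 is an arbitrary assumption.

```python
import numpy as np


def log_softmax(logits, axis=1):
    """Numerically stable log-softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def pixelwise_kd_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on temperature-softened class maps, averaged over pixels.

    Both inputs have shape (N, C, H, W). Multiplying by T**2 keeps the loss scale
    comparable across temperatures (the usual Hinton et al. convention).
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError("student and teacher maps must have the same shape")
    log_p_s = log_softmax(student_logits / temperature)
    log_p_t = log_softmax(teacher_logits / temperature)
    p_t = np.exp(log_p_t)
    kl_per_pixel = (p_t * (log_p_t - log_p_s)).sum(axis=1)  # sum over classes -> (N, H, W)
    return float(kl_per_pixel.mean() * temperature ** 2)
```

In PyTorch the same quantity is usually computed with F.kl_div, which expects the student's log-probabilities as its input and the teacher's probabilities as its target.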

emkd's People

Contributors

eaglemit, fubuki901

emkd's Issues

How do you obtain the results of LiTS?

In your paper, you mentioned that: "we follow the official split of LiTS, using 131 cases for training and 70 cases for testing."

Did you obtain your results by submitting the segmentation masks to the LiTS evaluation server? The ground-truth segmentation masks for the LiTS test cases do not seem to be publicly available.
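For context on what such results usually contain: LiTS results are commonly reported as Dice per case (the mean of per-volume Dice scores) and Dice global (Dice computed over all voxels pooled across cases). The sketch below shows the difference, assuming binary masks already loaded as arrays. It says nothing about how the authors actually produced their numbers, and scoring two empty masks as 1.0 is a convention chosen for this example.

```python
import numpy as np


def dice(pred, gt, empty_value=1.0):
    """Dice = 2|P ∩ G| / (|P| + |G|) for binary masks.

    empty_value: score returned when both masks are empty (a convention choice).
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    denom = int(pred.sum() + gt.sum())
    if denom == 0:
        return empty_value
    return 2.0 * int(np.logical_and(pred, gt).sum()) / denom


def dice_per_case_and_global(preds, gts):
    """Return (mean per-case Dice, global Dice over all cases pooled together)."""
    preds = [np.asarray(p, dtype=bool) for p in preds]
    gts = [np.asarray(g, dtype=bool) for g in gts]
    per_case = [dice(p, g) for p, g in zip(preds, gts)]
    inter = sum(int(np.logical_and(p, g).sum()) for p, g in zip(preds, gts))
    total = sum(int(p.sum() + g.sum()) for p, g in zip(preds, gts))
    global_dice = 2.0 * inter / total if total else 1.0
    return float(np.mean(per_case)), global_dice
```

The two scores can differ noticeably: per-case Dice weights every case equally, so small tumors count as much as large ones, while global Dice is dominated by cases with many tumor voxels.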

TypeError: __init__() missing 1 required positional argument: 'params'

Thank you very much for your previous reply. I can run train.py now, but when I follow your instructions to run train_kd.py, an error occurs at the line marked below:

      class KDPL(BasePLModel):
          def __init__(self, params):
              super(KDPL, self).__init__()
              self.save_hyperparameters(params)

              # load and freeze teacher net
              # the TypeError above is raised on the next line
              self.t_net = SegPL.load_from_checkpoint(checkpoint_path=self.hparams.tckpt)
              self.t_net.freeze()
      
              # student net
              self.net = get_model(self.hparams.smodel, channels=2)
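One plausible cause, offered as an untested hypothesis: to our understanding, PyTorch Lightning 1.x's load_from_checkpoint rebuilds the module as cls(**saved_hparams) and, when __init__ has no **kwargs, drops saved keys that are not __init__ parameters. If the checkpoint stored flattened hyperparameters instead of a single params entry, nothing named params survives and __init__ fails exactly like this. The pure-Python sketch below only emulates that mechanism; SegPLStub and instantiate_like_loader are hypothetical stand-ins, not Lightning or repository code.

```python
import inspect
from argparse import Namespace


class SegPLStub:
    """Hypothetical stand-in for SegPL: __init__ takes a single positional `params`."""

    def __init__(self, params):
        self.params = params


def instantiate_like_loader(cls, saved_hparams, **overrides):
    """Emulate building cls(**kwargs) from saved hyperparameters plus overrides.

    Keys that are not parameters of cls.__init__ are dropped unless __init__
    accepts **kwargs, mirroring the filtering described above.
    """
    kwargs = dict(saved_hparams)
    kwargs.update(overrides)
    init_params = inspect.signature(cls.__init__).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in init_params.values()
    )
    if not accepts_var_kw:
        kwargs = {k: v for k, v in kwargs.items() if k in init_params}
    return cls(**kwargs)
```

Calling instantiate_like_loader(SegPLStub, {"model": "raunet"}) raises "missing 1 required positional argument: 'params'", whereas passing params=Namespace(model="raunet") succeeds. If this is indeed the cause, forwarding params explicitly to SegPL.load_from_checkpoint, which passes extra keyword arguments on to __init__, or matching the pytorch-lightning version to the one listed under Requirements, may help.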

Inference

Can you provide an inference script for visualizing results?
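While waiting for an official script, a minimal post-processing sketch may help with visualization: turn the network's per-class scores for one slice into a label map with argmax, then blend the foreground onto the grayscale slice. The function names, the red overlay color, the 0.5 blending weight and the min-max intensity scaling are all assumptions for illustration; running the model itself is not shown.

```python
import numpy as np


def logits_to_mask(logits):
    """(C, H, W) class scores -> (H, W) integer label map via argmax over classes."""
    return np.argmax(logits, axis=0)


def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Blend a mask onto a 2D grayscale slice and return a uint8 RGB image (H, W, 3).

    The slice is min-max scaled to [0, 255] first (an assumption; CT windowing
    may be preferable for display). Every pixel with mask > 0 is tinted with `color`.
    """
    img = image.astype(np.float64)
    span = img.max() - img.min()
    img = (img - img.min()) / span * 255.0 if span > 0 else np.zeros_like(img)
    rgb = np.repeat(img[..., None], 3, axis=2)
    fg = mask > 0
    rgb[fg] = (1 - alpha) * rgb[fg] + alpha * np.asarray(color, dtype=np.float64)
    return rgb.round().astype(np.uint8)
```

Since OpenCV is already a requirement, the result could be saved with cv2.imwrite after converting it with cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR), because OpenCV expects BGR channel order.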

Implementation for 3D Data

Hello,
I was wondering what changes would be necessary to apply the code to 3D data.
Thank you for your help :)

Some questions in loss calculation

Hello
How are you?
Thanks for contributing to this project.
I have some questions about your loss calculation.

In the first method, I think tensors p and q should have the same size before F.kl_div is called.

Also, in the second method, I think tensors t and gt should have the same size.

What do you think about these?
Thanks
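The concern is that an element-wise loss only makes sense when both maps cover the same pixels. A common remedy is to resize one map to the other's spatial size before computing the loss (in PyTorch, typically F.interpolate on logits, with nearest-neighbor mode for label maps such as gt). The NumPy sketch below uses nearest-neighbor resizing so it works for both logits and labels; the helper names are hypothetical and this is not the repository's code.

```python
import numpy as np


def upsample_nearest(x, out_hw):
    """Nearest-neighbor resize of an (N, C, h, w) array to (N, C, H, W).

    Output index i maps to input index floor(i * h / H), so any target size works.
    """
    h, w = x.shape[2:]
    out_h, out_w = out_hw
    rows = (np.arange(out_h) * h) // out_h
    cols = (np.arange(out_w) * w) // out_w
    return x[:, :, rows[:, None], cols[None, :]]


def align_to(student, teacher):
    """Resize `student` to the spatial size of `teacher` if the two differ."""
    if student.shape[:2] != teacher.shape[:2]:
        raise ValueError("batch and channel dimensions must already match")
    if student.shape[2:] != teacher.shape[2:]:
        student = upsample_nearest(student, teacher.shape[2:])
    return student
```

Checking shapes explicitly like this makes a mismatch fail loudly instead of relying on whatever broadcasting or reduction the loss function applies.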

About reload the trained student model params

Hello,
I have trained the teacher network and then trained the student network with KD, but when I reload the student network, I find that the saved KD checkpoint contains the parameters of both t_net and net.

How can I reload only the student network parameters?
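A Lightning checkpoint stores the module's weights under its "state_dict" key, with parameter names prefixed by the attribute that holds each submodule. Given the KDPL code quoted in an earlier issue (teacher in self.t_net, student in self.net), one approach is to keep only the keys starting with "net." and strip that prefix. The helper below is a hypothetical sketch, not part of the repository.

```python
from collections import OrderedDict


def extract_submodule_state(state_dict, prefix="net."):
    """Keep entries whose key starts with `prefix`, with the prefix removed.

    Keys of the teacher start with 't_net.', so startswith('net.') excludes them.
    """
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )
```

Usage would look roughly like: ckpt = torch.load(path, map_location="cpu"), then student = get_model("enet", channels=2) and student.load_state_dict(extract_submodule_state(ckpt["state_dict"])), assuming the student was built with the same get_model arguments as during KD.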
