xiaojiean815 / synthetic-ct-generation-from-mri-using-3d-transformer-based-denoising-diffusion-model

This project forked from shaoyanpan/synthetic-ct-generation-from-mri-using-3d-transformer-based-denoising-diffusion-model

License: MIT License

Synthetic CT generation from MRI using 3D transformer-based denoising diffusion model

This is the repository for the paper published in Medical Physics: "Synthetic CT generation from MRI using 3D transformer-based denoising diffusion model". You can also find an early arXiv version at https://arxiv.org/abs/2305.19467

The code is built on image-guided diffusion, SwinUnet, and MONAI.

Notice: Due to data restrictions, we can only provide MATLAB files (containing no patient information) with over-smoothed CT volumes. These data are provided only to demonstrate how users should organize their own data. Processing of DICOM and NIfTI (.nii) files is also included in the Jupyter notebook.

Required packages

The required packages are listed in environment.yaml.

Create an environment using Anaconda:

conda env create -f \your directory\environment.yaml


Data organization

An example of the data organization is provided in the folder "MRI_to_CT_brain_for_dosimetric\imagesTr". Alternatively, see the screenshots below:

MATLAB files: each MATLAB file can store a dict containing both the image and the label, so you only need three folders: imagesTr for training, imagesTs for testing, and imagesVal for validation. You can rename these folders, but make sure to change the reading directories in the Jupyter notebook accordingly. (screenshot: Capture)
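A minimal sketch of this layout, using SciPy's savemat/loadmat to write one dummy case per split and read it back. The key names "image" (MRI) and "label" (CT) and the file name case001.mat are assumptions for illustration; check the notebook for the keys it actually reads.

```python
import os
import tempfile

import numpy as np
from scipy.io import loadmat, savemat

SPLITS = ("imagesTr", "imagesTs", "imagesVal")  # training, testing, validation


def write_dummy_dataset(root):
    """Write one hypothetical case per split, with image and label in one dict."""
    rng = np.random.default_rng(0)
    for split in SPLITS:
        os.makedirs(os.path.join(root, split), exist_ok=True)
        case = {
            "image": rng.random((8, 8, 4)).astype(np.float32),  # MRI volume (assumed key)
            "label": rng.random((8, 8, 4)).astype(np.float32),  # CT volume (assumed key)
        }
        savemat(os.path.join(root, split, "case001.mat"), case)


def load_split(root, split):
    """Return (image, label) pairs for every .mat file in one split folder."""
    folder = os.path.join(root, split)
    pairs = []
    for name in sorted(os.listdir(folder)):
        if name.endswith(".mat"):
            data = loadmat(os.path.join(folder, name))
            pairs.append((data["image"], data["label"]))
    return pairs


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        write_dummy_dataset(root)
        for split in SPLITS:
            image, label = load_split(root, split)[0]
            print(split, image.shape, label.shape)
```

Because image and label travel together in one file, a single folder per split is enough; renaming a folder only requires updating the folder name used when reading.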

NIfTI files: each .nii file can contain either an image or a label, not both. In this case you need imagesTr and labelsTr for training, imagesTs and labelsTs for testing, and imagesVal and labelsVal for validation. (screenshot: Capture2)
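Since images and labels live in separate folders, they have to be paired again when loading. The sketch below assumes (hypothetically) that an image and its label share the same file name, and returns the list-of-dicts form that MONAI dictionary transforms such as LoadImaged consume. The helper name build_file_list is illustrative, not part of this repository.

```python
from pathlib import Path

# Folder pairs from the README: (image folder, label folder) per split.
SPLIT_DIRS = {
    "train": ("imagesTr", "labelsTr"),
    "test": ("imagesTs", "labelsTs"),
    "val": ("imagesVal", "labelsVal"),
}


def build_file_list(root, split):
    """Pair every NIfTI image with the label file of the same name."""
    image_dir, label_dir = (Path(root) / name for name in SPLIT_DIRS[split])
    records = []
    for image_path in sorted(image_dir.iterdir()):
        if not image_path.name.endswith((".nii", ".nii.gz")):
            continue  # skip non-NIfTI files
        label_path = label_dir / image_path.name  # assumes matching file names
        if not label_path.exists():
            raise FileNotFoundError(f"no label found for {image_path.name}")
        records.append({"image": str(image_path), "label": str(label_path)})
    return records
```

For example, build_file_list(root, "train") yields entries like {"image": ".../imagesTr/case001.nii.gz", "label": ".../labelsTr/case001.nii.gz"}, and fails loudly if a label is missing rather than silently dropping the case.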

Usage

Full usage is shown in the Jupyter notebook MC-IDDPM main.ipynb, including how to build a diffusion process, how to build a network, and how to call the diffusion process to train and to sample new synthetic images. A simplified example is given below:

Create diffusion

from diffusion.Create_diffusion import *
from diffusion.resampler import *

diffusion_steps=1000
learn_sigma=True
timestep_respacing=[50]

# Don't touch these parameters; they are irrelevant to the image synthesis
sigma_small=False
class_cond=False
noise_schedule='linear'
use_kl=False
predict_xstart=False
rescale_timesteps=True
rescale_learned_sigmas=True
use_checkpoint=False

diffusion = create_gaussian_diffusion(
    steps=diffusion_steps,
    learn_sigma=learn_sigma,
    sigma_small=sigma_small,
    noise_schedule=noise_schedule,
    use_kl=use_kl,
    predict_xstart=predict_xstart,
    rescale_timesteps=rescale_timesteps,
    rescale_learned_sigmas=rescale_learned_sigmas,
    timestep_respacing=timestep_respacing,
)
schedule_sampler = UniformSampler(diffusion)
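With diffusion_steps=1000 and timestep_respacing=[50], the diffusion object above works on 50 of the 1000 steps. Assuming create_gaussian_diffusion follows guided diffusion here, the 50 kept steps are spread evenly over 0..999, and the betas are recomputed so that the cumulative product alpha-bar at each kept step matches the full linear schedule. The NumPy sketch below illustrates that idea; the function names are hypothetical and this is not the repository's code.

```python
import numpy as np


def linear_betas(num_steps):
    """Linear beta schedule as in DDPM, rescaled for step counts other than 1000."""
    scale = 1000 / num_steps
    return np.linspace(scale * 1e-4, scale * 0.02, num_steps, dtype=np.float64)


def evenly_spaced_timesteps(num_steps, count):
    """Pick `count` timesteps spread evenly over [0, num_steps - 1]."""
    if not 2 <= count <= num_steps:
        raise ValueError("count must be between 2 and num_steps")
    stride = (num_steps - 1) / (count - 1)
    return [round(i * stride) for i in range(count)]


def respaced_betas(num_steps, count):
    """Return kept timesteps and betas whose alpha-bar matches the full schedule."""
    alphas_cumprod = np.cumprod(1.0 - linear_betas(num_steps))
    kept = evenly_spaced_timesteps(num_steps, count)
    betas, last_alpha_bar = [], 1.0
    for t in kept:
        # Choose beta so that the running product of (1 - beta) hits alpha-bar at t.
        betas.append(1.0 - alphas_cumprod[t] / last_alpha_bar)
        last_alpha_bar = alphas_cumprod[t]
    return kept, np.array(betas)


kept, betas = respaced_betas(num_steps=1000, count=50)
print(len(kept), kept[0], kept[-1])  # 50 0 999
```

Keeping alpha-bar identical at the kept steps means the shortened chain sees the same noise levels as the original one, only with larger jumps between them.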

Create network

num_channels=64
attention_resolutions="32,16,8"
channel_mult = (1, 2, 3, 4)
num_heads=[4,4,8,16]
window_size = [[4,4,4],[4,4,4],[4,4,2],[4,4,2]]
num_res_blocks = [2,2,2,2]
sample_kernel=([2,2,2],[2,2,1],[2,2,1],[2,2,1])

attention_ds = []
for res in attention_resolutions.split(","):
    attention_ds.append(int(res))
class_cond = False
use_scale_shift_norm=True
resblock_updown = False
dropout = 0

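from network.Diffusion_model_transformer import *

# img_size (the 3D patch size, also used by the sliding-window inferer) and device
# must be defined before this point
model = SwinVITModel(
          image_size=img_size,
          in_channels=2,   # presumably the noisy CT concatenated with the MRI condition
          model_channels=num_channels,
          out_channels=2,  # with learn_sigma=True: predicted noise + learned variance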
          dims=3,
          sample_kernel = sample_kernel,
          num_res_blocks=num_res_blocks,
          attention_resolutions=tuple(attention_ds),
          dropout=dropout,
          channel_mult=channel_mult,
          num_classes=None,
          use_checkpoint=False,
          use_fp16=False,
          num_heads=num_heads,
          window_size = window_size,
          num_head_channels=64,
          num_heads_upsample=-1,
          use_scale_shift_norm=use_scale_shift_norm,
          resblock_updown=resblock_updown,
          use_new_attention_order=False,
      ).to(device)
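To see how these hyperparameters interact, the sketch below does shape bookkeeping only; it is not a description of SwinVITModel's internals. It assumes level i has num_channels * channel_mult[i] channels and that sample_kernel[i] is the per-axis downsampling factor from level i to level i + 1 (so, under this assumption, the last kernel is unused). Swin-style window attention partitions each feature map into non-overlapping windows, so sizes that divide evenly by window_size avoid padding. The img_size of (96, 96, 96) is hypothetical.

```python
def level_shapes(img_size, num_channels, channel_mult, sample_kernel):
    """List (channels, spatial size) for each resolution level.

    Assumption: level i has num_channels * channel_mult[i] channels, and
    sample_kernel[i] is the per-axis downsampling factor from level i to i + 1.
    """
    shapes = []
    size = tuple(img_size)
    for level, mult in enumerate(channel_mult):
        shapes.append((num_channels * mult, size))
        if level + 1 < len(channel_mult):
            factors = sample_kernel[level]
            if any(s % f for s, f in zip(size, factors)):
                raise ValueError(f"size {size} is not divisible by {factors}")
            size = tuple(s // f for s, f in zip(size, factors))
    return shapes


def windows_fit(shapes, window_size):
    """True if each level's spatial size divides evenly into its attention windows."""
    return all(
        s % w == 0
        for (_, size), window in zip(shapes, window_size)
        for s, w in zip(size, window)
    )


# Hyperparameters from the README; img_size is a hypothetical patch size.
shapes = level_shapes(
    img_size=(96, 96, 96),
    num_channels=64,
    channel_mult=(1, 2, 3, 4),
    sample_kernel=([2, 2, 2], [2, 2, 1], [2, 2, 1], [2, 2, 1]),
)
for channels, size in shapes:
    print(channels, size)
print(windows_fit(shapes, [[4, 4, 4], [4, 4, 4], [4, 4, 2], [4, 4, 2]]))
```

The anisotropic kernels [2,2,1] stop halving the third axis after the first level, which is why the deeper levels use windows of size 2 along that axis.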

Train the diffusion

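# target: batch of CT volumes, condition: matching batch of MRI volumes, shaped like (B, 1, H, W, D)
batch_size = target.shape[0]
t, weights = schedule_sampler.sample(batch_size, device)
all_loss = diffusion.training_losses(model, target, condition, t)
loss = (all_loss["loss"] * weights).mean()
optimizer.zero_grad()  # optimizer assumed defined, e.g. torch.optim.AdamW(model.parameters())
loss.backward()
optimizer.step()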
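For readers unfamiliar with what training_losses computes, here is a simplified NumPy sketch of the standard DDPM noise-prediction objective with uniform timestep sampling. It is an illustration under assumptions, not the repository's implementation: it omits the variational-bound term that learn_sigma=True adds for the learned variance, and predict_noise(x_t, t, condition) and zero_model are hypothetical.

```python
import numpy as np


def q_sample(x0, t, noise, alphas_cumprod):
    """Forward process: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise


def noise_prediction_loss(predict_noise, target, condition, alphas_cumprod, rng):
    """Simplified DDPM loss: MSE between the true and the predicted noise."""
    batch = target.shape[0]
    t = rng.integers(0, len(alphas_cumprod), size=batch)  # uniform timestep sampling
    weights = np.ones(batch)  # uniform sampling gives every sample weight 1
    noise = rng.standard_normal(target.shape)
    x_t = q_sample(target, t, noise, alphas_cumprod)
    predicted = predict_noise(x_t, t, condition)
    per_sample = ((noise - predicted) ** 2).reshape(batch, -1).mean(axis=1)
    return float((per_sample * weights).mean())


def zero_model(x_t, t, condition):
    """Hypothetical stand-in for the network: always predicts zero noise."""
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
target = rng.standard_normal((2, 1, 8, 8, 4))     # stand-in CT batch
condition = rng.standard_normal((2, 1, 8, 8, 4))  # stand-in MRI batch
print(noise_prediction_loss(zero_model, target, condition, alphas_cumprod, rng))
```

A model that always predicts zero noise scores about 1.0 (the variance of the Gaussian noise), which is a handy sanity baseline for a freshly initialized network.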

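Testing using MONAI's sliding-window inferer

from monai.inferers import SlidingWindowInferer

img_num = 12   # windows processed per batch (MONAI's sw_batch_size)
overlap = 0.5  # fractional overlap between neighbouring windows
inferer = SlidingWindowInferer(img_size, img_num, overlap=overlap, mode='constant')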
def diffusion_sampling(condition, model):
    sampled_images = diffusion.p_sample_loop(model,(condition.shape[0], 1,
                                                    condition.shape[2], condition.shape[3],condition.shape[4]),
                                                    condition = condition,clip_denoised=True)
    return sampled_images

sampled_images = inferer(condition,diffusion_sampling,model)
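SlidingWindowInferer tiles the MRI volume into img_size patches, runs diffusion_sampling on each patch (each with its own sampling noise), and blends the overlapping outputs. The simplified NumPy sketch below illustrates the idea with constant (plain-average) blending. It is not MONAI's code: MONAI additionally pads volumes smaller than the window, processes sw_batch_size windows at once (img_num here), and offers Gaussian blending. The helper names are hypothetical.

```python
import itertools

import numpy as np


def window_starts(size, roi, overlap):
    """Start offsets along one axis; the last window is flush with the far edge."""
    if roi >= size:
        return [0]
    stride = max(int(roi * (1 - overlap)), 1)
    starts = list(range(0, size - roi + 1, stride))
    if starts[-1] != size - roi:
        starts.append(size - roi)
    return starts


def sliding_window(volume, roi_size, fn, overlap=0.5):
    """Run fn on overlapping 3D windows of a (B, C, H, W, D) array and average.

    Constant blending: each voxel receives the plain mean of every window
    output that covers it. fn must return an array shaped like its input.
    """
    out = np.zeros(volume.shape, dtype=np.float64)
    count = np.zeros(volume.shape[2:], dtype=np.float64)
    starts = [window_starts(s, r, overlap) for s, r in zip(volume.shape[2:], roi_size)]
    for corner in itertools.product(*starts):
        region = tuple(slice(c, c + r) for c, r in zip(corner, roi_size))
        out[(Ellipsis, *region)] += fn(volume[(Ellipsis, *region)])
        count[region] += 1
    return out / count


volume = np.random.default_rng(0).standard_normal((1, 1, 20, 20, 12))
result = sliding_window(volume, roi_size=(8, 8, 8), fn=lambda w: w, overlap=0.5)
print(np.allclose(result, volume))  # True: identity windows average back to the input
```

With overlap=0.5 most voxels are covered by several independently sampled patches, so averaging smooths the seams between patches at the cost of running the diffusion sampler more times.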

Visual examples (CT, MRI, and generated CT; also shown in the paper)

Brain: (figure: Picture2)

Prostate: (figure: Picture3)

Contributors: shaoyanpan