Comments (5)
I'm also attaching the code
import argparse
import numpy as np
import os
join = os.path.join
import pathlib
import torch.distributed as dist
from skimage import io, color
import torch
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict,
)
import matplotlib.pyplot as plt
def create_argparser():
    defaults = dict(
        image_size=256,
        batch_size=1,
        num_channels=64,
        num_res_blocks=3,
        num_heads=1,
        diffusion_steps=1000,
        noise_schedule='linear',
        lr=1e-4,
        clip_denoised=False,
        num_samples=1,  # 10000
        use_ddim=True,
        # timestep_respacing='ddim250',
        model_path="",
    )
    ori = model_and_diffusion_defaults()
    # defaults.update(model_and_diffusion_defaults())
    ori.update(defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, ori)
    return parser
# def main():
args = create_argparser().parse_args()
logger.log(f"args: {args}")
dist_util.setup_dist()
logger.configure(dir='./log')
code_folder = './'
# data_folder = './datasets' # get_code_and_dataset_folders()
#%% load model
def read_model_and_diffusion(args, model_path):
    """Builds the model and diffusion, then loads the weights stored at model_path."""
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cuda"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion
ct_model_path = './work_dir/abdomenCT256/ema_0.9999_480000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)
mr_model_path = './work_dir/abdomenMR256/ema_0.9999_480000.pt'
t_model, t_diffusion = read_model_and_diffusion(args, mr_model_path)
save_path = './log'
#%% translate image
s_img_path = './demo-img'
names = sorted(os.listdir(s_img_path))
# names = ['ct_ori.png']
def sample2img(sample):
    # Map a model output in [-1, 1] to a uint8 HWC image (first item of the batch).
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]
    return sample
for name in names:
    ct_data = io.imread(join(s_img_path, name))
    # Scale to [0, 1] by the image maximum, then to [-1, 1] as the model expects.
    s_np = ct_data / np.max(ct_data)
    s_np = (s_np - 0.5) * 2.0
    # s_np = np.repeat(np.expand_dims(s_np, -1), 3, -1)
    # Format the shape into the message; concatenating a str and a tuple raises TypeError.
    assert s_np.shape == (256, 256, 3), f'shape error! Current shape: {ct_data.shape}'
    s_np = np.expand_dims(s_np, 0)
    source = torch.from_numpy(s_np.astype(np.float32)).permute(0, 3, 1, 2).to('cuda')
    # print(f"{source.shape=}")
    # Encode the source image into the latent (noise) with the source-domain model.
    noise = s_diffusion.ddim_reverse_sample_loop(
        s_model, source,
        clip_denoised=False,
        device=dist_util.dev(),
    )
    # Decode the latent with the source model (reconstruction) ...
    source_recon = s_diffusion.ddim_sample_loop(
        s_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )
    # ... and with the target model (translation).
    target = t_diffusion.ddim_sample_loop(
        t_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )
    #%% plot
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
    images = [ct_data, color.rgb2gray(sample2img(noise)), sample2img(source_recon), sample2img(target)]
    titles = ['CT image', 'CT noise encode',
              'CT reconstruction', 'CT2MR']
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(titles[i])
        ax.axis('off')
    plt.suptitle(name)
    plt.savefig(join(save_path, name), dpi=300)
    plt.close(fig)  # release the figure so they do not accumulate across images
from ddib.
Hi Jun, thanks for your interest in our work and for attempting to validate our method on CT and MR images.
DDIBs translate images via a (regularized) optimal transport process, which is both an advantage and a limitation of our method. Training the diffusion models on the two domains independently decouples the training process, but the resulting optimal-transport-based translation may not necessarily produce the images you desire.
You can refer to Appendix B of our paper (https://arxiv.org/pdf/2203.08382.pdf) for a detailed explanation of the phenomenon you observe. Let us know if you have additional questions!
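As a rough illustration of the point above (a toy sketch, not the authors' code), consider two 1-D "domains" that are both Gaussian. In that case the probability flow ODE of a variance-preserving diffusion has a closed-form solution, so encoding with the source domain and decoding with the target domain can be written down directly. The helper names, the domain parameters, and the value of `ALPHA_BAR_T` are all assumptions chosen for illustration.

```python
import numpy as np

# Assumed cumulative alpha at the final timestep; close to 0 for typical schedules.
ALPHA_BAR_T = 4e-5


def encode(x0, mean, std, alpha_bar=ALPHA_BAR_T):
    """Closed-form map x0 -> x_T for data distributed as N(mean, std^2)."""
    var_t = alpha_bar * std**2 + (1.0 - alpha_bar)  # variance of x_T
    return np.sqrt(alpha_bar) * mean + np.sqrt(var_t) / std * (x0 - mean)


def decode(x_t, mean, std, alpha_bar=ALPHA_BAR_T):
    """Inverse of encode: x_T -> x0 for data distributed as N(mean, std^2)."""
    var_t = alpha_bar * std**2 + (1.0 - alpha_bar)
    return mean + std / np.sqrt(var_t) * (x_t - np.sqrt(alpha_bar) * mean)


def translate(x_src, src=(0.0, 1.0), tgt=(3.0, 0.5)):
    """Encode with the source domain, decode with the target domain."""
    latent = encode(x_src, *src)
    return decode(latent, *tgt)


if __name__ == "__main__":
    x = np.array([-1.0, 0.0, 1.0])
    print(translate(x))
```

In this toy setting the composed map is, up to a small residual because `ALPHA_BAR_T` is not exactly zero, the increasing affine map between the two Gaussians, which is the optimal transport map for squared cost in 1-D. The map depends only on the two distributions, so nothing in it enforces that a translated sample keeps the content a user might expect.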
Hi @suxuann,
Thank you very much for your answer.
Now I understand the reason for the second question.
Could you please explain the following question a little more?
Why is cycle consistency not robust (the noise encoding cannot reconstruct the original image)? Based on the proof, it should hold for different images.
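One general factor that can be checked independently of this repository: deterministic DDIM encoding and decoding invert each other exactly only in the limit of infinitely many steps. The sketch below is a toy under stated assumptions, not the script above: scalar data drawn from a zero-mean Gaussian, the closed-form optimal noise predictor for that Gaussian in place of a trained network, and a linear beta schedule from 1e-4 to 0.02. All function names are hypothetical.

```python
import numpy as np


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def optimal_eps(x_t, alpha_bar, data_std):
    """E[eps | x_t] when x0 ~ N(0, data_std^2); stands in for a trained model."""
    var_t = alpha_bar * data_std**2 + (1.0 - alpha_bar)
    return np.sqrt(1.0 - alpha_bar) * x_t / var_t


def ddim_step(x, ab_from, ab_to, data_std):
    """One deterministic DDIM update between two noise levels (either direction)."""
    eps = optimal_eps(x, ab_from, data_std)
    x0_pred = (x - np.sqrt(1.0 - ab_from) * eps) / np.sqrt(ab_from)
    return np.sqrt(ab_to) * x0_pred + np.sqrt(1.0 - ab_to) * eps


def round_trip_error(x0, n_steps, data_std=0.5, total_steps=1000):
    """Encode x0 to the latent and decode it again; return |x_rec - x0|."""
    acp = linear_alphas_cumprod(total_steps)
    ts = np.linspace(0, total_steps - 1, n_steps).round().astype(int)
    abars = np.concatenate([[1.0], acp[ts]])  # 1.0 is the clean image level
    x = x0
    for a, b in zip(abars[:-1], abars[1:]):  # encode: clean -> noise
        x = ddim_step(x, a, b, data_std)
    rev = abars[::-1]
    for a, b in zip(rev[:-1], rev[1:]):  # decode: noise -> clean
        x = ddim_step(x, a, b, data_std)
    return abs(x - x0)


if __name__ == "__main__":
    for n in (10, 100, 1000):
        print(n, round_trip_error(1.0, n))
```

In this toy the reconstruction error shrinks as the number of steps grows, even though the noise predictor is exact. With a trained network, approximation error in the predicted noise adds to this discretization error, so a reconstruction that deviates from the input does not by itself contradict the continuous-time argument.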
Hi @JunMa11, I'd like to ask some questions about model training.
I want to train a new model on my own dataset, like your ./work_dir/abdomenCT256/ema_0.9999_480000.pt.
Could you please tell me how I should prepare the training script?
Hi @JunMa11,
I tried to reconstruct the original MR image as well, but my reconstruction has a color problem.
I tried normalizing the image to [0, 1], but it still can't reconstruct the original image.
Could you please tell me what the possible reason could be?
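One thing worth checking here, offered as a guess rather than a diagnosis: the script above feeds the model inputs scaled to [-1, 1] and converts outputs back with `(sample + 1) * 127.5`. The NumPy sketch below mirrors those two steps (the function names are hypothetical) and shows what the output conversion does to data that is in [0, 1] instead.

```python
import numpy as np


def to_model_range(img_uint8):
    """Mirror the script's preprocessing: divide by the max, then map to [-1, 1]."""
    img = img_uint8.astype(np.float32)
    img = img / np.max(img)
    return (img - 0.5) * 2.0


def to_uint8(sample):
    """Mirror sample2img's value mapping: [-1, 1] -> [0, 255], truncating to uint8."""
    return np.clip((sample + 1.0) * 127.5, 0, 255).astype(np.uint8)


if __name__ == "__main__":
    img = np.array([[0, 128, 255]], dtype=np.uint8)

    # Matching ranges: the round trip recovers the image up to truncation.
    print(to_uint8(to_model_range(img)))

    # Mismatched ranges: [0, 1] data pushed through the [-1, 1] conversion
    # lands in the upper half of the intensity range and looks washed out.
    print(to_uint8(img.astype(np.float32) / 255.0))
```

Note also that dividing by `np.max(img)` rescales any image whose maximum is below 255, so the reconstruction can look brighter than the input even when the model itself behaves well.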