Why is cycle consistency not robust? (ddib issue, open)

JunMa11 commented on August 23, 2024
Why is cycle consistency not robust?

from ddib.

Comments (5)

JunMa11 commented on August 23, 2024

I'm also attaching the code

import argparse
import numpy as np
import os
join = os.path.join
import pathlib
import torch.distributed as dist
from skimage import io, color
import torch
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict
)
import matplotlib.pyplot as plt

def create_argparser():
    defaults = dict(
        image_size=256,
        batch_size=1,
        num_channels=64,
        num_res_blocks=3,
        num_heads=1,
        diffusion_steps=1000,
        noise_schedule='linear',
        lr=1e-4,
        clip_denoised=False,
        num_samples=1, # 10000
        use_ddim=True,
        # timestep_respacing='ddim250',
        model_path="",
    )
    ori = model_and_diffusion_defaults()
    # defaults.update(model_and_diffusion_defaults())
    ori.update(defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, ori)
    return parser

# def main():
args = create_argparser().parse_args()

logger.log(f"args: {args}")

dist_util.setup_dist()
logger.configure(dir='./log')

code_folder = './'
# data_folder = './datasets' # get_code_and_dataset_folders()


#%% load model
def read_model_and_diffusion(args, model_path):
    """Reads the latest model from the given directory."""

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cuda"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion

ct_model_path =  './work_dir/abdomenCT256/ema_0.9999_480000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)
mr_model_path = './work_dir/abdomenMR256/ema_0.9999_480000.pt'
t_model, t_diffusion = read_model_and_diffusion(args, mr_model_path)
save_path = './log'
#%% translate image
s_img_path = './demo-img'
names = sorted(os.listdir(s_img_path))
# names = ['ct_ori.png']
def sample2img(sample):
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]
    
    return sample

for name in names:
    ct_data = io.imread(join(s_img_path, name))

    s_np = ct_data / np.max(ct_data)
    s_np = (s_np - 0.5) * 2.0
    # s_np = np.repeat(np.expand_dims(s_np, -1), 3, -1)
    # Use an f-string: concatenating a str with the shape tuple would raise a TypeError.
    assert s_np.shape == (256, 256, 3), f'shape error! Current shape: {ct_data.shape}'
    s_np = np.expand_dims(s_np, 0)
    
    source = torch.from_numpy(s_np.astype(np.float32)).permute(0,3,1,2).to('cuda')
    # print(f"{source.shape=}")
    noise = s_diffusion.ddim_reverse_sample_loop(
        s_model, source,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    source_recon = s_diffusion.ddim_sample_loop(
        s_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    target = t_diffusion.ddim_sample_loop(
        t_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    #%% plot
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8,8))
    images = [ct_data, color.rgb2gray(sample2img(noise)), sample2img(source_recon), sample2img(target)]
    titles = ['CT image', 'CT noise encode', \
        'CT reconstruction', 'CT2MR']
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(titles[i])
        ax.axis('off')
    plt.suptitle(name)

    plt.savefig(join(save_path, name), dpi=300)
    plt.close(fig)  # release the figure so figures do not accumulate across the loop
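
To put a number on how far the reconstruction is from the input, one option is to compare the two images with simple error metrics. The following is a minimal sketch and not part of the original script: `reconstruction_metrics` is a hypothetical helper, and it assumes both arrays have the same shape and the same intensity range (for example `sample2img(source)` and `sample2img(source_recon)`, both uint8 in [0, 255]).

```python
import numpy as np


def reconstruction_metrics(original, reconstruction, data_range=255.0):
    """Return mean absolute error and PSNR between two same-shaped images.

    Hypothetical helper for illustration; `data_range` is the assumed
    maximum intensity span of the inputs (255 for uint8 images).
    """
    # Cast to float first so uint8 subtraction cannot wrap around.
    original = np.asarray(original, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    if original.shape != reconstruction.shape:
        raise ValueError(
            f"shape mismatch: {original.shape} vs {reconstruction.shape}"
        )
    diff = original - reconstruction
    mae = float(np.mean(np.abs(diff)))
    mse = float(np.mean(diff ** 2))
    # PSNR is infinite for a perfect reconstruction.
    psnr = float("inf") if mse == 0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return {"mae": mae, "psnr": psnr}


if __name__ == "__main__":
    # Synthetic stand-ins for a source image and its reconstruction.
    rng = np.random.default_rng(0)
    src = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
    noisy = np.clip(src.astype(np.int16) + rng.integers(-5, 6, size=src.shape), 0, 255)
    print(reconstruction_metrics(src, noisy.astype(np.uint8)))
```

Logging these values per image would make it easier to see which inputs reconstruct well and which do not.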

suxuann commented on August 23, 2024

Hi Jun, thanks for your interest in our work and for attempting to validate our method on CT and MR images.

DDIBs translate images via a (regularized) optimal transport process. This is both an advantage and a limitation of our method. Training diffusion models on the two domains independently decouples the training process, but the resulting optimal-transport-based translation may not produce the images you desire.

Appendix B of our paper (https://arxiv.org/pdf/2203.08382.pdf) explains the phenomenon you observe in detail. Let us know if you have additional questions!

JunMa11 commented on August 23, 2024

Hi @suxuann ,

Thank you very much for your answer.
Now I understand the reason behind the second question.

Could you please explain the following question in a bit more detail?

Why is cycle consistency not robust, in the sense that the noise encoding cannot reconstruct the original image? Based on the proof, it should hold for different images.
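
One factor that may contribute, offered here as an assumption rather than an answer given in this thread, is that a DDIM encode/decode round trip is exact only in the continuous-time limit: with a finite number of steps, the encoder and decoder evaluate the noise prediction at different ends of each interval, so the two passes do not cancel exactly. The toy sketch below illustrates this on scalar data drawn from N(mu, s^2), where the optimal noise predictor is known in closed form, so any round-trip error comes from discretization alone. The linear beta schedule mirrors `noise_schedule='linear'` with 1000 steps; `mu`, `s`, and all function names are hypothetical choices for illustration. A trained network adds its own prediction error on top of this.

```python
import numpy as np


def alpha_bar_schedule(num_train_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_train_steps)
    return np.cumprod(1.0 - betas)


def eps_pred(x, ab, mu, s):
    """Exact E[eps | x_t] when the clean data follows N(mu, s^2)."""
    return np.sqrt(1.0 - ab) * (x - np.sqrt(ab) * mu) / (ab * s ** 2 + 1.0 - ab)


def ddim_step(x, ab_from, ab_to, mu, s):
    """Deterministic DDIM update between two noise levels (either direction)."""
    eps = eps_pred(x, ab_from, mu, s)
    x0_hat = (x - np.sqrt(1.0 - ab_from) * eps) / np.sqrt(ab_from)
    return np.sqrt(ab_to) * x0_hat + np.sqrt(1.0 - ab_to) * eps


def round_trip_error(x0, num_steps, mu=0.5, s=0.2):
    """Encode x0 to latent noise, decode it back, return the max abs error."""
    ab_full = alpha_bar_schedule()
    stride = len(ab_full) // num_steps
    # Noise levels visited: clean data (alpha_bar = 1), then every `stride`-th step.
    levels = np.concatenate(([1.0], ab_full[stride - 1::stride]))
    x = np.array(x0, dtype=np.float64)
    for i in range(len(levels) - 1):        # encode: data -> noise
        x = ddim_step(x, levels[i], levels[i + 1], mu, s)
    for i in range(len(levels) - 1, 0, -1):  # decode: noise -> data
        x = ddim_step(x, levels[i], levels[i - 1], mu, s)
    return float(np.max(np.abs(x - np.asarray(x0, dtype=np.float64))))


if __name__ == "__main__":
    x0 = np.array([-0.2, 0.3, 0.9])
    for n in (10, 100, 1000):
        print(f"{n:5d} steps: max round-trip error = {round_trip_error(x0, n):.4f}")
```

In this toy setting the error shrinks as the number of steps grows but does not vanish, which is consistent with cycle consistency holding only up to solver discretization error.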

leoil commented on August 23, 2024

Hi @JunMa11, I'd like to ask a question about model training.
I want to train a new model on my own dataset, similar to your ./work_dir/abdomenCT256/ema_0.9999_480000.pt.
Could you please tell me how to prepare the training script?

yang1173350896 commented on August 23, 2024

Hi @JunMa11,
I tried to reconstruct the original MR image as well, but my reconstruction has a color problem.
I tried normalizing the image to [0, 1], but it still can't reconstruct the original image.
Could you please tell me what the reason might be?
[attached image]
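
One possible cause, which is an assumption and not something confirmed in this thread, is a mismatch in intensity scaling: the script earlier in the thread feeds the model values in [-1, 1], and `sample2img` maps [-1, 1] back to [0, 255]. Data scaled to [0, 1] would then be decoded into the upper half of the intensity range only. The sketch below shows a matching pair of conversions on a synthetic image; `to_model_range` and `from_model_range` are hypothetical helper names. Channel order or a grayscale versus RGB mismatch could also produce color artifacts and would need to be checked separately.

```python
import numpy as np


def to_model_range(img_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0


def from_model_range(x):
    """Inverse mapping: [-1, 1] floats back to uint8, rounding to nearest."""
    return np.clip(np.rint((x + 1.0) * 127.5), 0, 255).astype(np.uint8)


# Synthetic stand-in for a real slice: every 8-bit intensity appears once.
img = np.arange(256, dtype=np.uint8).reshape(16, 16)

# Matching conversions recover the original pixels.
round_trip = from_model_range(to_model_range(img))

# Decoding [0, 1] data with the [-1, 1] inverse compresses the intensities.
washed_out = from_model_range(img.astype(np.float32) / 255.0)

print("round trip exact:", np.array_equal(round_trip, img))
print("mismatched range:", washed_out.min(), "to", washed_out.max())
```

Running the same round-trip check on a real input, before it goes through the diffusion model, helps separate preprocessing problems from model behavior.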
