

TaoHuang2018 commented on July 21, 2024

I am sorry for my late reply. For training Neighbor2Neighbor on the SIDD dataset:

  1. The code for data preparation is in the newly uploaded file dataset_tool_raw.py. Ideally, the differences between camera devices in the SIDD Medium Set should be taken into account. However, since the device information is not provided for the SIDD validation set, we simply ignore device differences when preparing the training set.
  2. The training code for the SIDD dataset is similar to that for the ImageNet validation set, except for some operations specific to raw images. The training scheme is also similar: 100 epochs and the same learning-rate decay schedule, except that we use a smaller $\gamma=1$. Here is some code for processing raw images.
```python
import glob
import os

import numpy as np
import torch
import torch.utils.data as data
from scipy.io import loadmat

# get_generator() is not part of this snippet; this is a minimal stand-in
# (assumption) for the helper in the repository's training script: each call
# returns a new generator with an incrementing seed, created on `device`.
operation_seed_counter = 0


def get_generator(device):
    global operation_seed_counter
    operation_seed_counter += 1
    generator = torch.Generator(device=device)
    generator.manual_seed(operation_seed_counter)
    return generator


def space_to_depth(x, block_size):
    # Pack each block_size x block_size cell into channels:
    # (N, C, H, W) -> (N, C * block_size**2, H / block_size, W / block_size)
    n, c, h, w = x.size()
    unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size**2, h // block_size,
                           w // block_size)


def depth_to_space(x, block_size):
    # Inverse of space_to_depth: unpack channels back into spatial cells.
    return torch.nn.functional.pixel_shuffle(x, block_size)


def generate_mask_pair(img):
    # prepare masks (N x C x H/2 x W/2)
    n, c, h, w = img.shape
    mask1 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    mask2 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    # prepare random mask pairs: the 8 ordered pairs of horizontally or
    # vertically adjacent positions inside each 2x2 cell
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    rd_idx = torch.zeros(size=(n * h // 2 * w // 2, ),
                         dtype=torch.int64,
                         device=img.device)
    torch.randint(low=0,
                  high=8,
                  size=(n * h // 2 * w // 2, ),
                  generator=get_generator(img.device),
                  out=rd_idx)
    rd_pair_idx = idx_pair[rd_idx]
    # offset each pair to the flat index of its own 2x2 cell
    rd_pair_idx += torch.arange(start=0,
                                end=n * h // 2 * w // 2 * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    # get masks
    mask1[rd_pair_idx[:, 0]] = 1
    mask2[rd_pair_idx[:, 1]] = 1
    return mask1, mask2


def generate_subimages(img, mask):
    # Take the one pixel per 2x2 cell selected by `mask`, for every channel.
    n, c, h, w = img.shape
    subimage = torch.zeros(n,
                           c,
                           h // 2,
                           w // 2,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # per channel
    for i in range(c):
        img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
        img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
        subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape(
            n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
    return subimage


class DataLoader_SIDD_Medium_Raw(data.Dataset):
    def __init__(self, data_dir):
        super(DataLoader_SIDD_Medium_Raw, self).__init__()
        self.data_dir = data_dir
        # get images path
        self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
        self.train_fns.sort()
        print('fetch {} samples for training'.format(len(self.train_fns)))

    def __getitem__(self, index):
        # fetch image: a .mat file holding one raw patch under key "x"
        fn = self.train_fns[index]
        im = loadmat(fn)["x"]
        im = im[np.newaxis, :, :]  # add channel dim -> (1, H, W)
        im = torch.from_numpy(im)
        return im

    def __len__(self):
        return len(self.train_fns)


def get_SIDD_validation(dataset_dir):
    # Load the noisy and ground-truth SIDD raw validation blocks.
    val_data_dict = loadmat(
        os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat"))
    val_data_noisy = val_data_dict['ValidationNoisyBlocksRaw']
    val_data_dict = loadmat(
        os.path.join(dataset_dir, 'ValidationGtBlocksRaw.mat'))
    val_data_gt = val_data_dict['ValidationGtBlocksRaw']
    num_img, num_block, _, _ = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt
```
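The $\gamma$ setting mentioned above weights the regularization term of the Neighbor2Neighbor loss. A minimal sketch of that loss, assuming (as in the training loop posted later in this thread) that the weight ramps linearly as `epoch / n_epoch * γ`, that $\gamma$ plays the role of `opt.increase_ratio` in that loop, and that `opt.Lambda1 = opt.Lambda2 = 1`. The function names are hypothetical, not the repository's.

```python
import torch


def regularization_weight(epoch, n_epoch, gamma):
    # Linear ramp: gamma / n_epoch at epoch 1, reaching gamma at the last epoch.
    return epoch / n_epoch * gamma


def neighbor2neighbor_loss(output_sub1, noisy_sub2, denoised_sub1, denoised_sub2, weight):
    """Reconstruction + regularization terms (Lambda1 = Lambda2 = 1 assumed).

    output_sub1:      network applied to the first noisy sub-image
    noisy_sub2:       second noisy sub-image, used as the target
    denoised_sub1/2:  sub-images of the full-image prediction (computed under no_grad)
    """
    diff = output_sub1 - noisy_sub2
    exp_diff = denoised_sub1 - denoised_sub2
    loss_rec = torch.mean(diff ** 2)
    loss_reg = weight * torch.mean((diff - exp_diff) ** 2)
    return loss_rec + loss_reg, loss_rec, loss_reg


# Toy check with gamma = 1 and 100 epochs.
w = regularization_weight(epoch=100, n_epoch=100, gamma=1.0)
out = torch.ones(1, 4, 8, 8)
zeros = torch.zeros_like(out)
total, rec, reg = neighbor2neighbor_loss(out, zeros, zeros, zeros, w)
print(w, total.item(), rec.item(), reg.item())  # 1.0 2.0 1.0 1.0
```

With $\gamma = 1$ the regularization term reaches equal weight with the reconstruction term only at the final epoch.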

from neighbor2neighbor.

madfff commented on July 21, 2024

I still have a small question. Is it right to generate the sub-images on the packed 4-channel raw images?
Thank you for your patience.


TaoHuang2018 commented on July 21, 2024

Yes, on the packed 4-channel raw images.
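To illustrate what sub-sampling the packed image implies, the sketch below packs a toy 8×8 mosaic with `space_to_depth`, draws a pair of neighbor sub-images, and checks where the selected pixels came from. Each packed channel holds a single CFA position, so the two sub-images pair pixels of the same color that sit two pixels apart in the raw mosaic (in a Bayer mosaic, horizontally or vertically adjacent pixels always have different colors). `neighbor_subimages` is a hypothetical, simplified CPU re-implementation of `generate_mask_pair` + `generate_subimages` using `torch.gather`; it is not the repository's code.

```python
import torch
import torch.nn.functional as F


def space_to_depth(x, block_size):
    # Same packing as the snippet above: each block_size x block_size cell -> channels.
    n, c, h, w = x.size()
    unfolded = F.unfold(x, block_size, stride=block_size)
    return unfolded.view(n, c * block_size ** 2, h // block_size, w // block_size)


def depth_to_space(x, block_size):
    return F.pixel_shuffle(x, block_size)


# Ordered pairs of horizontally/vertically adjacent positions in a 2x2 cell
# (position index k = 2 * row + col), the same 8 pairs as generate_mask_pair.
PAIRS = torch.tensor([[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]])


def neighbor_subimages(img, generator):
    """Hypothetical simplified stand-in for generate_mask_pair + generate_subimages."""
    n, c, h, w = img.shape
    choice = torch.randint(0, 8, (n, h // 2, w // 2), generator=generator)
    k1 = PAIRS[choice][..., 0].unsqueeze(1)  # (n, 1, h/2, w/2)
    k2 = PAIRS[choice][..., 1].unsqueeze(1)
    sub1, sub2 = [], []
    for i in range(c):  # the same random pair is used for every channel
        cells = space_to_depth(img[:, i:i + 1], 2)  # (n, 4, h/2, w/2)
        sub1.append(torch.gather(cells, 1, k1))
        sub2.append(torch.gather(cells, 1, k2))
    return torch.cat(sub1, dim=1), torch.cat(sub2, dim=1)


# Toy 8x8 single-channel mosaic whose value encodes its raw position: v = row * 8 + col.
W = 8
raw = torch.arange(W * W, dtype=torch.float32).reshape(1, 1, W, W)
packed = space_to_depth(raw, 2)  # (1, 4, 4, 4): one channel per CFA position
assert torch.equal(depth_to_space(packed, 2), raw)  # packing is lossless

sub1, sub2 = neighbor_subimages(packed, torch.Generator().manual_seed(0))
r1, c1 = sub1 // W, sub1 % W  # raw coordinates of the pixels in sub1
r2, c2 = sub2 // W, sub2 % W
print((r1 - r2).abs() + (c1 - c2).abs())  # all 2: pairs sit two raw pixels apart
```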


zejinwang commented on July 21, 2024

> I still have a small question. Is it right to generate the sub-images on the packed 4-channel raw images? Thank you for your patience.

Hello, I have also been unable to reproduce the 51.06 dB PSNR on SIDD raw-RGB. I implemented it directly on top of the source code provided by the authors, and the PSNR only reaches 46.7 dB. This is my code:

```python
for epoch in range(1, opt.n_epoch + 1):
    cnt = 0

    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']
    print("LearningRate of Epoch {} = {}".format(epoch, current_lr))

    network.train()
    for iteration, noisy in enumerate(TrainingLoader):
        st = time.time()
        noisy = noisy.cuda()
        # pack raw data
        noisy = space_to_depth(noisy, 2)

        optimizer.zero_grad()

        mask1, mask2 = generate_mask_pair(noisy)
        noisy_sub1 = generate_subimages(noisy, mask1)
        noisy_sub2 = generate_subimages(noisy, mask2)
        with torch.no_grad():
            noisy_denoised = network(noisy)
        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)

        noisy_output = network(noisy_sub1)
        noisy_target = noisy_sub2
        Lambda = epoch / opt.n_epoch * opt.increase_ratio
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised

        loss1 = torch.mean(diff**2)
        loss2 = Lambda * torch.mean((diff - exp_diff)**2)
        loss_all = opt.Lambda1 * loss1 + opt.Lambda2 * loss2

        loss_all.backward()
        optimizer.step()
        print(
            '{:04d} {:05d} Loss1={:.6f}, Lambda={}, Loss2={:.6f}, Loss_Full={:.6f}, Time={:.4f}'
            .format(epoch, iteration, np.mean(loss1.item()), Lambda,
                    np.mean(loss2.item()), np.mean(loss_all.item()),
                    time.time() - st))

    scheduler.step()

    if epoch % opt.n_snapshot == 0 or epoch == opt.n_epoch:
        network.eval()
        # save checkpoint
        checkpoint(network, epoch, "model")
        # validation
        save_model_path = os.path.join(opt.save_model_path, opt.log_name,
                                       systime)
        validation_path = os.path.join(save_model_path, "validation")
        os.makedirs(validation_path, exist_ok=True)
        np.random.seed(101)

        for valid_name, valid_data in valid_dict.items():
            psnr_result = []
            ssim_result = []
            num_img, num_block, valid_noisy, valid_gt = valid_data
            for idx in range(num_img):
                for idy in range(num_block):
                    im = valid_gt[idx, idy][:, :, np.newaxis]
                    noisy_im = valid_noisy[idx, idy][:, :, np.newaxis]

                    origin255 = im.copy() * 255.0
                    origin255 = origin255.astype(np.uint8)
                    noisy255 = noisy_im.copy() * 255.0
                    noisy255 = noisy255.astype(np.uint8)
                    # padding to square
                    H = noisy_im.shape[0]
                    W = noisy_im.shape[1]
                    val_size = (max(H, W) + 31) // 32 * 32
                    noisy_im = np.pad(
                        noisy_im,
                        [[0, val_size - H], [0, val_size - W], [0, 0]],
                        'reflect')
                    transformer = transforms.Compose([transforms.ToTensor()])
                    noisy_im = transformer(noisy_im)
                    noisy_im = torch.unsqueeze(noisy_im, 0)
                    noisy_im = noisy_im.cuda()
                    # pack raw data
                    noisy_im = space_to_depth(noisy_im, block_size=2)
                    with torch.no_grad():
                        prediction = network(noisy_im)
                        # unpack raw data
                        prediction = depth_to_space(prediction, block_size=2)
                        prediction = prediction[:, :, :H, :W]
                    prediction = prediction.permute(0, 2, 3, 1)
                    prediction = prediction.cpu().data.clamp(0, 1).numpy()
                    prediction = prediction.squeeze(0)
                    pred255 = np.clip(prediction * 255.0 + 0.5, 0,
                                      255).astype(np.uint8)
                    # calculate psnr
                    cur_psnr = calculate_psnr(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    psnr_result.append(cur_psnr)
                    cur_ssim = calculate_ssim(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    ssim_result.append(cur_ssim)

                    # visualization
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_clean.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(origin255.squeeze()).save(save_path)
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_noisy.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(noisy255.squeeze()).save(save_path)

                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_denoised.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(pred255.squeeze()).save(save_path)

            psnr_result = np.array(psnr_result)
            avg_psnr = np.mean(psnr_result)
            avg_ssim = np.mean(ssim_result)
            log_path = os.path.join(validation_path,
                                    "A_log_{}.csv".format(valid_name))
            with open(log_path, "a") as f:
                f.writelines("{},{},{}\n".format(epoch, avg_psnr, avg_ssim))
```
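One quick sanity check on the evaluation path in the loop above: before PSNR is computed, the ground-truth block is scaled by 255 and cast to `np.uint8` (which truncates), while the prediction goes through `np.clip(prediction * 255.0 + 0.5, 0, 255)` before the cast (which rounds). The sketch below compares PSNR on the float values with that 8-bit path. `psnr_float` and `psnr_uint8` are hypothetical helpers, and `psnr_uint8` assumes `calculate_psnr` is the usual `10·log10(255²/MSE)`.

```python
import numpy as np


def psnr_float(gt, pred, data_range=1.0):
    """PSNR computed directly on float arrays with the given peak value."""
    mse = np.mean((np.asarray(gt, np.float64) - np.asarray(pred, np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


def psnr_uint8(gt, pred):
    """PSNR after the same 8-bit conversion used in the validation loop above."""
    gt255 = (gt * 255.0).astype(np.uint8)  # truncation, as for origin255
    pred255 = np.clip(pred * 255.0 + 0.5, 0, 255).astype(np.uint8)  # rounding, as for pred255
    return psnr_float(gt255, pred255, data_range=255.0)


block = np.full((256, 256, 1), 0.5, dtype=np.float32)
print(psnr_float(block, block))  # inf
print(psnr_uint8(block, block))  # about 48.13, i.e. 20 * log10(255)
```

For a perfect prediction of a block whose values are all 0.5, the float PSNR is infinite, while the 8-bit path reports about 48.13 dB because 127.5 is truncated to 127 on one side and rounded to 128 on the other.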

