Comments (4)
Sorry for the late reply. For training Neighbor2Neighbor on the SIDD dataset:
- The data preparation code is in the newly uploaded file dataset_tool_raw.py. It would be better to account for camera device differences in the SIDD Medium Set. However, since device information is not provided in the SIDD validation set, we simply ignore device differences when preparing the training set.
- The training code for the SIDD dataset is similar to that for the ImageNet validation set, apart from some operations for raw images. The training scheme is also similar: the number of epochs is 100 and the learning rate decay schedule is the same, except that we use a smaller $\gamma=1$.

Here is some code for processing raw images.
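import glob
import os

import numpy as np
import torch
import torch.utils.data as data
from scipy.io import loadmat  # assumed source of loadmat; the original snippet shows no imports


def space_to_depth(x, block_size):
    # move each block_size x block_size patch into the channel dimension
    n, c, h, w = x.size()
    unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size**2, h // block_size,
                           w // block_size)


def depth_to_space(x, block_size):
    # inverse of space_to_depth
    return torch.nn.functional.pixel_shuffle(x, block_size)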


def generate_mask_pair(img):
    # prepare masks (N x C x H/2 x W/2)
    n, c, h, w = img.shape
    mask1 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    mask2 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    # prepare random mask pairs
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    rd_idx = torch.zeros(size=(n * h // 2 * w // 2, ),
                         dtype=torch.int64,
                         device=img.device)
    # get_generator() is not defined in this snippet; it has to return a
    # torch.Generator that lives on img.device
    torch.randint(low=0,
                  high=8,
                  size=(n * h // 2 * w // 2, ),
                  generator=get_generator(),
                  out=rd_idx)
    rd_pair_idx = idx_pair[rd_idx]
    rd_pair_idx += torch.arange(start=0,
                                end=n * h // 2 * w // 2 * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    # get masks
    mask1[rd_pair_idx[:, 0]] = 1
    mask2[rd_pair_idx[:, 1]] = 1
    return mask1, mask2


def generate_subimages(img, mask):
    n, c, h, w = img.shape
    subimage = torch.zeros(n,
                           c,
                           h // 2,
                           w // 2,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # per channel
    for i in range(c):
        img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
        img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
        subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape(
            n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
    return subimage


class DataLoader_SIDD_Medium_Raw(data.Dataset):
    def __init__(self, data_dir):
        super(DataLoader_SIDD_Medium_Raw, self).__init__()
        self.data_dir = data_dir
        # get images path
        self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
        self.train_fns.sort()
        print('fetch {} samples for training'.format(len(self.train_fns)))

    def __getitem__(self, index):
        # fetch image
        fn = self.train_fns[index]
        im = loadmat(fn)["x"]
        im = im[np.newaxis, :, :]
        im = torch.from_numpy(im)
        return im

    def __len__(self):
        return len(self.train_fns)


def get_SIDD_validation(dataset_dir):
    val_data_dict = loadmat(
        os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat"))
    val_data_noisy = val_data_dict['ValidationNoisyBlocksRaw']
    val_data_dict = loadmat(
        os.path.join(dataset_dir, 'ValidationGtBlocksRaw.mat'))
    val_data_gt = val_data_dict['ValidationGtBlocksRaw']
    num_img, num_block, _, _ = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt
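
The training scheme described at the top of this comment (100 epochs, a smaller γ = 1) can be illustrated with a small pure-Python sketch of how the regularization weight would grow over training. It assumes that γ corresponds to the `increase_ratio` option and that the weight follows the linear ramp `epoch / n_epoch * increase_ratio` used in the training loop quoted later in this thread. The function name `regularization_weight` is hypothetical and not part of the repository.

```python
def regularization_weight(epoch, n_epoch, increase_ratio):
    """Weight of the regularization term at a 1-based epoch (linear ramp)."""
    return epoch / n_epoch * increase_ratio


n_epoch = 100          # number of epochs stated in the comment above
increase_ratio = 1.0   # assumption: this option plays the role of gamma = 1

# One weight per epoch, from the first epoch to the last.
schedule = [regularization_weight(e, n_epoch, increase_ratio)
            for e in range(1, n_epoch + 1)]

print("first epoch:", schedule[0])
print("middle epoch:", schedule[n_epoch // 2 - 1])
print("last epoch:", schedule[-1])
```

Under these assumptions the weight starts near zero and only reaches γ in the final epoch, so the regularization term has little influence early in training.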
I still have a small question: is it right to generate subimages on the packed 4-channel raw images?
Thank you for your patience.
Yes, on the packed 4-channel raw images.
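
To make "subimages on the packed 4-channel raw images" concrete, here is a minimal pure-Python sketch on nested lists. It is an illustration of the index logic, not the PyTorch implementation. The helper names `pack_2x2`, `sample_pairs` and `subimage` are hypothetical. The sketch assumes that plane k = 2*dy + dx holds the pixel at offset (dy, dx) of each 2x2 block, which is the order `unfold` with kernel size 2 and stride 2 yields for a single-channel input, and that the same randomly drawn pair is applied to every plane, as a shared mask would do.

```python
import random

# Ordered pairs of adjacent positions inside a 2x2 cell (row-major 0..3),
# copied from idx_pair in generate_mask_pair.
IDX_PAIR = [(0, 1), (0, 2), (1, 3), (2, 3), (1, 0), (2, 0), (3, 1), (3, 2)]


def pack_2x2(mosaic):
    """Pack an H x W mosaic into 4 planes of H/2 x W/2; plane k = 2*dy + dx."""
    h, w = len(mosaic), len(mosaic[0])
    return [[[mosaic[2 * y + dy][2 * x + dx] for x in range(w // 2)]
             for y in range(h // 2)]
            for dy in (0, 1) for dx in (0, 1)]


def sample_pairs(h, w, rng):
    """Draw one adjacent pair per 2x2 cell of an h x w plane."""
    return [[IDX_PAIR[rng.randrange(len(IDX_PAIR))] for _ in range(w // 2)]
            for _ in range(h // 2)]


def subimage(plane, pairs, which):
    """Pick position pairs[y][x][which] from every 2x2 cell of one plane."""
    out = []
    for y, row in enumerate(pairs):
        out.append([plane[2 * y + p[which] // 2][2 * x + p[which] % 2]
                    for x, p in enumerate(row)])
    return out


# Toy 8 x 8 "raw" mosaic whose value encodes its position (row * 8 + col).
mosaic = [[r * 8 + c for c in range(8)] for r in range(8)]
planes = pack_2x2(mosaic)                       # 4 planes of 4 x 4
pairs = sample_pairs(4, 4, random.Random(0))    # shared by all 4 planes
sub1 = [subimage(p, pairs, 0) for p in planes]  # 4 planes of 2 x 2
sub2 = [subimage(p, pairs, 1) for p in planes]  # 4 planes of 2 x 2
```

Because sampling happens inside each packed plane, the two samples of a pair lie two pixels apart in the original mosaic and come from the same position of the 2x2 packing block.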
Hello, I have run into the same problem: I cannot reproduce the 51.06 dB reported on SIDD raw-RGB. I built directly on the source code provided by the author, and the PSNR only reaches 46.7 dB. This is my code:
for epoch in range(1, opt.n_epoch + 1):
    cnt = 0

    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']
    print("LearningRate of Epoch {} = {}".format(epoch, current_lr))

    network.train()
    for iteration, noisy in enumerate(TrainingLoader):
        st = time.time()
        noisy = noisy.cuda()
        # pack raw data
        noisy = space_to_depth(noisy, 2)

        optimizer.zero_grad()

        mask1, mask2 = generate_mask_pair(noisy)
        noisy_sub1 = generate_subimages(noisy, mask1)
        noisy_sub2 = generate_subimages(noisy, mask2)
        with torch.no_grad():
            noisy_denoised = network(noisy)
        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)

        noisy_output = network(noisy_sub1)
        noisy_target = noisy_sub2
        Lambda = epoch / opt.n_epoch * opt.increase_ratio
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised

        loss1 = torch.mean(diff**2)
        loss2 = Lambda * torch.mean((diff - exp_diff)**2)
        loss_all = opt.Lambda1 * loss1 + opt.Lambda2 * loss2

        loss_all.backward()
        optimizer.step()
        print(
            '{:04d} {:05d} Loss1={:.6f}, Lambda={}, Loss2={:.6f}, Loss_Full={:.6f}, Time={:.4f}'
            .format(epoch, iteration, np.mean(loss1.item()), Lambda,
                    np.mean(loss2.item()), np.mean(loss_all.item()),
                    time.time() - st))

    scheduler.step()

    if epoch % opt.n_snapshot == 0 or epoch == opt.n_epoch:
        network.eval()
        # save checkpoint
        checkpoint(network, epoch, "model")
        # validation
        save_model_path = os.path.join(opt.save_model_path, opt.log_name,
                                       systime)
        validation_path = os.path.join(save_model_path, "validation")
        os.makedirs(validation_path, exist_ok=True)
        np.random.seed(101)

        for valid_name, valid_data in valid_dict.items():
            psnr_result = []
            ssim_result = []
            num_img, num_block, valid_noisy, valid_gt = valid_data
            for idx in range(num_img):
                for idy in range(num_block):
                    im = valid_gt[idx, idy][:, :, np.newaxis]
                    noisy_im = valid_noisy[idx, idy][:, :, np.newaxis]
                    origin255 = im.copy() * 255.0
                    origin255 = origin255.astype(np.uint8)
                    noisy255 = noisy_im.copy() * 255.0
                    noisy255 = noisy255.astype(np.uint8)
                    # padding to square
                    H = noisy_im.shape[0]
                    W = noisy_im.shape[1]
                    val_size = (max(H, W) + 31) // 32 * 32
                    noisy_im = np.pad(
                        noisy_im,
                        [[0, val_size - H], [0, val_size - W], [0, 0]],
                        'reflect')
                    transformer = transforms.Compose([transforms.ToTensor()])
                    noisy_im = transformer(noisy_im)
                    noisy_im = torch.unsqueeze(noisy_im, 0)
                    noisy_im = noisy_im.cuda()
                    # pack raw data
                    noisy_im = space_to_depth(noisy_im, block_size=2)
                    with torch.no_grad():
                        prediction = network(noisy_im)
                        # unpack raw data
                        prediction = depth_to_space(prediction, block_size=2)
                        prediction = prediction[:, :, :H, :W]
                    prediction = prediction.permute(0, 2, 3, 1)
                    prediction = prediction.cpu().data.clamp(0, 1).numpy()
                    prediction = prediction.squeeze(0)
                    pred255 = np.clip(prediction * 255.0 + 0.5, 0,
                                      255).astype(np.uint8)
                    # calculate psnr
                    cur_psnr = calculate_psnr(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    psnr_result.append(cur_psnr)
                    cur_ssim = calculate_ssim(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    ssim_result.append(cur_ssim)

                    # visualization
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_clean.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(origin255.squeeze()).save(save_path)
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_noisy.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(noisy255.squeeze()).save(save_path)
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_denoised.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(pred255.squeeze()).save(save_path)

            psnr_result = np.array(psnr_result)
            avg_psnr = np.mean(psnr_result)
            avg_ssim = np.mean(ssim_result)
            log_path = os.path.join(validation_path,
                                    "A_log_{}.csv".format(valid_name))
            with open(log_path, "a") as f:
                f.writelines("{},{},{}\n".format(epoch, avg_psnr, avg_ssim))
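
One detail of the validation code above may be worth checking. The ground truth is converted with `astype(np.uint8)`, which drops the fractional part, and the prediction is rounded to uint8 before PSNR is computed. The pure-Python sketch below uses synthetic numbers, not SIDD data, to compare PSNR computed on float values with PSNR computed after those two conversions. The helpers `psnr`, `truncate_to_uint8` and `round_to_uint8` are hypothetical stand-ins and are not the repository's `calculate_psnr`. Whether this effect accounts for the gap reported above is not established in this thread.

```python
import math
import random


def psnr(ref, out, peak):
    """PSNR in dB between two equally long sequences of pixel values."""
    mse = sum((r - o) ** 2 for r, o in zip(ref, out)) / len(ref)
    return float("inf") if mse == 0 else 10.0 * math.log10(peak ** 2 / mse)


def truncate_to_uint8(x):
    # mirrors (x * 255.0).astype(np.uint8) for x in [0, 1]: the fraction is dropped
    return int(x * 255.0)


def round_to_uint8(x):
    # mirrors np.clip(x * 255.0 + 0.5, 0, 255).astype(np.uint8)
    return int(min(max(x * 255.0 + 0.5, 0.0), 255.0))


rng = random.Random(0)
# Synthetic "clean" values in [0.1, 0.9] and a prediction with a small residual error.
gt = [rng.uniform(0.1, 0.9) for _ in range(10000)]
pred = [g + rng.gauss(0.0, 0.002) for g in gt]

psnr_float = psnr(gt, pred, peak=1.0)
psnr_uint8 = psnr([truncate_to_uint8(g) for g in gt],
                  [round_to_uint8(p) for p in pred],
                  peak=255.0)
print("float: {:.2f} dB, uint8: {:.2f} dB".format(psnr_float, psnr_uint8))
```

In this synthetic setting the 8-bit conversions add their own error on top of the residual, so the uint8 PSNR comes out lower than the float PSNR.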