Comments (7)
Hello Rufael, please post your code for the CUB dataset here and I will take a look. A few things to consider: (1) How many epochs does it run for? (2) I added a vgg6 model; try it. (3) How are you handling class imbalance in the training data?
from signalpropagation.
Hi,
Here is the dataset class that I am using:
import os
import math
import torch
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import nn
from spacy.lang.en import English
from torchvision import transforms
from torchvision.models import inception_v3
from torch.utils.data import Dataset, DataLoader
DATASETS_PATH = '../../data/'
# CUB_200_2011
CUB_200_2011_PATH = os.path.join(DATASETS_PATH, 'cub-200-2011')
# Captions and metadata
CUB_200_2011_METADATA_PATH = os.path.join(CUB_200_2011_PATH, 'metadata.pth')
CUB_200_2011_IMG_ID_TO_CAPS_PATH = os.path.join(CUB_200_2011_PATH, 'cub_200_2011_img_id_to_caps.pth')
# All images
CUB_200_2011_IMGS_64_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_64x64.pth')
CUB_200_2011_IMGS_128_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_128x128.pth')
CUB_200_2011_IMGS_256_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_256x256.pth')
# Training/Validation split images
CUB_200_2011_TRAIN_VAL_IMGS_64_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_train_val_64x64.pth')
CUB_200_2011_TRAIN_VAL_IMGS_128_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_train_val_128x128.pth')
CUB_200_2011_TRAIN_VAL_IMGS_256_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_train_val_256x256.pth')
# Testing split images
CUB_200_2011_TEST_IMGS_64_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_test_64x64.pth')
CUB_200_2011_TEST_IMGS_128_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_test_128x128.pth')
CUB_200_2011_TEST_IMGS_256_PATH = os.path.join(CUB_200_2011_PATH, 'imgs_test_256x256.pth')
# GLOVE word embeddings for CUB 200 2011
CUB_200_2011_GLOVE_PATH = os.path.join(CUB_200_2011_PATH, 'glove_relevant_embeddings.pth')
CUB_200_2011_D_VOCAB = 1750
class CUB_200_2011(Dataset):
    """If should_pad is True, need to also provide a pad_to_length. Padding also adds <START> and <END> tokens to captions."""

    def __init__(self, split='all', d_image_size=64, transform=None, should_pad=False, pad_to_length=None, no_start_end=False, **kwargs):
        super().__init__()
        assert split in ('all', 'train_val', 'test')
        assert d_image_size in (64, 128, 256)
        if should_pad:
            assert pad_to_length >= 3  # <START> foo <END> need at least length 3.
        self.split = split
        self.d_image_size = d_image_size
        self.transform = transform
        self.should_pad = should_pad
        self.pad_to_length = pad_to_length
        self.no_start_end = no_start_end
        metadata = torch.load(CUB_200_2011_METADATA_PATH)
        # labels
        self.img_id_to_class_id = metadata['img_id_to_class_id']
        self.class_id_to_class_name = metadata['class_id_to_class_name']
        self.class_name_to_class_id = metadata['class_name_to_class_id']
        # images
        if split == 'all':
            self.img_ids = metadata['img_ids']
            if d_image_size == 64:
                imgs_path = CUB_200_2011_IMGS_64_PATH
            elif d_image_size == 128:
                imgs_path = CUB_200_2011_IMGS_128_PATH
            else:
                imgs_path = CUB_200_2011_IMGS_256_PATH
        elif split == 'train_val':
            self.img_ids = metadata['train_val_img_ids']
            if d_image_size == 64:
                imgs_path = CUB_200_2011_TRAIN_VAL_IMGS_64_PATH
            elif d_image_size == 128:
                imgs_path = CUB_200_2011_TRAIN_VAL_IMGS_128_PATH
            else:
                imgs_path = CUB_200_2011_TRAIN_VAL_IMGS_256_PATH
        else:
            self.img_ids = metadata['test_img_ids']
            if d_image_size == 64:
                imgs_path = CUB_200_2011_TEST_IMGS_64_PATH
            elif d_image_size == 128:
                imgs_path = CUB_200_2011_TEST_IMGS_128_PATH
            else:
                imgs_path = CUB_200_2011_TEST_IMGS_256_PATH
        self.imgs = torch.load(imgs_path)
        assert self.imgs.size()[1:] == (3, d_image_size, d_image_size) and self.imgs.dtype == torch.uint8

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        if self.transform:
            img = self.transform(img)
        img_id = self.img_ids[idx]
        class_id = self.img_id_to_class_id[img_id]
        return img, class_id
def to_one_hot(labels, d_classes):
    """
    Args:
        labels (Tensor): integer tensor of shape (d_batch, *)
        d_classes (int): number of classes
    Returns:
        (Tensor): float tensor of shape (d_batch, *, d_classes), one hot representation of the labels
    """
    return torch.zeros(*labels.size(), d_classes, device=labels.device).scatter_(-1, labels.unsqueeze(-1), 1)
def get_cub_200_2011(should_pad=False, split='train_val', shuffle=True, **kwargs):
    transform = transforms.Lambda(lambda x: (x.float() / 255.) * 2. - 1.)
    train_set = CUB_200_2011(transform=transform, split=split, should_pad=should_pad)
    return train_set
I am only using the images and labels, not the captions. I have also added the following to the make_dataset function:
if args.dataset == 'cub':
    dataset_info = dict(
        input_dim=64,
        input_ch=3,
        num_classes=200,
    )
    dataset_train = get_cub_200_2011(split='train_val')
    dataset_test = get_cub_200_2011(split='test')
(1) Regarding the number of epochs, I let it run for 300 epochs, but the test accuracy does not change while the train accuracy increases significantly. Do you think it is a case of overfitting?
(2) I will try vgg6 and let you know if anything changes.
(3) There is no significant imbalance in the dataset: each class contains roughly the same number of images, although that number is small.
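One way to back up point (3) with numbers is to count the images per class. The following is a minimal sketch, assuming the labels can be collected into a plain Python list (for the dataset class above that might be `[ds.img_id_to_class_id[i] for i in ds.img_ids]`). The helper name `class_balance` and the toy labels are hypothetical, not part of the repository.

```python
from collections import Counter


def class_balance(labels):
    """Return (min_count, max_count, max/min ratio) over the class labels."""
    counts = Counter(labels)
    smallest = min(counts.values())
    largest = max(counts.values())
    return smallest, largest, largest / smallest


# Toy labels standing in for the per-image class ids of a real dataset.
toy_labels = [0] * 30 + [1] * 29 + [2] * 30
smallest, largest, ratio = class_balance(toy_labels)
print(f"min={smallest} max={largest} ratio={ratio:.2f}")
```

A ratio close to 1 would support the claim that the classes are balanced; the small absolute count per class is a separate concern and may still contribute to overfitting.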
from signalpropagation.
Hi, I've encountered the same problem on the Imagenette dataset. Is there a solution? I haven't found the cause yet.
In my case, I ran vgg16. It works well on CIFAR10 but fails on Imagenette. Does this have anything to do with the complexity of the input? Thanks :D
from signalpropagation.
@Lily-Le, is the model overfitting for you as it was for rufaelfekadu? Try vgg6 (added to this repository) and post your results here.
from signalpropagation.
I ran on vgg16:
'vgg16': [ 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
Results for IMAGENETTE:
`Epoch Start: 59
[Info][Train Epoch 59/100][Batch 146/147] [loss 1.6029] [acc 0.4717]
[Sequential] Acc: 0.4531 (0.5458, 5135/9408) Loss: 6.9778 (6.5711)
[BlockConv] Acc: 0.4062 (0.4663, 4387/9408) Loss: 7.4894 (7.1177)
[BlockConv] Acc: 0.4531 (0.4878, 4589/9408) Loss: 7.1538 (6.9224)
[BlockConv] Acc: 0.4844 (0.4935, 4643/9408) Loss: 7.1543 (6.8956)
[BlockConv] Acc: 0.4688 (0.4947, 4654/9408) Loss: 7.1105 (6.8705)
[BlockConv] Acc: 0.4688 (0.4984, 4689/9408) Loss: 7.1243 (6.8609)
[BlockConv] Acc: 0.4531 (0.4933, 4641/9408) Loss: 7.1985 (6.8722)
[BlockConv] Acc: 0.4844 (0.4980, 4685/9408) Loss: 6.9924 (6.8204)
[BlockConv] Acc: 0.4844 (0.4994, 4698/9408) Loss: 7.1181 (6.8262)
[BlockConv] Acc: 0.5000 (0.4961, 4667/9408) Loss: 7.1534 (6.8395)
[BlockConv] Acc: 0.4531 (0.4991, 4696/9408) Loss: 7.1324 (6.8187)
[BlockConv] Acc: 0.4688 (0.4963, 4669/9408) Loss: 7.1760 (6.8172)
[BlockConv] Acc: 0.5000 (0.4971, 4677/9408) Loss: 7.2032 (6.8241)
[BlockLinear] Acc: 0.2969 (0.4134, 3889/9408) Loss: 10.8310 (10.9136)
[BlockLinear] Acc: 0.4375 (0.4439, 4176/9408) Loss: 8.8507 (8.7048)
[Info][Test Epoch 59/100] [loss 1.2826] [acc 0.5839]
[Sequential] Acc: 1.0000 (0.9869, 3853/3904) Loss: 8.3178 (8.2321)
[BlockConv] Acc: 1.0000 (0.9841, 3842/3904) Loss: 8.3178 (8.2473)
[BlockConv] Acc: 1.0000 (0.9869, 3853/3904) Loss: 8.3178 (8.2378)
[BlockConv] Acc: 1.0000 (0.9859, 3849/3904) Loss: 8.3178 (8.2341)
[BlockConv] Acc: 1.0000 (0.9862, 3850/3904) Loss: 8.3178 (8.2315)
[BlockConv] Acc: 1.0000 (0.9864, 3851/3904) Loss: 8.3178 (8.2304)
[BlockConv] Acc: 1.0000 (0.9869, 3853/3904) Loss: 8.3178 (8.2300)
[BlockConv] Acc: 1.0000 (0.9885, 3859/3904) Loss: 8.3178 (8.2265)
[BlockConv] Acc: 1.0000 (0.9887, 3860/3904) Loss: 8.3178 (8.2266)
[BlockConv] Acc: 1.0000 (0.9882, 3858/3904) Loss: 8.3178 (8.2267)
[BlockConv] Acc: 1.0000 (0.9864, 3851/3904) Loss: 8.3178 (8.2273)
[BlockConv] Acc: 1.0000 (0.9869, 3853/3904) Loss: 8.3178 (8.2276)
[BlockConv] Acc: 1.0000 (0.9867, 3852/3904) Loss: 8.3178 (8.2281)
[BlockLinear] Acc: 1.0000 (0.9859, 3849/3904) Loss: 8.3178 (8.2273)
[BlockLinear] Acc: 1.0000 (0.9874, 3855/3904) Loss: 8.3178 (8.2304)
Epoch Start: 60
[Info][Train Epoch 60/100][Batch 146/147] [loss 1.5741] [acc 0.4824]
[Sequential] Acc: 0.6094 (0.5555, 5226/9408) Loss: 6.3034 (6.5557)
[BlockConv] Acc: 0.5156 (0.4700, 4422/9408) Loss: 6.8742 (7.1059)
[BlockConv] Acc: 0.5156 (0.4892, 4602/9408) Loss: 6.4993 (6.8825)
[BlockConv] Acc: 0.5312 (0.4959, 4665/9408) Loss: 6.5132 (6.8680)
[BlockConv] Acc: 0.5312 (0.4968, 4674/9408) Loss: 6.5088 (6.8521)
[BlockConv] Acc: 0.5000 (0.5027, 4729/9408) Loss: 6.4299 (6.8478)
[BlockConv] Acc: 0.4688 (0.4983, 4688/9408) Loss: 6.4779 (6.8464)
[BlockConv] Acc: 0.5156 (0.5018, 4721/9408) Loss: 6.3565 (6.8026)
[BlockConv] Acc: 0.5156 (0.5027, 4729/9408) Loss: 6.4326 (6.7990)
[BlockConv] Acc: 0.5312 (0.5002, 4706/9408) Loss: 6.3938 (6.8050)
[BlockConv] Acc: 0.5625 (0.5044, 4745/9408) Loss: 6.2600 (6.7933)
[BlockConv] Acc: 0.5781 (0.5048, 4749/9408) Loss: 6.2030 (6.7899)
[BlockConv] Acc: 0.5781 (0.5067, 4767/9408) Loss: 6.2406 (6.8012)
[BlockLinear] Acc: 0.4062 (0.4136, 3891/9408) Loss: 11.0399 (10.6748)
[BlockLinear] Acc: 0.5156 (0.4433, 4171/9408) Loss: 7.4861 (8.6191)
[Info][Test Epoch 60/100] [loss 1.2561] [acc 0.5921]
[Sequential] Acc: 1.0000 (0.9869, 3853/3904) Loss: 8.3178 (8.2320)
[BlockConv] Acc: 1.0000 (0.9836, 3840/3904) Loss: 8.3178 (8.2463)
[BlockConv] Acc: 1.0000 (0.9841, 3842/3904) Loss: 8.3178 (8.2376)
[BlockConv] Acc: 1.0000 (0.9846, 3844/3904) Loss: 8.3178 (8.2347)
[BlockConv] Acc: 1.0000 (0.9828, 3837/3904) Loss: 8.3178 (8.2327)
[BlockConv] Acc: 1.0000 (0.9836, 3840/3904) Loss: 8.3178 (8.2323)
[BlockConv] Acc: 1.0000 (0.9836, 3840/3904) Loss: 8.3178 (8.2331)
[BlockConv] Acc: 1.0000 (0.9831, 3838/3904) Loss: 8.3178 (8.2314)
[BlockConv] Acc: 1.0000 (0.9846, 3844/3904) Loss: 8.3178 (8.2315)
[BlockConv] Acc: 1.0000 (0.9844, 3843/3904) Loss: 8.3178 (8.2318)
[BlockConv] Acc: 1.0000 (0.9851, 3846/3904) Loss: 8.3178 (8.2314)
[BlockConv] Acc: 1.0000 (0.9844, 3843/3904) Loss: 8.3178 (8.2322)
[BlockConv] Acc: 1.0000 (0.9846, 3844/3904) Loss: 8.3178 (8.2330)
[BlockLinear] Acc: 1.0000 (0.9846, 3844/3904) Loss: 8.3178 (8.2373)
[BlockLinear] Acc: 1.0000 (0.9854, 3847/3904) Loss: 8.3178 (8.2380)
`
Results for CIFAR10 on the same network (except that the input dim changes from 224 to 32):
`Epoch Start: 9
[Info][Train Epoch 9/150][Batch 780/781] [loss 2.4259] [acc 0.1013]
[Sequential] Acc: 0.6719 (0.6329, 31634/49984) Loss: 5.6800 (5.9229)
[BlockConv] Acc: 0.7188 (0.6518, 32579/49984) Loss: 5.6133 (5.8158)
[BlockConv] Acc: 0.7344 (0.6805, 34012/49984) Loss: 5.5556 (5.6277)
[BlockConv] Acc: 0.7188 (0.6895, 34465/49984) Loss: 5.4879 (5.5643)
[BlockConv] Acc: 0.7812 (0.7055, 35264/49984) Loss: 5.3193 (5.4678)
[BlockConv] Acc: 0.7812 (0.7107, 35524/49984) Loss: 5.2669 (5.4437)
[BlockConv] Acc: 0.7969 (0.7080, 35389/49984) Loss: 5.2427 (5.4614)
[BlockConv] Acc: 0.8125 (0.7040, 35188/49984) Loss: 5.2359 (5.4811)
[BlockConv] Acc: 0.7812 (0.6898, 34480/49984) Loss: 5.3128 (5.5590)
[BlockConv] Acc: 0.7656 (0.6805, 34014/49984) Loss: 5.4046 (5.6144)
[BlockConv] Acc: 0.7188 (0.6723, 33602/49984) Loss: 5.5831 (5.6514)
[BlockConv] Acc: 0.7188 (0.6630, 33139/49984) Loss: 5.6603 (5.6923)
[BlockConv] Acc: 0.7500 (0.6535, 32663/49984) Loss: 5.6682 (5.7378)
[BlockLinear] Acc: 0.1406 (0.1005, 5024/49984) Loss: 7.2372 (7.1900)
[BlockLinear] Acc: 0.1406 (0.0998, 4989/49984) Loss: 4.4739 (4.1325)
[Info][Test Epoch 9/150] [loss 2.4035] [acc 0.1262]
[Sequential] Acc: 0.5312 (0.5880, 5871/9984) Loss: 6.2647 (6.1128)
[BlockConv] Acc: 0.5938 (0.6298, 6288/9984) Loss: 5.9764 (5.8872)
[BlockConv] Acc: 0.5938 (0.6547, 6537/9984) Loss: 5.8911 (5.7394)
[BlockConv] Acc: 0.6562 (0.6615, 6604/9984) Loss: 5.8488 (5.6886)
[BlockConv] Acc: 0.6562 (0.6844, 6833/9984) Loss: 5.8790 (5.5892)
[BlockConv] Acc: 0.6562 (0.6903, 6892/9984) Loss: 5.8392 (5.5473)
[BlockConv] Acc: 0.6250 (0.6916, 6905/9984) Loss: 5.7869 (5.5506)
[BlockConv] Acc: 0.6250 (0.6868, 6857/9984) Loss: 5.7762 (5.5688)
[BlockConv] Acc: 0.6562 (0.6671, 6660/9984) Loss: 5.8855 (5.6569)
[BlockConv] Acc: 0.5938 (0.6523, 6513/9984) Loss: 5.9305 (5.7337)
[BlockConv] Acc: 0.5781 (0.6520, 6510/9984) Loss: 5.8862 (5.7148)
[BlockConv] Acc: 0.6094 (0.6445, 6435/9984) Loss: 5.9057 (5.7618)
[BlockConv] Acc: 0.6406 (0.6423, 6413/9984) Loss: 5.9459 (5.7829)
[BlockLinear] Acc: 0.0938 (0.1001, 999/9984) Loss: 7.0604 (6.8776)
[BlockLinear] Acc: 0.0938 (0.1001, 999/9984) Loss: 3.6729 (3.6420)
Epoch Start: 10
[Info][Train Epoch 10/150][Batch 780/781] [loss 2.6020] [acc 0.1023]
[Sequential] Acc: 0.6250 (0.6416, 32072/49984) Loss: 5.7659 (5.8761)
[BlockConv] Acc: 0.7031 (0.6567, 32825/49984) Loss: 5.6342 (5.7845)
[BlockConv] Acc: 0.7344 (0.6834, 34157/49984) Loss: 5.4152 (5.5966)
[BlockConv] Acc: 0.7031 (0.6958, 34779/49984) Loss: 5.3959 (5.5221)
[BlockConv] Acc: 0.7188 (0.7132, 35647/49984) Loss: 5.3778 (5.4194)
[BlockConv] Acc: 0.7344 (0.7181, 35894/49984) Loss: 5.3038 (5.3879)
[BlockConv] Acc: 0.7188 (0.7157, 35772/49984) Loss: 5.3032 (5.4017)
[BlockConv] Acc: 0.7188 (0.7120, 35587/49984) Loss: 5.3146 (5.4198)
[BlockConv] Acc: 0.7344 (0.7010, 35038/49984) Loss: 5.3615 (5.4895)
[BlockConv] Acc: 0.7500 (0.6917, 34573/49984) Loss: 5.3883 (5.5411)
[BlockConv] Acc: 0.7188 (0.6794, 33960/49984) Loss: 5.5378 (5.6055)
[BlockConv] Acc: 0.6875 (0.6752, 33748/49984) Loss: 5.5627 (5.6264)
[BlockConv] Acc: 0.6562 (0.6680, 33391/49984) Loss: 5.6148 (5.6683)
[BlockLinear] Acc: 0.0625 (0.1005, 5022/49984) Loss: 6.2389 (7.0155)
[BlockLinear] Acc: 0.0625 (0.0998, 4986/49984) Loss: 3.1281 (4.0298)
[Info][Test Epoch 10/150] [loss 2.3379] [acc 0.1025]
[Sequential] Acc: 0.5000 (0.5921, 5912/9984) Loss: 6.2656 (6.0863)
[BlockConv] Acc: 0.5625 (0.6406, 6396/9984) Loss: 5.9421 (5.8529)
[BlockConv] Acc: 0.5938 (0.6597, 6586/9984) Loss: 5.8152 (5.7241)
[BlockConv] Acc: 0.6094 (0.6693, 6682/9984) Loss: 5.7813 (5.6647)
[BlockConv] Acc: 0.6719 (0.6918, 6907/9984) Loss: 5.6013 (5.5462)
[BlockConv] Acc: 0.7031 (0.6946, 6935/9984) Loss: 5.5856 (5.5179)
[BlockConv] Acc: 0.7344 (0.6976, 6965/9984) Loss: 5.5719 (5.5190)
[BlockConv] Acc: 0.7188 (0.6962, 6951/9984) Loss: 5.5813 (5.5261)
[BlockConv] Acc: 0.6562 (0.6893, 6882/9984) Loss: 5.5825 (5.5544)
[BlockConv] Acc: 0.6562 (0.6855, 6844/9984) Loss: 5.5731 (5.5691)
[BlockConv] Acc: 0.6719 (0.6632, 6621/9984) Loss: 5.6038 (5.7429)
[BlockConv] Acc: 0.6719 (0.6410, 6400/9984) Loss: 5.7375 (5.8567)
[BlockConv] Acc: 0.7188 (0.6460, 6450/9984) Loss: 5.7438 (5.8346)
[BlockLinear] Acc: 0.1094 (0.0999, 997/9984) Loss: 7.0619 (6.4486)
[BlockLinear] Acc: 0.1094 (0.0999, 997/9984) Loss: 3.9911 (3.6814)
Epoch Start: 11
[Info][Train Epoch 11/150][Batch 780/781] [loss 2.4988] [acc 0.0966]
[Sequential] Acc: 0.6250 (0.6508, 32529/49984) Loss: 5.7883 (5.8365)
[BlockConv] Acc: 0.6562 (0.6622, 33100/49984) Loss: 5.6041 (5.7574)
[BlockConv] Acc: 0.7344 (0.6893, 34456/49984) Loss: 5.5137 (5.5632)
[BlockConv] Acc: 0.7500 (0.7034, 35161/49984) Loss: 5.4319 (5.4845)
[BlockConv] Acc: 0.7344 (0.7206, 36018/49984) Loss: 5.2818 (5.3811)
[BlockConv] Acc: 0.7188 (0.7278, 36378/49984) Loss: 5.2343 (5.3480)
[BlockConv] Acc: 0.7344 (0.7263, 36301/49984) Loss: 5.2226 (5.3628)
[BlockConv] Acc: 0.7344 (0.7227, 36123/49984) Loss: 5.2093 (5.3822)
[BlockConv] Acc: 0.7188 (0.7104, 35508/49984) Loss: 5.3145 (5.4582)
[BlockConv] Acc: 0.6719 (0.7003, 35003/49984) Loss: 5.3798 (5.5159)
[BlockConv] Acc: 0.6875 (0.6865, 34314/49984) Loss: 5.4960 (5.5765)
[BlockConv] Acc: 0.6875 (0.6848, 34231/49984) Loss: 5.6059 (5.5833)
[BlockConv] Acc: 0.6875 (0.6777, 33873/49984) Loss: 5.6564 (5.6189)
[BlockLinear] Acc: 0.1094 (0.1033, 5162/49984) Loss: 6.4103 (7.3552)
[BlockLinear] Acc: 0.1094 (0.1023, 5114/49984) Loss: 3.9425 (4.4211)
[Info][Test Epoch 11/150] [loss 2.2822] [acc 0.1458]
[Sequential] Acc: 0.4844 (0.5957, 5947/9984) Loss: 6.3106 (6.0652)
[BlockConv] Acc: 0.6094 (0.6485, 6475/9984) Loss: 5.9859 (5.8226)
[BlockConv] Acc: 0.6094 (0.6655, 6644/9984) Loss: 5.9259 (5.6847)
[BlockConv] Acc: 0.6094 (0.6807, 6796/9984) Loss: 5.9007 (5.6211)
[BlockConv] Acc: 0.6875 (0.7008, 6997/9984) Loss: 5.7632 (5.5251)
[BlockConv] Acc: 0.7031 (0.7065, 7054/9984) Loss: 5.7509 (5.4888)
[BlockConv] Acc: 0.7031 (0.7086, 7075/9984) Loss: 5.7585 (5.4753)
[BlockConv] Acc: 0.6719 (0.7058, 7047/9984) Loss: 5.8051 (5.5021)
[BlockConv] Acc: 0.6719 (0.6929, 6918/9984) Loss: 5.9003 (5.5846)
[BlockConv] Acc: 0.6250 (0.6817, 6806/9984) Loss: 5.9636 (5.6416)
[BlockConv] Acc: 0.6250 (0.6809, 6798/9984) Loss: 5.9051 (5.6140)
[BlockConv] Acc: 0.6562 (0.6821, 6810/9984) Loss: 5.9682 (5.6108)
[BlockConv] Acc: 0.6875 (0.6730, 6719/9984) Loss: 6.0381 (5.6816)
[BlockLinear] Acc: 0.0938 (0.1001, 999/9984) Loss: 6.1299 (5.7286)
[BlockLinear] Acc: 0.0938 (0.1001, 999/9984) Loss: 3.6390 (3.6205)
`
I also thought the dimension of the projection might be the cause, so I tried different input_ch values: 64, 128, and 256. But the result is always the same: the test loss is 8.3178 for Imagenette on every layer in every epoch.
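A quick arithmetic check on that constant, offered as an observation rather than a confirmed diagnosis: 8.3178 equals 2·ln 64, and 64 matches the batch size implied by the training log (9408 samples over 147 batches). That is the value two cross-entropy terms would take over a uniform 64-way distribution, which would be consistent with batches in which every sample carries the same label. Whether the repository's loss is actually composed this way is an assumption, and `uniform_pair_loss` is a hypothetical helper.

```python
import math


def uniform_pair_loss(n):
    """Sum of two cross-entropy terms, each over a uniform n-way distribution."""
    return 2 * math.log(n)


# Assumption: batch size inferred from the log (9408 samples / 147 batches).
batch_size = 9408 // 147
print(batch_size, round(uniform_pair_loss(batch_size), 4))
```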
from signalpropagation.
Case solved. In my case, I had to change the shuffle param from False to True in the test loader.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True, **kwargs)
MNIST and CIFAR10 are downloaded directly from torch, and their test samples are already shuffled, so the code in the repository works well even though the shuffle arg in the test loader is False. I generated my test samples from ImageFolder, however, and in that case all the test labels within one batch are the same. Normally the order of samples has no influence on test accuracy, but that is not the case here.
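The effect of sample order on batch composition can be illustrated without PyTorch. This is a minimal sketch in plain Python that stands in for DataLoader batching, assuming the samples are stored class by class (as an ImageFolder-style directory listing would produce). The function name `distinct_labels_per_batch` and the toy sizes (10 classes, 64 samples each) are hypothetical.

```python
import random


def distinct_labels_per_batch(labels, batch_size, shuffle, seed=0):
    """Split labels into consecutive batches and count the distinct labels in each."""
    order = list(range(len(labels)))
    if shuffle:
        random.Random(seed).shuffle(order)
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    return [len({labels[j] for j in batch}) for batch in batches]


# Class-sorted ordering: all samples of class 0, then class 1, and so on.
labels = [c for c in range(10) for _ in range(64)]

unshuffled = distinct_labels_per_batch(labels, 64, shuffle=False)
shuffled = distinct_labels_per_batch(labels, 64, shuffle=True)
print("without shuffle:", unshuffled)  # every batch holds a single class
print("with shuffle:   ", shuffled)
```

Without shuffling, each batch contains exactly one class, so any loss or metric computed relative to the other samples in the batch sees no label variety. With shuffling, the batches mix classes.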
By the way, I think it works pretty well with VGG16 on Imagenette. The network has no problem generalizing to more complex datasets.
If anything is wrong, please let me know. Thanks.
from signalpropagation.
@Lily-Le @rufaelfekadu, since this is solved, I am closing this issue.
from signalpropagation.