hongyi-zhang / mixup
Implementation of the mixup training method
License: BSD 3-Clause "New" or "Revised" License
Thanks for a very interesting paper! I tried out your code but was not able to reproduce the results reported in the paper; instead I got the results shown on the GitHub page.
For CIFAR-10 with PreAct ResNet-18, the paper reports a 3.9% error, but the GitHub page gives 4.24%. Are there any other settings to change, such as epochs or weight decay?
Thanks
I ran the command as provided in the readme:
python easy_mixup.py --sess my_session_1 --seed 11111
Since I am using PyTorch >= 1.0, I had to fix the loss calculation:
train_loss += loss.data[0] became train_loss += loss.data.item()
and test_loss += loss.data[0] became test_loss += loss.data.item()
The maximum accuracy I got was 93.90%, which translates to an error of 6.10%.
What am I doing wrong here?
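For context, PyTorch 0.4 made scalar losses 0-dimensional tensors, so indexing them with [0] raises an error; .item() extracts the Python number. A minimal check of the change described above:

```python
import torch
import torch.nn.functional as F

# A scalar loss is a 0-dim tensor in PyTorch >= 0.4;
# loss[0] would raise IndexError, so use .item() instead.
loss = F.mse_loss(torch.tensor([1.0]), torch.tensor([3.0]))
train_loss = 0.0
train_loss += loss.item()  # (2 - 0)^2... here: (1 - 3)^2 = 4.0
```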
The mixup method is applied within each batch and returns the mixed samples. Does the number of samples in the batch stay the same? And does the total number of samples in the training set stay the same?
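For reference, mixup is typically implemented by mixing each sample with another sample from the same batch selected by a random permutation, so both the batch size and the total number of training samples are unchanged. A minimal sketch (function and argument names are assumptions, not the repo's exact code):

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # Sample the mixing coefficient from Beta(alpha, alpha)
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))          # permutation within the batch
    mixed_x = lam * x + (1 - lam) * x[index]   # same shape as x
    return mixed_x, y, y[index], lam

x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))
mixed_x, y_a, y_b, lam = mixup_data(x, y)
assert mixed_x.shape == x.shape  # batch size is unchanged
```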
After epoch 100, lr /= 10 is applied on every epoch, so the learning rate quickly becomes zero. After epoch 150 it is divided by 100 on every epoch:
if epoch >= 100:
    lr /= 10
if epoch >= 150:
    lr /= 10
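One way to get the intended step schedule is to recompute the learning rate from the base value each epoch, so each division is applied only once rather than cumulatively. A sketch (the base rate of 0.1 and the decay epochs are assumptions):

```python
def learning_rate(epoch, base_lr=0.1):
    # Step schedule: one 10x decay at epoch 100, another at epoch 150.
    # Recomputing from base_lr each call avoids the cumulative-decay bug.
    lr = base_lr
    if epoch >= 100:
        lr /= 10
    if epoch >= 150:
        lr /= 10
    return lr  # assign to optimizer.param_groups[i]['lr'] each epoch
```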
My Ubuntu system restarts at loss.backward(). I am using Ubuntu 16.04 and CUDA 9.0. Could you help me resolve this issue? Much appreciated.
I've tried to run the code inside Colab; however, it shows the following error.
Do I need any specific setup in my Colab? Thank you
0%| | 10/20001 [00:00&lt;08:23, 39.67it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
&lt;ipython-input-2-51480497a079&gt; in &lt;module&gt;
    156             'samples/example_z=%d_%s_%1.1f_%06d.pt' %
    157             (n_latent, shape, mixup, iteration))
--> 158     plot(plot_real, plot_fake, mixup, iteration)
&lt;ipython-input-2-51480497a079&gt; in plot(x, y, mixup, iteration)
     51     plt.xlim(*lims)
     52     plt.ylim(*lims)
---> 53     plt.tight_layout(0, 0, 0)
     54     plt.show()
     55     plt.savefig("images/example_z=%d_%s_%1.1f_%06d.png" %
TypeError: tight_layout() takes 0 positional arguments but 3 were given
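The error comes from newer Matplotlib versions, where tight_layout made its parameters keyword-only. Passing the paddings as keywords should work (a sketch, assuming the three positional values were meant as pad, h_pad, w_pad):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

plt.plot([0, 1], [0, 1])
# Matplotlib >= 3.3 rejects positional arguments here; use keywords:
plt.tight_layout(pad=0, h_pad=0, w_pad=0)
plt.close()
```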
Hi, I have tried to use your code in my experiment on UCF101,
but I didn't get any improvement, and the number of correct predictions is always zero.
It is a little strange.
My code:
```python
def train_1epoch(self):
    print('==> Epoch:[{0}/{1}][training stage]'.format(self.epoch, self.nb_epochs))
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    self.model.train()
    end = time.time()
    train_loss = AverageMeter()
    total = 0
    correct = 0
    # mini-batch training
    progress = tqdm(self.train_loader)
    for i, (data_dict, label) in enumerate(progress):
        # measure data loading time
        data_time.update(time.time() - end)
        label = label.cuda()
        # accumulate the model outputs over the image streams
        output = torch.zeros(len(data_dict['img1']), 101).float().cuda()
        for j in range(len(data_dict)):  # renamed from i to avoid shadowing the batch index
            key = 'img' + str(j)
            input_var = data_dict[key].cuda()
            # generate mixed inputs, two label vectors and a mixing coefficient;
            # note: mixup_data is called per stream, so each stream gets its
            # own lam and permutation here
            input_var, label_a, label_b, lam = mixup_data(input_var, label, args.alpha, True)
            output += self.model(input_var)
        loss = mixup_criterion(self.criterion, output, label_a, label_b, lam)
        # compute gradient and do SGD step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # measure accuracy and record loss (loss.item() instead of loss.data[0])
        train_loss.update(loss.item(), data_dict[key].size(0))
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (lam * predicted.eq(label_a.data).cpu().sum().item()
                    + (1 - lam) * predicted.eq(label_b.data).cpu().sum().item())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    info = {'Epoch': [self.epoch],
            'Batch Time': [round(batch_time.avg, 3)],
            'Data Time': [round(data_time.avg, 3)],
            'Loss': [round(train_loss.avg, 5)],
            'correct': [round(correct, 4)],
            'Prec@1': [round(correct / total, 4)],
            'Prec@5': [round(correct / total, 4)],
            'lr': self.optimizer.param_groups[0]['lr'],
            'weight-decay': args.decay}
    record_info(info, 'record/spatial/rgb_train.csv', 'train')
```
I can't figure out where it goes wrong...
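For reference, the mixup loss is the lam-weighted combination of the losses against the two original label sets. A minimal sketch (assuming labels are class indices, as nn.CrossEntropyLoss expects; one-hot labels would break both the loss and the predicted.eq accuracy check):

```python
import torch
import torch.nn as nn

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Weighted sum of the losses against the two original label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

criterion = nn.CrossEntropyLoss()
pred = torch.randn(4, 101)             # logits for 101 classes (UCF101)
y_a = torch.randint(0, 101, (4,))      # class indices, not one-hot
y_b = torch.randint(0, 101, (4,))
loss = mixup_criterion(criterion, pred, y_a, y_b, lam=0.7)
```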