Comments (1)
Hi, please refer to #5 (comment).
There are a few changes in my codebase; you can take the code below as a reference:
def set_loader(self):
    """Build the train, unlabeled and memory dataloaders and derived state.

    Reads the dataset name and sizes from ``self.opt``, picks normalization
    statistics, builds the loaders through ``self.build_dataloader`` and
    stores the loaders, samplers, per-epoch iteration count, class/sample
    counts and the CUDA buffers used for pseudo-labels on ``self``.
    """
    opt = self.opt
    dataset_name = opt.dataset

    # Fall back to fixed default statistics when the dataset has no entry
    # in dataset_dict.
    if dataset_name not in dataset_dict:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        self.logger.msg_str(f'Dataset {dataset_name} does not exist in dataset_dict,'
                            f' use default normalizations: mean {mean}, std {std}.')
    else:
        mean, std = dataset_dict[dataset_name]

    normalize = transforms.Normalize(mean=mean, std=std, inplace=True)
    train_transform = self.train_transform(normalize)
    self.logger.msg_str('set transforms...')
    self.logger.msg_str(train_transform)

    self.logger.msg_str('set train and unlabeled dataloaders...')
    # Both loaders share the training transform and differ only in the
    # ``train`` flag passed to build_dataloader.
    # NOTE(review): presumably train=False selects the unlabeled split here
    # -- confirm against build_dataloader.
    train_loader, labels, train_sampler = self.build_dataloader(
        transform=train_transform,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        sampler=True,
        train=True)
    unlabeled_loader, _, unlabeled_sampler = self.build_dataloader(
        transform=train_transform,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        sampler=True,
        train=False)

    # Evaluation-style transform; ImageNet-like datasets get an extra
    # resize/center-crop before the common resize to opt.img_size.
    test_transform = []
    if 'imagenet' in dataset_name:
        test_transform += [transforms.Resize(256), transforms.CenterCrop(224)]
    test_transform += [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        normalize
    ]
    test_transform = transforms.Compose(test_transform)

    # self.logger.msg_str('set test dataloaders...')
    # test_loader = self.build_dataloader(test_transform, train=False, batch_size=opt.batch_size)[0]

    self.logger.msg_str('set memory dataloaders...')
    memory_loader = self.build_dataloader(test_transform, train=True, batch_size=opt.batch_size, sampler=True)[0]

    # self.test_loader = test_loader
    self.train_loader = train_loader
    self.memory_loader = memory_loader
    self.unlabeled_loader = unlabeled_loader
    self.train_sampler = train_sampler
    self.unlabeled_sampler = unlabeled_sampler

    # One "epoch" is a full pass over the train loader plus a full pass
    # over the unlabeled loader.
    self.iter_per_epoch = len(train_loader) + len(unlabeled_loader)
    self.num_classes = len(np.unique(labels))
    self.num_samples = len(labels)
    # Buffer for the sampler's index order; sized by the sampler length.
    self.indices = torch.zeros(len(self.train_sampler), dtype=torch.long).cuda()
    self.num_cluster = self.num_classes if opt.num_cluster is None else opt.num_cluster
    # One pseudo-label slot per sample, initialised to zero.
    self.psedo_labels = torch.zeros((self.num_samples,)).long().cuda()

    self.logger.msg_str('load {} images...'.format(self.num_samples))
def fit(self):
    """Run the training loop over unlabeled and labeled loaders.

    Each pass of the outer loop is one epoch: first every batch of
    ``self.unlabeled_loader`` is trained with ``train_unlabeled``, then
    pseudo-labels are optionally refreshed, then every batch of
    ``self.train_loader`` is trained with ``train``. The loop stops once
    the global iteration counter exceeds ``opt.epochs * iter_per_epoch``.
    """
    opt = self.opt
    # training routine
    self.progress_bar = tqdm.tqdm(total=self.iter_per_epoch * opt.epochs, disable=(self.rank != 0))
    # Global iteration counter; starts just past the resumed epochs.
    n_iter = self.iter_per_epoch * opt.resume_epoch + 1
    self.progress_bar.update(n_iter)
    max_iter = opt.epochs * self.iter_per_epoch

    while True:
        epoch = int(n_iter // self.iter_per_epoch + 1)
        self.train_sampler.set_epoch(epoch)
        self.unlabeled_sampler.set_epoch(epoch)

        for inputs in self.unlabeled_loader:
            inputs = convert_to_cuda(inputs)
            self.train_unlabeled(inputs, n_iter)
            self.progress_bar.refresh()
            self.progress_bar.update()
            n_iter += 1

        apply_kmeans = epoch % opt.reassign == 0
        if apply_kmeans:
            self.psedo_labeling(n_iter)

        # Record this epoch's sampler order so train() can map a batch
        # position back to sample indices. Build the tensor as int64
        # directly: going through a float32 tensor (torch.Tensor(...))
        # cannot represent indices above 2**24 exactly.
        # NOTE(review): assumes iterating the sampler here yields the same
        # order the DataLoader will see for this epoch -- confirm.
        self.indices.copy_(torch.tensor(list(self.train_sampler), dtype=torch.long))

        for inputs in self.train_loader:
            inputs = convert_to_cuda(inputs)
            self.adjust_learning_rate(n_iter)
            self.train(inputs, n_iter)
            self.progress_bar.refresh()
            self.progress_bar.update()
            n_iter += 1

        # if epoch % opt.save_freq == 0:
        #     self.logger.checkpoints(int(epoch))
        #     self.test(n_iter)

        if n_iter > max_iter:
            break
def train(self, inputs, n_iter):
    """Run one optimization step on a batch from the train loader.

    Args:
        inputs: ``(images, labels)`` pair; ``images`` unpacks into the two
            views ``(im_q, im_k)``. The labels are not used.
        n_iter: Global iteration counter. It locates the batch inside the
            current epoch and gates the cluster loss after warm-up.
    """
    opt = self.opt
    images, _ = inputs
    self.byol.train()
    im_q, im_k = images

    # Position of this batch in the recorded sampler order: iteration
    # within the epoch, minus the unlabeled iterations, times batch size.
    # NOTE(review): assumes the unlabeled pass runs first in each epoch and
    # that batches arrive in sampler order -- confirm against fit().
    _start = ((n_iter - 1) % self.iter_per_epoch - len(self.unlabeled_loader)) * opt.batch_size
    indices = self.indices[_start: _start + opt.batch_size]
    psedo_labels = self.psedo_labels[indices]

    # compute loss
    contrastive_loss, cluster_loss_batch, q = self.byol(
        im_q, im_k, psedo_labels)

    loss = contrastive_loss
    if ((n_iter - 1) / self.iter_per_epoch) > opt.warmup_epochs:
        # Out-of-place add: ``loss += ...`` would modify the tensor that
        # ``contrastive_loss`` also names, so the logged contrastive loss
        # would silently include the cluster term.
        loss = loss + cluster_loss_batch * opt.cluster_loss_weight

    self.optimizer.zero_grad()
    # SGD
    loss.backward()
    self.optimizer.step()

    with torch.no_grad():
        # Mean per-dimension standard deviation of q over the batch.
        q_std = torch.std(q.detach(), dim=0).mean()

    outputs = [contrastive_loss, cluster_loss_batch, q_std]
    self.logger.msg(outputs, n_iter)
def train_unlabeled(self, inputs, n_iter):
    """Run one optimization step on a batch from the unlabeled loader.

    Only the first loss returned by ``self.byol`` is optimized; the model
    is called with ``None`` in place of pseudo-labels.

    Args:
        inputs: ``(images, labels)`` pair; ``images`` unpacks into the two
            views ``(im_q, im_k)``. The labels are not used.
        n_iter: Global iteration counter, forwarded to the logger.
    """
    images, _ = inputs
    self.byol.train()
    im_q, im_k = images

    # compute loss
    unlabeled_contrastive_loss, _, q = self.byol(im_q, im_k, None)

    self.optimizer.zero_grad()
    # SGD
    unlabeled_contrastive_loss.backward()
    self.optimizer.step()

    with torch.no_grad():
        # Mean per-dimension standard deviation of q over the batch.
        unlabeled_q_std = torch.std(q.detach(), dim=0).mean()

    outputs = [unlabeled_contrastive_loss, unlabeled_q_std]
    self.logger.msg(outputs, n_iter)
from propos.
Related Issues (16)
- About spherical k-means implementation HOT 2
- About Performance on Imagenetdogs HOT 4
- Could u please share the config file for STL? HOT 2
- Got 65% ACC for BYOL on ImageNetdogs HOT 2
- About cifar10 performance on SimSiam HOT 5
- How to change parameters when num_device =2 HOT 1
- How to train ProPos on the STL-10 dataset?
- what does the memory data loader do in the basic_template.py? HOT 1
- How to obtain the clustering result after training? HOT 3
- how to reproduce pcl HOT 2
- Can't Reproduce Result in CIFAR-20 HOT 14
- How to generate pseudo label for unlabeled data in STL-10 and use them in PSL term? HOT 1
- I can't get the results in the paper using the pretrained models you provided HOT 7
- Results on CIFAR10 with ResNet34 HOT 1
- Training speed of ImageNet-10 HOT 2
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization; use data as art.
-
Game
Something interesting about games; make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from propos.