Comments (8)
Added in #264.
Sorry for the trouble, thanks!
Sorry for the late response. I will investigate the issue soon.
@Guriido It seems to be a serializer issue in ExponentialShift rather than in ChainerMN's checkpointer, as the checkpointer just calls the trainer serializer (and consequently serializes all owned objects). Could you make sure that the issue is not reproducible without ChainerMN? If so, I'd be happy to test it once a minimal reproducible script is provided.
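Something along these lines might work as a ChainerMN-free check (an untested sketch; I'm not fully sure of ExponentialShift's internals across versions, and the file name is made up):

import chainer
from chainer import serializers
from chainer.training import extensions

optimizer = chainer.optimizers.MomentumSGD(lr=0.1, momentum=0.9)
optimizer.setup(chainer.links.Linear(784, 10))

# Shift once: with an explicit optimizer, the trainer argument is unused.
shift = extensions.ExponentialShift('lr', 0.1, init=0.1, optimizer=optimizer)
shift(None)
print(optimizer.lr)  # 0.01 after one shift

# Round-trip the extension's state through plain Chainer serializers.
serializers.save_npz('shift.npz', shift)
restored = extensions.ExponentialShift('lr', 0.1, init=0.1,
                                       optimizer=optimizer)
serializers.load_npz('shift.npz', restored)
restored.initialize(None)  # should re-apply the last (shifted) value
print(optimizer.lr)  # 0.01 again if the state was restored correctly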
Sorry for the late answer. I tested many things but couldn't reproduce the issue without ChainerMN (I cannot claim my tests were exhaustive, though).
With the following script (a modified version of the MNIST example), thanks to the custom trigger, the learning rate is shifted after the second epoch. If I stop the training afterwards (at the fourth epoch, for example) and resume it by running the same script (with the same parameters, of course), the learning rate is reset to the initial value (0.1) and not to the expected shifted value (0.01).
#!/usr/bin/env python
from __future__ import print_function
import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from chainer.training import util
from mpi4py import MPI

import chainermn


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # the size of the inputs to each layer will be inferred
            l1=L.Linear(784, n_units),      # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),    # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=60,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPUs')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--double_buffering', action='store_true',
                        help='Improves speed')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.double_buffering:
        args.communicator = 'pure_nccl'
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank if args.gpu else -1

    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(
            MPI.COMM_WORLD.Get_size()))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    initial_lr = 0.1

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=initial_lr, momentum=0.9),
        comm, double_buffering=args.double_buffering)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole
    # dataset; it is then evenly split and scattered to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train = None
        test = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm)

    train_iter = chainer.iterators.SerialIterator(
        train, args.batchsize, shuffle=False)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.observe_lr(), trigger=(1, 'epoch'))
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy',
             'elapsed_time', 'lr']))
        trainer.extend(extensions.ProgressBar())

    trainer.run()


class LRShiftTrigger(object):

    """Trigger that fires at the specific epochs listed in ``triggers``.

    The epochs follow the schedule used in the ResNet paper. The check is
    performed once per epoch through an
    :class:`~chainer.training.triggers.IntervalTrigger`.
    """

    triggers = [2, 50]

    def __init__(self):
        self._interval_trigger = util.get_trigger((1, 'epoch'))

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with.

        Returns:
            bool: ``True`` if the corresponding extension should be invoked
            in this iteration.
        """
        if not self._interval_trigger(trainer):
            return False
        epoch = trainer.updater.epoch
        return epoch in LRShiftTrigger.triggers


if __name__ == '__main__':
    main()
My environment settings (using Docker):
Ubuntu 16.04
python 3.5.2
Open MPI 2.1.2 with infiniband
mpi4py 3.0.0
chainermn 1.2.0
chainer 4.0.0b3
If you need any other information, I will be more than happy to help.
@Guriido Thank you for reporting, and for your effort in cutting this down to a reproducible script! I successfully (?) reproduced the bug and will work on it.
I dug in a little deeper and found that the optimizer's update rule states are actually not saved. I'll keep this issue open while chainer/chainer#4749 is open.
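(For reference, a hypothetical way to check the update rule states with plain Chainer, assuming a MomentumSGD optimizer that has performed at least one update; the file name and the parameter path are illustrative only:)

import numpy as np
from chainer import serializers

# MomentumSGD keeps one velocity array 'v' per parameter in
# param.update_rule.state; compare it before and after a round trip.
w = model.predictor.l1.W  # parameter path from the script above
before = w.update_rule.state['v'].copy()

serializers.save_npz('opt.npz', optimizer)
serializers.load_npz('opt.npz', optimizer)

after = w.update_rule.state['v']
assert np.array_equal(before, after)  # fails if the state was not saved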
@Guriido I have understood what is going on in your example code.
checkpointer = chainermn.create_multi_node_checkpointer(
    name='mnist-example', comm=comm)
checkpointer.maybe_load(trainer, optimizer)
trainer.extend(checkpointer, trigger=(1, 'epoch'))

trainer.extend(extensions.ExponentialShift(
    'lr', 0.1), trigger=LRShiftTrigger())
In your code, the ExponentialShift extension is added after the snapshot is loaded. So the reloaded trainer has the correct learning rate, but the later extension injection overwrites it with the initial 0.1. I'd recommend putting the snapshot-related code right before trainer.run(), so that no stateful extension is initialized after the checkpoint reload, like this:
trainer.extend(extensions.ExponentialShift(
    'lr', 0.1), trigger=LRShiftTrigger())

(snip)

checkpointer = chainermn.create_multi_node_checkpointer(
    name='mnist-example', comm=comm)
checkpointer.maybe_load(trainer, optimizer)
trainer.extend(checkpointer, trigger=(1, 'epoch'))

trainer.run()
This code worked correctly in my environment.
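More generally, the same pitfall can be sketched with plain Chainer snapshots (a hypothetical snippet, not taken from the example; snapshot_iter_1000 is a made-up file name). The trainer's serializer only visits extensions that are already registered, so anything extended after the load keeps its initial state:

from chainer import serializers
from chainer.training import extensions

# Wrong order (the two blocks below are alternatives, not sequential code):
# ExponentialShift is registered after the load, so its serialized state
# is never restored and the learning rate falls back to the initial value.
serializers.load_npz('snapshot_iter_1000', trainer)
trainer.extend(extensions.ExponentialShift('lr', 0.1))

# Right order: register every stateful extension first, then load.
trainer.extend(extensions.ExponentialShift('lr', 0.1))
serializers.load_npz('snapshot_iter_1000', trainer)
trainer.run()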
Thank you very much for your time and efforts.
I indeed did not expect the ExponentialShift extension to affect the trainer state...
Maybe it would be helpful to put a warning or a note about this in the multi_node_checkpointer documentation?
There is obviously already a remark about calling it before trainer.run(), but I think it would save users some trouble if there were a mention like "it is recommended to perform the load right before running the trainer".