
Comments (8)

Guriido commented on June 30, 2024

Added in #264.
Sorry for the trouble, thanks!


keisukefukuda commented on June 30, 2024

Sorry for the late response. I will investigate the issue soon.


kuenishi commented on June 30, 2024

@Guriido It seems to be a serializer issue in ExponentialShift rather than in ChainerMN's checkpointer, as the checkpointer just calls the trainer's serializer (and consequently serializes all owned objects). Could you make sure that the issue is not reproducible without ChainerMN? If so, I'd be happy to test it once a minimal reproducible script is provided.
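
A minimal sketch of such a standalone check (not from this thread; the dummy optimizer, the direct initialize(None)/shift(None) calls, and the file name are just illustrations against the chainer 4.x ExponentialShift API):

import chainer.links as L
from chainer import optimizers, serializers
from chainer.training import extensions

model = L.Classifier(L.Linear(784, 10))
optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
optimizer.setup(model)

# Passing the optimizer directly lets the extension run without a trainer.
shift = extensions.ExponentialShift('lr', 0.1, optimizer=optimizer)
shift.initialize(None)   # lr -> 0.1 (initial value)
shift(None)              # one shift: lr -> 0.01

# Round-trip the extension state through an NPZ file.
serializers.save_npz('shift_state.npz', shift)
restored = extensions.ExponentialShift('lr', 0.1, optimizer=optimizer)
optimizer.lr = 0.1       # simulate a fresh process with the initial lr
serializers.load_npz('shift_state.npz', restored)
restored.initialize(None)
print(optimizer.lr)      # 0.01 if the state survived, 0.1 otherwise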


Guriido commented on June 30, 2024

Sorry for the late answer. I tested many things but couldn't reproduce the issue without ChainerMN (I cannot claim my tests were exhaustive, though).

With the following script (a modified version of the MNIST example), the learning rate is shifted after the second epoch thanks to the custom trigger. If I stop the training afterwards (at the fourth epoch, for example) and resume it by running the same script (with the same parameters, of course), the learning rate is reset to the initial value (0.1) instead of the expected shifted value (0.01).

#!/usr/bin/env python
from __future__ import print_function

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI
import chainermn
from chainer.training import util


class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            # the size of the inputs to each layer will be inferred
            l1=L.Linear(784, n_units),  # n_in -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=60,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    parser.add_argument('--double_buffering', action='store_true',
                        help='Use double buffering to improve speed')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.double_buffering:
        args.communicator = 'pure_nccl'

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank if args.gpu else -1

    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    initial_lr = 0.1

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=initial_lr, momentum=0.9), comm, double_buffering=args.double_buffering)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train = None
        test = None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize, shuffle=False)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    # Multiply the learning rate by 0.1 at each epoch listed in LRShiftTrigger.
    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.observe_lr(), trigger=(1, 'epoch'))
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time', 'lr']))
        trainer.extend(extensions.ProgressBar())

    trainer.run()


class LRShiftTrigger(object):

    """Trigger invoked on specific epoch defined by ResNet Paper author

    Args:
        key (str): Key of value.
        compare (function): Compare function which takes current best value and
            new value and returns whether new value is better than current
            best.
        trigger: Trigger that decides the comparison interval between current
            best value and new value. This must be a tuple in the form of
            ``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.

    """

    triggers = [2, 50]

    def __init__(self):
        self._interval_trigger = util.get_trigger((1, 'epoch'))

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The epoch count of its updater
                is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if the corresponding extension should be invoked in
            this iteration.

        """

        if not self._interval_trigger(trainer):
            return False

        epoch = trainer.updater.epoch
        return epoch in LRShiftTrigger.triggers


if __name__ == '__main__':
    main()

My environment settings (using Docker):
Ubuntu 16.04
Python 3.5.2
Open MPI 2.1.2 with InfiniBand
mpi4py 3.0.0
chainermn 1.2.0
chainer 4.0.0b3

If you need any other information, I will be more than happy to help.


kuenishi commented on June 30, 2024

@Guriido Thank you for reporting, and for your effort in boiling it down to a reproducible script! I successfully (?) reproduced the bug and will work on it.


kuenishi commented on June 30, 2024

I dug in a little deeper and found that the optimizer's update rule states are actually not saved. I'll keep this open while chainer/chainer#4749 is open.
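
As a sketch of how one might check this by hand (the tiny model, the single update step, and the file name are just illustrative), save the optimizer with save_npz after one update and look for the MomentumSGD state entries in the archive:

import numpy as np
import chainer.links as L
from chainer import optimizers, serializers

model = L.Classifier(L.Linear(4, 2))
optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
optimizer.setup(model)

# One update so that MomentumSGD's velocity state ('v') is populated.
x = np.random.rand(8, 4).astype(np.float32)
t = np.zeros(8, dtype=np.int32)
optimizer.update(model, x, t)

serializers.save_npz('opt.npz', optimizer)
f = np.load('opt.npz')
# If the update-rule states were saved, keys ending in '/v' would appear.
print(sorted(f.files))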


kuenishi commented on June 30, 2024

@Guriido I have understood what is going on in your example code.

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

In your code, the ExponentialShift extension is registered after the snapshot has been loaded. The reloaded trainer therefore has the correct learning rate, but because the extension was not registered at load time its state is not restored, and its initialization overwrites the reloaded learning rate with the initial 0.1. I'd recommend putting the snapshot-related code right before trainer.run(), so that no stateful extension is set up after the checkpoint reload, like this:

    trainer.extend(extensions.ExponentialShift(
        'lr', 0.1), trigger=LRShiftTrigger())

(snip)

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='mnist-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=(1, 'epoch'))

    trainer.run()

This code worked correctly in my environment.


Guriido commented on June 30, 2024

Thank you very much for your time and effort.
I indeed did not expect the ExponentialShift extension to affect the trainer state...

Maybe it would be helpful to put a warning or a note about this in the multi_node_checkpointer documentation?
There is obviously a remark about calling it before trainer.run(), but I think it could save users some trouble if there were a mention like "it is recommended to call the load right before running the trainer".

