
Comments (10)

ArthurZucker commented on July 24, 2024

Would be nice to have this merged then!


ri938 commented on July 24, 2024

Looks like the issue is that torch.manual_seed is used by both nn.Dropout and by the data loader.
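A minimal sketch of the coupling (just an illustration): with no explicit generator, both nn.Dropout and RandomSampler draw their randomness from the default generator that torch.manual_seed controls, so one global seed drives both:

import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())

torch.manual_seed(100)
mask_a = torch.nn.functional.dropout(torch.ones(8), p=0.5)  # dropout mask
order_a = list(RandomSampler(dataset))                      # sampling order

torch.manual_seed(100)
mask_b = torch.nn.functional.dropout(torch.ones(8), p=0.5)
order_b = list(RandomSampler(dataset))

assert torch.equal(mask_a, mask_b)  # same seed -> same dropout mask
assert order_a == order_b           # same seed -> same data order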


ri938 commented on July 24, 2024

It seems that the data_seed argument is not used, but it should be possible to set the seed here for the random sampler.

Since we can only influence dropout via torch.manual_seed, can I implement a change so that data_seed is used to seed the RandomSampler? Or will this not be backwards compatible, in which case I can add a new argument to do this and deprecate the random seed?
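One possible shape for that change, as a rough sketch (not the actual Trainer internals): give the sampler its own torch.Generator seeded from data_seed, so shuffling no longer depends on the global seed that also drives dropout:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def build_train_dataloader(dataset, data_seed, batch_size=8):
    # hypothetical helper: a dedicated generator decouples shuffling from torch.manual_seed
    generator = torch.Generator()
    generator.manual_seed(data_seed)
    sampler = RandomSampler(dataset, generator=generator)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

dataset = TensorDataset(torch.arange(64).float())
loader = build_train_dataloader(dataset, data_seed=100)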


younesbelkada commented on July 24, 2024

Hi @ri938,
thanks for this interesting issue. I am not really familiar with the way accelerate sets the seed for the data sampler. I am also not sure how you set both the seed for dropout and the sampler in your code, could you share more details about that?


ri938 commented on July 24, 2024

So I set the seed at startup to the same value (100) on each device:

def set_training_seed(seed):
    from transformers import set_seed
    set_seed(seed)

This ensures that each device has the same weight initialization before training starts.

Then I set the seed in the TrainingArguments, which gets passed to the Trainer, to the same constant value of 100:

trainer_args = TrainingArguments(seed=100, **kwargs)


ri938 commented on July 24, 2024

The seed set via

torch.manual_seed(x)

is what impacts dropout. It also impacts the RandomSampler.

Therefore there is no way to make the dropout masks vary across devices without also breaking the data ordering, which requires the same seed to be set on every device.
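To make the conflict concrete, here is a small sketch simulating two ranks: giving each rank its own seed does vary the dropout randomness, but it also permutes the data differently on each rank:

import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
orders = []
for rank in (0, 1):
    torch.manual_seed(100 + rank)  # per-rank seed, as one might want for dropout
    orders.append(list(RandomSampler(dataset)))

print(orders[0] == orders[1])  # almost certainly False: data ordering diverges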

I would argue this potentially affects many training runs for many users. There should therefore be both a way to avoid the issue and a warning message or error so that people do not train unaware of it.


RUFFY-369 commented on July 24, 2024

It seems that the data_seed argument is not used, but it should be possible to set the seed here for the random sampler.

Hi @ri938, you are right: the data_seed attribute is not used, and set_seed is used for both data sampling and training.
Please refer to the discussion in issue #31255.


ri938 commented on July 24, 2024

Yes, I was suggesting that if data_seed were used for the data sampling, this could fix the issue. But that would break backwards compatibility.

Here is another image to illustrate the problem. When training GPT-2, the gradient norms are huge when you use the same seed for each device, but when you vary the seed per device they are more sensible.

[W&B chart, 12/06/2024: gradient norms with the same seed on every device vs. varied seeds per device]


ri938 commented on July 24, 2024

This is the workaround I am using to fix this issue. I am adding a callback:

import os
import torch
from transformers import TrainerCallback

class SeedDeviceRandomlyCallback(TrainerCallback):
    # Give each device (rank) a different torch seed once training has begun.
    def on_train_begin(self, args, state, control, **kwargs):
        global_rank = int(os.environ['RANK'])
        new_seed = args.seed + global_rank  # offset the shared seed by the global rank
        print('Setting torch seed to {} on device {}'.format(new_seed, global_rank))
        torch.manual_seed(new_seed)

This works because the seed has to be changed to a per-device value only after get_train_dataloader has been called, so that the data ordering is not broken.
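For reference, the callback would be registered like any other Trainer callback; in this sketch, model and train_dataset are assumed to be defined elsewhere:

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=trainer_args,
    train_dataset=train_dataset,
    callbacks=[SeedDeviceRandomlyCallback()],
)
trainer.train()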


ri938 commented on July 24, 2024

After applying just this one callback, here is a demonstration of how much it improved performance:

[Two W&B charts, 14/06/2024: training curves after applying the callback]

