Comments (2)

thomwolf commented on May 1, 2024

Hi,

Thanks for the feedback. It's always interesting to compare the various possible ways to train the model.

The most likely cause for (2) is that MRPC is a small dataset, and the model shows high variance in the results depending on, for example, the initialization of the weights (see the original BERT repo on this as well). The distributed and multi-GPU setups probably do not use the random generators in exactly the same order, which leads to different initializations.

You can get an intuition for this by training with different seeds: you will see that the final accuracy can easily vary by 10%.

If you can do that, a better way to compare the results would be to run something like 10 different seeds for each training condition and compare the mean and standard deviation of the results.
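As a rough sketch of that comparison, the snippet below runs one training condition once per seed and reports the mean and standard deviation of the final accuracies. The names `summarize_over_seeds` and `fake_run` are hypothetical; `fake_run` is only a placeholder that returns made-up numbers, and you would swap in your own function that fine-tunes with the given seed and returns the evaluation accuracy.

```python
import random
import statistics
from typing import Callable, Dict, List


def summarize_over_seeds(run: Callable[[int], float], seeds: List[int]) -> Dict[str, float]:
    """Run one training condition once per seed and summarize the final accuracies."""
    scores = [run(seed) for seed in seeds]
    return {
        "mean": statistics.mean(scores),
        "std": statistics.stdev(scores),  # sample standard deviation (divides by n - 1)
        "min": min(scores),
        "max": max(scores),
    }


def fake_run(seed: int) -> float:
    """Hypothetical placeholder for a real fine-tuning run; the accuracy is made up."""
    rng = random.Random(seed)
    return 0.80 + rng.uniform(-0.05, 0.05)


seeds = list(range(10))  # something like 10 different seeds per condition
summary = summarize_over_seeds(fake_run, seeds)
print(summary)
```

Calling `summarize_over_seeds` once per setup (single GPU, multi-GPU, distributed) with the same seed list gives summaries that can be compared directly. If the means differ by less than the standard deviations, the gap is probably seed noise.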

from transformers.

llidev commented on May 1, 2024

Thanks for your feedback!

After some investigation, it looks like t_total is not set properly for distributed training in BertAdam. The actual t_total per distributed worker should be the total number of training steps divided by the worker count.

I have included the following fix in my PR #58:

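    # Each distributed worker only runs its share of the optimization steps,
    # so divide t_total by the number of workers when running distributed.
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)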
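To see why the division matters, here is a minimal, self-contained sketch with no PyTorch dependency. It assumes a linear-warmup, linear-decay schedule driven by `step / t_total`; `lr_multiplier` is a simplified illustration written for this example, not BertAdam's actual code, and the step count and world size are made-up values.

```python
def steps_per_worker(num_train_steps: int, world_size: int) -> int:
    """Number of optimizer steps each worker takes when the work is split evenly."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return num_train_steps // world_size


def lr_multiplier(step: int, t_total: int, warmup: float = 0.1) -> float:
    """Linear warmup then linear decay (simplified illustration, not BertAdam's code)."""
    progress = step / t_total
    if progress < warmup:
        return progress / warmup
    return max(0.0, 1.0 - progress)


num_train_steps = 1000  # hypothetical total step count
world_size = 4          # hypothetical number of distributed workers

t_total = steps_per_worker(num_train_steps, world_size)
last_step = t_total  # the final step a single worker actually reaches

# Without the fix, the schedule believes training is only a quarter done.
uncorrected = lr_multiplier(last_step, num_train_steps)
# With the fix, the schedule completes its decay by the worker's last step.
corrected = lr_multiplier(last_step, t_total)
print(uncorrected, corrected)
```

Under these assumptions, the uncorrected schedule ends training with the learning rate still at most of its peak value, and the warmup phase covers a proportionally larger share of the steps each worker really performs. That could account for part of the gap between distributed and single-process results.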

