
Spikes in PPO policy loss (trl, closed, 9 comments)

huggingface commented on July 30, 2024
Spikes in PPO policy loss


Comments (9)

lvwerra commented on July 30, 2024

The issue with the loss spikes in the sentiment-control notebook was that sometimes only a few new tokens (1-2) would be generated, and this would cause the loss to spike. Not sure yet where exactly this behaviour comes from, but we now know where to look: we can deliberately generate short sequences and investigate what causes the loss explosion.
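A minimal repro sketch along those lines, assuming the GPT-2 IMDB model from the sentiment notebook and capping generation at 2 new tokens (the prompt text is an arbitrary example):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")

query = tokenizer("This movie was", return_tensors="pt")
# Force the degenerate case: only 1-2 freshly generated tokens.
response = model.generate(
    **query,
    max_new_tokens=2,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Responses this short would then be fed to the PPO step to see if the loss spikes.
print(tokenizer.decode(response[0]))
```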


natolambert commented on July 30, 2024

@younesbelkada your idea makes sense.

Some follow-ups:

  1. @lvwerra what experiment setup was this? I'd love to dig further.
  2. What does a clip frac of 0.55 mean? Is it that half of the samples are clipped in the PPO update, or am I off by a factor of 100? (See the sketch after this comment.)
  3. How should we configure the logger for a rich, research-y approach (lots of unknowns)?

Below are some musings on PPO stability:

  • A thread from Stable Baselines suggests the entropy coefficient was way too high (different domain than RLHF); I'll add more if I find it.

The more I look, the more it seems there is some numerical instability in the loss computation at that step (NaN), which it is impressive the run recovers from. I'm thinking about which intermediate values would be the right ones to log (maybe optionally). Could we make it so that if there is a NaN or a very large loss value, we dump a bunch of values to the logger? I'm sure we will see things like this when doing more RLHF.
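To make the clip-fraction question concrete, here is a minimal sketch of one common definition, together with the dump-on-anomaly guard proposed above. All names (ratio, advantages, threshold) are illustrative assumptions, not TRL's actual internals:

```python
import torch

def clip_fraction(ratio: torch.Tensor, cliprange: float = 0.2) -> float:
    # One common definition: the fraction of samples whose probability
    # ratio left the trust region. A value of 0.55 would then mean 55%
    # of samples were clipped in the PPO update (no factor of 100).
    return ((ratio - 1.0).abs() > cliprange).float().mean().item()

def dump_if_anomalous(loss: torch.Tensor, tensors: dict, threshold: float = 1e3):
    # Proposed guard: if the loss is NaN/inf or suspiciously large,
    # dump summary statistics of the intermediates for debugging.
    if not bool(torch.isfinite(loss)) or loss.abs().item() > threshold:
        for name, t in tensors.items():
            print(f"{name}: min={t.min().item():.4g} "
                  f"max={t.max().item():.4g} mean={t.float().mean().item():.4g}")

# Hypothetical usage inside a PPO step:
#   dump_if_anomalous(pg_loss, {"ratio": ratio, "advantages": advantages})
```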


tengxiaoliu commented on July 30, 2024

I also experienced the loss spike in my case. I'm using a seq2seq T5 model as the backbone, initialized from a supervised fine-tuned checkpoint. I find that the loss spike comes from steps that have a negative advantage and an extremely high ratio r(\theta). This falls into case 6 in the figure below.
[figure: the cases of the clipped PPO objective by advantage sign and ratio size]

In my case, removing pg_losses1 and keeping only the clipped pg_losses2 helps restrict the ratio and stabilize the loss (sketch below). I didn't train the model from scratch, so the clip fraction is low (less than 3%), but this would be a problem if the clip fraction were high and most of the loss were clipped. It's not a general solution, though, just a finding from my case.
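Roughly, the change looks like this. The names mirror the pg_losses1 / pg_losses2 variables in TRL's PPO step, but treat it as an illustrative sketch rather than a patch:

```python
import torch

def pg_loss(ratio: torch.Tensor, advantages: torch.Tensor,
            cliprange: float = 0.2, clipped_only: bool = False) -> torch.Tensor:
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    if clipped_only:
        # Keep only the clipped term: with a negative advantage and an
        # exploding ratio (case 6), this bounds the loss contribution.
        return pg_losses2.mean()
    # Standard PPO: pessimistic (element-wise) max of the two terms.
    return torch.max(pg_losses1, pg_losses2).mean()
```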


younesbelkada commented on July 30, 2024

One idea could be that we don't mask out the logits corresponding to padding tokens when computing the loss. It is something I am looking at in #100, but I am not sure whether this is really the root cause.
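For reference, a minimal sketch of what masking padding positions in the loss could look like; this illustrates the idea, not the actual #100 diff, and the example values are hypothetical:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over real-token positions (mask is 1.0 there, 0.0 on padding).
    return (values * mask).sum() / mask.sum().clamp(min=1)

# Hypothetical per-token losses where the padded tail carries garbage values:
per_token_loss = torch.tensor([0.5, 0.7, 9.9, 9.9])
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(masked_mean(per_token_loss, mask))  # 0.6, padding excluded
```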


natolambert commented on July 30, 2024

Yeah, so something weird is going on, with a simultaneous large drop in entropy, clip fraction, etc. Can we log the model outputs at that step? Is there any chance the model output gets stuck on something?


DaehanKim commented on July 30, 2024

I also observed a spike in the policy loss when running the sentiment-control example, and I initially thought it was caused by some strange samples or by high variance in the positive logits.

And I found this: the pipeline doesn't always output the 'POSITIVE' logit at index 1.
[screenshot: the label order is swapped]

And in the notebook, output[1]['score'] is treated as the positive logit and fed into the PPOTrainer. I guess this causes unstable training because the reward signal is not valid. Am I making sense? (A label-robust parsing sketch is below.)
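A sketch of label-robust parsing: look the 'POSITIVE' score up by label instead of by position. The model name matches the one used in the sentiment examples; the output nesting varies across transformers versions, hence the defensive flattening:

```python
from transformers import pipeline

sentiment_pipe = pipeline(
    "sentiment-analysis", model="lvwerra/distilbert-imdb", top_k=None
)

outputs = sentiment_pipe("This movie was really good!")
# Scores come back sorted by score, not by label, so index 1 is not
# guaranteed to be 'POSITIVE'. Flatten in case of per-input nesting.
if outputs and isinstance(outputs[0], list):
    outputs = outputs[0]
positive_score = next(o["score"] for o in outputs if o["label"] == "POSITIVE")
print(positive_score)
```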

By the way, I didn't realize this and ran several experiments with changed reward definitions (using both the positive and negative logits), and reward_mean wasn't increasing as training went on.
[figure: reward_mean staying flat over training]

I'll report further experiment results in #120.


DaehanKim commented on July 30, 2024

I corrected the parsing of the pipeline output, and the loss spike still remains in the sentiment-control notebook example, so there may be another reason for this instability.

[figure: training curves still showing a policy-loss spike]


lvwerra commented on July 30, 2024

Thanks @DaehanKim, yes, there is an issue besides the order of the logits. I tracked it down to some changes made in #80 (no spikes at the beginning of the PR, but spikes at the time of merge), and I started chasing the issue in #126. I'll report here as well if I figure it out!


github-actions commented on July 30, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.


