Comments (9)
The issue with the loss spikes in the sentiment control notebook was that sometimes only a few new tokens (1-2) would be generated, and this would cause the loss to spike. Not sure yet where exactly this behaviour comes from, but we now know where to look: we can actively generate short sequences and investigate what causes the loss explosion.
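For anyone who wants to reproduce this, here is a minimal sketch of forcing 1-2 token responses so the short-generation case can be studied in isolation (the model name and prompt are just placeholders, and min_new_tokens needs a reasonably recent transformers):

```python
# Hypothetical repro sketch: pin the response length to 1-2 new tokens
# to see whether the PPO loss spikes correlate with very short generations.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

query = tokenizer("This movie was really", return_tensors="pt")

response = model.generate(
    **query,
    min_new_tokens=1,  # force the degenerate short-response case
    max_new_tokens=2,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(response[0]))
```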
@younesbelkada your idea makes sense.
Some follow-ups:
- @lvwerra what experiment setup was this? I'd love to dig further.
- what does a clip frac of .55 mean? Is that half of the samples being clipped in the PPO update, or am I off by a factor of 100?
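For reference, clip fraction is commonly computed as the share of samples whose importance ratio left the clip range, so .55 would indeed mean 55% of samples, not a factor of 100 off. A rough sketch (not the exact trl implementation):

```python
import torch

def clip_fraction(ratio: torch.Tensor, cliprange: float = 0.2) -> float:
    # Share of samples whose ratio pi_new/pi_old fell outside
    # [1 - cliprange, 1 + cliprange], i.e. where clipping was active.
    clipped = (ratio < 1.0 - cliprange) | (ratio > 1.0 + cliprange)
    return clipped.float().mean().item()

ratios = torch.tensor([0.5, 0.95, 1.05, 1.6])
print(clip_fraction(ratios))  # 0.5 -> half the samples were clipped
```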
Below are some musings on PPO stability:
- A thread from Stable Baselines suggests the entropy coefficient was way too high (a different domain than RLHF)
(will add more if I find it)
The more I look, the more sure I am that there is some numerical instability in the loss computation at that step (NaN), which it is impressive the training recovers from. I'm thinking about what the right intermediate values to log are (maybe optionally). Can we do something so that if there is a NaN or a big loss value, we dump a bunch of values to the logger (see the sketch below)? I am sure we will see things like this when doing more RLHF.
Relatedly: how should we configure the logger for a rich, research-style approach (lots of unknowns)?
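A guard along these lines would do it; just a sketch with made-up names and a made-up threshold, not existing trl code:

```python
import logging
import torch

logger = logging.getLogger("ppo_debug")
LOSS_SPIKE_THRESHOLD = 10.0  # hypothetical cutoff, tune per experiment

def maybe_dump_stats(loss, ratio, advantages, logits):
    # On NaN or an unusually large loss, dump the intermediates that
    # usually explain a PPO blow-up so the bad step can be inspected.
    if torch.isnan(loss).item() or loss.abs().item() > LOSS_SPIKE_THRESHOLD:
        logger.warning(
            "loss spike: loss=%.3f ratio min/max=%.3f/%.3f "
            "adv mean/std=%.3f/%.3f logits max=%.3f",
            loss.item(),
            ratio.min().item(), ratio.max().item(),
            advantages.mean().item(), advantages.std().item(),
            logits.max().item(),
        )
```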
I also experienced the loss spike in my case. I'm using a seq2seq T5 model as the backbone, initialized from a supervised fine-tuned checkpoint. I find that the loss spikes come from steps that have a negative advantage and an extremely high ratio r(\theta) = \pi_\theta / \pi_{\theta_old}. This is "situation 6" of the PPO clipping cases: r(\theta) > 1 + \epsilon while the advantage is negative, so the unclipped term of the objective is unbounded.
In my case, removing pg_losses1 and keeping only the clipped pg_losses2 helps restrict the ratio and stabilize the loss. I didn't train the model from scratch, so the clip fraction is low (less than 3%); this would be a problem if the clip fraction were too high and most of the loss were clipped. It's not a general solution though, just some findings from my case.
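For context, the relevant part of the loss looks roughly like this (a simplified sketch without trl's response masking, using the pg_losses1/pg_losses2 naming from the comment above):

```python
import torch

def ppo_policy_loss(logprobs, old_logprobs, advantages,
                    cliprange=0.2, clipped_only=False):
    # Importance ratio r(theta) = pi_new / pi_old, computed in log space.
    ratio = torch.exp(logprobs - old_logprobs)
    pg_losses1 = -advantages * ratio                 # unclipped term
    pg_losses2 = -advantages * torch.clamp(          # clipped term
        ratio, 1.0 - cliprange, 1.0 + cliprange)
    if clipped_only:
        # Variant described above: dropping the unclipped term hard-bounds
        # the loss even when the advantage is negative and the ratio
        # explodes (the "situation 6" case).
        return pg_losses2.mean()
    # Standard PPO: elementwise max gives the pessimistic bound, which is
    # exactly what spikes when adv < 0 and ratio >> 1 + cliprange.
    return torch.max(pg_losses1, pg_losses2).mean()
```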
One idea could be that we don't mask out the logits corresponding to padding tokens when computing the loss; it is something I am looking at in #100. But I am not sure this is really the root cause.
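For clarity, this is the kind of masking meant here; a minimal sketch with hypothetical names:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over real token positions so padded positions cannot
    # leak garbage logits/logprobs into the loss.
    return (values * mask).sum() / mask.sum()

# per_token_loss: (batch, seq_len); attention_mask: 1 = real token, 0 = pad
per_token_loss = torch.randn(2, 5)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.float32)
print(masked_mean(per_token_loss, attention_mask))
```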
Yeah, so something weird is going on, with a simultaneous large drop in entropy, clip fraction, etc. Can we log the model outputs at that step? Is there any chance the model output gets stuck on something?
I also observed a spike in the policy loss when running the sentiment-control example, and I initially thought it was caused by some strange samples or by high variance in the positive logits.
Then I found this: the pipeline doesn't always output the 'POSITIVE' logit at index 1, yet in the notebook output[1]['score'] is treated as the positive logit and fed into the PPOTrainer. I guess this causes unstable training because the reward signal is not valid. Am I making sense?
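A sketch of selecting the score by label instead of by position (the model name is the one from the notebook; top_k=None is the newer spelling of return_all_scores=True, so adjust for your transformers version):

```python
from transformers import pipeline

sentiment_pipe = pipeline(
    "sentiment-analysis",
    model="lvwerra/distilbert-imdb",
    top_k=None,  # return scores for every label, not just the argmax
)

def positive_score(text: str) -> float:
    # Select by label rather than index: the order of the entries is not
    # guaranteed, so 'POSITIVE' does not always sit at position 1.
    outputs = sentiment_pipe([text])[0]
    return next(o["score"] for o in outputs if o["label"] == "POSITIVE")

print(positive_score("this movie was great"))
```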
btw, I didn't realize this and ran several experiments with changed reward definitions (using both positive and negative logits), and reward_mean wasn't increasing as training went on.
I'll report further experiment results in #120
I corrected the parsing of the pipeline output, and the loss spike still remains in the sentiment-control notebook example, so there may be another reason for this instability.
Thanks @DaehanKim, yes, there is an issue besides the order of the logits. I tracked it down to some changes made in #80 (no spikes at the beginning of the PR, spikes by the time of merge) and started digging into it in #126. I'll report here as well if I figure it out!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.