Comments (10)
Would be nice to have this merged then!
Looks like the issue is that torch.manual_seed is used by both nn.Dropout and by the data loader.
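A minimal standalone sketch of the coupling (not taken from the Trainer source): both nn.Dropout and a default RandomSampler draw from the global RNG that torch.manual_seed controls, so setting the same seed reproduces both the dropout masks and the shuffle order.

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

torch.manual_seed(100)

dropout = torch.nn.Dropout(p=0.5)
print(dropout(torch.ones(8)))        # dropout mask drawn from the global RNG

dataset = TensorDataset(torch.arange(10))
print(list(RandomSampler(dataset)))  # shuffle order drawn from the same global RNG
```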
It seems that the data_seed argument is not used, but it should be possible to set the seed for the random sampler here.

Since we can only influence Dropout via torch.manual_seed, can I implement a change so that data_seed is used to seed the RandomSampler? Or would that not be backwards compatible, in which case I can add a new argument for this and deprecate the old one?
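Roughly, the change could look like this (a hypothetical sketch, not the actual Trainer code): seed a dedicated torch.Generator from data_seed and hand it to the RandomSampler, so the data order no longer depends on torch.manual_seed.

```python
import torch
from torch.utils.data import RandomSampler

def build_train_sampler(train_dataset, data_seed):
    # Hypothetical helper: a generator seeded from data_seed keeps the shuffle
    # order fixed even if torch.manual_seed differs per device for dropout.
    generator = torch.Generator()
    generator.manual_seed(data_seed)
    return RandomSampler(train_dataset, generator=generator)
```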
Hi @ri938,
Thanks for this interesting issue. I am not really familiar with the way accelerate sets the seed for the data sampler. I am also not sure how you set both the seed for dropout and the seed for the sampler in your code; could you share more details about that?
So I set the seed on startup to the same value (100) on each device:

```python
def set_training_seed(seed):
    from transformers import set_seed
    set_seed(seed)
```

This ensures that each device has the same initialization of the weights before training starts. Then I also set the seed in the TrainingArguments, which get passed to the Trainer, to the constant value 100:

```python
trainer_args = TrainingArguments(seed=100, **kwargs)
```
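Putting those two steps together, the per-device startup looks roughly like this (output_dir and the remaining kwargs here are just placeholders):

```python
from transformers import TrainingArguments

set_training_seed(100)  # identical weight init on every device
trainer_args = TrainingArguments(output_dir="out", seed=100)
```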
The seed set via torch.manual_seed(x) is what impacts dropout. It also impacts the RandomSampler.

Therefore there is no way to ensure that dropout masks vary across devices without also breaking the data ordering, which requires the same seed to be set on each device.

I would argue this is an issue potentially impacting many training runs for many users. There should therefore be both a way to avoid it and a warning message or error, so that people are not training unaware of it.
> It seems that the data_seed argument is not used, but it should be possible to set the seed for the random sampler here.

Hi @ri938, you are right: the class variable data_seed is not used, and set_seed is used for both data sampling and training. Please refer to the discussion in issue #31255.
Yes, I was suggesting that if we used data_seed for the data sampling then this could fix the issue, but it would break backwards compatibility.

Here is another image to illustrate the problem. When training gpt2 the gradient norms are huge when you use the same seed for each device, but when you vary the seed per device they are much more sensible.
This is the workaround I am using to fix this issue: I am adding a callback.

```python
import os
import torch
from transformers import TrainerCallback

class SeedDeviceRandomlyCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        # Offset the shared seed by the global rank so each device draws different dropout masks.
        global_rank = int(os.environ['RANK'])
        new_seed = args.seed + global_rank
        print('Setting torch seed to {} on device {}'.format(new_seed, global_rank))
        torch.manual_seed(new_seed)
```

A callback is needed because you have to change the seed only after get_train_dataloader has been called, in order not to break the data ordering.
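For completeness, here is roughly how the callback gets wired into the Trainer (model, trainer_args and train_dataset are placeholders, not from the original comment):

```python
from transformers import Trainer

trainer = Trainer(
    model=model,                # placeholder model
    args=trainer_args,          # TrainingArguments(seed=100, ...) as above
    train_dataset=train_dataset,
    callbacks=[SeedDeviceRandomlyCallback()],
)
trainer.train()  # on_train_begin fires after the train dataloader has been built
```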
After applying just this one callback, here is a demonstration of how much it improved performance: