Comments (7)
from transformers.
Hm, maybe I misunderstand the problem. My understanding is that the focus here is on what happens when the `Trainer` loads from a checkpoint: it calls `skip_first_batches` to skip past the beginning of the dataset until the DataLoader iterator points to where it was at that checkpoint.
For an `IterableDataset`, the way this is done under the hood is a manual loop over the items to skip. `StatefulDataLoader` could solve this by letting the `Trainer` call `load_state_dict` somewhere while loading the checkpoint, and write the `StatefulDataLoader`'s state dict to the checkpoint when saving.
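To make the contrast concrete, here is a minimal sketch in plain Python (no torchdata required; all names are illustrative stand-ins, not the real `Trainer`/`StatefulDataLoader` APIs):

```python
def skip_first_batches(data, n):
    # What the Trainer effectively does today for an IterableDataset:
    # re-iterate from the start and throw away the first n items -> O(n) resume.
    it = iter(data)
    for _ in range(n):
        next(it)
    return it

class StatefulStream:
    # What a StatefulDataLoader-style object enables: record the position in a
    # state dict at checkpoint time and jump straight back to it -> O(1) resume.
    def __init__(self, data):
        self.data = data
        self.pos = 0

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]
```

The difference only matters for iterable datasets, where there is no random access to seek with; for map-style datasets the sampler can simply be fast-forwarded.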
> This process involves returning self-defined classes like `DataLoaderShard` to handle cases involving distributed data dispatch.
Yes, it seems like `DataLoaderShard` and `DataLoaderDispatcher` are created in the `prepare_data_loader` and `skip_first_batches` functions in the accelerate library. Both classes are subclasses of `DataLoader`, so they would likely need to be modified or copied to extend `StatefulDataLoader` instead.
So IIUC, the implementation of this feature would involve the following steps?

- In the `accelerate` library, either refactor `DataLoaderShard` and `DataLoaderDispatcher` to compose a `StatefulDataLoader`, or add new variants that inherit from it.
- In the `Trainer` class, allow dropping in a `StatefulDataLoader` instead of a regular `DataLoader`.
- Also in the `Trainer` class, support loading and saving the `state_dict` to and from the checkpoint.
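As a purely illustrative sketch of the first step, the shard-specific behavior could live in a mixin so the base class can be swapped between the two loaders without duplicating code. Everything below uses stand-in classes, not accelerate's or torchdata's actual implementations:

```python
class DataLoader:                      # stand-in for torch.utils.data.DataLoader
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(self.dataset)

class StatefulDataLoader(DataLoader):  # stand-in for torchdata's class
    def state_dict(self):
        return {"pos": getattr(self, "_pos", 0)}

    def load_state_dict(self, state):
        self._pos = state["pos"]

class ShardMixin:
    # Shard/dispatch logic shared by both variants would live here.
    def shard_info(self):
        return {"rank": 0, "world_size": 1}

class DataLoaderShard(ShardMixin, DataLoader):
    pass

class StatefulDataLoaderShard(ShardMixin, StatefulDataLoader):
    pass
```

With this shape, only the base class differs between the stock and stateful variants, which is roughly what "refactor to compose or add new variants" would amount to.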
Thanks for pointing this out. I still might not be understanding correctly; maybe it's a lot more complicated than this.
Thank you for your responses @byi8220 @muellerzr.
Yes, I agree with you that if we properly manage the states of the dataloaders in the Trainer, we no longer need the accelerate `skip_first_batches` option for recovery.
As a workaround, I bypass accelerate when preparing my dataloaders by hacking the `Trainer` class to support stateful ones:
```python
class Trainer(transformers.Trainer):
    def get_train_dataloader(self) -> DPAwareDataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        logger.info(f"Split the dataset for the node at rank {self.args.process_index} / {self.args.world_size}.")
        train_dataset = HuggingFaceDataset(
            self.train_dataset,
            self.tokenizer,
            self.args.context_length,
            self.args.process_index,
            self.args.world_size,
        )
        loader = DPAwareDataLoader(
            rank=self.args.process_index,
            dataset=train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=self.args.dataloader_persistent_workers,
        )
        # Register a callback that saves/restores the loader state.
        data_callback = DataCallback(loader)
        self.add_callback(data_callback)
        return loader
```
The `DPAwareDataLoader` is borrowed from torchtitan's implementation; that package is developing similar ideas. States are then saved and loaded via a self-defined callback:
```python
class DataCallback(TrainerCallback, ExportableState):
    def __init__(self, loader):
        self.loader = loader

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Locate the checkpoint directory we are resuming from, if any.
        output_dir = None
        if isinstance(args.resume_from_checkpoint, bool):
            if args.resume_from_checkpoint:
                output_dir = get_last_checkpoint(args.output_dir)
        elif args.resume_from_checkpoint is not None:
            output_dir = args.resume_from_checkpoint
        if output_dir is not None:
            if args.world_size <= 1:
                data_state_pth = os.path.join(output_dir, "data_state.json")
            else:
                data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
            with open(data_state_pth, "r") as f:
                self.loader.load_state_dict(json.loads(f.read()))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Write one state file per process alongside the checkpoint.
        output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        if args.world_size <= 1:
            data_state_pth = os.path.join(output_dir, "data_state.json")
        else:
            data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
        with open(data_state_pth, "w") as f:
            f.write(json.dumps(self.state(), indent=2, sort_keys=True) + "\n")

    def state(self) -> dict:
        return self.loader.state_dict()
```
The `skip_first_batches` logic is bypassed by passing `--ignore_data_skip`.
I performed some minimal unit tests, and the states were successfully recovered without any perceptible delay.
This approach can be extremely useful when performing online tokenization with an `IterableDataset`.
Some people have benchmarked this and observed even faster speeds than pre-tokenization: https://github.com/XinDongol/on-the-fly-tokenization-profiling.
I've used stateful loaders with the ugly hack above to train a Mamba model on subsets of the 627B-token SlimPajama data, reducing the total training time from ~173h to ~170h.
This also saves ~3TB of disk space compared to pre-tokenized map-style data.
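For illustration, the core of online tokenization with chunk packing fits in a few lines of plain Python; the `tokenize` callable below is a stand-in for a real tokenizer, and the generator is the kind of thing an `IterableDataset.__iter__` would wrap:

```python
def online_tokenize(text_stream, tokenize, context_length):
    # Tokenize raw text lazily and pack tokens into fixed-length chunks, so no
    # pre-tokenized copy of the corpus is ever written to disk.
    buffer = []
    for text in text_stream:
        buffer.extend(tokenize(text))
        while len(buffer) >= context_length:
            yield buffer[:context_length]
            buffer = buffer[context_length:]
```

Because the stream position lives entirely in iterator state, resuming mid-epoch is exactly what a stateful dataloader's `state_dict` would have to capture.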
So I'm really looking forward to your official implementation; very happy to hear about any progress :D
Hey, just giving my 2 cents: unless I'm missing something, this seems extremely simple to implement, since `StatefulDataLoader` is a drop-in replacement for `DataLoader` (i.e. literally just replace `DataLoader` construction with `StatefulDataLoader` construction in `trainer.py`).
If it's simple enough, I could probably take a shot at implementing it?
The only caveat is that `torchdata.stateful_dataloader` is a beta feature only available in the nightly build. Does it make sense to officially support unreleased features?
@byi8220 Hi, as far as I can see, the HF `Trainer` uses the accelerate library internally to prepare the dataloader. This process involves returning self-defined classes like `DataLoaderShard` to handle cases involving distributed data dispatch. I think it might be challenging to combine the `Trainer` with `StatefulDataLoader` directly without delving into the intricate details of the `Trainer` implementation.
Correct, we need to:

- Support the `StatefulDataLoader` in `accelerate` and use it as an optional alternative in the `DataLoaderConfiguration`
- Then we can move it to the `Trainer`!
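From the user's side, the finished feature could look something like the sketch below. The `use_stateful_dataloader` flag name is an assumption about the eventual `DataLoaderConfiguration` API, and the guards keep the snippet runnable when accelerate is absent or predates the flag:

```python
try:
    from accelerate.utils import DataLoaderConfiguration
    # Assumed flag name: opt in to StatefulDataLoader-backed prepared loaders.
    config = DataLoaderConfiguration(use_stateful_dataloader=True)
except (ImportError, TypeError):  # accelerate missing, or flag not supported
    config = None
```

The `Trainer` could then pass such a configuration through to `Accelerator` and rely on the prepared loaders exposing `state_dict`/`load_state_dict`.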
Makes sense. It also seems like there's a related issue raised in accelerate: huggingface/accelerate#2859.
Regarding using it in the `Trainer`, it feels a bit awkward. IIUC, the desired behavior is that if a `StatefulDataLoader` is being used and we are loading from a checkpoint, then it should not call `skip_first_batches` at all, unless the state dict and checkpoints are passed in to that function as well. But imo `skip_first_batches` and "restore from checkpoint" are two separate concepts.