Comments (13)
cc @muellerzr and @younesbelkada for sft + trainer save
Hey, the error says:
ValueError: [UserWarning('do_sample is set to False. However, temperature is set to 0.9 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset temperature.'), UserWarning('do_sample is set to False. However, top_p is set to 0.6 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset top_p.')]
Are you sure you are running the example unmodified?
Yes, unmodified. I also ran it twice: the second time it worked, but when I ran it a third time the error came back.
Mmm, that's a bit weird. It runs fine the second time, so maybe it's a cache thing?
I'm facing a similar issue
In my experience, this is a standard warning for generation tasks when do_sample=False is combined with sampling arguments such as temperature/top_p.
The warning is given because those arguments will be ignored (they are only used when do_sample=True) since sampling is disabled.
This by itself should not throw an error.
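As a quick illustration of where the error actually comes from (a minimal sketch, assuming a recent transformers release that validates the generation config strictly when saving; the /tmp path is just a placeholder):

```python
from transformers import GenerationConfig

# Constructing the config only emits the two UserWarnings quoted above
# (do_sample=False combined with sampling flags); it does not raise.
cfg = GenerationConfig(do_sample=False, temperature=0.9, top_p=0.6)

# Saving re-validates the config strictly and wraps any caught warnings
# into a ValueError, which is the error reported in this issue.
cfg.save_pretrained("/tmp/gen_cfg")  # placeholder directory
```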
Here is my full SFT training code:
import logging
import os, sys
from contextlib import nullcontext
import argparse

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"
    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset
import datasets
from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments
from trl import (
    ModelConfig,
    RichProgressCallback,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
)

sys.path.append("/home/chenyanan/trl")
from prompts import *
from datasets import Dataset

tqdm.pandas()

if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)

if __name__ == "__main__":
    parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()
    # https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
    print(args)
    print('-' * 50)
    print(training_args)
    print('-' * 50)
    print(model_config)
    print('-' * 50)
    # print('cuda===>', os.environ['CUDA_VISIBLE_DEVICES'])

    if model_config.use_peft:
        use_peft = 'peft'
    else:
        use_peft = 'nopeft'

    output_dir = "/home/chenyanan/trl/" + model_config.model_name_or_path.split('/')[-1] + "_sft_{}".format(training_args.output_dir)
    print('output_dir===>', output_dir)

    # Force use our print callback
    if TRL_USE_RICH:
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    print("torch_dtype===>", torch_dtype)

    if model_config.use_peft:
        quantization_config = None
    else:
        quantization_config = get_quantization_config(model_config)
    print("quantization_config===>", quantization_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    ds_dic = {'prompt': [], 'completion': []}
    # if args_.ds == 'alpaca':
    #     ds_raw = load_dataset('tatsu-lab/alpaca')['train']
    #     ds_raw_shuffle = ds_raw.shuffle(seed=777)
    #     ds100 = ds_raw_shuffle.select(list(range(args_.samplecnt)))
    #     for ii in ds100:
    #         ds_dic['prompt'].append(ii['instruction'] + ' ' + ii['input'])
    #         ds_dic['completion'].append(ii['output'])
    #     train_dataset = load_dataset("json", data_files="/home/chenyanan/trl/train_ft.jsonl", split="train")
    #     train_dataset = datasets.load_from_disk("/vol/chenyanan/trl/TravelPlanner_datasets/TravelPlanner_train/")['train']
    # elif args_.ds == 'tp':
    #     ds_raw = load_dataset('osunlp/TravelPlanner', 'train')['train']
    ds_raw = datasets.load_from_disk("/home/chenyanan/TravelPlanner_datasets/TravelPlanner_train/")['train']
    for ii in ds_raw:
        prompt = planner_agent_prompt.format(text=ii['reference_information'], query=ii['query'])
        if len(tokenizer.encode(prompt)) > 10000:
            continue
        ds_dic['prompt'].append(prompt)
        ds_dic['completion'].append(parse_from_json_plan(eval(ii['annotated_plan'])[-1]))

    ds = Dataset.from_dict(ds_dic)
    print("ds rows:", ds.num_rows)

    if training_args.do_eval:
        print("do eval")
        ds_ = ds.train_test_split(test_size=0.2)
        train_ds = ds_['train']
        eval_ds = ds_["test"]
    else:
        print("do not eval")
        train_ds = ds
        eval_ds = None
    # print('dataset_text_field===>', args.dataset_text_field)

    ################
    # Optional rich context managers
    ################
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {output_dir}")
    )

    ################
    # Training
    ################
    with init_context:
        trainer = SFTTrainer(
            model=model_config.model_name_or_path,
            model_init_kwargs=model_kwargs,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            # dataset_text_field=args.dataset_text_field,
            max_seq_length=args.max_seq_length,
            tokenizer=tokenizer,
            packing=args.packing,
            peft_config=get_peft_config(model_config),
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
        )

    trainer.train()

    with save_context:
        trainer.save_model(output_dir)
launch script:
accelerate launch \
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/sft_tp.py \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--report_to="wandb" \
--learning_rate=4e-5 \
--per_device_train_batch_size=1 \
--gradient_accumulation_steps=4 \
--output_dir="tp_deepspeed_epoch10" \
--logging_steps=1 \
--num_train_epochs=5 \
--save_strategy "no" \
--lr_scheduler_type "constant" \
--max_steps=-1 \
--gradient_checkpointing \
--logging_strategy "epoch" \
--bf16 True \
--packing False \
--do_eval False \
--overwrite_output_dir True \
--evaluation_strategy 'no' \
--attn_implementation 'flash_attention_2' \
--max_seq_length 10000
deepspeed_zero3.yaml config:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 4
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: true
main_process_port: 29533
gpu_ids: 0,1,2,3
The weirdest thing is that after changing **num_train_epochs** from 50 to 5 (or 2), this error disappears completely.
Everything else is left unchanged.
The error is triggered in transformers/generation/configuration_utils.py, at line 663.
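For reference, the check around that line effectively treats any generation-config validation warning as a save-blocking error. A rough sketch of that behavior (the helper name is made up, and the exact internals may differ by transformers version):

```python
import warnings

from transformers import GenerationConfig


def save_would_fail(config: GenerationConfig) -> bool:
    """Hypothetical helper: mimic the strict re-validation that
    GenerationConfig.save_pretrained performs before writing the file."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        config.validate()
    # save_pretrained raises ValueError if any warnings were caught here
    return len(caught) > 0


print(save_would_fail(GenerationConfig(do_sample=False, temperature=0.9, top_p=0.6)))  # True
print(save_would_fail(GenerationConfig(temperature=None, top_p=None)))                 # False
```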
Hi @yananchen1989
Hmm, this is very strange and hard to repro. I checked the generation config of the Mistral model and neither top_p nor temperature is set. SFTTrainer, which inherits from Trainer, does not change the generation config during training at all. To be on the safe side, can you try to set a dummy generation config on the model before calling trainer.save? Something like:
trainer.model.generation_config = GenerationConfig(temperature=None, top_p=None)
Sure, thanks @younesbelkada.

with save_context:
    trainer.model.generation_config = transformers.GenerationConfig(temperature=None, top_p=None)
    trainer.save_model(output_dir)

Testing this block.
Seems that it works, thanks.