Comments (13)

ArthurZucker commented on May 22, 2024

cc @muellerzr and @younesbelkada for sft + trainer save

ArthurZucker commented on May 22, 2024

Hey, the error says:

ValueError: [UserWarning('do_sample is set to False. However, temperature is set to 0.9 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset temperature.'), UserWarning('do_sample is set to False. However, top_p is set to 0.6 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset top_p.')]

Are you sure you are running the example unmodified?

yananchen1989 commented on May 22, 2024

Yes, unmodified. I ran it a second time and it worked, but when I ran it a third time the error came back.

ArthurZucker commented on May 22, 2024

Hmm, that's a bit weird; it runs fine the second time, so maybe it's a caching thing?

yaysummeriscoming commented on May 22, 2024

I'm facing a similar issue.

saxenarohit commented on May 22, 2024

In my experience, this is the default warning for generation when do_sample=False is combined with sampling arguments such as temperature/top_p.
The warning is emitted because temperature/top_p will be ignored (they are only used when do_sample=True) since sampling is disabled.

This itself should not throw an error.
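For reference, a minimal sketch (not from this thread, assuming a recent transformers version) showing that this combination only produces UserWarnings during ordinary validation; it does not raise on its own:

import warnings
from transformers import GenerationConfig

# With do_sample=False, temperature/top_p are ignored; plain validation only warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = GenerationConfig(do_sample=False, temperature=0.9, top_p=0.6)

for w in caught:
    print(w.message)  # the same UserWarnings quoted in the error above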

yananchen1989 commented on May 22, 2024

Here is my full SFT training code:

import logging
import os,sys
from contextlib import nullcontext
import argparse

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"

    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset
import datasets

from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments

from trl import (
    ModelConfig,
    RichProgressCallback,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
)
sys.path.append("/home/chenyanan/trl")
from prompts import *
from datasets import Dataset
tqdm.pandas()

if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":

    parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()
    # https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
    print(args)
    print('-' * 50)
    print(training_args)
    print('-' * 50)
    print(model_config)
    print('-' * 50)


    # print('cuda===>', os.environ['CUDA_VISIBLE_DEVICES'])

    if model_config.use_peft:
        use_peft = 'peft'
    else:
        use_peft = 'nopeft'

    output_dir = "/home/chenyanan/trl/" + model_config.model_name_or_path.split('/')[-1]+ "_sft_{}".format(training_args.output_dir)
    print('output_dir===>', output_dir)

    # Force use our print callback
    if TRL_USE_RICH:
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    print("torch_dtype===>", torch_dtype)
    if model_config.use_peft:
        quantization_config = None 
    else:
        quantization_config = get_quantization_config(model_config)
    print("quantization_config===>", quantization_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    ds_dic = {'prompt':[], 'completion':[]}

    # if args_.ds == 'alpaca':
    # ds_raw  = load_dataset('tatsu-lab/alpaca')['train']
    # ds_raw_shuffle = ds_raw.shuffle(seed=777)
    # ds100 = ds_raw_shuffle.select(list(range(args_.samplecnt)))
    # for ii in ds100:
    #     ds_dic['prompt'].append(ii['instruction'] + ' ' + ii['input'])
    #     ds_dic['completion'].append(ii['output'])

    # train_dataset = load_dataset("json", data_files="/home/chenyanan/trl/train_ft.jsonl", split="train")
    # train_dataset = datasets.load_from_disk("/vol/chenyanan/trl/TravelPlanner_datasets/TravelPlanner_train/")['train']

    # elif args_.ds == 'tp':
    # ds_raw  = load_dataset('osunlp/TravelPlanner', 'train')['train']
    ds_raw = datasets.load_from_disk("/home/chenyanan/TravelPlanner_datasets/TravelPlanner_train/")['train']

    for ii in ds_raw:
        prompt = planner_agent_prompt.format(text=ii['reference_information'], query=ii['query'])
        if len(tokenizer.encode(prompt)) > 10000:
            continue
        ds_dic['prompt'].append(prompt)
        ds_dic['completion'].append(parse_from_json_plan(eval(ii['annotated_plan'])[-1]))

    ds = Dataset.from_dict(ds_dic)
    print("ds rows:", ds.num_rows)

    if training_args.do_eval:
        print("do eval")
        ds_ = ds.train_test_split(test_size=0.2)
        train_ds = ds_['train']
        eval_ds = ds_["test"]
    else:
        print("do not eval")
        train_ds = ds 
        eval_ds = None
    # print('dataset_text_field===>', args.dataset_text_field)
    ################
    # Optional rich context managers
    ###############
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {output_dir}")
    )

    ################
    # Training
    ################
    with init_context:
        trainer = SFTTrainer(
            model=model_config.model_name_or_path,
            model_init_kwargs=model_kwargs,
            args=training_args,
            train_dataset= train_ds,
            eval_dataset= eval_ds,
            #dataset_text_field=args.dataset_text_field,
            max_seq_length=args.max_seq_length,
            tokenizer=tokenizer,
            packing=args.packing,
            peft_config=get_peft_config(model_config),
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
        )

    trainer.train()

    with save_context:
        trainer.save_model(output_dir)

Launch script:


accelerate launch \
    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_tp.py \
    --model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2"   \
    --report_to="wandb" \
    --learning_rate=4e-5 \
    --per_device_train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --output_dir="tp_deepspeed_epoch10" \
    --logging_steps=1 \
    --num_train_epochs=5 \
    --save_strategy "no" \
    --lr_scheduler_type "constant" \
    --max_steps=-1 \
    --gradient_checkpointing \
    --logging_strategy  "epoch" \
    --bf16 True \
    --packing False \
    --do_eval False \
    --overwrite_output_dir True \
    --evaluation_strategy 'no' \
    --attn_implementation 'flash_attention_2' \
    --max_seq_length 10000 

yananchen1989 commented on May 22, 2024

deepspeed_zero3.yaml config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 4
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: true
main_process_port: 29533
gpu_ids: 0,1,2,3

yananchen1989 commented on May 22, 2024

The weirdest thing is that after changing **num_train_epochs** from 50 to 5 (or 2), this error disappears completely.
Everything else is left unchanged.

yananchen1989 commented on May 22, 2024

The error is triggered in transformers/generation/configuration_utils.py at line 663.
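That line sits in GenerationConfig.save_pretrained: when the trainer saves the model, the model's generation_config is saved alongside it, and in recent transformers versions save_pretrained refuses to write a config whose validation produces warnings. A rough sketch of that check, paraphrased rather than copied from the actual source:

import warnings

def save_generation_config_sketch(generation_config, save_directory):
    # Approximate logic of the strict check in GenerationConfig.save_pretrained
    # (names and details are illustrative, not the real implementation):
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        generation_config.validate()
    if len(caught_warnings) > 0:
        # this is the ValueError from the traceback: it wraps the recorded UserWarnings
        raise ValueError([w.message for w in caught_warnings])
    # ...otherwise the config is serialized to save_directory as usual

This would explain why do_sample/temperature/top_p only warn during training but turn into a hard error at trainer.save_model.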

younesbelkada commented on May 22, 2024

Hi @yananchen1989
Hmm, this is very strange and hard to reproduce. I checked the generation config of the Mistral model and neither top_p nor temperature is set. SFTTrainer, which inherits from Trainer, does not change the generation config during training at all. To be on the safe side, can you try setting a dummy generation config on the model before calling trainer.save_model? Something like:

from transformers import GenerationConfig
trainer.model.generation_config = GenerationConfig(temperature=None, top_p=None)

yananchen1989 commented on May 22, 2024

Sure, thanks @younesbelkada.

  with save_context:
      # note: this requires `import transformers` at the top of the script
      trainer.model.generation_config = transformers.GenerationConfig(temperature=None, top_p=None)
      trainer.save_model(output_dir)

testing this block.

yananchen1989 commented on May 22, 2024

Seems that it works. Thanks.
