
Comments (6)

XiangLi1999 commented on September 15, 2024

I think 60%-70% makes sense!

Great question: the speed gain in prefix-tuning comes from not having to update as many parameters stored in the optimizer (i.e., fewer trainable parameters), but backprop is still required all the way down to the bottom Transformer layer. A thought experiment that explains this: imagine you train only the last layer of a Transformer. Then both the number of trainable parameters and the number of layers you must backprop through are reduced (you only need to backprop through one layer, since you don't need gradients for the earlier layers). However, if you train only the first layer, you must backprop all the way down, despite having the same number of trainable parameters.

Based on this first-layer vs. last-layer analogy, let's go back to prefix-tuning. We tune activations at every layer, so we need to backprop all the way back to the first layer, and backprop time is not reduced. The only saved computation is that we don't need to perform as many parameter updates.
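Here is a minimal PyTorch sketch of the thought experiment (my own illustration, not the paper's code; a stack of `nn.Linear` layers stands in for the Transformer blocks):

```python
import torch
import torch.nn as nn

# Toy stand-in for a 12-layer Transformer: a stack of linear layers.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(12)])

def train_only(model, idx):
    # Freeze every layer except the one at `idx`.
    for i, layer in enumerate(model):
        for p in layer.parameters():
            p.requires_grad = (i == idx)

x = torch.randn(8, 64)

# Train only the LAST layer: the frozen layers' outputs don't require grad,
# so autograd stops right at layer 11 -- a short backward pass.
train_only(model, 11)
model(x).sum().backward()

# Train only the FIRST layer: same number of trainable parameters, but the
# gradient must flow back through all 12 layers to reach layer 0.
train_only(model, 0)
model(x).sum().backward()
```

Prefix-tuning behaves like the second case, since tuned activations sit at every layer.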

Let me know if this makes sense.


Timothyxxx commented on September 15, 2024

Thanks a lot for your analysis! I assumed it was for the same reasons too, haha. Thanks again!


lrongzheni commented on September 15, 2024

What's your GPU hardware environment? Can it be trained on a single GPU? Thanks~ @tianbao Xie


Timothyxxx commented on September 15, 2024

Of course. It depends on the model; I think 11GB of GPU memory is enough for the E2E dataset with GPT2.
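If you want to sanity-check your card before launching a run, a quick generic check (not repo-specific) is:

```python
import torch

# Print the visible GPU's name and total memory; roughly 11 GiB or more
# should be enough for the E2E dataset with GPT2, as noted above.
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.total_memory / 1024**3:.1f} GiB")
```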


lrongzheni commented on September 15, 2024

When trying to train GPT2, the problem below is troubling me. Can you help me fix it? Thanks~

python train_e2e.py --optim_prefix yes --preseqlen 5 --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
webnlg_models/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1
python run_language_modeling.py --output_dir=webnlg_models/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1 --model_type=gpt2 --model_name_or_path=gpt2-medium --tokenizer_name=gpt2-medium --per_device_train_batch_size 5 --per_device_eval_batch_size 5 --save_steps 500000 --num_train_epochs 5 --do_train --train_data_file=/u/scr/xlisali/WebNLG/webnlg-dataset/webnlg_challenge_2017/train.json --do_eval --line_by_line --save_total_limit 1 --overwrite_output_dir --task_mode webnlg --eval_data_file=/u/scr/xlisali/WebNLG/webnlg-dataset/webnlg_challenge_2017/dev.json --tuning_mode prefixtune --logging_dir webnlg_models/runs/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1 --train_embs no --optim_prefix yes --preseqlen 5 --prefix_mode activation --format_mode cat --gradient_accumulation_steps 1 --learning_rate 5e-05 --weight_decay 0.0 --seed 101 --disable_tqdm --mid_dim 512 --init_random no --use_dropout no --prefix_dropout 0.0 --objective_mode 1 --evaluate_during_training --eval_steps 5000 --cache_dir /u/scr/xlisali/contrast_LM/transformers/examples/control/gpt2-medium-s3
/data/lirongzhen/PrefixTuning/transformers/src/transformers/__init__.py
Traceback (most recent call last):
File "/data/lirongzhen/PrefixTuning/gpt2/run_language_modeling.py", line 1159, in <module>
main()
File "/data/lirongzhen/PrefixTuning/gpt2/run_language_modeling.py", line 498, in main
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
File "/data/lirongzhen/PrefixTuning/transformers/src/transformers/hf_argparser.py", line 40, in __init__
self._add_dataclass_arguments(dtype)
File "/data/lirongzhen/PrefixTuning/transformers/src/transformers/hf_argparser.py", line 72, in _add_dataclass_arguments
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
File "/data/anaconda3/envs/PrefixTuning/lib/python3.9/typing.py", line 847, in __subclasscheck__
return issubclass(cls, self.__origin__)
TypeError: issubclass() arg 1 must be a class
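This looks like a Python 3.9 typing incompatibility in the repo's pinned transformers copy: under 3.9, `field.type.__origin__` can be a typing special form such as `typing.Union` rather than a class, so the `issubclass()` call raises. Running under Python 3.7/3.8 should avoid it; alternatively, here is a sketch of a patched check using `typing.get_origin` (my guess at a workaround, not an official fix):

```python
# Hypothetical replacement for the failing List check in hf_argparser.py.
# Requires Python 3.8+, where typing.get_origin() is available.
from typing import List, Optional, get_origin

def is_list_type(tp):
    # get_origin(List[int]) is list; get_origin(Optional[int]) is typing.Union,
    # the case that makes the original issubclass() call blow up.
    return get_origin(tp) is list

assert is_list_type(List[int])
assert not is_list_type(Optional[int])
```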


Timothyxxx commented on September 15, 2024

Sorry for forgetting to close this issue, thanks again!

