DS2 with prompt tuning and prefix tuning

Paper link: https://arxiv.org/abs/2203.01552

Original DS2 repo: https://github.com/jshin49/ds2
Original prefix-tuning repo: https://github.com/XiangLi1999/PrefixTuning

(add more explanation to clearly show how to run the code)

How to use the code

  1. Install the repository as an editable pip package (this resolves all path issues)
pip install -e .
pip install -r requirements.txt # requires python 3.8

# transformers 3.2.0 throws a type error when loading the tokenizer and model, so install 4.11.0:
pip install transformers==4.11.0
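The version pins above are easy to miss in an existing environment. The following is a hypothetical check script (not shipped with the repo) that reports whether the active `python3` and its `transformers` package match Python 3.8 and transformers 4.11.0. It assumes `python3` is the interpreter you installed into, and it only prints a report; it does not install or change anything.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the ds2 repo): compare installed versions
# against the pins named in this README (Python 3.8, transformers 4.11.0).

# Succeed when the installed version equals the pinned one.
version_matches() {
  [ "$1" = "$2" ]
}

# Print one report line: $1 = package, $2 = installed version, $3 = pinned version.
report() {
  if version_matches "$2" "$3"; then
    echo "ok: $1 $2"
  else
    echo "mismatch: $1 $2 (expected $3)"
  fi
}

# Query the interpreter; fall back to "missing" if python3 or transformers is absent.
py_ver=$(python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])' 2>/dev/null || echo missing)
tf_ver=$(python3 -c 'import transformers; print(transformers.__version__)' 2>/dev/null || echo missing)

report python "$py_ver" 3.8
report transformers "$tf_ver" 4.11.0
```

Run it inside the environment you will train in; a `mismatch` line for transformers usually means the `pip install transformers==4.11.0` step was skipped or overridden by another requirement.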
  2. Get the MWOZ data (the commands for 2.0 and 2.1 differ only in the version number)

For 2.0

python scripts/create_data_mwoz.py --mwz_ver=2.0 --main_dir=data_mwoz_2.0 --target_path=data_mwoz_2.0/mwz

For 2.1

python scripts/create_data_mwoz.py --mwz_ver=2.1 --main_dir=data_mwoz_2.1 --target_path=data_mwoz_2.1/mwz
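Because the two data commands differ only in the version string, they can be generated in one loop. This is a sketch with a hypothetical `mwoz_cmd` function and a hypothetical `RUN` variable (neither is part of the repo): by default it only prints the commands (a dry run), and with `RUN=1` it executes them, which assumes you are in the repository root.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper around scripts/create_data_mwoz.py (not part of the repo).

# Print the data-creation command for one MultiWOZ version.
mwoz_cmd() {
  ver="$1"
  echo "python scripts/create_data_mwoz.py --mwz_ver=$ver --main_dir=data_mwoz_$ver --target_path=data_mwoz_$ver/mwz"
}

for ver in 2.0 2.1; do
  cmd=$(mwoz_cmd "$ver")
  if [ "${RUN:-0}" = "1" ]; then
    $cmd          # unquoted on purpose: split into words (no argument contains spaces)
  else
    echo "$cmd"   # dry run: show what would be executed
  fi
done
```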
  3. Training and Inference - Cross-domain (CD)

We have uploaded a T5 model pre-trained on dialogue summarization to the HuggingFace Model Hub at https://huggingface.co/jaynlp/t5-large-samsum. You can now choose between BART and T5 as follows:

  • model_name=bart and model_checkpoint=Salesforce/bart-large-xsum-samsum
  • model_name=t5 and model_checkpoint=jaynlp/t5-large-samsum
  • For multi-GPU mode:
    • add --GPU={number of GPUs} to the command line
    • set accelerator=ddp in the trainer (in train_ds2.py)
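The model choice above is a fixed pairing of `--model_name` and `--model_checkpoint`. The sketch below encodes that pairing in a hypothetical `checkpoint_for` function and builds the model-related flags with a hypothetical `model_flags` function (neither is part of the repo). It assumes `--GPU` only needs to be passed for runs on more than one GPU.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): map --model_name to the
# --model_checkpoint listed in this README.

checkpoint_for() {
  case "$1" in
    bart) echo "Salesforce/bart-large-xsum-samsum" ;;
    t5)   echo "jaynlp/t5-large-samsum" ;;
    *)    echo "unknown model_name: $1" >&2; return 1 ;;
  esac
}

# Print the model-related flags; the GPU count defaults to 1.
model_flags() {
  name="$1"
  gpus="${2:-1}"
  ckpt=$(checkpoint_for "$name") || return 1
  flags="--model_name=$name --model_checkpoint=$ckpt"
  # Assumption: --GPU is only added for multi-GPU runs. accelerator=ddp still
  # has to be set in the trainer inside train_ds2.py; no flag does that here.
  if [ "$gpus" -gt 1 ]; then
    flags="$flags --GPU=$gpus"
  fi
  echo "$flags"
}

model_flags t5
model_flags bart 4
```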

3.1 Pre-training (target domain = attraction): train on all domains except the target domain (--except_domain)

CUDA_VISIBLE_DEVICES={gpu} python ds2/scripts/train_ds2.py \
    --dev_batch_size=8 \
    --test_batch_size=8 \
    --train_batch_size=2 \
    --n_epochs=100 \
    --num_beams=1 \
    --test_num_beams=1 \
    --val_check_interval=1.0 \
    --fewshot=0.01 \
    --grad_acc_steps=1 \
    --model_name=bart \
    --model_checkpoint=Salesforce/bart-large-xsum-samsum \
    --except_domain=attraction \
    --mode=finetune \
    --exp_name=bart-CD-1-Attr-pre \
    --seed=577 \
    --version=2.1

3.2 Fine-tune on the target domain (--only_domain), loading the checkpoint from pre-training (--load_pretrained)

CUDA_VISIBLE_DEVICES={gpu} python ds2/scripts/train_ds2.py \
    --dev_batch_size=8 \
    --test_batch_size=8 \
    --train_batch_size=2 \
    --n_epochs=100 \
    --num_beams=1 \
    --test_num_beams=1 \
    --val_check_interval=1.0 \
    --fewshot=0.01 \
    --grad_acc_steps=1 \
    --model_name=bart \
    --model_checkpoint=Salesforce/bart-large-xsum-samsum \
    --only_domain=attraction \
    --mode=finetune \
    --load_pretrained={bart-CD-1-Attr-pre/ckpt_path} \
    --exp_name=bart-CD-1-Attr \
    --seed=577 \
    --version=2.1
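The two cross-domain stages share every flag except the domain flag (`--except_domain` for pre-training, `--only_domain` for fine-tuning), `--exp_name`, and `--load_pretrained`. The sketch below prints both commands for a given target domain as a dry run. The `cd_commands` function and its argument order are hypothetical, the `/path/to/pre-trained/ckpt` value is only a placeholder, and this README does not say where the pre-training run writes its checkpoint, so you must pass the real path yourself. Prefix the printed commands with `CUDA_VISIBLE_DEVICES={gpu}` as above.

```shell
#!/usr/bin/env bash
# Hypothetical dry-run helper (not part of the repo) for the two-stage
# cross-domain (CD) recipe: pre-train without the target domain, then
# fine-tune on the target domain from the pre-trained checkpoint.

# Flags shared by both stages, copied from the commands in this README.
COMMON="--dev_batch_size=8 --test_batch_size=8 --train_batch_size=2"
COMMON="$COMMON --n_epochs=100 --num_beams=1 --test_num_beams=1"
COMMON="$COMMON --val_check_interval=1.0 --fewshot=0.01 --grad_acc_steps=1"
COMMON="$COMMON --model_name=bart --model_checkpoint=Salesforce/bart-large-xsum-samsum"
COMMON="$COMMON --mode=finetune --seed=577 --version=2.1"

# $1 = target domain (e.g. attraction), $2 = short tag for --exp_name (e.g. Attr),
# $3 = checkpoint path produced by the pre-training run.
cd_commands() {
  domain="$1"; tag="$2"; ckpt="$3"
  echo "python ds2/scripts/train_ds2.py $COMMON --except_domain=$domain --exp_name=bart-CD-1-$tag-pre"
  if [ -z "$ckpt" ]; then
    echo "pass the pre-training checkpoint path as the third argument" >&2
    return 1
  fi
  echo "python ds2/scripts/train_ds2.py $COMMON --only_domain=$domain --load_pretrained=$ckpt --exp_name=bart-CD-1-$tag"
}

cd_commands attraction Attr /path/to/pre-trained/ckpt
```

The stages must run in order: the fine-tuning command is only valid once the pre-training run has finished and written its checkpoint.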
  4. Training and Inference - Multi-domain (MD)

CUDA_VISIBLE_DEVICES={gpu} python ds2/scripts/train_ds2.py \
    --dev_batch_size=8 \
    --test_batch_size=8 \
    --train_batch_size=2 \
    --n_epochs=100 \
    --num_beams=1 \
    --test_num_beams=1 \
    --val_check_interval=1.0 \
    --fewshot=0.01 \
    --grad_acc_steps=1 \
    --model_name=bart \
    --model_checkpoint=Salesforce/bart-large-xsum-samsum \
    --mode=finetune \
    --exp_name=bart-MD-1 \
    --seed=577 \
    --version=2.1
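To run the multi-domain setting at several few-shot ratios, a loop can generate one command per ratio. This sketch rests on two assumptions not confirmed by this README: that the trailing number in names such as `bart-MD-1` is the `--fewshot` ratio in percent, and that `--fewshot` accepts the example ratios listed. The functions `percent_label` and `md_command` are hypothetical; adjust the ratio list to the settings you need.

```shell
#!/usr/bin/env bash
# Hypothetical sweep (not part of the repo): one multi-domain (MD) command per
# few-shot ratio. Assumption: the exp_name suffix is the ratio in percent.

# Convert a ratio such as 0.01 to a percentage label such as 1.
percent_label() {
  awk -v r="$1" 'BEGIN { printf "%g\n", r * 100 }'
}

# Print the MD command for one ratio (no domain flag: all domains are used).
md_command() {
  ratio="$1"
  args="--dev_batch_size=8 --test_batch_size=8 --train_batch_size=2"
  args="$args --n_epochs=100 --num_beams=1 --test_num_beams=1"
  args="$args --val_check_interval=1.0 --fewshot=$ratio --grad_acc_steps=1"
  args="$args --model_name=bart --model_checkpoint=Salesforce/bart-large-xsum-samsum"
  args="$args --mode=finetune --exp_name=bart-MD-$(percent_label "$ratio")"
  args="$args --seed=577 --version=2.1"
  echo "python ds2/scripts/train_ds2.py $args"
}

# Example ratios (assumption): 1%, 5%, 10%.
for ratio in 0.01 0.05 0.1; do
  md_command "$ratio"
done
```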
  5. Training and Inference - Cross-task (CT)

CUDA_VISIBLE_DEVICES={gpu} python ds2/scripts/train_ds2.py \
    --dev_batch_size=8 \
    --test_batch_size=8 \
    --train_batch_size=2 \
    --n_epochs=100 \
    --num_beams=1 \
    --test_num_beams=1 \
    --val_check_interval=1.0 \
    --fewshot=0.01 \
    --grad_acc_steps=1 \
    --model_name=bart \
    --model_checkpoint=Salesforce/bart-large-xsum-samsum \
    --mode=finetune \
    --only_domain=attraction \
    --exp_name=bart-CT-Attr-1 \
    --seed=577 \
    --version=2.1
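The CD, MD, and CT commands above share the same base flags and differ only in how the domain is selected. The hypothetical `domain_flags` function below summarizes that difference as it appears in this README's commands; the setting labels (`cd-pre`, `cd-ft`, `md`, `ct`) are names invented for this sketch, and `/path/to/ckpt` is a placeholder.

```shell
#!/usr/bin/env bash
# Hypothetical summary (not part of the repo): domain-selection flags per
# setting, as they appear in this README's commands.

# $1 = setting label, $2 = target domain, $3 = pre-trained checkpoint (cd-ft only).
domain_flags() {
  setting="$1"; domain="$2"; ckpt="$3"
  case "$setting" in
    cd-pre) echo "--except_domain=$domain" ;;                       # CD stage 1: all other domains
    cd-ft)  echo "--only_domain=$domain --load_pretrained=$ckpt" ;; # CD stage 2: target domain, warm start
    md)     echo "" ;;                                              # MD: all domains, no domain flag
    ct)     echo "--only_domain=$domain" ;;                         # CT: target domain only, from the summarization checkpoint
    *)      echo "unknown setting: $setting" >&2; return 1 ;;
  esac
}

for s in cd-pre cd-ft md ct; do
  printf '%-7s %s\n' "$s" "$(domain_flags "$s" attraction /path/to/ckpt)"
done
```

Seen side by side, CT and the CD fine-tuning stage both train only on the target domain; the difference is that CT starts directly from the summarization checkpoint instead of loading a model pre-trained on the other domains.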

Contributors

jshin49, cwyoon-99, gykim-intellius
