BERT-JAM

Introduction

This repository contains the code for BERT-JAM, which is adapted from the bertnmt repository.

Requirements and Installation

  • PyTorch version: 1.2
  • Python version: 3.7
  • Versions of other packages are listed in the version.txt file (a minimal setup sketch follows this list)
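
A minimal environment setup might look like the following (a sketch assuming a fresh conda environment; take exact package versions from version.txt):

conda create -n bertjam python=3.7
conda activate bertjam
pip install torch==1.2.0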

Installing from source

To install fairseq from source and develop locally:

cd bertnmt
pip install --editable .
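
To confirm that the editable install is picked up, a quick import check can help (a small sanity check, not part of the original instructions):

python -c "import fairseq; print(fairseq.__version__)"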

Getting Started

Data Preparation

First, download the BERT model files and put them under the ./pretrained directory (a download sketch follows the directory tree below). The folder structure should look like this:

bertnmt
|---bert
|---data-bin
|---docs
|---examples
|---fairseq
|---fairseq-cli
|---my
|---pretrained
|   |---bert-base-german-uncased
|   |   |---config.json
|   |   |---pytorch_model.bin
|   |   |---vocab.txt
|---save
|---scripts
|---test
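
One way to obtain a suitable checkpoint is from the Hugging Face hub. The sketch below assumes the model id dbmdz/bert-base-german-uncased and the huggingface_hub package; substitute whichever German uncased BERT checkpoint you intend to use, as long as it ships the three files shown above:

pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
# assumed model id; any German uncased BERT with these three files should fit the layout
for f in ['config.json', 'pytorch_model.bin', 'vocab.txt']:
    hf_hub_download('dbmdz/bert-base-german-uncased', f,
                    local_dir='pretrained/bert-base-german-uncased')
"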

The scripts for pre-processing the data are under the ./examples/translation/script/ directory. For example, run the following commands to pre-process the IWSLT'14 De-En data.

cd ./examples/translation/
bash script/prepare-iwslt14.de2en.sh
cd iwslt14.tokenized.de-en
bash ../script/makedataforbert.sh de

Then preprocess the data as in fairseq:

src=de
tgt=en
TEXT=examples/translation/iwslt14.tokenized.de-en
python preprocess.py --source-lang $src --target-lang $tgt \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir $DATADIR/iwslt14_de_en/  --joined-dictionary \
  --bert-model-name pretrained/bert-base-german-uncased
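
Note that the preprocessing command above writes to $DATADIR/iwslt14_de_en/, while the training commands below read from data-bin/iwslt14.tokenized.$src-$tgt. Either keep the two paths consistent or point DATAPATH at your chosen destination, for example (one possible convention, not prescribed by the original scripts):

DATADIR=data-bin
# after preprocessing, train from the same directory:
DATAPATH=$DATADIR/iwslt14_de_en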

Training

The model is trained following the three-phase optimization strategy, using the fairseq training scripts. The following commands show how to train the model on the IWSLT'14 De-En dataset. For the first phase:

BERT=bert-base-german-uncased
src=de
tgt=en
model=bt_glu_joint
ARCH=${model}_iwslt_de_en
DATAPATH=data-bin/iwslt14.tokenized.$src-$tgt
SAVE=save/${model}.iwslt14.$src-$tgt.$BERT.
mkdir -p $SAVE
python train.py $DATAPATH \
-a $ARCH --optimizer adam --lr 0.0005 -s $src -t $tgt --label-smoothing 0.1 \
--dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --warmup-updates 4000 --warmup-init-lr '1e-07' --keep-last-epochs 10 \
--adam-betas '(0.9,0.98)' --save-dir $SAVE --share-all-embeddings   \
--encoder-bert-dropout --encoder-bert-dropout-ratio 0.5 \
--bert-model-name pretrained/$BERT \
--user-dir my --no-progress-bar --max-epoch 40 --fp16 \
--ddp-backend=no_c10d \
| tee -a $SAVE/training.log

For the second phase:

cp $SAVE/checkpoint_last.pt $SAVE/checkpoint_nmt.pt
python train.py $DATAPATH \
-a $ARCH --optimizer adam --lr 0.0005 -s $src -t $tgt --label-smoothing 0.1 \
--dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --warmup-updates 4000 --warmup-init-lr '1e-07' --keep-last-epochs 10 \
--adam-betas '(0.9,0.98)' --save-dir $SAVE --share-all-embeddings   \
--encoder-bert-dropout --encoder-bert-dropout-ratio 0.5 \
--bert-model-name pretrained/$BERT \
--user-dir my --no-progress-bar --max-epoch 50 --fp16 \
--ddp-backend=no_c10d \
--adjust-layer-weights \
--warmup-from-nmt \
| tee -a $SAVE/adjust.log

For the third phase:

cp $SAVE/checkpoint_last.pt $SAVE/checkpoint_nmt.pt
python train.py $DATAPATH \
-a $ARCH --optimizer adam --lr 0.0005 -s $src -t $tgt --label-smoothing 0.1 \
--dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --warmup-updates 4000 --warmup-init-lr '1e-07' --keep-last-epochs 10 \
--adam-betas '(0.9,0.98)' --save-dir $SAVE --share-all-embeddings   \
--encoder-bert-dropout --encoder-bert-dropout-ratio 0.5 \
--bert-model-name pretrained/$BERT \
--user-dir my --no-progress-bar --max-epoch 60 --fp16 \
--ddp-backend=no_c10d \
--adjust-layer-weights \
--finetune-bert \
--warmup-from-nmt \
| tee -a $SAVE/finetune.log

Generation

We generate translations for the test split using the fairseq generation script. Depending on the metric, different evaluation scripts are used.

For the tasks that use the multi-bleu script:

python scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"

CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATAPATH \
    --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 64 --beam 5 --remove-bpe \
    --lenpen 1 --gen-subset test --quiet --user-dir my  \
    --bert-model-name pretrained/$BERT

For the tasks that additionally perform compound splitting:

python scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"

CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATAPATH \
    --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 64 --beam 4 --remove-bpe \
    --lenpen 0.6 --gen-subset test --user-dir my  \
    --bert-model-name pretrained/$BERT > ${SAVE}/gen.txt

source scripts/compound_split_bleu.sh ${SAVE}/gen.txt

For the tasks that report sacreBLEU scores:

python scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"

CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATAPATH \
    --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 64 --beam 5 --remove-bpe \
    --lenpen 1 --gen-subset test --user-dir my  \
    --bert-model-name pretrained/$BERT > ${SAVE}/gen.txt

source scripts/calc_sacrebleu.sh $src $tgt $SAVE/gen.txt
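
If calc_sacrebleu.sh is unavailable, the hypotheses can also be scored with the sacrebleu command-line tool directly. The sketch below assumes the standard fairseq-generate output format (hypothesis lines prefixed with H) and a detokenized reference file whose path is only a placeholder, since sacreBLEU expects detokenized text:

# collect hypotheses from the generation output, in sentence order
grep ^H ${SAVE}/gen.txt | LC_ALL=C sort -V | cut -f3- > ${SAVE}/hyp.txt
# score against a detokenized reference (placeholder path)
sacrebleu /path/to/test.detok.$tgt < ${SAVE}/hyp.txt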

Trained Models

Model            Files
IWSLT'14 De-En   iwslt14_de_en.tar.gz (Extraction Code: a5yh)
WMT'14 En-De     wmt14_en_de.tar.gz.00 (Extraction Code: pegt)
                 wmt14_en_de.tar.gz.01 (Extraction Code: o49a)
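
The WMT'14 En-De archive is split into two parts; after downloading both, reassemble and extract them with standard shell tools (assuming the parts are consecutive segments of one gzipped tar file):

cat wmt14_en_de.tar.gz.00 wmt14_en_de.tar.gz.01 > wmt14_en_de.tar.gz
tar xzf wmt14_en_de.tar.gz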
