Overcoming Catastrophic Forgetting beyond Continual Learning: Balanced Training for Neural Machine Translation

This repository contains the source code for our ACL 2022 paper "Overcoming Catastrophic Forgetting beyond Continual Learning: Balanced Training for Neural Machine Translation" (pdf). Our method is implemented on top of the open-source toolkit fairseq; the main modifications are in train.py and cokd_loss.py.
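
For orientation, a token-level loss that interpolates distillation against a teacher with label-smoothed cross-entropy can be sketched as follows. This is a minimal illustrative sketch, not the actual code in cokd_loss.py: the function name, signature, and the exact way COKD constructs and weights its online teachers are assumptions.

import torch.nn.functional as F

def interpolated_kd_loss(student_logits, teacher_logits, target,
                         kd_alpha=0.95, label_smoothing=0.1, pad_idx=1):
    # Hypothetical sketch of a distillation-style loss; the repo's
    # cokd_loss.py may combine terms differently.
    lprobs = F.log_softmax(student_logits, dim=-1)          # student log-probs
    teacher_probs = F.softmax(teacher_logits, dim=-1).detach()  # teacher distribution

    # Distillation term: cross-entropy against the teacher's distribution.
    kd = -(teacher_probs * lprobs).sum(dim=-1)

    # Label-smoothed cross-entropy against the reference tokens.
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.mean(dim=-1)
    ce = (1.0 - label_smoothing) * nll + label_smoothing * smooth

    mask = target.ne(pad_idx)                               # ignore padding positions
    loss = kd_alpha * kd + (1.0 - kd_alpha) * ce
    return (loss * mask).sum() / mask.sum()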

Requirements

The code has been tested in the following environment:

  • Python version = 3.8
  • PyTorch version = 1.7
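
A minimal environment setup could look like the following; the exact torch build and the editable install (the usual convention for fairseq forks) are assumptions, not steps documented by this repository.

# install a matching PyTorch build (adjust the CUDA variant to your system)
pip install torch==1.7.1
# install this fairseq fork in editable mode, from the repository root
pip install --editable .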

Replicate the TED results

Pre-processing

We use the tokenized TED dataset released by VOLT, which can be downloaded from here and pre-processed into subword units by prepare-ted-bilingual.sh.

We provide the pre-processed TED En-Es dataset in this repository. First, process the data into the fairseq format.

TEXT=./data
python preprocess.py --source-lang en --target-lang es \
        --trainpref $TEXT/es-en.train \
        --validpref $TEXT/es-en.valid \
        --testpref $TEXT/es-en.test \
        --destdir data-bin/tedbpe10kenes \
        --nwordssrc 10240 --joined-dictionary --workers 16
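
If binarization succeeds, the destination directory should contain the joined dictionaries and binarized splits; the file names below follow standard fairseq conventions.

ls data-bin/tedbpe10kenes
# expected (fairseq naming): dict.en.txt  dict.es.txt  train.en-es.en.bin  train.en-es.en.idx  ...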

Training

To train the Transformer baseline, run the following command.

data_dir=data-bin/tedbpe10kenes
save_dir=output/enes_base

python train.py $data_dir \
    --fp16 --dropout 0.3  --save-dir $save_dir \
    --arch transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --max-tokens 4096 --update-freq 1 \
    --no-progress-bar --log-format json --log-interval 100 --save-interval-updates 1000 \
    --max-update 18000 --keep-interval-updates 10 --no-epoch-checkpoints
    
python scripts/average_checkpoints.py --inputs $save_dir \
 --num-update-checkpoints 5  --output $save_dir/average-model.pt

To train the COKD model, run the following command.

data_dir=data-bin/tedbpe10kenes
save_dir=output/enes_cokd

python train.py $data_dir \
    --fp16 --dropout 0.2 --kd-alpha 0.95 --num-teachers 1 --save-dir $save_dir \
    --arch transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion cokd_loss --label-smoothing 0.1 --max-tokens 4096 --update-freq 1 \
    --no-progress-bar --log-format json --log-interval 100 --save-interval-updates 1000 \
    --max-update 18000 --keep-interval-updates 10 --no-epoch-checkpoints
    
python scripts/average_checkpoints.py --inputs $save_dir \
 --num-update-checkpoints 5  --output $save_dir/average-model.pt

The above commands assume a machine with 8 GPUs. With a different number of GPUs, adjust --update-freq so that the effective batch size stays at about 32K tokens (number of GPUs × --max-tokens × --update-freq).
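
As a concrete check of that arithmetic:

# effective batch size = num_gpus x max_tokens x update_freq
# 8 GPUs:  8 x 4096 x 1 = 32768 tokens  (the setting above)
# 4 GPUs:  4 x 4096 x 2 = 32768 tokens  -> --update-freq 2
# 1 GPU:   1 x 4096 x 8 = 32768 tokens  -> --update-freq 8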

Inference

Run the following command for inference.

python generate.py data-bin/tedbpe10kenes --path output/enes_cokd/average-model.pt \
    --gen-subset test --beam 5 --batch-size 100 --remove-bpe --lenpen 1 > out
# because fairseq's output is unordered, we need to recover its order
grep ^H out | cut -f1,3- | cut -c3- | sort -k1n | cut -f2- > pred.es
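# strip the BPE markers from the reference as well (the hypothesis side was handled by --remove-bpe)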
sed -r 's/(@@ )|(@@ ?$)//g' data/es-en.test.es > ref.es
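# multi-bleu.perl is the standard BLEU script from the Moses toolkit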
perl multi-bleu.perl ref.es < pred.es

The expected BLEU scores are 40.86 for the Transformer baseline and 42.50 for the COKD model.

Citation

If you find the resources in this repository useful, please cite as:

@inproceedings{cokd,
  title = {Overcoming Catastrophic Forgetting beyond Continual Learning: Balanced Training for Neural Machine Translation},
  author = {Chenze Shao and Yang Feng},
  booktitle = {Proceedings of ACL 2022},
  year = {2022},
}
