Adaptive kNN-MT

Code for our ACL 2021 paper "Adaptive Nearest Neighbor Machine Translation". Please cite our paper if you find this repository helpful in your research:

@inproceedings{Zheng2021AdaptiveNN,
  title={Adaptive Nearest Neighbor Machine Translation},
  author={Xin Zheng and Zhirui Zhang and Junliang Guo and Shujian Huang and Boxing Chen and Weihua Luo and Jiajun Chen},
  year={2021}
}

This project implements our Adaptive kNN-MT as well as Vanilla kNN-MT. The implementation is built upon fairseq and heavily inspired by knn-lm; many thanks to the authors for making their code available.

Requirements and Installation

  • pytorch version >= 1.5.0
  • python version >= 3.6
  • faiss-gpu >= 1.6.5
  • pytorch_scatter = 2.0.5
  • 1.19.0 <= numpy < 1.20.0

You can install this project with:

pip install --editable ./

Instructions

The following example shows how to use the code.

Pre-trained Model and Data

The pre-trained translation model can be downloaded from this site. We use the De->En Single Model for all experiments.

The raw data can be downloaded from this site. You should preprocess it with the Moses toolkit and the BPE codes provided with the pre-trained model. For convenience, we also provide pre-processed data.

Create Datastore

This script creates the datastore (keys.npy and vals.npy) for the data.

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/pretrained_model_path
DATA_PATH=/path/to/fairseq_preprocessed_data_path
DATASTORE_PATH=/path/to/save_datastore
PROJECT_PATH=/path/to/ada_knnmt

mkdir -p $DATASTORE_PATH

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/save_datastore.py $DATA_PATH \
    --dataset-impl mmap \
    --task translation \
    --valid-subset train \
    --path $MODEL_PATH \
    --max-tokens 4096 --skip-invalid-size-inputs-valid-test \
    --decoder-embed-dim 1024 --dstore-fp16 --dstore-size $DSTORE_SIZE --dstore-mmap $DATASTORE_PATH
 
# --max-tokens 4096 depends on your device; --decoder-embed-dim 1024 depends on your model

DSTORE_SIZE is the number of tokens in the target-language training data. You can obtain it in two ways:

  • find it in the preprocess.log file, which fairseq-preprocess writes to the binarized data folder.
  • compute wc -l + wc -w on the target-side raw data file (number of lines plus number of words).
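
The second option can also be sketched in Python. This is a minimal sketch, assuming a whitespace-tokenized file with one sentence per line; the function name count_datastore_entries is hypothetical and not part of this repository. Each line contributes its words plus one extra entry, presumably because fairseq appends an end-of-sentence token to every target sentence.

```python
def count_datastore_entries(path):
    """Return (number of lines) + (number of whitespace-separated words)."""
    n_lines = 0
    n_words = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Unlike `wc -l`, a final line without a trailing newline
            # is still counted as a line here.
            n_lines += 1
            n_words += len(line.split())
    return n_lines + n_words


# Hypothetical usage:
# print(count_datastore_entries("/path/to/train.en"))
```

If the result disagrees with the count in preprocess.log, the log is the safer source, since it reflects what was actually binarized.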

The datastore sizes used in our paper are listed below:

Domain          IT       Medical  Koran   Law
Datastore size  3613350  6903320  524400  19070000

Build Faiss Index

This script builds a faiss index over the keys, which is used for fast kNN search. Once knn_index is built, you can remove keys.npy to save disk space.

PROJECT_PATH=/path/to/ada_knnmt
DSTORE_PATH=/path/to/saved_datastore
DSTORE_SIZE=3613350

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/train_datastore_gpu.py \
  --dstore_mmap $DSTORE_PATH \
  --dstore_size $DSTORE_SIZE \
  --dstore_fp16 \
  --faiss_index ${DSTORE_PATH}/knn_index \
  --ncentroids 4096 \
  --probe 32 \
  --dimension 1024
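
To gauge how much disk space the keys occupy, here is a rough back-of-the-envelope sketch. It assumes the keys are stored as a dense DSTORE_SIZE x dimension array of 16-bit floats, as the --dstore_fp16 and --dimension 1024 flags suggest. The helper name keys_size_bytes is hypothetical, and the actual on-disk layout in this repository may differ.

```python
def keys_size_bytes(dstore_size, dimension, fp16=True):
    """Estimate the size in bytes of a dense key matrix."""
    bytes_per_value = 2 if fp16 else 4  # float16 vs. float32
    return dstore_size * dimension * bytes_per_value


# IT-domain datastore size with 1024-dimensional fp16 keys
size = keys_size_bytes(3613350, 1024)
print(f"{size / 2**30:.2f} GiB")  # prints 6.89 GiB
```

Under these assumptions, the larger domains scale linearly with DSTORE_SIZE, so removing the keys after indexing can free a substantial amount of space.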

Train Adaptive kNN-MT Model


DSTORE_SIZE=3613350
DATA_PATH=/path/to/fairseq_preprocessed_data_path
PROJECT_PATH=/path/to/ada_knnmt
MODEL_PATH=/path/to/pretrained_model_path
DATASTORE_PATH=/path/to/saved_datastore

max_k_grid=(4 8 16 32)
batch_size_grid=(32 32 32 32)
update_freq_grid=(1 1 1 1)
valid_batch_size_grid=(32 32 32 32)

for idx in ${!max_k_grid[*]}
do

  MODEL_RECORD_PATH=/path/to/save/model/train-hid32-maxk${max_k_grid[$idx]}
  TRAINING_RECORD_PATH=/path/to/save/tensorboard/train-hid32-maxk${max_k_grid[$idx]}
  mkdir -p "$TRAINING_RECORD_PATH"

  CUDA_VISIBLE_DEVICES=0 python \
  $PROJECT_PATH/fairseq_cli/train.py \
  $DATA_PATH \
  --log-interval 100 --log-format simple \
  --arch transformer_wmt19_de_en_with_datastore \
  --tensorboard-logdir "$TRAINING_RECORD_PATH" \
  --save-dir "$MODEL_RECORD_PATH" --restore-file "$MODEL_PATH" \
  --reset-dataloader --reset-lr-scheduler --reset-meters --reset-optimizer \
  --validate-interval-updates 100 --save-interval-updates 100 --keep-interval-updates 1 --max-update 5000 --validate-after-updates 1000 \
  --save-interval 10000 --validate-interval 100 \
  --keep-best-checkpoints 1 --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
  --train-subset valid --valid-subset valid --source-lang de --target-lang en \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.001 \
  --max-source-positions 1024 --max-target-positions 1024 \
  --batch-size "${batch_size_grid[$idx]}" --update-freq "${update_freq_grid[$idx]}" --batch-size-valid "${valid_batch_size_grid[$idx]}" \
  --task translation \
  --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --min-lr 3e-05 --lr 0.0003 --clip-norm 1.0 \
  --lr-scheduler reduce_lr_on_plateau --lr-patience 5 --lr-shrink 0.5 \
  --patience 30 --max-epoch 500 \
  --load-knn-datastore --dstore-filename $DATASTORE_PATH --use-knn-datastore \
  --dstore-fp16 --dstore-size $DSTORE_SIZE --probe 32 \
  --knn-sim-func do_not_recomp_l2 \
  --use-gpu-to-search --move-dstore-to-mem --no-load-keys \
  --knn-lambda-type trainable --knn-temperature-type fix --knn-temperature-value 10 --only-train-knn-parameter \
  --knn-k-type trainable --k-lambda-net-hid-size 32 --k-lambda-net-dropout-rate 0.0 --max-k "${max_k_grid[$idx]}" --k "${max_k_grid[$idx]}" \
  --label-count-as-feature
done

Adjust --batch-size and --update-freq to fit your GPU.
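
When a GPU cannot fit the desired batch size, one common approach is to lower --batch-size and raise --update-freq so that the number of sentences per optimizer step stays the same (fairseq accumulates gradients over --update-freq batches before each update). A minimal sketch of that arithmetic follows; the helper name update_freq_for is hypothetical.

```python
import math


def update_freq_for(target_sentences_per_step, batch_size, num_gpus=1):
    """Smallest --update-freq whose effective batch reaches the target."""
    sentences_per_pass = batch_size * num_gpus
    return math.ceil(target_sentences_per_step / sentences_per_pass)


# The script above uses 32 sentences per update on one GPU.
# If only 8 sentences fit in memory at once:
print(update_freq_for(32, 8))  # prints 4
```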

Inference with Adaptive kNN-MT

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/trained_model

DATASTORE_PATH=/path/to/datastore
DATA_PATH=/path/to/data
PROJECT_PATH=/path/to/ada_knnmt

OUTPUT_PATH=/path/to/save_output_result

mkdir -p "$OUTPUT_PATH"

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/experimental_generate.py $DATA_PATH \
    --gen-subset test \
    --path "$MODEL_PATH" --arch transformer_wmt19_de_en_with_datastore \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --batch-size 32 \
    --tokenizer moses --remove-bpe \
    --model-overrides "{'load_knn_datastore': True, 'use_knn_datastore': True,
    'dstore_filename': '$DATASTORE_PATH', 'dstore_size': $DSTORE_SIZE, 'dstore_fp16': True, 'probe': 32,
    'knn_sim_func': 'do_not_recomp_l2', 'use_gpu_to_search': True, 'move_dstore_to_mem': True, 'no_load_keys': True,
    'knn_temperature_type': 'fix', 'knn_temperature_value': 10,}" \
    | tee "$OUTPUT_PATH"/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok
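
The grep/cut pipeline above relies on the tab-separated layout of fairseq's generation log: S- and T- lines carry the text in the second field, while H- and D- lines carry a score in the second field and the text in the third. The following Python sketch does the same extraction in one pass. The function name split_generate_log is hypothetical and not part of this repository, and it matches the prefixes "S-", "T-", "H-", and "D-", which is slightly stricter than grep ^S.

```python
def split_generate_log(lines):
    """Group fairseq generation output by line type."""
    # prefix -> (output name, index of the first text field)
    layout = {
        "S-": ("src", 1),
        "T-": ("ref", 1),
        "H-": ("hyp", 2),        # field 1 is the hypothesis score
        "D-": ("hyp.detok", 2),  # field 1 is the hypothesis score
    }
    out = {name: [] for name, _ in layout.values()}
    for line in lines:
        line = line.rstrip("\n")
        entry = layout.get(line[:2])
        if entry is None:
            continue  # skip P- lines, log messages, and the BLEU summary
        name, first = entry
        fields = line.split("\t")
        out[name].append("\t".join(fields[first:]))
    return out


# Hypothetical usage:
# with open("/path/to/save_output_result/generate.txt", encoding="utf-8") as f:
#     parts = split_generate_log(f)
```

As with the grep commands, entries stay in the order fairseq printed them, so the four lists remain aligned with each other.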

Base NMT Inference

We also provide scripts for inference with the base NMT model and with vanilla kNN-MT.

MODEL_PATH=/path/to/pretrained_model_path/
DATA_PATH=/path/to/fairseq_preprocessed_path/
DATASTORE_PATH=/path/to/saved_datastore/
PROJECT_PATH=/path/to/ada_knnmt

OUTPUT_PATH=/path/to/save_output_result

mkdir -p "$OUTPUT_PATH"

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/fairseq_cli/generate.py $DATA_PATH \
    --gen-subset test \
    --path $MODEL_PATH \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --max-tokens 4096 \
    --tokenizer moses --remove-bpe | tee $OUTPUT_PATH/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok

Vanilla kNN-MT Inference

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/pre_trained_model

DATASTORE_PATH=/path/to/datastore
DATA_PATH=/path/to/data
PROJECT_PATH=/path/to/ada_knnmt

OUTPUT_PATH=/path/to/save_output_result

mkdir -p "$OUTPUT_PATH"

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/experimental_generate.py $DATA_PATH \
    --gen-subset test \
    --path $MODEL_PATH --arch transformer_wmt19_de_en_with_datastore \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --batch-size 32 \
    --tokenizer moses --remove-bpe \
    --model-overrides "{'load_knn_datastore': True, 'use_knn_datastore': True,
    'dstore_filename': '$DATASTORE_PATH', 'dstore_size': $DSTORE_SIZE, 'dstore_fp16': True, 'k': 8, 'probe': 32,
    'knn_sim_func': 'do_not_recomp_l2', 'use_gpu_to_search': True, 'move_dstore_to_mem': True, 'no_load_keys': True,
    'knn_lambda_type': 'fix', 'knn_lambda_value': 0.7, 'knn_temperature_type': 'fix', 'knn_temperature_value': 10,
     }" \
    | tee "$OUTPUT_PATH"/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok

We recommend the hyper-parameters below to reproduce good vanilla kNN-MT results. For Adaptive kNN-MT, we use the same temperature values as listed below.

             IT   Medical  Law  Koran
k            8    4        4    16
lambda       0.7  0.8      0.8  0.8
temperature  10   10       10   100
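
To make the roles of k, lambda, and temperature concrete, here is a small self-contained sketch of the vanilla kNN-MT interpolation as it is usually described: the k retrieved neighbors are turned into a distribution with a softmax over negative distances divided by the temperature, and that distribution is mixed with the NMT distribution using lambda. This is a toy illustration in pure Python, not the code in this repository; the function name and the toy numbers are made up.

```python
import math


def knn_mt_interpolate(p_nmt, neighbors, lam, temperature):
    """Mix an NMT distribution with a kNN distribution.

    p_nmt:     dict mapping token -> probability from the NMT model
    neighbors: non-empty list of (token, distance) pairs, one per retrieved entry
    """
    # Softmax over -distance / temperature (shifted for numerical stability).
    logits = [-dist / temperature for _, dist in neighbors]
    shift = max(logits)
    weights = [math.exp(x - shift) for x in logits]
    total = sum(weights)

    # Neighbors that share a target token pool their probability mass.
    p_knn = {}
    for (token, _), w in zip(neighbors, weights):
        p_knn[token] = p_knn.get(token, 0.0) + w / total

    vocab = set(p_nmt) | set(p_knn)
    return {
        t: lam * p_knn.get(t, 0.0) + (1 - lam) * p_nmt.get(t, 0.0)
        for t in vocab
    }


# Toy example with k = 3 neighbors, lambda = 0.7, temperature = 10
p = knn_mt_interpolate(
    {"cat": 0.6, "dog": 0.4},
    [("dog", 1.0), ("dog", 2.0), ("cat", 5.0)],
    lam=0.7,
    temperature=10,
)
```

In this toy case the NMT model prefers "cat", but two close neighbors vote for "dog", so the mixed distribution prefers "dog". A higher temperature flattens the neighbor weights, and a larger lambda trusts the retrieved neighbors more.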
