The BLEU score is always 0 (knn-box issue, closed)

njunlp commented on May 29, 2024
The BLEU score is always 0.

from knn-box.

Comments (5)

ZhaoQianfeng commented on May 29, 2024

@oraby8 Hi, I think the problem is with the '--arch' option: you need to make sure it matches your pre-trained NMT model.

I found that your pre-trained model's architecture differs from the WMT19 winner model. For example, your model's encoder_embed_dim is 768, while the WMT19 winner model's encoder_embed_dim is 1024. So you can't simply use --arch vanilla_knn_mt@transformer_wmt19_de_en when building the datastore and running inference. The correct steps are as follows:

  1. Add your arch function at the end of knnbox/common_utils/archs.py:

def my_arch(args):
    # getattr only applies the default when the option is not already set
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
    # other options (layers, FFN sizes, heads, ...) so that they match your checkpoint
    ...
    # fill any remaining options with the base Transformer defaults
    base_architecture(args)

  2. Register your arch at the end of knnbox/models/vanilla_knn_mt.py:

@register_model_architecture("vanilla_knn_mt", "vanilla_knn_mt@my_arch")
def transformer_my_arch(args):
    archs.my_arch(args)

  3. Use --arch vanilla_knn_mt@my_arch when building the datastore and during inference. For building the datastore:

CUDA_VISIBLE_DEVICES=1 python $PROJECT_PATH/knnbox-scripts/common/validate.py $DATA_PATH \
    --task translation \
    --path $BASE_MODEL \
    --source-lang en --target-lang ar \
    --model-overrides "{'eval_bleu': False, 'required_seq_len_multiple':1, 'load_alignments': False}" \
    --dataset-impl mmap \
    --valid-subset train \
    --skip-invalid-size-inputs-valid-test \
    --max-tokens 2048 \
    --bpe fastbpe \
    --user-dir $PROJECT_PATH/knnbox/models \
    --arch vanilla_knn_mt@my_arch \
    --knn-mode build_datastore \
    --knn-datastore-path $DATASTORE_SAVE_PATH
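To see why the --arch value matters, here is a minimal, self-contained sketch of the mechanism the steps above rely on. It assumes, as in fairseq, that an architecture function only fills in hyperparameters that are still unset (the getattr(args, name, default) pattern), so whichever function --arch selects decides the layer sizes the checkpoint is loaded into. ARCH_REGISTRY, register_arch, resolve and the simplified base_architecture are hypothetical stand-ins, not the real fairseq or knn-box API, and the WMT19 entry reproduces only the embedding sizes.

```python
from argparse import Namespace

# Hypothetical miniature of an --arch registry, for illustration only.
ARCH_REGISTRY = {}


def register_arch(name):
    """Decorator: map an --arch string to a function that fills in defaults."""
    def wrap(fn):
        ARCH_REGISTRY[name] = fn
        return fn
    return wrap


def base_architecture(args):
    # Simplified stand-in: only fills options that are still unset.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)


@register_arch("vanilla_knn_mt@transformer_wmt19_de_en")
def wmt19_arch(args):
    # Only the embedding sizes of the WMT19 model are reproduced here.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    base_architecture(args)


@register_arch("vanilla_knn_mt@my_arch")
def my_arch(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
    base_architecture(args)


def resolve(arch, **explicit):
    """Build args the way --arch would: explicit values win, defaults fill the rest."""
    args = Namespace(**explicit)
    ARCH_REGISTRY[arch](args)
    return args


if __name__ == "__main__":
    print(resolve("vanilla_knn_mt@transformer_wmt19_de_en").encoder_embed_dim)  # 1024
    print(resolve("vanilla_knn_mt@my_arch").encoder_embed_dim)  # 768
```

In this toy registry, pointing --arch at the WMT19 entry yields 1024-dimensional layers even though the checkpoint was trained with 768, which is the kind of mismatch the reply warns about; a dedicated my_arch entry keeps the two in sync.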

There is no need to specify encoder_embed_dim during inference, because --arch vanilla_knn_mt@my_arch already carries this information:

CUDA_VISIBLE_DEVICES=1 python $PROJECT_PATH/knnbox-scripts/common/generate.py $DATA_PATH \
    --task translation \
    --path $BASE_MODEL \
    --dataset-impl mmap \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 \
    --source-lang en --target-lang ar \
    --gen-subset test \
    --max-tokens 2048 \
    --scoring sacrebleu \
    --tokenizer moses \
    --remove-bpe \
    --user-dir $PROJECT_PATH/knnbox/models \
    --arch vanilla_knn_mt@my_arch \
    --knn-mode inference \
    --knn-datastore-path $DATASTORE_LOAD_PATH \
    --knn-k 8 --knn-lambda 0.7 --knn-temperature 10.0
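The --knn-k, --knn-lambda and --knn-temperature flags control how retrieved neighbors are mixed into the NMT distribution. The sketch below assumes the standard vanilla kNN-MT formulation: the datastore built in build_datastore mode holds (decoder hidden state, target token) pairs, the k keys nearest to the current hidden state are found by squared L2 distance, their weights are a softmax over negative distance divided by the temperature, and the final distribution is lambda * p_kNN + (1 - lambda) * p_NMT. It uses brute-force search and plain dicts for clarity; knn-box's actual retriever and combiner (built on Faiss and tensors) may differ in details, and the datastore layout here is hypothetical.

```python
import math
from collections import defaultdict


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_distribution(query, datastore, k, temperature):
    """p_kNN(y): softmax over -distance/temperature of the k nearest entries,
    with probability mass summed per target token."""
    # datastore: list of (key_vector, target_token) pairs (hypothetical layout)
    neighbors = sorted(
        ((squared_l2(query, key), token) for key, token in datastore),
        key=lambda pair: pair[0],
    )[:k]
    logits = [-dist / temperature for dist, _ in neighbors]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(logit - top) for logit in logits]
    total = sum(weights)
    probs = defaultdict(float)
    for (_, token), weight in zip(neighbors, weights):
        probs[token] += weight / total
    return dict(probs)


def interpolate(p_nmt, p_knn, lam):
    """Vanilla kNN-MT mixture: lam * p_kNN(y) + (1 - lam) * p_NMT(y)."""
    vocab = set(p_nmt) | set(p_knn)
    return {y: lam * p_knn.get(y, 0.0) + (1.0 - lam) * p_nmt.get(y, 0.0) for y in vocab}


if __name__ == "__main__":
    # Toy 2-d "hidden states"; real keys are decoder representations.
    datastore = [((0.0, 0.0), "a"), ((1.0, 0.0), "b"), ((0.0, 1.0), "a"), ((5.0, 5.0), "c")]
    p_knn = knn_distribution((0.1, 0.0), datastore, k=3, temperature=10.0)
    p_nmt = {"a": 0.2, "b": 0.5, "c": 0.3}
    print(interpolate(p_nmt, p_knn, lam=0.7))
```

A higher temperature flattens the neighbor weights, and a larger lambda trusts the datastore more than the base model.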

oraby8 commented on May 29, 2024

Hi @ZhaoQianfeng, thanks for replying so quickly. However, I tried the same changes and they had no effect on the results.
Here are the model args; let me know if you see anything strange in them:

Namespace(_name='fconv_wmt_en_de', activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.998)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, all_gather_list_size=16384, amp=False, amp_batch_retries=2, amp_init_scale=128, amp_scale_window=None, arch='vanilla_knn_mt@my_arch', attention_dropout=0.0, azureml_logging=False, batch_size=None, batch_size_valid=None, best_checkpoint_metric='loss', bf16=False, bpe='fastbpe', broadcast_buffers=False, bucket_cap_mb=25, build_faiss_index_with_cpu=False, checkpoint_shard_count=1, checkpoint_suffix='', clip_norm=0.1, combine_valid_subsets=None, cpu=False, cpu_offload=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, curriculum=0, data='/mnt/data-2/new_knn_box/knn-box/data-bin/small/en_ar', data_buffer_size=10, dataset_impl='mmap', ddp_backend='pytorch_ddp', ddp_comm_hook='none', decoder_attention='True', decoder_attention_heads=8, decoder_embed_dim=768, decoder_embed_path=None, decoder_ffn_embed_dim=2048, decoder_input_dim=768, decoder_layerdrop=0, decoder_layers=8, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_out_embed_dim=512, decoder_output_dim=768, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=1, dropout=0.2, empty_cache_freq=0, encoder_attention_heads=8, encoder_embed_dim=768, encoder_embed_path=None, encoder_ffn_embed_dim=2048, encoder_layerdrop=0, encoder_layers=8, encoder_layers_to_keep=None, encoder_learned_pos=False, encoder_normalize_before=False, eos=2, eval_bleu=False, eval_bleu_args='{}', eval_bleu_detok='space', eval_bleu_detok_args='{}', eval_bleu_print_samples=False, eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, fast_stat_sync=False, find_unused_parameters=False, finetune_from_model=None, fix_batches_to_gpus=False, 
fixed_validation_seed=None, fp16=True, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, fp32_reduce_scatter=False, gen_subset='test', heartbeat_timeout=-1, ignore_prefix_size=0, ignore_unused_valid_subsets=False, keep_best_checkpoints=-1, keep_interval_updates=-1, keep_interval_updates_pattern=-1, keep_last_epochs=-1, knn_datastore_path='/mnt/data-2/new_knn_box/knn-box/knnbox-scripts/vanilla-knn-mt/../../datastore/vanilla/small/en_ar', knn_k=8, knn_lambda=0.7, knn_mode='build_datastore', knn_temperature=10, label_smoothing=0.1, layernorm_embedding=False, left_pad_source=False, left_pad_target=False, load_alignments=False, load_checkpoint_on_all_dp_ranks=False, localsgd_frequency=3, log_file=None, log_format=None, log_interval=100, lr=[0.0005], lr_patience=0, lr_scheduler='reduce_lr_on_plateau', lr_shrink=0.5, lr_threshold=0.0001, max_epoch=100, max_source_positions=1024, max_target_positions=1024, max_tokens=2048, max_tokens_valid=4096, max_update=0, max_valid_steps=None, maximize_best_checkpoint_metric=False, memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_overrides="{'eval_bleu': False, 'required_seq_len_multiple':1, 'load_alignments': False}", model_parallel_size=1, no_cross_attention=False, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=False, no_reshard_after_forward=False, no_save=False, no_save_optimizer_state=False, no_scale_embedding=False, no_seed_provided=False, no_token_positional_embeddings=False, nprocs_per_node=1, num_batch_buckets=0, num_shards=1, num_workers=1, optimizer='adam', optimizer_overrides='{}', pad=1, path='/mnt/data-2/new_knn_box/knn-box/data-bin/small/model/checkpoint_best.pt', patience=10, pipeline_balance=None, pipeline_checkpoint='never', pipeline_chunks=0, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_devices=None, pipeline_encoder_balance=None, pipeline_encoder_devices=None, 
pipeline_model_parallel=False, plasma_path='/tmp/plasma', profile=False, quant_noise_pq=0, quant_noise_pq_block_size=8, quant_noise_scalar=0, quantization_config_path=None, report_accuracy=False, required_batch_size_multiple=8, required_seq_len_multiple=1, reset_dataloader=False, reset_logging=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='checkpoints', save_interval=1, save_interval_updates=0, scoring='bleu', seed=1, sentence_avg=False, shard_id=0, share_all_embeddings=False, share_decoder_input_output_embed=False, share_input_output_embed=False, simul_type=None, skip_invalid_size_inputs_valid_test=True, slowmo_algorithm='LocalSGD', slowmo_momentum=None, source_lang='en', stop_min_lr=-1.0, stop_time_hours=0, suppress_crashes=False, target_lang='ar', task='translation', tensorboard_logdir='log_dir', threshold_loss_scale=None, tie_adaptive_weights=False, tokenizer=None, tpu=False, train_subset='train', truncate_source=False, unk=3, update_freq=[1], upsample_primary=-1, use_bmuf=False, use_old_adam=False, use_plasma_view=False, use_sharded_state=False, user_dir='/mnt/data-2/new_knn_box/knn-box/knnbox-scripts/vanilla-knn-mt/../../knnbox/models', valid_subset='train', validate_after_updates=0, validate_interval=1, validate_interval_updates=0, wandb_project=None, warmup_init_lr=-1, warmup_updates=4000, weight_decay=0.0001, write_checkpoints_asynchronously=False, zero_sharding='none')

ZhaoQianfeng commented on May 29, 2024

@oraby8 Hello, is the pre-trained NMT model you are using a convolutional NMT model rather than a Transformer? Your args suggest so:

Namespace(_name='fconv_wmt_en_de', ...

The knn-box toolkit currently does not support convolutional NMT models, and all k-nearest-neighbor machine translation (kNN-MT) research uses Transformer models. So when training your NMT model, please use --arch transformer, and when building the datastore and running inference, use --arch vanilla_knn_mt@transformer.

I am not completely sure this is why your BLEU score is 0. If it is convenient, you can share your model checkpoint and data-bin files on Google Drive, and I can help you troubleshoot the failure.

oraby8 commented on May 29, 2024

Yes, here are the model and data-bin files:
https://drive.google.com/file/d/1VmPEfPgrmpNifp2dVziiNCQhdTKKdd31/view?usp=sharing

ZhaoQianfeng commented on May 29, 2024

Hello, I downloaded the model file and checked it. The problem is that the pre-trained model you are using is not a Transformer but a convolutional NMT model. You need to train a Transformer NMT model; please refer to the "IWSLT 14 German to English (Transformer)" section of https://github.com/facebookresearch/fairseq/tree/main/examples/translation.

  1. When training the NMT model, use --arch transformer.
  2. When building the datastore, use --arch vanilla_knn_mt@transformer.
  3. When running inference with kNN-MT, use --arch vanilla_knn_mt@transformer.
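One way to run the check from the last reply yourself is to read the architecture name stored in the checkpoint before building a datastore. The sketch below assumes the usual fairseq checkpoint layout: older checkpoints store an argparse Namespace under "args" with an arch field, and newer ones store a config under "cfg" whose "model" entry records the architecture as _name (likely the source of the _name='fconv_wmt_en_de' value in the args dump above). The Transformer test is a simple name-prefix heuristic, and suggest_knn_arch is a hypothetical helper that assumes knn-box registers a matching vanilla_knn_mt@<arch> variant; if it does not, register one as in the first reply.

```python
from argparse import Namespace
from collections.abc import Mapping
from typing import Any, Optional, Union


def _field(obj: Union[Mapping, Namespace, None], key: str) -> Any:
    """Read a key from a dict-like config or an attribute from a Namespace."""
    if obj is None:
        return None
    if isinstance(obj, Mapping):
        return obj.get(key)
    return getattr(obj, key, None)


def checkpoint_arch(state: Mapping) -> Optional[str]:
    """Return the architecture name recorded in a fairseq-style checkpoint dict."""
    # Older checkpoints: an argparse Namespace under "args" with an "arch" field.
    arch = _field(state.get("args"), "arch")
    if arch:
        return arch
    # Newer checkpoints: a config under "cfg" whose "model" entry stores "_name".
    model_cfg = _field(state.get("cfg"), "model")
    return _field(model_cfg, "_name") or _field(model_cfg, "arch")


def suggest_knn_arch(base_arch: Optional[str]) -> str:
    """Map a Transformer arch to a knn-box arch string; reject anything else."""
    # Heuristic: fairseq Transformer architecture names start with "transformer".
    if not base_arch or not base_arch.startswith("transformer"):
        raise ValueError(
            f"{base_arch!r} is not a Transformer architecture; "
            "retrain the NMT model with --arch transformer"
        )
    # Assumes knn-box registers vanilla_knn_mt@<base_arch>; register it otherwise.
    return f"vanilla_knn_mt@{base_arch}"
```

To use it, load the checkpoint with torch.load("checkpoint_best.pt", map_location="cpu") and pass the resulting dict to checkpoint_arch. On recent PyTorch versions you may also need weights_only=False, because fairseq checkpoints contain pickled Python objects rather than only tensors; only do this for checkpoints you trust.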
