Hi, when training CNAT, I get the following error:
if "VQ" in inner_states[GlobalNames.PRI_RET]:
KeyError: 'prior_ret'
Detailed information:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/search/CNAT/fairseq/fairseq/distributed_utils.py", line 270, in distributed_main
main(args, **kwargs)
File "/search/CNAT/train.py", line 112, in main
valid_losses, should_stop = train(args, trainer, task, epoch_itr)
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
return func(*args, **kwds)
File "/search/CNAT/train.py", line 190, in train
log_output = trainer.train_step(samples)
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/contextlib.py", line 74, in inner
return func(*args, **kwds)
File "/search/CNAT/fairseq/fairseq/trainer.py", line 486, in train_step
ignore_grad=is_dummy_batch,
File "/search/CNAT/latent_nat/nat_task.py", line 154, in train_step
return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad)
File "/search/CNAT/fairseq/fairseq/tasks/translation_lev.py", line 178, in train_step
loss, sample_size, logging_output = criterion(model, sample)
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/CNAT/latent_nat/awesome_nat_loss.py", line 55, in forward
outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/CNAT/fairseq/fairseq/legacy_distributed_data_parallel.py", line 85, in forward
return self.module(*inputs, **kwargs)
File "/root/anaconda3/envs/torch_fairseq/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/CNAT/latent_nat/cnat.py", line 46, in forward
losses.update(self._compute_vq_loss(inner_states))
File "/search/CNAT/latent_nat/cnat.py", line 57, in _compute_vq_loss
if "VQ" in inner_states[GlobalNames.PRI_RET]:
KeyError: 'prior_ret'
My full training command is:
python3 train.py ${DATA_BIN} \
--user-dir ${USER_DIR} \
--save-dir $CHECKPOINT \
--ddp-backend=no_c10d \
--task nat \
--criterion awesome_nat_loss \
--arch cnat_wmt14 \
--self-attn-cls shaw \
--block-cls highway \
--max-rel-positions 4 \
--enc-self-attn-cls shaw \
--enc-block-cls highway \
--share-rel-embeddings \
--share-decoder-input-output-embed \
--mapping-func interpolate \
--mapping-use output \
--noise full_mask \
--apply-bert-init \
--optimizer adam \
--lr 0.0007 \
--lr-scheduler inverse_sqrt \
--warmup-updates 10000 \
--warmup-init-lr 1e-07 \
--min-lr 1e-09 \
--weight-decay 0.0 \
--dropout 0.1 \
--encoder-learned-pos \
--decoder-learned-pos \
--pred-length-offset \
--length-loss-factor 0.1 \
--label-smoothing 0.0 \
--log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 4096 \
--update-freq 1 \
--save-interval-updates 500 \
--keep-best-checkpoints 5 \
--no-epoch-checkpoints \
--keep-interval-updates 5 \
--max-update 300000 \
--num-workers 0 \
--eval-bleu \
--eval-bleu-detok moses \
--eval-bleu-remove-bpe \
--best-checkpoint-metric bleu \
--maximize-best-checkpoint-metric \
--iter-decode-max-iter 0 \
--iter-decode-eos-penalty 0 \
--left-pad-source False \
--batch-size-valid 128 \
--latent-factor 0.5 \
--num-codes 64 \
--vq-ema \
--crf-cls BCRF \
--crf-num-head 4 \
--latent-layers 5 \
--vq-schedule-ratio 0.5 \
--find-unused-parameters