Comments (2)
@bot66 See this comment: #136 (comment)
To reiterate:
Our trg sequence will be something like [<sos>, x1, x2, x3, <eos>]. When we do trg[:,:-1] the sequence will be [<sos>, x1, x2, x3], and our predicted sequence will be [y1, y2, y3, y4], where y1 should be x1, y2 should be x2, y3 should be x3, and y4 should be <eos>. The predicted sequence should be a shifted version of the target sequence -- this is because we calculate the loss of output against trg[:,1:] = [x1, x2, x3, <eos>].
With padding, let's say the target sequence is [<sos>, x1, x2, x3, <eos>, <pad>, <pad>], thus trg[:,:-1] = [<sos>, x1, x2, x3, <eos>, <pad>], and our predicted sequence is [y1, y2, y3, y4, y5, y6]. Same as before, but y5 and y6 should be the model predicting <pad> tokens, so we are calculating the loss of our output against trg[:,1:] = [x1, x2, x3, <eos>, <pad>, <pad>]. However, because we use nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX), we ignore the loss values over the <pad> tokens, so we only calculate the loss of the output compared against trg[:,1:] = [x1, x2, x3, <eos>] -- which is the exact same sequence without padding.
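The shifting and ignore_index behaviour can be sketched as follows (the toy vocabulary indices and the random stand-in logits are hypothetical, not the tutorial's actual data):

```python
import torch
import torch.nn as nn

TRG_PAD_IDX = 0
# toy vocabulary (hypothetical indices): 0 = <pad>, 1 = <sos>, 2 = <eos>, 3/4/5 = x1/x2/x3
trg = torch.tensor([[1, 3, 4, 5, 2, 0, 0]])   # [<sos>, x1, x2, x3, <eos>, <pad>, <pad>]

decoder_input = trg[:, :-1]   # [<sos>, x1, x2, x3, <eos>, <pad>] -> fed to the decoder
target        = trg[:, 1:]    # [x1, x2, x3, <eos>, <pad>, <pad>] -> compared to output

vocab_size = 6
torch.manual_seed(0)
# stand-in for the decoder's output logits, shape [batch, trg_len - 1, vocab]
output = torch.randn(1, decoder_input.shape[1], vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
loss = criterion(output.reshape(-1, vocab_size), target.reshape(-1))

# the ignored positions contribute nothing: this equals the loss computed
# over only the four unpadded target positions [x1, x2, x3, <eos>]
loss_no_pad = nn.CrossEntropyLoss()(output[0, :4], target[0, :4])
```

The two loss values are identical, which is exactly the "same sequence without padding" point above.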
I'm not sure exactly what you mean by the masking (feel free to clarify), but the trg_mask ensures that each element in output, each prediction yi, can only "see" target tokens at or before it in time. For example, y1 can only "see" the first element in trg[:,:-1], the shifted target, which is the <sos> token, and must use only that and the encoded source sequence, enc_src, to predict x1. y2 can only see <sos> and x1 and must use those to predict x2. When it comes to the end-of-sequence token, the decoder can "see" [<sos>, x1, x2, x3] and can use this (with enc_src) to predict the <eos> token. It does the same thing with [<sos>, x1, x2, x3, <eos>] to predict the <pad> token -- but as mentioned before, we do not calculate loss over the padding tokens, so this doesn't really matter.
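A minimal sketch of how such a target mask can be built, combining a padding mask with a lower-triangular "no peeking ahead" mask (the shapes follow the tutorial's make_trg_mask idea, but this standalone version is illustrative):

```python
import torch

def make_trg_mask(trg, pad_idx):
    # hide <pad> tokens and all "future" positions
    trg_len = trg.shape[1]
    pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)       # [batch, 1, 1, trg_len]
    sub_mask = torch.tril(torch.ones(trg_len, trg_len)).bool()  # lower-triangular [trg_len, trg_len]
    return pad_mask & sub_mask                                   # [batch, 1, trg_len, trg_len]

trg = torch.tensor([[1, 3, 4, 5, 2, 0]])  # [<sos>, x1, x2, x3, <eos>, <pad>], hypothetical indices
mask = make_trg_mask(trg, pad_idx=0)
# row i of mask[0, 0] is True only at non-pad columns <= i:
# y1 sees only <sos>; y2 sees <sos> and x1; and so on.
```

Row 0 of the mask allows attention to the <sos> position only, matching the description of y1 above.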
You do not need to explicitly mask the <eos> token. It is just another token you need to predict and should not be treated as special. E.g., think of a translation model in production: how would you know when to stop outputting tokens? You have to predict the <eos> token.
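This "stop when you predict <eos>" point can be sketched as a greedy decoding loop. The step_fn here is a hypothetical stand-in for the real decoder (the tutorial's actual inference code takes the model, source tensor, and masks); the token indices are illustrative:

```python
import torch

SOS_IDX, EOS_IDX = 1, 2

def greedy_decode(step_fn, sos_idx, eos_idx, max_len=10):
    # step_fn(tokens) returns logits over the vocabulary for the next token
    tokens = [sos_idx]
    for _ in range(max_len):
        next_tok = int(torch.argmax(step_fn(tokens)))
        tokens.append(next_tok)
        if next_tok == eos_idx:   # the model itself predicts when to stop
            break
    return tokens

# toy step function that emits tokens 3, 4, then <eos>
script = iter([3, 4, EOS_IDX])
def toy_step(tokens):
    logits = torch.zeros(6)
    logits[next(script)] = 1.0
    return logits

result = greedy_decode(toy_step, SOS_IDX, EOS_IDX)  # [1, 3, 4, 2]
```

The max_len cap is the usual safeguard in case the model never emits <eos>.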
Again, let me know if there's anything that needs clarifying.
from pytorch-seq2seq.
Thanks for your reply! OK, I get it: when calculating the loss, the loss on the logit predicted at the <eos> input position is ignored, since the corresponding target is <pad>.
What I was doing was something like replacing the <eos> token with the <pad> token in the trg input to the decoder transformer 🤣