
bentrevett commented on May 23, 2024

@bot66 See this comment: #136 (comment)

To reiterate:

Our trg sequence will be something like [<sos>, x1, x2, x3, <eos>]. When we do trg[:,:-1] the sequence will be [<sos>, x1, x2, x3], and our predicted sequence will be [y1, y2, y3, y4], where y1 should be x1, y2 should be x2, y3 should be x3 and y4 should be <eos>. The predicted sequence should be a shifted version of the target sequence -- this is because we compute the loss of the output against trg[:,1:] = [x1, x2, x3, <eos>].
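The shifting above can be sketched with plain Python lists (the actual tutorial slices batched tensors as trg[:,:-1] and trg[:,1:]; the token names here are placeholders):

```python
# Toy target sequence; "x1".."x3" stand in for real token ids.
trg = ["<sos>", "x1", "x2", "x3", "<eos>"]

decoder_input = trg[:-1]  # fed into the decoder
loss_target = trg[1:]     # what the decoder's outputs are compared against

# Each prediction y_i is scored against the token one step ahead.
assert decoder_input == ["<sos>", "x1", "x2", "x3"]
assert loss_target == ["x1", "x2", "x3", "<eos>"]
```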

With padding, let's say the target sequence is [<sos>, x1, x2, x3, <eos>, <pad>, <pad>], so trg[:,:-1] = [<sos>, x1, x2, x3, <eos>, <pad>] and our predicted sequence is [y1, y2, y3, y4, y5, y6]. Same as before, but y5 and y6 should be the model predicting <pad> tokens, so we are calculating the loss of our output against trg[:,1:] = [x1, x2, x3, <eos>, <pad>, <pad>]. However, because we use nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) we ignore the loss values over the <pad> tokens, so we only calculate loss from the output compared against [x1, x2, x3, <eos>] -- which is the exact same sequence without padding.
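Here is a framework-free sketch of what ignore_index does (the real code just passes ignore_index to nn.CrossEntropyLoss; the vocabulary indices and logit values below are made up for illustration):

```python
import math

def cross_entropy(logit_row, tgt):
    # Standard softmax cross-entropy for one position.
    log_z = math.log(sum(math.exp(v) for v in logit_row))
    return log_z - logit_row[tgt]

def masked_loss(logits, targets, pad_idx):
    # Mean loss over positions whose target is NOT <pad>,
    # mimicking nn.CrossEntropyLoss(ignore_index=pad_idx).
    losses = [cross_entropy(row, t)
              for row, t in zip(logits, targets) if t != pad_idx]
    return sum(losses) / len(losses)

PAD = 0
# Hypothetical vocab: 0=<pad>, 1=<eos>, 2..4 = x1..x3.
logits = [[0.1, 0.2, 2.0, 0.3, 0.4],   # y1, target x1
          [0.1, 0.2, 0.3, 2.0, 0.4],   # y2, target x2
          [0.1, 0.2, 0.3, 0.4, 2.0],   # y3, target x3
          [0.1, 2.0, 0.3, 0.4, 0.5],   # y4, target <eos>
          [9.9, 0.0, 0.0, 0.0, 0.0],   # y5, target <pad> -> ignored
          [0.0, 9.9, 0.0, 0.0, 0.0]]   # y6, target <pad> -> ignored
targets = [2, 3, 4, 1, PAD, PAD]

with_pad = masked_loss(logits, targets, PAD)
without_pad = masked_loss(logits[:4], targets[:4], PAD)
assert abs(with_pad - without_pad) < 1e-12  # pad positions change nothing
```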

I'm not sure exactly what you mean by the masking (feel free to clarify), but the trg_mask ensures that each element in output, each prediction yi, can only "see" target tokens "at or before" it in time. For example, y1 can only "see" the first element in trg[:,:-1], the shifted target, which is the <sos> token, and must use only that and the encoded source sequence, enc_src, to predict x1. y2 can only see <sos> and x1 and must use those to predict x2. When it comes to the end-of-sequence token, the decoder can "see" [<sos>, x1, x2, x3] and can use this (with enc_src) to predict the <eos> token. It does the same thing with [<sos>, x1, x2, x3, <eos>] to predict the <pad> token -- but as mentioned before, we do not calculate loss over the padding tokens so this doesn't really matter.
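The causal ("subsequent") part of trg_mask is just a lower-triangular matrix (the tutorial builds it with torch.tril and combines it with a padding mask; this is a minimal pure-Python sketch):

```python
# Row i says which target positions prediction y_{i+1} may attend to:
# position i may only look at positions <= i.
trg_len = 4  # e.g. the shifted target [<sos>, x1, x2, x3]
trg_mask = [[j <= i for j in range(trg_len)] for i in range(trg_len)]

# y1 (row 0) sees only <sos>; y4 (row 3) sees <sos>, x1, x2, x3.
assert trg_mask[0] == [True, False, False, False]
assert trg_mask[1] == [True, True, False, False]
assert trg_mask[3] == [True, True, True, True]
```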

You do not need to explicitly mask the <eos> token. It is just another token you need to predict and should not be treated as a special token. For example, think of a translation model in production: how would you know when to stop outputting tokens? You have to predict the <eos> token.
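This is why inference loops stop on a predicted <eos>. A hedged sketch of greedy decoding, where predict_next is a stand-in for a real decoder call (not the tutorial's actual API):

```python
SOS, EOS = "<sos>", "<eos>"

def greedy_decode(predict_next, max_len=50):
    # Generate tokens one at a time, stopping when the model
    # itself predicts <eos> (or at max_len as a safety cap).
    tokens = [SOS]
    for _ in range(max_len):
        nxt = predict_next(tokens)  # model's most likely next token
        tokens.append(nxt)
        if nxt == EOS:              # <eos> is predicted, never masked
            break
    return tokens

# Toy "model" that emits a fixed translation, then <eos>.
canned = iter(["x1", "x2", "x3", EOS])
out = greedy_decode(lambda toks: next(canned))
assert out == [SOS, "x1", "x2", "x3", EOS]
```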

Again, let me know if there's anything that needs clarifying.

from pytorch-seq2seq.

bot66 commented on May 23, 2024

Thanks for your reply! OK, I get it: when calculating the loss, the loss on the logits predicted after <eos> is ignored, since the corresponding target is <pad>.

What I was doing was something like replacing the <eos> token with a <pad> token in the trg input to the decoder transformer 🤣

