
Where is Expectation (pytorch-sgns) · OPEN · 5 comments

theeluwin commented on July 4, 2024
Where is Expectation


Comments (5)

msummerfield commented on July 4, 2024

Thanks for publishing this repo. I have found it very useful in improving the performance of my own word2vec implementation.

On this particular issue, I note that mean(), sum(), log(), and sigmoid() are all continuous and monotonically-increasing functions of their inputs. Thus, barring any issues with numerical stability, minimising:

-(t.bmm(ovectors, ivectors).squeeze().sigmoid().log().mean(1) + t.bmm(nvectors, ivectors).squeeze().sigmoid().log().view(-1, context_size, self.n_negs).sum(2).mean(1))

is equivalent to minimising:

-(t.bmm(ovectors, ivectors).mean() + t.bmm(nvectors, ivectors).mean())

Given that all of the vectors start out 'small', and do not become excessively large in the course of training, numerical stability does not seem to be an issue. (Indeed, if there were some problem with stability I suspect it might arise anyway, since each function is computed successively, rather than using some numerically-stable compound implementation in the manner of nn.NLLLoss().) So, even though the loss function is 'wrong', it has the same argmin as the 'correct' loss function, which is all we really care about.
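For concreteness, here is a minimal sketch (toy shapes, plain PyTorch, not the repo's actual forward()) that evaluates both expressions above on random tensors, so they can be compared side by side:

import torch as t

# Toy dimensions; any small values will do.
batch_size, context_size, n_negs, dim = 4, 5, 20, 25
ivectors = t.randn(batch_size, dim, 1)                             # centre word as a column vector
ovectors = t.randn(batch_size, context_size, dim)                  # context (positive) words
nvectors = t.randn(batch_size, context_size * n_negs, dim).neg()   # negated negative samples

# Full SGNS objective (the first expression), averaged over the batch to get a scalar.
oloss = t.bmm(ovectors, ivectors).squeeze().sigmoid().log().mean(1)
nloss = t.bmm(nvectors, ivectors).squeeze().sigmoid().log().view(-1, context_size, n_negs).sum(2).mean(1)
full_loss = -(oloss + nloss).mean()

# Simplified objective (the second expression): drop sigmoid() and log().
simple_loss = -(t.bmm(ovectors, ivectors).mean() + t.bmm(nvectors, ivectors).mean())

print(full_loss.item(), simple_loss.item())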

The same argument should apply to the 'improved' computation. Ultimately, the order of application of mean() and sum() operations makes no difference to the location of the minimum of the loss function. So, in terms of the optimisation, all you are doing is increasing the number of nwords. But so long as you have 'enough' negative samples, you should be fine - as Mikolov et al. say in their paper, 'we are free to simplify NCE as long as the vector representations retain their quality.'

I have tried the above simplification in my own implementation. It seems to work, in the sense that it converges and produces a sensible-looking result, although I have not done any testing to check that it produces the same embeddings (all else being equal). However, the speed-up is hardly worth it - about 5% for small vectors, e.g. around 25 elements, but with much smaller relative benefits for larger vectors, since the matrix multiplications dominate the computation time. The advantage of retaining the log() and sigmoid() functions is that the magnitude of the loss function is about the same, regardless of the parameters of the model, rather than being, e.g., roughly proportional to the vector size.

Incidentally, as far as I can tell from the original word2vec code (https://github.com/dav/word2vec/blob/9b8b58001ba5d58babe1c62327a8501b62cd6179/src/word2vec.c#L529) they use a fixed number of negative samples (just five by default), and it looks like they compute the sigmoid() function (by table lookup), but not the log().
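For reference, here is a rough Python rendering of that lookup-table trick; the constants mirror the defaults in word2vec.c, but this is only an illustrative sketch, not a verified port:

import numpy as np

EXP_TABLE_SIZE = 1000   # table resolution, as in word2vec.c
MAX_EXP = 6             # scores outside [-MAX_EXP, MAX_EXP] saturate

# Precompute sigmoid over the clamped score range.
xs = (np.arange(EXP_TABLE_SIZE) / EXP_TABLE_SIZE * 2 - 1) * MAX_EXP
exp_table = np.exp(xs)
exp_table = exp_table / (exp_table + 1)   # sigmoid(xs)

def fast_sigmoid(score):
    # Out-of-range scores are clamped, as in the C code.
    if score >= MAX_EXP:
        return 1.0
    if score <= -MAX_EXP:
        return 0.0
    idx = int((score + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))
    return exp_table[idx]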


theeluwin commented on July 4, 2024

That would give a more precise loss value; however, I'm afraid that applying it might take too much training time. I quickly wrote some sample code for your idea.
With an appropriate self.n_samples on the model, change the lines below https://github.com/theeluwin/pytorch-sgns/blob/master/model.py#L66 as follows:

if self.weights is not None:
    # Draw n_samples independent negative sets per (context word, negative slot),
    # using the word-frequency distribution when available.
    nwords = t.multinomial(self.weights, batch_size * context_size * self.n_negs * self.n_samples, replacement=True).view(batch_size, -1)
else:
    # Otherwise sample negative word indices uniformly over the vocabulary.
    nwords = FT(batch_size, context_size * self.n_negs * self.n_samples).uniform_(0, self.vocab_size - 1).long()
ivectors = self.embedding.forward_i(iword).unsqueeze(2)  # (batch, dim, 1)
ovectors = self.embedding.forward_o(owords)              # (batch, context_size, dim)
nvectors = self.embedding.forward_o(nwords).neg()        # (batch, context_size * n_negs * n_samples, dim)
oloss = t.bmm(ovectors, ivectors).squeeze().sigmoid().log().mean(1)
# Average over the extra n_samples dimension before summing over negatives,
# so the negative term approximates an expectation over negative draws.
nloss = t.bmm(nvectors, ivectors).squeeze().sigmoid().log().view(-1, context_size, self.n_negs, self.n_samples).mean(3).sum(2).mean(1)
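If I'm reading the rest of forward() correctly, the two terms would still be combined the same way as in the current code, i.e. something like:

return -(oloss + nloss).mean()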


zetyquickly commented on July 4, 2024

I completely agree with you that this will make training slower. My colleague and I concluded that your approach also leads to convergence. But I wish there were some notes about this formula and its implementation in the README or in the code. Thank you.


theeluwin commented on July 4, 2024

If you have a good idea for the implementation, please go ahead and open a PR. Otherwise, I'll close this issue. Thank you.


theeluwin commented on July 4, 2024

@msummerfield Thanks for the detailed feedback! Awesome.
The idea of using the 'faster' loss looks meaningful. The main reason I retained all the details is that the overall loss value remains mathematically correct (e.g., a loss of 1 corresponds to a predicted probability of about e^-1 ≈ 36.8%), but yes, the long chain of calculations might suffer from numerical issues, since I've never paid much attention to those.
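As a quick sanity check of that figure (assuming the per-pair loss is -log(sigmoid(score)) for a positive pair):

import math

# If -log(sigmoid(score)) == 1, then sigmoid(score) == e**-1.
print(math.exp(-1))  # ~0.368, i.e. roughly a 36.8% predicted probability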
Again, thanks for the awesome feedback. It inspired me in many ways, so I hope it could be posted in some other spaces (like a blog) too.

