
Comments (6)

quoniammm commented on May 28, 2024 23

You should change:
loss += criterion(decoder_output[0], target_variable[di]) to
loss += criterion(decoder_output, target_variable[di])
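
Why this helps, as a minimal sketch: the traceback in this thread ends with `Expected 2 or 4 dimensions (got 1)`, and indexing with `[0]` drops the batch dimension, leaving a 1-D tensor. The tensors below are stand-ins, and `vocab_size` is a hypothetical value, not taken from the tutorial.

```python
import torch
import torch.nn as nn

vocab_size = 5  # hypothetical output vocabulary size

# Stand-in for one decoder step: log-probabilities of shape (1, vocab_size).
decoder_output = torch.log_softmax(torch.randn(1, vocab_size), dim=1)
# Stand-in for target_variable[di]: a single class index of shape (1,).
target = torch.tensor([2])

criterion = nn.NLLLoss()

print(decoder_output.shape)     # 2-D: a batch of one
print(decoder_output[0].shape)  # 1-D: indexing removed the batch dimension

# Passing the 2-D tensor keeps the (batch, classes) layout the loss expects.
loss = criterion(decoder_output, target)
print(loss)
```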

from practical-pytorch.

czs0x55aa commented on May 28, 2024 13

You need to modify the score function in the attention model.
Maybe you can use the following code:

def score(self, hidden, encoder_output):
    if self.method == 'dot':
        energy = torch.dot(hidden.view(-1), encoder_output.view(-1))
    elif self.method == 'general':
        energy = self.attn(encoder_output)
        energy = torch.dot(hidden.view(-1), energy.view(-1))
    elif self.method == 'concat':
        energy = self.attn(torch.cat((hidden, encoder_output), 1))
        energy = torch.dot(self.v.view(-1), energy.view(-1))
    return energy

But this implementation will be much slower on the GPU. #56
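
One common way to avoid calling `torch.dot` once per source position is to score every position in a single matrix-vector product. The sketch below covers only the 'dot' method and uses made-up sizes and random stand-in tensors. It shows that both approaches give the same scores, not how much faster either one is.

```python
import torch

seq_len, hidden_size = 4, 8  # hypothetical sizes

hidden = torch.randn(1, hidden_size)                 # current decoder hidden state
encoder_outputs = torch.randn(seq_len, hidden_size)  # one row per source position

# Per-position scoring: one torch.dot call for each encoder output.
loop_scores = torch.stack([
    torch.dot(hidden.view(-1), encoder_outputs[i].view(-1))
    for i in range(seq_len)
])

# The same scores from one matrix-vector product over all positions.
batched_scores = encoder_outputs @ hidden.view(-1)

print(torch.allclose(loop_scores, batched_scores, atol=1e-5))
```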


czs0x55aa commented on May 28, 2024 10

PyTorch v0.2 removed implicit flattening for torch.dot (pytorch/pytorch#2313).
You can flatten both operands explicitly instead:

torch.dot(hidden.view(-1), energy.view(-1))
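
A small worked example with made-up values: both operands have shape (1, 3), so each is flattened to 1-D before the dot product.

```python
import torch

hidden = torch.tensor([[1.0, 2.0, 3.0]])  # shape (1, 3)
energy = torch.tensor([[4.0, 5.0, 6.0]])  # shape (1, 3)

# torch.dot needs 1-D inputs, so flatten each operand with view(-1).
score = torch.dot(hidden.view(-1), energy.view(-1))
print(score)  # tensor(32.)  since 1*4 + 2*5 + 3*6 = 32
```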


caozhen-alex commented on May 28, 2024 3

@czs0x55aa Thank you very much! It seems to work. However, it raised another error.


ValueError Traceback (most recent call last)
in ()
8
9 # Run the train function
---> 10 loss = train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
11
12 # Keep track of loss

in train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length)
32 for di in range(target_length):
33 decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
---> 34 loss += criterion(decoder_output[0], target_variable[di])
35 decoder_input = target_variable[di] # Next target is next input
36

~/anaconda/envs/pytorch_nmt3.5/lib/python3.5/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
222 for hook in self._forward_pre_hooks.values():
223 hook(self, input)
--> 224 result = self.forward(*input, **kwargs)
225 for hook in self._forward_hooks.values():
226 hook_result = hook(self, input, result)

~/anaconda/envs/pytorch_nmt3.5/lib/python3.5/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
130 _assert_no_grad(target)
131 return F.nll_loss(input, target, self.weight, self.size_average,
--> 132 self.ignore_index)
133
134

~/anaconda/envs/pytorch_nmt3.5/lib/python3.5/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index)
674 return _functions.thnn.NLLLoss2d.apply(input, target, weight, size_average, ignore_index)
675 else:
--> 676 raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
677
678

ValueError: Expected 2 or 4 dimensions (got 1)

By the way, I am a beginner at coding. How can I deal with this kind of error when it is raised from library source code? Thank you for your help!


caozhen-alex commented on May 28, 2024

@czs0x55aa Thank you very much for answering. But where should I make the change in this case (seq2seq translation)?


czs0x55aa commented on May 28, 2024

Sorry, I'm not facing this issue.
Maybe you should check the tensor sizes of decoder_output and target_variable.
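
One possible way to do that check, sketched with a hypothetical helper (`describe_shapes` is not part of PyTorch or the tutorial) and stand-in tensors in place of the real training variables:

```python
import torch

def describe_shapes(**tensors):
    """Hypothetical helper: summarize the shape of each named tensor."""
    return ", ".join(f"{name}={tuple(t.shape)}" for name, t in tensors.items())

decoder_output = torch.zeros(1, 5)              # stand-in for one decoder step
target_step = torch.zeros(1, dtype=torch.long)  # stand-in for target_variable[di]

# Print this just before the criterion call to see what the loss receives.
summary = describe_shapes(decoder_output=decoder_output, target=target_step)
print(summary)  # decoder_output=(1, 5), target=(1,)
```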

