
Comments (13)

yinboc commented on July 25, 2024

Hi,
As I understand it, the original Prototypical Networks paper uses a cross-entropy loss on page 3.
I also saw the -log term in the loss in their code.
So I guess the loss functions are the same?

from prototypical-network-pytorch.

cultivater commented on July 25, 2024

Thank you for your reply.
But are you sure the following equation is a cross-entropy loss? In my view, the goal of this loss on page 3 is to minimize the distance between matching pairs and maximize the sum of the distances between different pairs (although it has exp() and log() terms, I think you understand what I mean). That is different from cross-entropy loss, whose goal is to minimize the difference between the true label distribution and the predicted one.

[image: the loss equation from page 3 of the paper]
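For reference, Algorithm 1 of the Prototypical Networks paper (Snell et al., 2017) adds the following term to the episode loss J for each query x whose true class is k. Note that the log-sum-exp runs over all prototypes c_{k'}, including the true class k:

```latex
J \leftarrow J + \frac{1}{N_C N_Q}\left[ d\big(f_\phi(\mathbf{x}), \mathbf{c}_k\big) + \log \sum_{k'} \exp\big(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\big) \right]
```

The bracketed term is -\log p_\phi(y = k \mid \mathbf{x}) with p_\phi the softmax over negative distances, which is why it can be read as a cross-entropy.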

Besides, I think the author's code also differs from his paper. Just look at this:
[screenshot of the author's code]
After the log_softmax is calculated, line 52 shows that only the log_softmax values of the matching pairs are gathered into the loss, ignoring the mismatching pairs. However, this operation matches page 2 of his paper:

[screenshot of page 2 of the paper]

Those are the two points I am confused about. Anyway, I think your cross-entropy loss is more reasonable, and it outperforms the one in the author's code.
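On the point about gathering only the matching pairs: a minimal PyTorch sketch (random logits standing in for negative distances; the 4-query, 3-class shapes are made up, and this is not the repository's code) suggesting that the mismatching pairs are not ignored. They enter through the softmax normalizer, so every mismatching logit still receives a gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical episode: 4 queries, 3 classes; logits[i, j] stands in for -d(f(x_i), c_j).
logits = torch.randn(4, 3, requires_grad=True)
labels = torch.tensor([0, 2, 1, 2])

# Keep only the log-probability of each query's true class, as a gather does.
log_p = F.log_softmax(logits, dim=1)
loss = -log_p.gather(1, labels.unsqueeze(1)).squeeze(1).mean()
loss.backward()

# Analytic gradient of mean(-log softmax): (softmax - one_hot) / n_queries.
probs = torch.softmax(logits.detach(), dim=1)
expected = (probs - F.one_hot(labels, num_classes=3).float()) / 4
print(torch.allclose(logits.grad, expected, atol=1e-6))  # True

# Every mismatching (query, class) pair has a positive gradient, so gradient
# descent lowers that logit, i.e. increases that distance.
mismatch = F.one_hot(labels, num_classes=3) == 0
print(bool((logits.grad[mismatch] > 0).all()))  # True
```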


yinboc commented on July 25, 2024

Thanks for your question.
The cross-entropy loss is -x[class] + \log( \sum_j \exp(x[j]) ), which is the same as the formulation above once the logits are taken to be x[j] = -d(f_{\phi}(x), c_j). See torch.nn.CrossEntropyLoss for details.
-log p_{\phi}(y = k | x) is also equivalent; you just need to expand p_{\phi}.
Let me know if I am mistaken.
Best.
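A quick numerical check of this equivalence, as a sketch with random tensors (the 16-d embeddings, 3 prototypes, and squared Euclidean distance are illustrative assumptions): feeding the negative distances as logits to F.cross_entropy gives the same value as the paper's d + log-sum-exp form.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical embeddings: 5 queries and 3 class prototypes in a 16-d space.
queries = torch.randn(5, 16)
prototypes = torch.randn(3, 16)
labels = torch.tensor([0, 1, 2, 1, 0])

# Squared Euclidean distance d(f(x), c_k), shape (5, 3).
dist = ((queries.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)

# Paper form: d(f(x), c_k) + log sum_k' exp(-d(f(x), c_k')), averaged over queries.
d_true = dist[torch.arange(5), labels]
paper_loss = (d_true + torch.logsumexp(-dist, dim=1)).mean()

# Cross-entropy form: -x[class] + log sum_j exp(x[j]) with logits x = -dist.
ce_loss = F.cross_entropy(-dist, labels)

print(torch.allclose(paper_loss, ce_loss, atol=1e-5))  # True
```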


ujsyehao commented on July 25, 2024

@cyvius96 Hi, I have a question:

  1. The paper emphasizes that the mean of each class's support set is the prototype representation of that class, but look at your code. For example, choose 2 classes with 3 images per class in the training phase, so I have 6 images: class1_1, class2_1, class1_2, class2_2, class1_3, class2_3. They are divided into a support set (4 images: class1_2, class2_2, class1_3, class2_3) and a query set (2 images: class1_1, class2_1), and then the distance between the support set and the query set is calculated:
    [screenshot of the distance computation]
    Finally, it takes the class with the maximum distance as the predicted class.
    My confusion is that it calculates the distance from the test image to both the class1_2 and class1_3 images, but class1_2 and class1_3 belong to the same class. According to the paper, we should first compute the mean of the support images that belong to the same class.
    Thank you in advance!


yinboc commented on July 25, 2024

https://github.com/cyvius96/prototypical-network-pytorch/blob/a3f8f1e1afd7fcb8cab64ba89268a80790761f88/train.py#L74
I think this line computes the mean over the shots of each class.
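A minimal sketch of what a reshape-then-mean of this kind does, assuming (as in the example above) that the support embeddings are ordered with the class index varying fastest. The tiny 2-way 3-shot tensor is made up, with each value encoding its class and shot so the averaging is easy to see; class indices 0 and 1 correspond to class1 and class2.

```python
import torch

# Hypothetical 2-way 3-shot support set, ordered class1_1, class2_1, class1_2, ...
# (the class index varies fastest).
shot, way, dim = 3, 2, 4
support = torch.stack([
    torch.full((dim,), float(10 * c + s))  # value encodes (class c, shot s)
    for s in range(shot) for c in range(way)
])  # shape (shot * way, dim)

# reshape(shot, way, -1): row s holds shot s of every class; the mean over
# dim 0 averages the shots of each class, giving one prototype per class.
proto = support.reshape(shot, way, -1).mean(dim=0)
print(proto.shape)           # torch.Size([2, 4])
print(proto[:, 0].tolist())  # [1.0, 11.0]
```

Distances are then computed from each query to these per-class prototypes, not to the individual support images.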


ujsyehao commented on July 25, 2024

@cyvius96 I tested this code, proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0), and it has no effect.
Test code:
[screenshot of the test code]
Test result:
[screenshot of the test output]
I used torch.equal() to check, and the tensor is unchanged.


yinboc commented on July 25, 2024

Of course: since this is 30-way 1-shot, there are already 30 prototypes (each of dimension 1600) for the 30 classes, so averaging over a single shot changes nothing.
If you test n-way 5-shot, there should be n prototypes instead of 5n, because the images of each class are averaged into one.
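To see why the line has no visible effect in the 1-shot case but matters with more shots, here is a sketch with random tensors (the 30-way setting and 1600-d embeddings come from the reply above; the tensors themselves are made up):

```python
import torch

way, dim = 30, 1600

# 1-shot: reshape(1, way, -1).mean(dim=0) only drops the singleton shot axis,
# so the "prototypes" are the support embeddings themselves.
support_1shot = torch.randn(1 * way, dim)
proto_1shot = support_1shot.reshape(1, way, -1).mean(dim=0)
print(torch.equal(proto_1shot, support_1shot))  # True

# 5-shot: 5 * way embeddings are averaged down to `way` prototypes.
support_5shot = torch.randn(5 * way, dim)
proto_5shot = support_5shot.reshape(5, way, -1).mean(dim=0)
print(proto_5shot.shape)  # torch.Size([30, 1600])
```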


ujsyehao commented on July 25, 2024

@cyvius96 Yeah, you are right! Thank you!
I am new to this field and have two more questions:

  1. In your code, query_variable is always larger than shot_variable. Is that normal? Can you give me an explanation or some links to resources?
  2. I have written a demo.py. For example, I trained a model that classifies cars and trucks, and at inference I provide 3 images (a car image and a truck image as the support set, and a test image as the query set), but it performs badly. Should I provide more images as the support set (and then compute the mean over the shots)?


yinboc commented on July 25, 2024

You are welcome.

  1. It is the setting used by most recent works.
  2. Umm... Maybe?
    Good luck with your research.
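For the demo question, a minimal sketch of nearest-prototype inference with several support shots per class. It assumes, as in the paper, that logits are negative squared Euclidean distances, so taking the maximum logit (the "max" mentioned earlier in the thread) picks the nearest prototype. The helper name classify_queries and the 2-d "car"/"truck" embeddings are hypothetical, standing in for the outputs of a trained encoder.

```python
import torch


def classify_queries(support_emb: torch.Tensor, query_emb: torch.Tensor,
                     shot: int, way: int) -> torch.Tensor:
    """Nearest-prototype classification (hypothetical helper, not from the repo).

    support_emb: (shot * way, dim), ordered with the class index varying fastest.
    query_emb:   (n_query, dim).
    Returns the predicted class index for each query.
    """
    # One prototype per class: the mean of that class's support embeddings.
    proto = support_emb.reshape(shot, way, -1).mean(dim=0)  # (way, dim)
    # Logits are negative squared Euclidean distances, shape (n_query, way).
    logits = -((query_emb.unsqueeze(1) - proto.unsqueeze(0)) ** 2).sum(dim=2)
    # The largest logit is the smallest distance, i.e. the nearest prototype.
    return logits.argmax(dim=1)


# Toy 2-way example ("car" = class 0, "truck" = class 1) with 3 shots each,
# using made-up 2-d embeddings clustered around (0, 0) and (5, 5).
torch.manual_seed(0)
centers = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
support = torch.cat([centers + 0.1 * torch.randn(2, 2) for _ in range(3)])  # (6, 2)
queries = torch.tensor([[0.2, -0.1], [4.8, 5.3]])
print(classify_queries(support, queries, shot=3, way=2).tolist())  # [0, 1]
```

Averaging more shots gives a less noisy estimate of each class mean, which may help a 1-shot demo, though how much depends on the trained encoder and the data.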


ujsyehao commented on July 25, 2024

I found an error in a formula in the paper:
[screenshots of the prototype formula from the paper]
It should not be 1/N_C.


ujsyehao commented on July 25, 2024

@cyvius96 Hi, can you help me check the formula?


yinboc commented on July 25, 2024

Yes, I think it should be 1/N_S.
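Written out, the corrected prototype line in Algorithm 1 averages the N_S support embeddings sampled for class k, with S_k the support set of class k:

```latex
\mathbf{c}_k = \frac{1}{N_S} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_\phi(\mathbf{x}_i)
```

This agrees with Eq. (1) of the paper, c_k = (1/|S_k|) Σ f_φ(x_i), since |S_k| = N_S within an episode.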


ujsyehao commented on July 25, 2024

@cyvius96 Thank you for your kindness.
In the paper, the author uses a softmax loss (softmax function + cross-entropy loss), but you use cross-entropy loss instead. I think there is some difference between softmax loss and cross-entropy loss.
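One point that may resolve this: PyTorch's F.cross_entropy takes raw logits and applies log_softmax internally, so it already computes the "softmax function + cross-entropy" combination described above. A sketch with random logits (the 6-query, 5-class shapes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 5)          # hypothetical: 6 queries, 5 classes
labels = torch.randint(0, 5, (6,))

# "Softmax loss": softmax probabilities followed by the negative log-likelihood.
probs = torch.softmax(logits, dim=1)
softmax_loss = -torch.log(probs[torch.arange(6), labels]).mean()

# cross_entropy is equivalent to log_softmax followed by nll_loss.
ce = F.cross_entropy(logits, labels)
fused = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(softmax_loss, ce), torch.allclose(fused, ce))  # True True
```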
