Hi,
It is my understanding that the original Prototypical Networks paper uses a cross-entropy loss (page 3).
I also saw the -log term in the authors' code.
So I guess the loss function would be the same?
from prototypical-network-pytorch.
Thank you for your reply.
But are you sure the following equation is a cross-entropy loss? In my view, the goal of the loss on page 3 is to minimize the distance between matching pairs and maximize the sum of the distances between non-matching pairs (although it has exp() and log() terms, I think you can understand what I mean). That is different from the cross-entropy loss, whose goal is to minimize the difference between the distribution of the true labels and the predicted distribution.
Besides, I think the author's code also differs from his paper. Just look at this:
When the log_softmax is calculated, line 52 shows that only the log_softmax values of the matching pairs are gathered into the loss, ignoring the mismatching pairs. However, this operation does match page 2 of his paper:
Those are my two points of confusion. Anyway, I think your cross-entropy loss is more reasonable, and it outperforms the author's loss in his code.
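To illustrate why gathering only the matching entry of the log_softmax still involves the mismatching pairs, here is a minimal pure-Python sketch. The distances are made-up numbers and the helper names (log_softmax, proto_loss_for_query) are hypothetical, not taken from the repository's train.py; the logits are negative distances, as in the paper.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax: x_j - log(sum_i exp(x_i))
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def proto_loss_for_query(distances, true_class):
    # Logits are negative distances from the query to each prototype.
    logits = [-d for d in distances]
    # "Gather" only the entry of the true (matching) class.
    return -log_softmax(logits)[true_class]

# Hypothetical squared distances from one query to 3 prototypes; class 0 is correct.
loss_a = proto_loss_for_query([1.0, 4.0, 9.0], 0)
# Move a *mismatching* prototype closer: the gathered loss still changes,
# because every class contributes to the log-sum-exp normalizer.
loss_b = proto_loss_for_query([1.0, 2.0, 9.0], 0)
print(loss_a < loss_b)  # True
```

In other words, the mismatching pairs enter through the denominator of the softmax, so the gathered term penalizes queries that are close to wrong prototypes.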
Thanks for your question.
The cross-entropy loss is $-x_{\text{class}} + \log \sum_j \exp(x_j)$, which is the same as the formulation above. See torch.nn.CrossEntropyLoss for details.
$-\log p_{\phi}(y = k \mid \mathbf{x})$ is also equivalent; you just need to expand $p_{\phi}$.
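Expanding $p_{\phi}$ with the paper's definition (a softmax over negative distances $d$ from the embedded query $f_{\phi}(\mathbf{x})$ to the class prototypes $\mathbf{c}_{k'}$) makes the equivalence explicit. Here $z_{k'}$ is a name introduced only for this derivation; it plays the role of x[j] in the cross-entropy expression of this comment:

```latex
\begin{aligned}
-\log p_{\phi}(y = k \mid \mathbf{x})
  &= -\log \frac{\exp\left(-d(f_{\phi}(\mathbf{x}), \mathbf{c}_k)\right)}
                {\sum_{k'} \exp\left(-d(f_{\phi}(\mathbf{x}), \mathbf{c}_{k'})\right)} \\
  &= d(f_{\phi}(\mathbf{x}), \mathbf{c}_k)
     + \log \sum_{k'} \exp\left(-d(f_{\phi}(\mathbf{x}), \mathbf{c}_{k'})\right) \\
  &= -z_k + \log \sum_{k'} \exp(z_{k'}),
  \qquad z_{k'} := -d(f_{\phi}(\mathbf{x}), \mathbf{c}_{k'}).
\end{aligned}
```

The first term pulls the query toward its own prototype, and the log-sum-exp term pushes it away from all prototypes, which is the "minimize matching, maximize non-matching" reading from the question.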
Let me know if I am mistaken.
Best.
@cyvius96 Hi, I have a question:
- The paper emphasizes that the mean of each class's support set is the prototype representation of that class. But in your code, for example, if I choose 2 classes with 3 images per class in the training phase, I have 6 images: class1_1, class2_1, class1_2, class2_2, class1_3, class2_3. They are divided into a support set (4 images: class1_2, class2_2, class1_3, class2_3) and a query set (2 images: class1_1, class2_1), and then the distances between the support set and the query set are calculated:
Finally, it chooses the max distance as the predicted class.
My confusion is that it calculates the distance from the test image to the class1_2 image and to the class1_3 image separately, but class1_2 and class1_3 belong to the same class. According to the paper, we should compute the mean of the support set belonging to each class.
Thank you in advance!
https://github.com/cyvius96/prototypical-network-pytorch/blob/a3f8f1e1afd7fcb8cab64ba89268a80790761f88/train.py#L74
I think this line computes the mean over the shots of each class.
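As a pure-Python sketch of what proto.reshape(args.shot, args.train_way, -1).mean(dim=0) does, assuming the embeddings arrive shot-major (class1_1, class2_1, class1_2, class2_2, ..., as in the example in the question): a row-major reshape to (shot, way, dim) places shot s of class c at flat index s * way + c, and the mean over dimension 0 averages those shots. The function name class_prototypes is hypothetical.

```python
def class_prototypes(embeddings, shot, way):
    """Pure-Python analogue of proto.reshape(shot, way, -1).mean(dim=0).

    Assumes `embeddings` is ordered shot-major:
    [c0_s0, c1_s0, ..., c{way-1}_s0, c0_s1, c1_s1, ...].
    Returns one mean vector (prototype) per class.
    """
    assert len(embeddings) == shot * way
    dim = len(embeddings[0])
    protos = []
    for c in range(way):
        # Row-major reshape to (shot, way, dim): element (s, c) sits at index s * way + c.
        members = [embeddings[s * way + c] for s in range(shot)]
        # Mean over the shot axis, coordinate by coordinate.
        protos.append([sum(v[i] for v in members) / shot for i in range(dim)])
    return protos

# 2-way 3-shot, shot-major: class 0 lies near the x-axis, class 1 near (12, 10).
emb = [[0, 0], [10, 10], [2, 0], [12, 10], [4, 0], [14, 10]]
print(class_prototypes(emb, shot=3, way=2))  # → [[2.0, 0.0], [12.0, 10.0]]
```

With shot=1 each "mean" is over a single vector, so the output equals the input; with shot > 1, shot * way embeddings are reduced to way prototypes.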
@cyvius96 I tested this code:
proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
It has no effect.
Test code:
Test result:
I used torch.equal() to test it, and it shows no change.
Of course: it is 30-way 1-shot, so there are already 30 prototypes (each of dimension 1600) for the 30 classes, and taking the mean over a single shot changes nothing.
If you test n-way 5-shot, there should be n prototypes instead of 5n: the images of each class are reduced to one vector by the mean.
@cyvius96 Yeah, you are right! Thank you!
I am new to this field, and I have two more questions:
- In your code, query_variable is always larger than shot_variable. Is that normal? Could you give me an explanation or point me to some resources?
- I have written a demo.py. For example, I trained a model that classifies cars and trucks. At inference time I provide 3 images (a car image and a truck image as the support set, and a test image as the query set), but it performs poorly. Should I provide more images for the support set (and then compute the mean over the shots)?
You are welcome.
- It is the setting adopted by most recent works.
- Umm... Maybe?
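The demo.py itself is not shown in the thread. As a minimal sketch of the inference step the paper describes, assuming a trained encoder has already turned each image into an embedding vector (the 2-D vectors and the names mean_vector, squared_euclidean and predict below are hypothetical), each class's support embeddings are averaged into a prototype and the query is assigned to the nearest one. Providing more support images simply means averaging over more shots.

```python
def mean_vector(vectors):
    # Prototype of one class: element-wise mean of its support embeddings.
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]

def squared_euclidean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def predict(support, query):
    """support: dict mapping label -> list of support embeddings (any number of shots).
    query: a single query embedding.
    Returns the label of the nearest prototype."""
    protos = {label: mean_vector(vs) for label, vs in support.items()}
    # Smallest squared distance == largest negative distance (the paper's logit).
    return min(protos, key=lambda label: squared_euclidean(protos[label], query))

# Hypothetical 2-shot support embeddings standing in for encoder outputs.
support = {
    "car": [[0.9, 0.1], [1.1, -0.1]],
    "truck": [[-1.0, 0.2], [-0.8, 0.0]],
}
print(predict(support, [0.7, 0.0]))  # → car
```

Taking the minimum distance is the same decision as taking the argmax of the negative distances, which is how the paper turns distances into softmax logits.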
Good luck with your research.
I found an error in a formula in the paper:
It should not be 1/N_C.
@cyvius96 Hi, could you help me check the formula?
Yes, I think it should be 1/N_S.
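Assuming the formula under discussion is the prototype computation in the paper's training episode (where N_C is the number of classes per episode and N_S the number of support examples per class), the corrected line would be the mean of the N_S support embeddings of class k:

```latex
\mathbf{c}_k \leftarrow \frac{1}{N_S} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_{\phi}(\mathbf{x}_i)
```

Dividing by N_C would only give the mean when N_C happens to equal N_S.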
@cyvius96 Thank you for your kindness.
In the paper, the author uses a softmax loss (softmax function + cross-entropy loss), but you use a cross-entropy loss instead. I think there is some difference between the softmax loss and the cross-entropy loss.
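One way to check whether the two formulations differ is to compute both on the same logits. torch.nn.CrossEntropyLoss is documented as combining LogSoftmax and NLLLoss, so it applies the softmax internally. The pure-Python sketch below uses made-up logits (think of them as negative distances from one query to 3 prototypes) and hypothetical helper names.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_then_nll(logits, target):
    # "Softmax loss": apply softmax, then take -log of the target probability.
    return -math.log(softmax(logits)[target])

def cross_entropy_from_logits(logits, target):
    # Logit form quoted in the thread: -x[class] + log(sum_j exp(x[j])).
    m = max(logits)
    return -logits[target] + m + math.log(sum(math.exp(x - m) for x in logits))

# Hypothetical logits, e.g. negative distances from one query to 3 prototypes.
logits = [-1.0, -4.0, -9.0]
for k in range(len(logits)):
    # Both columns agree up to floating-point rounding.
    print(k, softmax_then_nll(logits, k), cross_entropy_from_logits(logits, k))
```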