yinboc / prototypical-network-pytorch
A re-implementation of "Prototypical Networks for Few-shot Learning"
License: MIT License
Since evaluation is episodic and the classes involved differ from episode to episode, can you help me compute a confusion matrix? A confusion matrix would give a better idea of where the model is failing.
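One possible approach is to map each episode's local labels (0 to way-1) back to the dataset's global class ids and accumulate counts over all episodes. The following is a minimal sketch in plain Python, assuming the global class ids sampled for each episode are available; the names `update_confusion`, `episode_classes`, `local_true`, and `local_pred` are hypothetical and not part of this repository.

```python
from collections import defaultdict

def update_confusion(confusion, episode_classes, local_true, local_pred):
    """Add one episode's predictions to a global confusion count.

    episode_classes[i] is the global class id that local label i stands for
    in this episode (assumed to be obtainable from the sampler).
    """
    for t, p in zip(local_true, local_pred):
        confusion[(episode_classes[t], episode_classes[p])] += 1
    return confusion

confusion = defaultdict(int)  # keys are (true_class, predicted_class)

# Episode 1: a 3-way episode over hypothetical global classes 7, 2 and 9.
update_confusion(confusion, [7, 2, 9], local_true=[0, 1, 2], local_pred=[0, 2, 2])
# Episode 2: another 3-way episode over hypothetical global classes 2, 9 and 4.
update_confusion(confusion, [2, 9, 4], local_true=[0, 1, 2], local_pred=[1, 1, 2])

for (true_cls, pred_cls), count in sorted(confusion.items()):
    print(f"true={true_cls} pred={pred_cls} count={count}")
```

Note that a query can only be confused with classes present in the same episode, so the accumulated counts reflect confusions among classes that happened to be sampled together.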
Hello,
Thanks for this implementation.
Recently, I ran this repo with step-size=100 and max-epochs=2000 and got higher performance.
For 5-way 1-shot training I got 50.30%, and for 5-way 5-shot training I got 67.11%.
Dataset could not be downloaded. It needs permission.
Hi!
Could you provide the license for the code? (hopefully MIT :))
Thank you for sharing your code.
If I understood correctly, you don't sample the support and query sets randomly, unlike the original Prototypical Networks paper. Could this be the reason for overfitting on a custom dataset?
data, _ = [_.cuda() for _ in batch]
p = args.shot * args.train_way
data_shot, data_query = data[:p], data[p:]
label = torch.arange(args.train_way).repeat(args.query)
Why are the labels not taken from train_loader? Won't that affect accuracy?
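The labels built here are episode-local indices (0 to way-1), not dataset labels, which would explain why the loader's labels are discarded. The sketch below illustrates why `torch.arange(way).repeat(query)` can line up with the query samples. It assumes the sampler arranges each batch so that classes cycle fastest (one sample of every class, then the next sample of every class); that ordering is an assumption about the sampler, and the class ids are made up. Plain Python lists stand in for tensors.

```python
way, shot, query = 3, 1, 2

# Hypothetical global class ids drawn for this episode.
episode_classes = [41, 7, 18]

# Assumed batch layout: one sample of every class, repeated shot + query times.
batch_global_labels = episode_classes * (shot + query)

# Same split point as data[:p], data[p:] in the snippet above.
p = shot * way
query_global_labels = batch_global_labels[p:]

# Pure-Python equivalent of torch.arange(way).repeat(query).
label = list(range(way)) * query

# Relabel each query sample by the position of its class within the episode.
relabelled = [episode_classes.index(c) for c in query_global_labels]

print(label)       # [0, 1, 2, 0, 1, 2]
print(relabelled)  # [0, 1, 2, 0, 1, 2]
```

Under that layout, prototype k is built from the support samples of the k-th episode class, so the local index k is exactly the target the loss needs and accuracy is unaffected.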
Dear Sir,
I am trying to run the code on Google Colab, but it seems to be slow. I wonder whether I am just being impatient or whether it is due to the way mini_imagenet.py is written.
I wanted to ask which of the two options is faster?
Dear Chen,
I noticed that you use a cross-entropy loss in your code, whereas the author of the paper minimizes the distance between query and support data that share the same label and uses that as the loss function. Here is his code: https://github.com/jakesnell/prototypical-networks
So, do you have any doubts about the author's loss function?
Thank you!
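The two formulations may in fact coincide. The paper's per-query loss is the negative log-probability of the true class under a softmax over negative distances, which expands to d(x, c_k) + log Σ_k' exp(-d(x, c_k')). Cross-entropy applied to logits equal to the negative distances gives the same value. A small numeric check in plain Python, assuming the logits passed to `F.cross_entropy` are negative distances to the prototypes (the distance values are made up):

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for a single sample: -log softmax(logits)[target]."""
    log_norm = math.log(sum(math.exp(z) for z in logits))
    return -(logits[target] - log_norm)

def paper_loss(distances, target):
    """d(query, c_target) + log sum_k exp(-d(query, c_k))."""
    return distances[target] + math.log(sum(math.exp(-d) for d in distances))

# Hypothetical squared distances from one query to three prototypes.
distances = [0.5, 2.0, 3.5]
logits = [-d for d in distances]  # negative distances used as logits

for k in range(len(distances)):
    print(k, cross_entropy(logits, k), paper_loss(distances, k))
```

Both columns agree up to floating-point error for every target class, so minimizing one minimizes the other.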
How can I get the training data?
Does the program produce training and test curves? Is trlog the training curve?
Hi
I am trying to create a demo version in which a single image can be loaded and classified. I am a bit new to PyTorch and am facing a few issues.
I modified the code as below. I load the query image and the 7 different support images into data_query and data_shot respectively.
def load_support():
    img_data = []
    img_transform = transforms.Compose([
        transforms.Resize(84),
        transforms.CenterCrop(84),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    for file in os.listdir("/home/abc/data/support/"):
        if file.endswith(".jpg"):
            path = "/home/abc/data/support/" + file
            image = img_transform(Image.open(path).convert('RGB'))
            img_data.append(image.tolist())
    img_data = torch.cuda.FloatTensor(img_data)
    return img_data

if __name__ == '__main__':
    img_transform = transforms.Compose([
        transforms.Resize(84),
        transforms.CenterCrop(84),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    model = Convnet().cuda()
    model.load_state_dict(torch.load('./save/proto-17-5/max-acc.pth'))
    model.eval()
    data_shot = load_support()
    path = "/home/abc/data/val/test-15.jpg"
    data_query = img_transform(Image.open(path).convert('RGB'))
    data_query = data_query.cuda()
    print("support", data_shot)
    print("query", data_query)
    x = model(data_shot)
    print("shape", x.shape)
    logits = euclidean_metric(model(data_query), x)  # <-- error raised here
    label = torch.arange(7).repeat(1)
    label = label.type(torch.cuda.LongTensor)
The line logits = euclidean_metric(model(data_query), x) throws an error saying:
Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional input of size [3, 84, 84] instead.
What additional dimension am I missing? Please advise.
Thanks
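The error message suggests that the convolution expects a batch dimension: a single image tensor of shape [3, 84, 84] would need to become [1, 3, 84, 84] before being passed to the model. A minimal sketch, using a stand-in `nn.Conv2d` layer rather than the repository's `Convnet`, random tensors in place of real images, and CPU tensors so that it runs without a GPU:

```python
import torch
import torch.nn as nn

# Stand-in for the first layer of the encoder (not the repository's Convnet).
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)

# Random tensor in place of img_transform(Image.open(path).convert('RGB')).
data_query = torch.randn(3, 84, 84)

# Add a leading batch dimension: [3, 84, 84] -> [1, 3, 84, 84].
batched_query = data_query.unsqueeze(0)

# Stacking a list of image tensors builds the support batch directly,
# without the round trip through image.tolist().
support_images = [torch.randn(3, 84, 84) for _ in range(7)]
data_shot = torch.stack(support_images)

with torch.no_grad():
    out = conv(batched_query)

print(batched_query.shape)  # torch.Size([1, 3, 84, 84])
print(data_shot.shape)      # torch.Size([7, 3, 84, 84])
print(out.shape)            # torch.Size([1, 64, 84, 84])
```

In the code above this would correspond to calling `model(data_query.unsqueeze(0))` instead of `model(data_query)`.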
Hi, I was able to reach 48.2% accuracy for 1-shot on the test set after 200 epochs. Could you please state the number of epochs you trained the model for to achieve the reported 1-shot and 5-shot results? Were any other architectural changes made to get these results?
Hi.
I was trying out few-shot learning on a custom dataset. My custom data has 17 categories and 6 images per category. I created a class similar to MiniImageNet and used it for validation, i.e. as valset. But it throws an error at line 114:
loss = F.cross_entropy(logits, label)
File "/home/search_env/lib/python3.6/site-packages/torch/nn/functional.py", line 2056, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/home/search_env/lib/python3.6/site-packages/torch/nn/functional.py", line 1869, in nll_loss
.format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (5) to match target batch_size (75).
Am I missing something? Kindly guide
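The numbers in the error are consistent with each class having fewer images than an episode asks for: 75 targets would correspond to 5 ways times 15 queries, while only 5 query rows came out of the split. This is a guess from the message, not a confirmed diagnosis. A small sketch of the arithmetic, assuming a sampler that silently returns at most the available images per class; the function name `episode_sizes` and the shot and query settings are assumptions.

```python
def episode_sizes(images_per_class, way, shot, query):
    """Return (number of query rows, number of labels) for one episode.

    Assumes the sampler draws min(images_per_class, shot + query) images
    from each class, which is an assumption about the sampler's behaviour.
    """
    drawn_per_class = min(images_per_class, shot + query)
    batch_size = way * drawn_per_class
    p = shot * way                      # same split point as data[:p], data[p:]
    query_rows = max(batch_size - p, 0)
    labels = way * query                # length of arange(way).repeat(query)
    return query_rows, labels

# 6 images per class cannot supply 5 shots plus 15 queries.
print(episode_sizes(images_per_class=6, way=5, shot=5, query=15))  # (5, 75)
# Settings that fit within 6 images per class.
print(episode_sizes(images_per_class=6, way=5, shot=1, query=5))   # (25, 25)
```

If this reading is right, choosing shot and query so that shot + query does not exceed the 6 images available per class would make the two batch sizes match.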
Thank you very much for sharing your code.
I want to train models on a dataset other than miniImageNet or Omniglot. Could you explain how to arrange the dataset (main folder, sub-folders) and how to train models on a new dataset?
Thanks a lot.
Hi, thanks for your source code. However, when I run the 1-shot code on two GTX 1080 Ti GPUs in parallel, I still get "CUDA error: out of memory". All I can do is reduce "train-way" and "query", but I am afraid the results would drop. Could you give me some advice, given that I only have two GTX 1080 Ti cards? And how many GPUs did you use to run this code? Thanks a lot!
Hi, I have trained the model on the mini-ImageNet dataset and used test.py. Now I want to use the model for inference (input one image and get its result), but I don't know how to write demo.py. Can you give me some advice?
As I understand it, to test on a single image I must at least provide images of the trained classes. Is that right?
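In a prototypical network a query can only be assigned to classes for which support examples are supplied, so at least one support image per candidate class is needed. A minimal sketch of the classification step on already-computed embeddings, in plain Python with made-up 2-D vectors; the function names and class names are hypothetical, and in practice the embeddings would come from the trained encoder.

```python
def prototype(embeddings):
    """Mean of the support embeddings of one class."""
    n = len(embeddings)
    return [sum(dim) / n for dim in zip(*embeddings)]

def squared_distance(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def classify(query_embedding, support_by_class):
    """Return the class whose prototype is nearest to the query."""
    prototypes = {c: prototype(e) for c, e in support_by_class.items()}
    return min(prototypes,
               key=lambda c: squared_distance(query_embedding, prototypes[c]))

# Made-up embeddings: two support examples for each of two classes.
support_by_class = {
    "cat": [[0.0, 1.0], [0.2, 0.8]],
    "dog": [[1.0, 0.0], [0.8, 0.2]],
}

print(classify([0.9, 0.1], support_by_class))  # dog
```

A demo script would follow the same shape: embed the support images, average them per class, embed the single query image, and pick the nearest prototype.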
According to the paper, training with 5-way 5-shot and testing with 5-way 5-shot should give an accuracy of ~65.77%, but I can only get ~43.79%. I don't know why. Could you please give me any suggestions?