prototypical-network-pytorch's Issues

Selecting support and query sets randomly

Thank you for sharing your code.

If I understood correctly, you don't sample the support and query sets randomly, unlike the original Prototypical Networks paper. Could this be the reason for overfitting on a custom dataset?

data, _ = [_.cuda() for _ in batch] 
p = args.shot * args.train_way
data_shot, data_query = data[:p], data[p:]
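For reference, a positional split like the one above can still give random episodes if the randomness lives in the batch sampler. The sketch below shows that idea under stated assumptions; it is not a copy of this repository's sampler, and `sample_episode` and its arguments are hypothetical names.

```python
import torch

def sample_episode(labels, n_way, n_shot, n_query):
    """Pick random classes and random images, laid out as [support | query].

    Assumes every class has at least n_shot + n_query images.
    """
    labels = torch.as_tensor(labels)
    classes = torch.unique(labels)
    # Random subset of classes for this episode.
    chosen = classes[torch.randperm(len(classes))[:n_way]]
    per_class = []
    for c in chosen:
        idx = (labels == c).nonzero(as_tuple=True)[0]
        # Random images of class c; the first n_shot of them become support.
        pick = torch.randperm(len(idx))[: n_shot + n_query]
        per_class.append(idx[pick])
    # (n_way, n_shot + n_query) -> transpose so that classes cycle fastest.
    batch = torch.stack(per_class).t().reshape(-1)
    return batch, chosen

# Hypothetical dataset: 10 classes with 20 images each.
labels = [c for c in range(10) for _ in range(20)]
batch, chosen = sample_episode(labels, n_way=5, n_shot=1, n_query=15)
p = 1 * 5
support_idx, query_idx = batch[:p], batch[p:]
```

With this layout the first `n_shot * n_way` indices are support images and the rest are queries, so slicing by position does not remove the randomness.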

Cannot get the reported results.

According to the paper, training with 5-way 5-shot and testing with 5-way 5-shot should give ~65.77% accuracy, but I can only get ~43.79%. I don't know why. Could you give me any suggestions?

Problem about labels

label = torch.arange(args.train_way).repeat(args.query)
Why are the labels not taken from train_loader? Won't that affect accuracy?
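A small check of why position-based labels can stand in for the dataset labels. It assumes the batch is laid out with the episode's classes cycling in a fixed order, which is an assumption about the sampler, and the class ids below are made up. The cross-entropy target only needs the index of the matching prototype, not the original class id.

```python
import torch

n_way, n_query = 3, 2
# Hypothetical dataset class ids for the query part of one episode,
# with the three episode classes (17, 4, 42) cycling in a fixed order.
dataset_labels = torch.tensor([17, 4, 42, 17, 4, 42])
episode_classes = dataset_labels[:n_way]  # order in which prototypes are built

# Remap each dataset id to the position of its prototype (0 .. n_way - 1).
matches = dataset_labels.unsqueeze(1) == episode_classes.unsqueeze(0)
remapped = matches.int().argmax(dim=1)

# The same targets, built without looking at the dataset labels at all.
label = torch.arange(n_way).repeat(n_query)
```

Both tensors hold the same prototype indices, so the loss and accuracy are unchanged as long as the layout assumption holds.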

LICENSE?

Hi!

Could you provide the license for the code? (hopefully MIT :))

Running on Google Colab

Dear Sir,
I am trying to run the code on Google Colab, but it seems slow. I wonder whether I am just being impatient or whether it is due to the way mini_imagenet.py is written.
Which of these two options is faster?

  1. accessing the images directly from folders
  2. converting the dataset into another format such as pickle or HDF5

If the second option is better, how do I do the conversion?
Thanks
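If decoding individual image files turns out to be the bottleneck (an assumption; nothing here has been measured), one option is to decode everything once and store the result in a single file. A minimal sketch using only the standard library; `build_cache` and `load_fn` are hypothetical names, and `load_fn` stands for whatever function opens and preprocesses one image.

```python
import os
import pickle

def build_cache(paths, load_fn, cache_file):
    """Load every item once and reuse a single pickle file afterwards."""
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    data = [load_fn(p) for p in paths]  # slow path: runs only on the first call
    with open(cache_file, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    return data
```

Whether this helps depends on the dataset fitting in memory; for larger datasets a chunked format such as HDF5 may be a better fit.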

How to train with a new dataset

Thank you very much for sharing your code.

I want to train models on a dataset other than miniImageNet or Omniglot. Could you explain how to arrange the dataset (main folder, sub-folders) and how to train models on it?

Thanks a lot.
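One common arrangement, assumed here rather than taken from this repository, is one sub-folder per class under a main folder (for example `root/cat/001.jpg` and `root/dog/001.jpg`). A sketch that turns such a tree into the parallel path and label lists a dataset class needs; `list_image_folder` is a hypothetical helper.

```python
import os

def list_image_folder(root, extensions=(".jpg", ".jpeg", ".png")):
    """Return (paths, labels, class_names) for a root/class_name/image layout."""
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    paths, labels = [], []
    for label, name in enumerate(class_names):
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(extensions):
                paths.append(os.path.join(folder, fname))
                labels.append(label)  # integer id = position in class_names
    return paths, labels, class_names
```

A dataset class modelled on the existing miniImageNet one could then store `paths` and `labels` and open each image on access.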

DataSets

How can I get this training data?

about the memory

Hi, thanks for your source code. When I run the 1-shot code on two GTX 1080 Ti GPUs in parallel, I still get "CUDA error: out of memory". All I can do is reduce "train-way" and "query", but I am afraid the results would get worse. Could you give me some advice for a setup with only two GTX 1080 Ti cards? And how many GPUs did you use to run this code? Thanks a lot!

confusion matrix plot for evaluation

Since we use episodes, and each episode involves different classes, can you help me compute a confusion matrix? A confusion matrix would give a better idea of where the model is failing.
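Because each episode uses its own subset of classes, one approach is to map the episode-relative predictions back to dataset class ids and accumulate a single matrix over all episodes. A sketch under that assumption; `update_confusion` and `episode_classes` (the dataset ids of the episode's prototypes, in prototype order) are hypothetical names.

```python
import torch

def update_confusion(conf, logits, label, episode_classes):
    """Add one episode's predictions to a (num_classes, num_classes) matrix.

    Rows are true classes, columns are predicted classes.
    """
    pred = logits.argmax(dim=1)        # episode-relative prediction
    true_ids = episode_classes[label]  # back to dataset class ids
    pred_ids = episode_classes[pred]
    for t, p in zip(true_ids.tolist(), pred_ids.tolist()):
        conf[t, p] += 1
    return conf

# Toy episode: 3 of 10 classes, one query per class, made-up logits.
conf = torch.zeros(10, 10, dtype=torch.long)
episode_classes = torch.tensor([7, 2, 5])
label = torch.tensor([0, 1, 2])
logits = torch.tensor([[9.0, 0.0, 0.0], [0.0, 9.0, 0.0], [9.0, 0.0, 0.0]])
update_confusion(conf, logits, label, episode_classes)
```

Note that a query can only be confused with classes present in the same episode, so rarely co-sampled class pairs will have few counts.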

About loss function

Dear Chen,
I noticed that you use a cross-entropy loss in your code, whereas the author of the paper minimizes the distance between query and support data that share the same label and uses that as the loss function. Here is his code: https://github.com/jakesnell/prototypical-networks
So, do you have doubts about the author's loss function?
Thank you!
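For comparison, the loss in the paper is the negative log-probability of the true class under a softmax over negative distances to the prototypes. If the logits are negative squared Euclidean distances (an assumption about this repository, suggested by the name euclidean_metric), cross-entropy computes the same quantity. A numerical check with random tensors made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_way, n_query, dim = 5, 3, 16
protos = torch.randn(n_way, dim)             # one prototype per class
queries = torch.randn(n_way * n_query, dim)  # embedded query points
label = torch.arange(n_way).repeat(n_query)

# Squared Euclidean distance from every query to every prototype.
dist = ((queries.unsqueeze(1) - protos.unsqueeze(0)) ** 2).sum(dim=2)

# Cross-entropy on logits = -distance ...
loss_ce = F.cross_entropy(-dist, label)

# ... versus the explicit form: -log softmax(-d)[true class], averaged.
log_p = F.log_softmax(-dist, dim=1)
loss_paper = -log_p[torch.arange(len(label)), label].mean()
```

The two values agree up to floating-point error, so under that assumption the formulations differ in how they are written rather than in what they optimize.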

Can you provide demo.py?

Hi, I have trained the model on the mini-ImageNet dataset and run test.py, but I want to use the model for inference (input one image and get its result), and I don't know how to write demo.py. Can you give me some advice?
As I understand it, to test on a single image I need to provide at least one image for each of the trained classes. Is that right?
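A prototypical network classifies by distance to class prototypes, so inference on one image needs at least one labelled support image per candidate class. A sketch of that procedure, not code from this repository; `predict_one` is a hypothetical name and `embed` stands for the trained model.

```python
import torch

def predict_one(embed, support, support_labels, query):
    """Classify one image by its nearest class prototype.

    support: (N, C, H, W) images, support_labels: (N,) class ids,
    query: a single (C, H, W) image.
    """
    with torch.no_grad():
        classes = torch.unique(support_labels)
        feats = embed(support)  # (N, D)
        # Prototype = mean embedding of each class's support images.
        protos = torch.stack(
            [feats[support_labels == c].mean(dim=0) for c in classes]
        )
        q = embed(query.unsqueeze(0))          # add batch dim -> (1, D)
        dist = ((q - protos) ** 2).sum(dim=1)  # one distance per class
        return classes[dist.argmin()].item()
```

The prototypes could also be computed once and stored, so that later predictions only need the query image.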

Number of epochs needed to achieve the reported results

Hi, I was able to reach 48.2% accuracy for 1-shot on the test set after 200 epochs. Could you please state the number of epochs you trained the model for to achieve the reported 1-shot and 5-shot results? Were any other architectural changes made to get these results?

Single image prediction

Hi,
I am trying to create a demo version where a single image can be loaded and classified. I am a bit new to PyTorch and am facing a few issues.
I modified the code as below. I load the query image and the 7 different support images into data_query and data_shot respectively.

def load_support():   
    img_data = []
    img_transform = transforms.Compose([
            transforms.Resize(84),
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])      
    for file in os.listdir("/home/abc/data/support/"):
        if file.endswith(".jpg"):
            path="/home/abc/data/support/"+file
            image = img_transform(Image.open(path).convert('RGB')) 
            img_data.append(image)
    # Stack the per-image tensors into one (N, 3, 84, 84) batch on the GPU.
    return torch.stack(img_data).cuda()

if __name__ == '__main__':
    
    img_transform = transforms.Compose([
            transforms.Resize(84),
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
 
    model = Convnet().cuda()
    model.load_state_dict(torch.load('./save/proto-17-5/max-acc.pth'))
    model.eval()

    data_shot= load_support()
    path="/home/abc/data/val/test-15.jpg"
    data_query = img_transform(Image.open(path).convert('RGB'))
    
    data_query=data_query.cuda()
    
    print("support",data_shot)
    print("query",data_query)
   
    x = model(data_shot)
    print("shape",x.shape)
    

    logits = euclidean_metric(model(data_query), x)  # this line raises the error below

    label = torch.arange(7).repeat(1)
    label = label.type(torch.cuda.LongTensor)

The line logits = euclidean_metric(model(data_query), x) throws this error:
Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional input of size [3, 84, 84] instead.

What am I missing? Please advise.
Thanks
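The error message points at the input shape: the convolution expects a batch dimension, so a single image of shape [3, 84, 84] has to become [1, 3, 84, 84]. A minimal sketch with a stand-in layer (the `nn.Conv2d` below is only a placeholder for the first layer of the real model):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # stand-in for the real model
image = torch.randn(3, 84, 84)                     # one transformed image, no batch dim

batched = image.unsqueeze(0)  # shape becomes (1, 3, 84, 84)
out = conv(batched)           # output shape: (1, 64, 84, 84)
```

In the script above, that would correspond to passing `data_query.unsqueeze(0)` to the model instead of `data_query`.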

Custom dataset

Hi.
I was trying out few-shot learning on a custom dataset with 17 categories and 6 images per category. I created a class similar to MiniImageNet and used it for validation, i.e. as valset. But it throws an error at line 114:
loss = F.cross_entropy(logits, label)
File "/home/search_env/lib/python3.6/site-packages/torch/nn/functional.py", line 2056, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/home/search_env/lib/python3.6/site-packages/torch/nn/functional.py", line 1869, in nll_loss
.format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (5) to match target batch_size (75).

Am I missing something? Kindly guide
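One plausible reading of the numbers, assuming the sampler can return at most the 6 available images per class and the run used 5-way, 5-shot with 15 queries: the batch holds 5 × 6 = 30 images, the first 25 become support, only 5 remain as queries, while the label tensor still has 5 × 15 = 75 entries. The arithmetic can be checked up front; `check_episode` is a hypothetical helper.

```python
def check_episode(images_per_class, n_way, n_shot, n_query):
    """Return (query rows actually available, number of labels built)."""
    per_class = min(images_per_class, n_shot + n_query)  # sampler cannot exceed this
    batch_size = n_way * per_class
    query_rows = batch_size - n_shot * n_way  # what is left after data[:p]
    label_rows = n_way * n_query              # arange(way).repeat(query)
    return query_rows, label_rows

print(check_episode(6, n_way=5, n_shot=5, n_query=15))  # mismatching sizes
print(check_episode(6, n_way=5, n_shot=5, n_query=1))   # matching sizes
```

With 6 images per class, shot + query has to stay at or below 6, for example 5-shot with 1 query or 1-shot with 5 queries.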
