m-nauta / prototree

ProtoTrees: Neural Prototype Trees for Interpretable Fine-grained Image Recognition, published at CVPR2021

License: MIT License

Python 100.00%
pytorch computer-vision explainable-ai interpretability interpretable-machine-learning deep-neural-networks explainable-ml fine-grained-visual-categorization explainability interpretable-deep-learning

prototree's People

Contributors

m-nauta


prototree's Issues

Model inference question

Hi, thanks for this wonderful work. I was wondering if you could provide a tutorial on how to run inference with a trained model, given model.pth, model_state.pth and tree.pkl.

Greatly appreciate the help
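For what it's worth, the usual PyTorch pattern is to rebuild the model and then load the saved state dict into it; whether model.pth vs. model_state.pth holds a fully pickled module or just the weights is repo-specific. A minimal, self-contained sketch of the generic pattern (TinyNet is a hypothetical stand-in, not the ProtoTree class):

```python
import torch
import torch.nn as nn

# Toy stand-in for the saved network; with ProtoTree you would first rebuild
# the tree with the same constructor arguments used during training
# (this TinyNet and its shapes are purely illustrative).
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
torch.save(model.state_dict(), "model_state.pth")  # analogous to the repo's model_state.pth

# In a fresh script: rebuild the architecture, then load the weights into it.
restored = TinyNet()
restored.load_state_dict(torch.load("model_state.pth"))
restored.eval()  # freeze dropout/batch-norm behavior for inference

with torch.no_grad():
    x = torch.randn(1, 4)                          # stands in for a preprocessed image batch
    assert torch.allclose(model(x), restored(x))   # weights round-tripped intact
    pred = restored(x).argmax(dim=1)               # predicted class index
```

The key point is that load_state_dict only restores parameters, so the architecture must be constructed identically first.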

Training time

First of all, congrats on the very interesting paper!

I was wondering what architecture you used to train the model, and how much time it took. I'm trying to train on Colab with a GPU and it's taking ~40s per iteration (or ~5h per epoch). Is that the expected training time?
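As a quick sanity check on those numbers (illustrative arithmetic only; the ~469 iterations per epoch figure is taken from the training logs posted elsewhere in these issues):

```python
seconds_per_iter = 40        # reported Colab speed per iteration
iters_per_epoch = 469        # iteration count seen in the posted training logs
hours = seconds_per_iter * iters_per_epoch / 3600
print(f"{hours:.1f} h per epoch")  # about 5.2 h, consistent with the ~5 h/epoch estimate
```

So the two reported figures are at least internally consistent; whether 40 s/iteration itself is expected depends on the GPU Colab assigned.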

Is finetuning available?

Hi, thanks for this awesome paper! I was just wondering whether finetuning a trained ProtoTree is possible?

Acc lower than paper claims

Hi @M-Nauta, thank you for your very interesting paper and this very well-written repo. However, when I follow the README to reproduce the results on CUB_200_2011, I get 72% and 78% accuracy in two runs. Did I miss anything? How can I fix this?

My final acc is only 76%

```
Eval Epoch 95: 100% 91/91 [01:49<00:00,  1.21s/it, Batch [91/91], Acc: 0.559]
Train Epoch 96: 100% 469/469 [05:52<00:00,  1.33it/s, Batch [469/469], Loss: 0.3
Eval Epoch 96: 100% 91/91 [01:51<00:00,  1.22s/it, Batch [91/91], Acc: 0.618]
Train Epoch 97: 100% 469/469 [05:55<00:00,  1.32it/s, Batch [469/469], Loss: 0.2
Eval Epoch 97: 100% 91/91 [01:47<00:00,  1.18s/it, Batch [91/91], Acc: 0.529]
Train Epoch 98: 100% 469/469 [05:52<00:00,  1.33it/s, Batch [469/469], Loss: 0.3
Eval Epoch 98: 100% 91/91 [01:48<00:00,  1.19s/it, Batch [91/91], Acc: 0.588]
Train Epoch 99: 100% 469/469 [05:56<00:00,  1.32it/s, Batch [469/469], Loss: 0.4
Eval Epoch 99: 100% 91/91 [01:46<00:00,  1.17s/it, Batch [91/91], Acc: 0.618]
Train Epoch 100: 100% 469/469 [05:53<00:00,  1.33it/s, Batch [469/469], Loss: 0.
Eval Epoch 100: 100% 91/91 [01:47<00:00,  1.18s/it, Batch [91/91], Acc: 0.647]
Eval Epoch pruned: 100% 91/91 [01:39<00:00,  1.10s/it, Batch [91/91], Acc: 0.647
Projection: 100% 375/375 [02:51<00:00,  2.19it/s, Batch: 375/375]
Eval Epoch pruned_and_projected: 100% 91/91 [01:36<00:00,  1.06s/it, Batch [91/9
Eval Epoch pruned_and_projected: 100% 91/91 [01:33<00:00,  1.03s/it, Batch [91/9
Eval Epoch pruned_and_projected: 100% 91/91 [02:16<00:00,  1.50s/it, Batch [91/9
Fidelity: 100% 91/91 [02:47<00:00,  1.84s/it, Batch [91/91]]
```

In the overview table:

| Epoch | Test acc. | Mean train acc. | Mean train loss |
| ----- | --------- | --------------- | --------------- |
| 85   | 0.764411 | 0.927753346 | 0.315042642 |
| 86   | 0.766483 | 0.929271055 | 0.312246734 |
| 87   | 0.768208 | 0.929870736 | 0.310522149 |
| 88   | 0.77028  | 0.928586235 | 0.310088791 |
| 89   | 0.76631  | 0.929519071 | 0.306573199 |
| 90   | 0.763721 | 0.929437633 | 0.307344964 |
| 91   | 0.764066 | 0.929333985 | 0.301229446 |
| 92   | 0.765447 | 0.929681947 | 0.299217476 |
| 93   | 0.765274 | 0.930585169 | 0.298738088 |
| 94   | 0.768381 | 0.929467247 | 0.298132296 |
| 95   | 0.765965 | 0.929104478 | 0.296845926 |
| 96   | 0.76631  | 0.929918858 | 0.295910276 |
| 97   | 0.764066 | 0.930070629 | 0.295259658 |
| 98   | 0.763548 | 0.92906746  | 0.296624792 |
| 99   | 0.764411 | 0.929082267 | 0.296345691 |
| 100  | 0.761823 | 0.929915156 | 0.294409203 |

Using proto tree for multi label classification

Thank you for your interesting work.
I was wondering if this work can be applied in multi label classification setting, say in the domain of medical image analysis.
If not, could you let me know what shortcomings prevent it from being applied there?
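Not an authoritative answer, but the usual recipe for turning a single-label classifier into a multi-label one is to replace the softmax/cross-entropy head with independent per-class sigmoids and binary cross-entropy; whether ProtoTree's soft decision-tree routing, whose leaf distributions are normalized over classes, is compatible with that is exactly the open question. A generic sketch of the head change, not ProtoTree-specific:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])   # scores for 3 labels (made-up values)
targets = torch.tensor([[1.0, 0.0, 1.0]])   # multi-hot target: labels 0 and 2 present

# Multi-label: one independent sigmoid per class + binary cross-entropy,
# instead of a softmax that forces the classes to compete for probability mass.
loss = nn.BCEWithLogitsLoss()(logits, targets)
probs = torch.sigmoid(logits)               # each in (0, 1); they do not sum to 1
predicted = (probs > 0.5).int()             # threshold each label independently
```

In a medical-imaging setting the per-label threshold would typically be tuned rather than fixed at 0.5.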

Accuracy of non-iNat Networks

Hi,

I was able to recreate the accuracy of ResNet50 pre-trained on iNat using the suggestions I found in the closed issues (i.e., around 82%), but when I substitute the network for another (say VGG-16) I get an accuracy of around 11%. Or ResNet50 pre-trained on ImageNet gets around 62%, and ResNet18 gets around 30%.

I'm just wondering if that's normal? Or are there other hyperparameters that need to be changed to boost accuracy on other networks? Thank you if you have time to offer suggestions.

So, for example, I may use this:

```bash
python main_tree.py --epochs 150 --log_dir ./runs/protoree_cub --dataset CUB-200-2011 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --num_features 256 --depth 9 --net vgg16 --freeze_epochs 30 --milestones 60,80,100,120,140
```

Accuracy of Prototree ensemble 5

Thanks for your excellent work!
When I reproduce the ProtoTree ensemble of 5 trees, my accuracy is 83%, lower than the 87.2% reported in the paper. Here is my run command:

```bash
python main_ensemble.py --epochs 100 --log_dir ./runs/protoree_cub --dataset CUB-200-2011 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --num_features 256 --depth 9 --net resnet50_inat --freeze_epochs 30 --milestones 60,70,80,90,100 --nr_trees_ensemble 5
```

How can I solve that?

Model can not be loaded

First of all, thanks a lot for your interesting work! I recreated the training process without problems. However, there seems to be a problem loading the model. When I use a new file to load the model (with tree.load() or tree.load_state_dict()), the loaded tree no longer performs well on the original training dataset. Did I do something wrong? Could you help me fix this problem?
(screenshots attached: load_code, result)

Upsample issue multiple bounding boxes

Hi Meike,

Since I am not a collaborator on this repository, I am unable to push any changes. The problem in upsample.py, where too-large bounding boxes are drawn for a prototype image, can be fixed by replacing lines 56-58 with the following code:
```python
# Save the highly activated patch: select the prototype's own patch location
# instead of the heatmap's argmax (the heatmap can contain multiple maxima).
masked_similarity_map = np.zeros(similarity_map.shape)
prototype_index = prototype_info['patch_ix']
# Convert the flat patch index to (row, col) on the 7x7 latent grid
masked_similarity_map[prototype_index // 7, prototype_index % 7] = 1
```

Instead of taking the maximum value of the latent heatmap, prototype_info is used to find the location of the patch/prototype. The problem with the previous code was that the heatmap can contain multiple maxima.
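A small NumPy illustration of why relying on the heatmap's maximum is ambiguous when the map contains ties (the 7x7 map and the patch_ix value are hypothetical; only the flat-index conversion mirrors the fix above):

```python
import numpy as np

similarity_map = np.zeros((7, 7))
similarity_map[1, 2] = similarity_map[5, 6] = 0.9   # two equally high activations

# np.argmax silently returns only the FIRST maximum in row-major order,
# so a box drawn around "the" maximum may land on the wrong patch.
flat = int(np.argmax(similarity_map))
print(divmod(flat, 7))              # (1, 2) -- the second maximum at (5, 6) is ignored

# Using the stored patch index instead is unambiguous:
patch_ix = 41                       # hypothetical value of prototype_info['patch_ix']
mask = np.zeros((7, 7))
mask[patch_ix // 7, patch_ix % 7] = 1   # exactly one patch selected: (5, 6)
```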

Kind regards,
Guido

connect to swin transformer

Hi, thanks for this wonderful work. I was wondering if the author could provide code for using a Swin Transformer as the feature extractor. Thanks, I deeply appreciate the support.

Projection of prototypes

Hi, thanks for the great repo; it was really easy to reproduce.

I just have a general question about the projection step: you seem to use the bounding boxes to do projection, right? That is, you crop the images to just the box around each bird for the data loader used in projection.

Why is that? Does it not work as well with the full images? Or is there a technical reason? Thanks in advance.
