
mlforhealth / s2sd


(ICML 2021) Implementation for S2SD - Simultaneous Similarity-based Self-Distillation for Deep Metric Learning. Paper Link: https://arxiv.org/abs/2009.08348

License: MIT License

Languages: Python 97.38%, Shell 2.62%
deep-metric-learning metric-learning deep-learning pytorch cub200-2011 cars196 stanford-online-products image-retrieval zero-shot-learning

s2sd's Introduction

Authors: Karsten Roth, Timo Milbich, Bjoern Ommer, Joseph Paul Cohen, Marzyeh Ghassemi
Contact: [email protected]

Suggestions are always welcome!


What is S2SD about?

Simultaneous Similarity-based Self-Distillation, or S2SD for short, is a cheap and easy-to-use extension for any Deep Metric Learning application:

[Figure: S2SD setup overview]

During training, it learns high-dimensional, better-generalizing context which it distills into the base embedding space, while also tackling the dimensional bottleneck between feature representations and final embeddings. This significantly improves performance at no additional test-time cost and with very little change in training time. It is especially effective at boosting very low-dimensional embedding spaces, and easily lifts strong baseline objectives to state-of-the-art performance.
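At its core, the distillation operates on mini-batch similarity matrices: the pairwise similarities in each high-dimensional auxiliary space are turned into softmax distributions and matched to those of the base embedding space via a KL divergence. A minimal sketch of this idea (illustrative only; function and variable names are our own, not the repo's exact code):

import torch
import torch.nn.functional as F

def similarity_distillation(base_emb, target_emb, T=1.0):
    # base_emb: (batch, d_base), target_emb: (batch, d_target); both L2-normalized.
    source_smat = base_emb @ base_emb.t()                 # base-space similarities
    target_smat = (target_emb @ target_emb.t()).detach()  # no gradient into the target branch
    log_p_source = F.log_softmax(source_smat / T, dim=-1)
    p_target = F.softmax(target_smat / T, dim=-1)
    # Temperature-scaled KL divergence, averaged over the batch.
    return F.kl_div(log_p_source, p_target, reduction='sum') * T**2 / base_emb.shape[0]

base = F.normalize(torch.randn(112, 128), dim=-1)      # low-dimensional base embeddings
target = F.normalize(torch.randn(112, 2048), dim=-1)   # high-dimensional auxiliary embeddings
loss = similarity_distillation(base, target)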


What can I find in this repo?

The base code in this repo is taken from the repository accompanying "Revisiting Training Strategies and Generalization Performance in Deep Metric Learning" (cited below), which includes implementations and training protocols for various DML objectives. We have integrated S2SD into this codebase in the criteria folder as s2sd.py. Other than that, the majority of the code is unchanged (with the exception of a checkpointing mechanism tied to the --checkpoint flag).


How to use S2SD

For general usage instructions for this repo, we refer to the original repository.

To use S2SD with any DML objective, simply run

python main.py --dataset cub200 --source $datapath --n_epochs 150 --log_online --project <your_project_name> --group <run_name_which_groups_seeds> \
--seed 0 --gpu $gpu --bs 112 --loss s2sd --loss_s2sd_source multisimilarity --loss_s2sd_target multisimilarity \
--loss_s2sd_T 1 --loss_s2sd_w 50 --loss_s2sd_target_dims 512 1024 1536 2048 --loss_s2sd_feat_distill \
--loss_s2sd_feat_distill_w 50 --arch resnet50_frozen_normalize --embed_dim 128 --loss_s2sd_feat_distill_delay 1000

Besides the basic parameters, S2SD introduces the following new ones (a short sketch of the target branches follows the list):

--loss s2sd

--loss_s2sd_source <criterion>

--loss_s2sd_target <criterion>

--loss_s2sd_T <temperature>

--loss_s2sd_w <distillation weights>

--loss_s2sd_target_dims <dimensions of target branches>

--loss_s2sd_pool_aggr >> Flag; use both global max- and average-pooling in the auxiliary branches

--loss_s2sd_feat_distill >> Flag; apply feature-space distillation

--loss_s2sd_feat_distill_w <feature distillation weight>

--loss_s2sd_feat_distill_delay <activation delay of feature distillation>
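To make the role of --loss_s2sd_target_dims concrete, here is a rough sketch of how auxiliary target heads of different dimensionalities could sit on top of the shared backbone features (hypothetical module and names, not the repo's actual implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class S2SDHeads(nn.Module):
    # Hypothetical sketch: one base head (--embed_dim) plus several
    # higher-dimensional target heads (--loss_s2sd_target_dims).
    def __init__(self, feat_dim=2048, embed_dim=128, target_dims=(512, 1024, 1536, 2048)):
        super().__init__()
        self.base = nn.Linear(feat_dim, embed_dim)
        self.targets = nn.ModuleList(nn.Linear(feat_dim, d) for d in target_dims)

    def forward(self, feats):
        # feats: backbone features, shape (batch, feat_dim)
        base = F.normalize(self.base(feats), dim=-1)
        targets = [F.normalize(t(feats), dim=-1) for t in self.targets]
        return base, targets

heads = S2SDHeads()
base, targets = heads(torch.randn(112, 2048))   # batch size as in --bs 112
print(base.shape, [t.shape for t in targets])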

Citations

If you use this repository, please make sure to cite the S2SD paper:

@misc{roth2020s2sd,
    title={S2SD: Simultaneous Similarity-based Self-Distillation for Deep Metric Learning},
    author={Karsten Roth and Timo Milbich and Björn Ommer and Joseph Paul Cohen and Marzyeh Ghassemi},
    year={2020},
    eprint={2009.08348},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

If you use the base code in this repo (and thus the original repository it builds on), please make sure to also cite the corresponding paper:

@misc{roth2020revisiting,
    title={Revisiting Training Strategies and Generalization Performance in Deep Metric Learning},
    author={Karsten Roth and Timo Milbich and Samarth Sinha and Prateek Gupta and Björn Ommer and Joseph Paul Cohen},
    year={2020},
    eprint={2002.08473},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

s2sd's People

Contributors: confusezius


s2sd's Issues

The Kullback-Leibler divergence

I found that the Kullback-Leibler divergence in this paper differs from the standard form: there is one extra operation (a log applied to A). Why is it done this way?

S2SD/criteria/s2sd.py

Lines 133 to 135 in 1ef26e4

log_p_A = F.log_softmax(A/self.T, dim=-1)  # source similarities as temperature-scaled log-probabilities
p_B = F.softmax(B/self.T, dim=-1)          # target similarities as temperature-scaled probabilities
kl_div = F.kl_div(log_p_A, p_B, reduction='sum') * (self.T**2) / A.shape[0]  # batch-averaged, T^2-scaled KL

What's more, in previous KD methods the teacher (target) network's logits or features are passed as the first argument of the kl_div function and the student's (source) as the second. But in this paper the order is reversed. Why?

kl_divs.append(self.kl_div(source_smat, target_smat.detach()))
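For reference, both questions come down to PyTorch's convention: F.kl_div(input, target) expects input to already be log-probabilities and computes sum(target * (log(target) - input)), i.e. KL(target ‖ input-distribution). The extra log on A is therefore required by the API rather than a modification of the KL definition, and passing the source's log-probabilities first and the detached target's probabilities second is exactly that convention. A minimal self-contained check:

import torch
import torch.nn.functional as F

# F.kl_div expects log-probabilities as its first argument and computes
# sum(p_B * (log(p_B) - log_p_A)), i.e. KL(p_B || p_A).
A, B = torch.randn(4, 8), torch.randn(4, 8)
log_p_A = F.log_softmax(A, dim=-1)
p_B = F.softmax(B, dim=-1)

builtin = F.kl_div(log_p_A, p_B, reduction='sum')
manual = (p_B * (p_B.log() - log_p_A)).sum()
assert torch.allclose(builtin, manual)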

Question about sampling process of "Sharing Matters for Generalization in Deep Metric Learning"

Dear authors,

Thank you for the series of great works on metric learning and for the released code.

I am very interested in the idea of learning shared features across classes in [1], and I noticed that this work has been accepted by TPAMI. Congratulations!

The shared features are obtained by constructing various class triplets in Sec. 3.2, but I have a question about the online sampling process. In the sub-section "Online sampling of inter-class triplets t∗", the triplets are said to be sampled based on an analytical distance distribution. Are both positive and negative samples drawn from that same distribution, or are only negative samples drawn from the distance distribution while positive samples are sampled randomly?

PS: I am raising the question in this repo because I could not find open-sourced code for [1]. Thanks.

[1] Sharing Matters for Generalization in Deep Metric Learning
