
Magnet Loss and RepMet in PyTorch

NOTE: THIS PROJECT IS ON HOLD WHILE I WORK ON THE DETECTION PIPELINE IN MXNET

This borrows heavily from the TensorFlow Magnet Loss code: pumpikano/tf-magnet-loss

Magnet Loss

(Figure 3 from the Magnet Loss paper)

"Metric Learning with Adaptive Density Discrimination" introduced a new optimization objective for distance metric learning called Magnet Loss that, unlike related losses, operates on entire neighborhoods in the representation space and adaptively defines the similarity that is being optimized to account for the changing representations of the training data.
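The objective can be sketched in PyTorch roughly as follows. This is a minimal sketch, not this repo's implementation: the function name, signature, and the hinged log-space formulation are my own, it assumes every sample's class has at least one cluster in the batch, and it targets a recent PyTorch rather than the 0.4 listed below.

```python
import torch
import torch.nn.functional as F

def magnet_loss(embeddings, cluster_means, cluster_classes, sample_classes, alpha=1.0):
    # d2[n, m]: squared euclidean distance from sample n to cluster mean m
    d2 = torch.cdist(embeddings, cluster_means) ** 2
    same = sample_classes.unsqueeze(1) == cluster_classes.unsqueeze(0)  # (N, M)
    inf = torch.full_like(d2, float('inf'))
    # distance of each sample to its own (nearest same-class) cluster mean
    d_own = torch.where(same, d2, inf).min(dim=1).values
    # sigma^2: spread of samples about their own cluster means
    sigma2 = d_own.sum() / (embeddings.size(0) - 1)
    # numerator / denominator of the softmax-style objective, in log space;
    # the denominator only sums over clusters of *other* classes
    log_num = -d_own / (2 * sigma2) - alpha
    log_den = torch.logsumexp(-torch.where(same, inf, d2) / (2 * sigma2), dim=1)
    return F.relu(log_den - log_num).mean()  # hinged at zero, as in the paper
```

With two tight, well-separated clusters the hinge saturates and the loss is zero, which is the behaviour the paper's {.}+ operator is there to produce.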

RepMet

(Figure 2 from the RepMet paper)

"RepMet: Representative-based Metric Learning for Classification and One-shot Object Detection" extends Magnet Loss by storing the centroids as learnable representatives, rather than statically recomputing them every now and again with k-means.
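The "learnable representatives" idea can be sketched as a small module whose centroids are plain parameters updated by backprop (a sketch only; the class name and layout are my own, not the repo's or the paper's code):

```python
import torch
import torch.nn as nn

class Representatives(nn.Module):
    """K learnable representatives per class, moved by backprop
    instead of being recomputed periodically with k-means."""
    def __init__(self, n_classes, k_per_class, dim):
        super().__init__()
        # one row per representative, learned like any other weight
        self.reps = nn.Parameter(torch.randn(n_classes * k_per_class, dim))
        # class label of each representative row (fixed, not learned)
        self.register_buffer(
            'rep_classes',
            torch.arange(n_classes).repeat_interleave(k_per_class))

    def forward(self, embeddings):
        # squared distances from each embedding to every representative
        return torch.cdist(embeddings, self.reps) ** 2
```

Because `reps` is an `nn.Parameter`, any loss built from these distances pulls the representatives around along with the embedding network.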

Implementation

NOTE: Currently only classification, RepMet's detection functionality coming soon :)

Tested with Python 3.6 + PyTorch 0.4 + CUDA 9.1

See train.py for training the model; please ensure your path is set in configs.py.

RepMet.v2

There are two versions of the RepMet loss implemented in this code, as the original authors suggested a modification to the loss presented in the original paper.

Version 1: As in the original paper, it uses the closest same-class representative (R*) in the numerator and disregards same-class representatives in the denominator:

(equation: RepMet v1 loss)

where:

(equation: RepMet v1 terms)
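A sketch of Version 1 in PyTorch (my own names and signature; I write it hinged in log space with an alpha margin, mirroring the Magnet Loss structure above, and take sigma2 as a fixed argument):

```python
import torch
import torch.nn.functional as F

def repmet_loss_v1(embeddings, reps, rep_classes, sample_classes,
                   alpha=1.0, sigma2=0.5):
    d2 = torch.cdist(embeddings, reps) ** 2                      # (N, M)
    same = sample_classes.unsqueeze(1) == rep_classes.unsqueeze(0)
    inf = torch.full_like(d2, float('inf'))
    # numerator: only the closest same-class representative R*
    d_star = torch.where(same, d2, inf).min(dim=1).values
    log_num = -d_star / (2 * sigma2) - alpha
    # denominator: representatives of *other* classes only
    log_den = torch.logsumexp(-torch.where(same, inf, d2) / (2 * sigma2), dim=1)
    return F.relu(log_den - log_num).mean()
```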

Version 2: Sums the (exponentiated) distances to all representatives of the same class in the numerator, and doesn't disregard any representatives in the denominator:

(equation: RepMet v2 loss)
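Version 2 can be sketched the same way (names are my own; note that because the denominator now contains every numerator term, the ratio never exceeds 1 and the loss never quite reaches 0, which matches the complaint in the MyLoss section below):

```python
import torch

def repmet_loss_v2(embeddings, reps, rep_classes, sample_classes, sigma2=0.5):
    logits = -torch.cdist(embeddings, reps) ** 2 / (2 * sigma2)   # (N, M)
    same = sample_classes.unsqueeze(1) == rep_classes.unsqueeze(0)
    # numerator: sum over *all* same-class representatives
    log_num = torch.logsumexp(logits.masked_fill(~same, float('-inf')), dim=1)
    # denominator: sum over every representative, same class included
    log_den = torch.logsumexp(logits, dim=1)
    return (log_den - log_num).mean()   # -log(num/den), always >= 0
```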

MyLoss

Testing my own loss, as the RepMet losses don't make sense to me yet (they never reach 0, since the denominator is always greater than the numerator). My loss is defined as:

(equation: my loss, v1)

Although it works better using squared Euclidean distance, like the other losses:

(equation: my loss, v2 with squared Euclidean distance)

It takes the maximum distance between each embedding and its same-class clusters/representatives/modes, adds the alpha margin, and subtracts this from every distance for all embeddings. We only sum over the J clusters/representatives that correspond to the classes seen in the batch (which differs from what RepMet seems to do [I have to check with the authors], but matches Magnet Loss, since the means are taken in-batch).
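One plausible reading of that description, sketched in PyTorch (the hinge form, names, and signature are my interpretation, not a definitive implementation; this is the plain-Euclidean v1 variant):

```python
import torch

def my_loss(embeddings, reps, rep_classes, sample_classes, alpha=1.0):
    # keep only representatives whose class actually appears in the batch
    seen = (rep_classes.unsqueeze(1) == sample_classes.unsqueeze(0)).any(dim=1)
    reps, rep_classes = reps[seen], rep_classes[seen]
    d = torch.cdist(embeddings, reps)                      # plain euclidean
    same = sample_classes.unsqueeze(1) == rep_classes.unsqueeze(0)
    # worst (largest) same-class distance per sample, plus the alpha margin
    d_max = d.masked_fill(~same, float('-inf')).max(dim=1).values + alpha
    # hinge: every other-class distance should exceed that worst case,
    # so the loss reaches exactly zero once the margin is satisfied
    return torch.relu(d_max.unsqueeze(1) - d).masked_fill(same, 0.0).mean()
```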

Datasets

Currently works on MNIST; support for Oxford Flowers 102 and Stanford Dogs is in progress.

Evaluation

During training a number of accuracies are calculated:

1. Batch Accuracy: the percentage of batch samples that are correctly assigned to their cluster (for Magnet Loss, how many are closest to their in-batch mean).

2. Simple Accuracy: assign a sample x to the class of its closest training cluster:

(equation: simple accuracy)
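A minimal sketch of this nearest-cluster rule (the function name and signature are my own, not this repo's):

```python
import torch

def simple_accuracy(embeddings, cluster_means, cluster_classes, labels):
    # assign each sample to the class of its nearest cluster mean
    nearest = torch.cdist(embeddings, cluster_means).argmin(dim=1)
    return (cluster_classes[nearest] == labels).float().mean().item()
```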

3. Magnet Accuracy: uses eq. 6 from the Magnet Loss paper. Take the min(L, n_clusters) closest clusters to a sample x, sum the distances per class, and take the class with the minimum summed distance (or maximum after the exp):

(eq. 6 from the Magnet Loss paper)

where sig2 is the average of all sig2 values from training. This is equivalent to Simple Accuracy when n_clusters < L and K=1.
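A sketch of that computation (names are my own, and sigma2 is passed in as a fixed value rather than tracked as the training average):

```python
import torch

def magnet_accuracy(embeddings, cluster_means, cluster_classes, labels,
                    L=8, sigma2=0.5):
    d2 = torch.cdist(embeddings, cluster_means) ** 2       # (N, M)
    L = min(L, cluster_means.size(0))
    # keep only each sample's L closest clusters
    neg_d2, idx = torch.topk(-d2, k=L, dim=1)
    w = torch.exp(neg_d2 / (2 * sigma2))                   # exp(-d2 / 2*sigma2)
    n_classes = int(cluster_classes.max()) + 1
    scores = torch.zeros(embeddings.size(0), n_classes)
    # per-class sums of the kept clusters' weights, then argmax over classes
    scores.scatter_add_(1, cluster_classes[idx], w)
    return (scores.argmax(dim=1) == labels).float().mean().item()
```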

4. RepMet Accuracy: uses eq. 6 from the RepMet paper. Equivalent to Magnet Accuracy, but takes all clusters (n_clusters) into consideration, not just the top L. It also doesn't normalise into a probability distribution before the arg max:

(eq. 6 from the RepMet paper)

where sig2 is set to 0.5, as in training. This is equivalent to Simple Accuracy when K=1.
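A sketch of the same idea over all representatives, with sigma2 fixed at 0.5 (names and signature are my own; normalisation is skipped since it doesn't change the argmax):

```python
import torch

def repmet_accuracy(embeddings, reps, rep_classes, labels, sigma2=0.5):
    # weights over *all* representatives, no top-L cut
    w = torch.exp(-torch.cdist(embeddings, reps) ** 2 / (2 * sigma2))  # (N, M)
    n_classes = int(rep_classes.max()) + 1
    scores = torch.zeros(embeddings.size(0), n_classes)
    # sum each representative's weight into its class's score
    idx = rep_classes.repeat(embeddings.size(0), 1)
    scores.scatter_add_(1, idx, w)
    return (scores.argmax(dim=1) == labels).float().mean().item()
```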

5. Unsupervised Accuracy: run K-Means on the set (i.e. don't use the trained clusters), then greedily assign a class to each cluster based on the classes of the samples that fall in it.
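Given k-means assignments (from sklearn.cluster.KMeans or any other implementation), the labelling step might look like this. A sketch only: majority vote is my reading of "greedily assign", and the function name is made up.

```python
import torch

def greedy_cluster_accuracy(assignments, labels, n_clusters):
    # label each k-means cluster with the majority class of its members,
    # then score the induced predictions against the true labels
    pred = torch.zeros_like(labels)
    for c in range(n_clusters):
        members = assignments == c
        if members.any():
            pred[members] = labels[members].mode().values
    return (pred == labels).float().mean().item()
```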

Test error can be considered 1 - a (1 minus these accuracies).

Results (Coming Soon)

These results are calculated with evaluate.py and reference the accuracy calculations above.

After 1000 iterations with a pretrained ResNet-18 whose last fully-connected layer is replaced with a 1024-d embedding layer; M=12, D=4, K=3.

Oxford Flowers 102

| Loss        | Simple Error, Train (Test) | Magnet Error, Train (Test) | RepMet Error, Train (Test) | Unsup. Error, Train (Test) |
|-------------|---------------|---------------|---------------|---------------|
| Magnet Loss | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) |
| RepMet.v1   | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) |
| RepMet.v2   | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) |
| RepMet.v3   | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) |
| My Loss     | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) | 00.00 (00.00) |
