
xydxdy / triplet-loss-pytorch


This project forked from alfonmedela/triplet-loss-pytorch


Highly efficient PyTorch version of the Semi-hard Triplet loss ⚡️

Home Page: https://alfonsomedela.com

License: Apache License 2.0

Python 100.00%

triplet-loss-pytorch's Introduction

Triplet SemiHardLoss

PyTorch semi-hard triplet loss, based on the TensorFlow Addons version (tfa.losses.TripletSemiHardLoss). There is no need to build a siamese architecture with this implementation: it is as simple as following the CNN creation process in main_train_triplet.py!
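To make "semi-hard" concrete, here is a minimal batch-wise sketch in plain PyTorch. It is not the repository's own code; it follows our reading of the TF Addons semantics. For every (anchor, positive) pair in the batch, it picks the closest negative that is still farther from the anchor than the positive. If no such negative exists, it falls back to the farthest negative. It then applies a hinge with a margin and averages over all (anchor, positive) pairs. The function name `semihard_triplet_loss`, the use of Euclidean distance, and the default `margin=1.0` are assumptions.

```python
import torch


def semihard_triplet_loss(embeddings: torch.Tensor,
                          labels: torch.Tensor,
                          margin: float = 1.0) -> torch.Tensor:
    """Hypothetical batch semi-hard triplet loss (sketch, TF Addons-style semantics).

    embeddings: (B, D) float tensor, typically L2-normalised.
    labels:     (B,) integer class labels.
    """
    # Pairwise Euclidean distances (B, B); the clamp keeps sqrt's gradient finite at 0.
    delta = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    dist = delta.pow(2).sum(-1).clamp(min=1e-12).sqrt()

    same = labels.unsqueeze(0) == labels.unsqueeze(1)       # same-class mask
    neg_mask = ~same                                        # valid negatives
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye                                  # valid (anchor, positive) pairs

    # Index [a, p, n] compares d(a, p) with d(a, n).
    d_ap = dist.unsqueeze(2)                                # (B, B, 1)
    d_an = dist.unsqueeze(1)                                # (B, 1, B)
    semihard = neg_mask.unsqueeze(1) & (d_an > d_ap)        # (B, B, B)

    # Closest negative that is still farther away than the positive.
    inf = torch.full((), float("inf"), dtype=dist.dtype, device=dist.device)
    neg_outside = torch.where(semihard, d_an, inf).amin(dim=2)   # (B, B)
    # Fallback when no semi-hard negative exists: the farthest negative of the anchor.
    neg_inside = (dist * neg_mask).amax(dim=1, keepdim=True)     # (B, 1)
    neg = torch.where(semihard.any(dim=2), neg_outside, neg_inside)

    # Hinge over every (anchor, positive) pair, averaged over the number of pairs.
    loss_mat = (margin + dist - neg).clamp(min=0.0) * pos_mask
    return loss_mat.sum() / pos_mask.sum().clamp(min=1)
```

The `[a, p, n]` broadcast uses O(B³) memory, which is acceptable for typical batch sizes. A batch must contain at least two samples of some class, otherwise there are no (anchor, positive) pairs and the sketch returns 0.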

The triplet loss is a great choice for classification problems where N_CLASSES >> N_SAMPLES_PER_CLASS, for example face recognition.

The CNN architecture used with the triplet loss needs to be cut off before the classification layer. In addition, an L2 normalization layer has to be added on top.
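As an illustration, here is a small MNIST-sized network that ends in an embedding layer instead of class logits and L2-normalizes its output. It is a minimal sketch. The class name `EmbeddingNet`, the layer sizes, and `embedding_dim=64` are assumptions, not the network from main_train_triplet.py. For a pretrained torchvision backbone, the same idea is usually implemented by replacing its final fully connected layer with `nn.Identity()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """Hypothetical CNN that outputs L2-normalised embeddings (no classification layer)."""

    def __init__(self, embedding_dim: int = 64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),                                  # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),                          # -> (N, 64, 1, 1)
        )
        # Embedding layer in place of the usual N_CLASSES-way classifier.
        self.embed = nn.Linear(64, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.features(x).flatten(1)
        # L2 normalisation: every embedding lies on the unit hypersphere.
        return F.normalize(self.embed(h), p=2, dim=1)
```

Normalizing the embeddings bounds all pairwise distances to [0, 2], so the triplet margin has a consistent scale regardless of the raw activation magnitudes.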

Results on MNIST

I tested the triplet loss on the MNIST dataset. I can't compare it directly with TF Addons because I didn't run that experiment, but the comparison could be interesting from a performance point of view. Here are the training logs if you want to compare results. The reported accuracy is not relevant and shouldn't be there, since we are not training a classification model.

Phase 1

First, we train only the last layer and the batch normalization layers, reaching a validation loss close to 0.079.

Phase 2

Finally, after unfreezing all the layers, the validation loss can get close to 0.05 with enough training and hyperparameter tuning.
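The README does not show how the layers were frozen in the two phases. In plain PyTorch, one way is to toggle `requires_grad`, as in the minimal sketch below. The helper names `freeze_for_phase1` and `unfreeze_all` are hypothetical, and `head` stands for whatever module plays the role of the last layer.

```python
import torch.nn as nn

# BatchNorm layer types that stay trainable in Phase 1.
_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_for_phase1(model: nn.Module, head: nn.Module) -> None:
    """Hypothetical Phase 1 helper: train only `head` and the BatchNorm layers."""
    for p in model.parameters():
        p.requires_grad = False
    for m in model.modules():
        if isinstance(m, _BN_TYPES):
            for p in m.parameters():
                p.requires_grad = True
    for p in head.parameters():
        p.requires_grad = True


def unfreeze_all(model: nn.Module) -> None:
    """Hypothetical Phase 2 helper: make every parameter trainable again."""
    for p in model.parameters():
        p.requires_grad = True
```

A common pattern is to create a fresh optimizer at the start of each phase so that it only tracks the parameters that are meant to be updated.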

Test

To evaluate the embeddings, there are two useful options: training a classification model on top of them, and plotting the train and test embeddings to check whether samples of the same category cluster together. The following figure shows the original 10,000 validation samples.

[Figure: t-SNE projection of the 10,000 validation embeddings]

We get an accuracy of around 99.3% on the validation set by training a linear SVM or a simple kNN on the embeddings. This repository is not focused on maximizing this accuracy by tweaking data augmentation, architecture and hyperparameters, but on providing an effective implementation of the triplet loss in PyTorch. For more information on state-of-the-art results on MNIST, check out this Kaggle discussion.
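The kNN evaluation code is not shown here. Below is a minimal sketch of a kNN classifier on top of the embeddings, written in plain PyTorch. The helpers `knn_predict` and `accuracy`, and the default `k=5`, are hypothetical. scikit-learn's `KNeighborsClassifier` or `LinearSVC` fitted on the embeddings would be the usual off-the-shelf alternatives.

```python
import torch


def knn_predict(train_emb: torch.Tensor, train_labels: torch.Tensor,
                query_emb: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Hypothetical helper: majority vote among the k nearest training embeddings."""
    dist = torch.cdist(query_emb, train_emb)              # (Q, N) Euclidean distances
    idx = dist.topk(k, dim=1, largest=False).indices      # (Q, k) nearest neighbours
    return torch.mode(train_labels[idx], dim=1).values    # most common label per query


def accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of predictions that match the targets."""
    return (pred == target).float().mean().item()
```

When the embeddings of each class form tight, well-separated clusters, as the t-SNE plot suggests, even this simple vote classifies most queries correctly. With many classes, ties between labels are possible, and `torch.mode` then returns one of the tied labels.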

Contact me with any questions: alfonsomedela.com

Watch my latest TEDx talk: The medicine of the future


Donations ₿

BTC Wallet: 1DswCAGmXYQ4u2EWVJWitySM7Xo7SH4Wdf

IMPORTANT

If you're using the fastai library, learn.predict will return an error when predicting the embeddings. fastai knows internally that your data has N classes. If the embedding vector has M > N dimensions and its highest value falls at a position beyond those N classes, fastai tries to look up a class that does not exist and raises an error. So either write your own prediction function or make a simple modification to the source code that changes the length of the self.classes list.
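If you take the first route, a custom prediction function only needs to run the model in eval mode without gradients and concatenate the raw outputs, with no class decoding. Below is a minimal sketch. The name `predict_embeddings` is hypothetical, and the function assumes a plain PyTorch `DataLoader` that yields either input tensors or (input, target) tuples. With fastai, the underlying PyTorch module is available as `learn.model`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_embeddings(model: nn.Module, loader: DataLoader, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: return raw embeddings without mapping them to classes."""
    model.eval().to(device)
    chunks = []
    for batch in loader:
        # Accept either bare inputs or (inputs, targets) batches.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        chunks.append(model(x.to(device)).cpu())
    return torch.cat(chunks)


# Example with a stand-in model and random inputs.
demo_model = nn.Linear(3, 2)
demo_loader = DataLoader(TensorDataset(torch.randn(10, 3), torch.zeros(10)), batch_size=4)
demo_emb = predict_embeddings(demo_model, demo_loader)
```

Because nothing is mapped to a class index, the M > N mismatch described above cannot occur.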
