
avilash / pytorch-siamese-triplet


One-Shot Learning with Triplet CNNs in Pytorch

pytorch convolutional-neural-networks deep-learning meta-learning one-shot-learning triplet-loss siamese-network pytorch-implmention mnist fashionmnist

pytorch-siamese-triplet's Introduction

Deep metric learning using Triplet network in PyTorch

The following repository contains code for training a Triplet Network in PyTorch.
Siamese and Triplet networks use a similarity metric to bring similar images closer together in the embedding space while pushing dissimilar ones apart.
Popular uses of such networks include:

  • Face Verification / Classification
  • Learning deep embeddings for other tasks like classification / detection / segmentation

Paper - Deep metric learning using Triplet network
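The similarity objective described above is commonly trained with a margin-based triplet loss. For an anchor a, a positive p (same class) and a negative n (different class), the loss is max(d(a, p) - d(a, n) + margin, 0), which reaches zero once the negative is at least `margin` farther from the anchor than the positive is. The pure-Python sketch below computes this for a single triplet. It is the common formulation, not necessarily the exact loss used in the cited paper or in this repository's train.py. The names `euclidean` and `triplet_margin_loss` are illustrative, and the default margin of 1.0 is an assumption (it matches the default of PyTorch's `torch.nn.TripletMarginLoss`).

```python
import math
from typing import Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """L2 distance between two embedding vectors of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_margin_loss(anchor: Sequence[float],
                        positive: Sequence[float],
                        negative: Sequence[float],
                        margin: float = 1.0) -> float:
    """Hinge loss on one (anchor, positive, negative) triplet.

    Zero once the negative is at least `margin` farther from the anchor
    than the positive; otherwise it grows linearly with the violation.
    """
    d_pos = euclidean(anchor, positive)
    d_neg = euclidean(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)
```

For example, with anchor (0, 0), positive (0, 1) and negative (3, 4), the negative is already 4 units farther away than the positive, so the loss is 0. Moving the negative to (0, 1.5) violates the margin and yields a loss of 0.5.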

Installation

Install PyTorch

pip install -r requirements.txt  

Demo

Colab notebook with pretrained weights

Training

python train.py --cuda  

By default, this trains on the MNIST dataset.

MNIST / FashionMNIST

python train.py --result_dir results --exp_name MNIST_exp1 --cuda --dataset <mnist>/<fmnist>  

To create a tSNE visualisation

python tsne.py --ckp <path to model>  

The embeddings and labels are stored in the experiment folder as a pickle file, so you do not have to run the model every time you create a visualisation. Just pass the saved embeddings via the --pkl parameter:

python tsne.py --pkl <path to stored embeddings>  

Sample tSNE visualisation on MNIST

VGGFace2

Specify the location of the dataset in test.yaml.
The directory should have the following structure:

+-- root
|   +-- train
|       +-- class1
|           +-- img1.jpg
|           +-- img2.jpg
|           +-- img3.jpg
|       +-- class2
|       +-- class3
|   +-- test
|       +-- class4
|       +-- class5

python train.py --result_dir results --exp_name VGGFace2_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset vggface2 --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 

Custom Dataset

Specify the location of the dataset in test.yaml.
The directory should have the following structure:

+-- root
|   +-- train
|       +-- class1
|           +-- img1.jpg
|           +-- img2.jpg
|           +-- img3.jpg
|       +-- class2
|       +-- class3
|   +-- test
|       +-- class4
|       +-- class5

python train.py --result_dir results --exp_name Custom_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset custom --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 

TODO

  • Train on MNIST / FashionMNIST
  • Train on a public dataset
  • Multi GPU Training
  • Custom Dataset
  • Include popular models - ResneXT / Resnet / VGG / Inception

pytorch-siamese-triplet's People

Contributors

avilash

pytorch-siamese-triplet's Issues

opencv BGR bug

I've been browsing the code, and it seems that you use cv2.imread() to read the images, which produces BGR output. However, you still use the normalisation parameters for RGB: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. You can convert BGR to RGB with cv2.cvtColor(img, cv2.COLOR_BGR2RGB). Thanks for your work!

In PyTorch's examples, they read with PIL, which reads in RGB by default, but they use the same constants.
Reference: https://pytorch.org/hub/pytorch_vision_alexnet/
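To illustrate the fix proposed in this issue: `cv2.imread` returns channels in BGR order, while the mean/std constants above are listed in RGB order, so the channels must be reversed before normalising. The pure-Python sketch below does this for nested lists of 8-bit pixels. In real code, the issue's `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` (or `img[..., ::-1]` on a NumPy array) performs the channel swap. The helper names are hypothetical.

```python
from typing import List, Sequence

# ImageNet statistics, listed in RGB channel order.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_bgr_pixel(bgr: Sequence[int]) -> List[float]:
    """Turn one 8-bit BGR pixel (as cv2.imread yields) into a normalised RGB pixel."""
    rgb = list(reversed(bgr))  # BGR -> RGB, equivalent to cv2.COLOR_BGR2RGB
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]


def normalize_bgr_image(img: List[List[Sequence[int]]]) -> List[List[List[float]]]:
    """Apply normalize_bgr_pixel to every pixel of an H x W x 3 nested list."""
    return [[normalize_bgr_pixel(px) for px in row] for row in img]
```

Skipping the swap would subtract the red-channel mean from the blue channel (and vice versa), so the network would see inputs whose statistics differ from those the constants assume.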

Custom dataset

Do you have any code like mnist.getTriplet for custom data?
Thanks!

How to do inference on an image?

Hi, your work is helpful to me.
How do I run inference on a single image? I couldn't find the relevant code, and tsne.py couldn't load the checkpoint.
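One common way to run inference with a triplet or siamese model is to embed the query image with the trained network, compare it against a small gallery of embeddings of labelled reference images (one per class in the one-shot setting), and return the nearest label. The sketch below covers only that comparison step, in pure Python. Producing the embeddings (loading the checkpoint and running the model in eval mode) is not shown, because the checkpoint format is not documented here. `nearest_label` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest_label(query: Sequence[float],
                  gallery: List[Tuple[Sequence[float], str]]) -> Tuple[str, float]:
    """Return the label of the gallery embedding closest to `query`, and its distance.

    `gallery` holds (embedding, label) pairs, e.g. one reference image per class.
    """
    if not gallery:
        raise ValueError("gallery is empty")
    best_emb, best_label = min(gallery, key=lambda item: l2(query, item[0]))
    return best_label, l2(query, best_emb)
```

For verification (same versus different identity), you would instead compare the distance against a threshold. That threshold would need to be tuned on held-out pairs.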

adding tensorboard

The current codebase does not have TensorBoard support. It would be great to add it.
