hfsoftmax's Introduction

Accelerated Training for Massive Classification via Dynamic Class Selection (HF-Softmax) pdf

Training

  1. Install PyTorch. (Installing the latest master from source is recommended.)
  2. Follow the InsightFace instructions to download the training data.
  3. Decode the data (.rec) to images and generate the training/validation lists.
python tools/rec2img.py --in-folder xxx --out-folder yyy
  4. Try normal training. It uses torch.nn.DataParallel (multi-threaded) for parallelism.
sh scripts/train.sh dataset_path
  5. Try sampled training. It uses one GPU for training, and the default sampling number is 1000.
python paramserver/paramserver.py
sh scripts/train_hf.sh dataset_path

Distributed Training

For distributed training, there is one process on each GPU.

Several backends are available for PyTorch distributed training. If you want to use NCCL as the backend, please follow the instructions to install NCCL2.

You can test your distributed setup by running:

sh scripts/test_distributed.sh

After NCCL2 is installed, recompile PyTorch from source:

python setup.py clean install

In our case, we use libnccl2=2.2.13-1+cuda9.0 and libnccl-dev=2.2.13-1+cuda9.0 with PyTorch master at version 0.5.0a0+e31ab99.

Hashing Forest

We use Annoy to approximate the hashing forest. You can adjust sample_num, ntrees and interval to balance performance and cost.
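As a rough illustration of what the approximate lookup stands in for, the sketch below selects classes by brute force: it scores every class against a feature and keeps the highest-scoring ones. It does not use Annoy and is not the repository's code. The function name is hypothetical, and it assumes that sample_num means the number of classes kept per sample and that the ground-truth class is always retained.

```python
import heapq


def select_active_classes(feature, class_weights, label, sample_num):
    """Return `sample_num` class indices, ground-truth label first.

    Brute-force stand-in (an assumption, not the repository's code): score
    each class by its dot product with the feature and keep the top ones.
    An approximate index would replace the full scan over `class_weights`.
    """
    scores = (
        (sum(w * x for w, x in zip(weights, feature)), idx)
        for idx, weights in enumerate(class_weights)
        if idx != label  # the label is added unconditionally below
    )
    top = heapq.nlargest(sample_num - 1, scores)
    return [label] + [idx for _, idx in top]
```

The full scan costs time linear in the number of classes for every sample, which is the cost an approximate index avoids. With Annoy, a query such as get_nns_by_vector on an index built over the class weights would take the place of the scan, at the price of occasionally missing a high-scoring class.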

Parameter Server

The parameter server is decoupled from PyTorch. A client is provided to communicate with the server, and other platforms can integrate the parameter server via the communication API. Currently, it only supports a synchronized SGD updater.
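To make the idea of a synchronized SGD updater concrete, here is a toy in-process sketch. It is not the repository's server or its communication API: the class and method names are hypothetical, and it assumes that the server stores one weight row per class, that workers pull only the rows they need, and that an update is applied once every worker has pushed its gradients.

```python
class ToyParamServer:
    """Hypothetical in-process stand-in for a synchronized SGD server."""

    def __init__(self, rows, lr, num_workers):
        self.rows = [list(r) for r in rows]  # one weight row per class
        self.lr = lr
        self.num_workers = num_workers
        self.pending = []  # gradients received during the current step

    def pull(self, class_ids):
        """Return copies of the requested rows only."""
        return {i: list(self.rows[i]) for i in class_ids}

    def push(self, grads):
        """Buffer one worker's gradients; update once all have reported."""
        self.pending.append(grads)
        if len(self.pending) < self.num_workers:
            return False  # still waiting for the other workers
        self._apply()
        return True

    def _apply(self):
        # Sum gradients per row, then take one averaged SGD step.
        summed = {}
        for grads in self.pending:
            for i, g in grads.items():
                acc = summed.setdefault(i, [0.0] * len(g))
                for k, v in enumerate(g):
                    acc[k] += v
        for i, g in summed.items():
            self.rows[i] = [
                w - self.lr * v / self.num_workers
                for w, v in zip(self.rows[i], g)
            ]
        self.pending = []
```

Because only the pulled rows travel between server and worker, the traffic per step scales with the number of sampled classes, not with the full class count.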

Evaluation

./scripts/eval.sh xxx.pth.tar dataset_path outputs

It uses torch.nn.DataParallel to extract features and saves them as .npy files. The features are subsequently used to perform the verification test.

If you use distributed training, set strict=False during feature extraction.

Note that the bin files from InsightFace, lfw.bin for example, are pickled with Python 2 and cannot be read directly under Python 3. You can either use Python 2 for evaluation or re-pickle the bin file with Python 3 first.
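One possible way to re-pickle such a file under Python 3 is sketched below. The helper name is hypothetical, and the sketch assumes the file is an ordinary pickle whose Python 2 str values should be kept as raw bytes.

```python
import pickle


def repickle(src_path, dst_path):
    """Load a Python 2 pickle under Python 3 and save it again.

    encoding="bytes" keeps Python 2 `str` objects as raw bytes instead of
    trying to decode them as text. Dict keys that were `str` become bytes too.
    """
    with open(src_path, "rb") as f:
        data = pickle.load(f, encoding="bytes")
    with open(dst_path, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    return data
```

For example, repickle("lfw.bin", "lfw_py3.bin") would write a copy that Python 3 loads with a plain pickle.load call; the output file name here is only an illustration.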

Citation

Please cite the following paper if you use this repository in your research.

@inproceedings{zhang2018accelerated,
  title     = {Accelerated Training for Massive Classification via Dynamic Class Selection},
  author    = {Xingcheng Zhang and Lei Yang and Junjie Yan and Dahua Lin},
  booktitle = {AAAI},
  year      = {2018},
}
