Train a face verification model with ArcFace, implemented in PyTorch. The models are intended to run on mobile devices.
This repo is tested with Python 3.7 and PyTorch 1.13. Versions of the other packages are listed in `requirements.txt`.
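ArcFace adds an angular margin `m` to the angle between an embedding and its target class center, then applies a scaled softmax with scale `s`. The framework-free sketch below shows that computation for a single sample. It is not this repo's PyTorch implementation. The defaults `s=64` and `m=0.5` are the values from the ArcFace paper and are assumptions here; the config in this repo may use different values.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def arcface_loss(embedding, class_centers, target, s=64.0, m=0.5):
    """ArcFace loss for a single sample (illustrative sketch, not the repo's code).

    embedding:     feature vector produced by the backbone
    class_centers: one weight vector (class center) per identity
    target:        index of the sample's true identity
    s, m:          scale and additive angular margin (paper defaults, assumed)
    """
    e = l2_normalize(embedding)
    # Cosine of the angle between the embedding and every class center.
    cosines = [sum(a * b for a, b in zip(e, l2_normalize(c))) for c in class_centers]
    logits = []
    for j, cos_j in enumerate(cosines):
        if j == target:
            # Add the margin m to the target angle; clamp so acos stays in its domain.
            theta = math.acos(max(-1.0, min(1.0, cos_j)))
            logits.append(s * math.cos(theta + m))
        else:
            logits.append(s * cos_j)
    # Softmax cross-entropy, computed with the log-sum-exp trick for stability.
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target]
```

With the margin, the target logit is harder to satisfy than in a plain softmax, so the same embedding yields a higher loss; this is what pushes each identity into a tighter cluster. Many implementations also add a fallback for when `theta + m` exceeds pi, where `cos(theta + m)` stops decreasing; this sketch omits it.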
This project uses two datasets:
- MS1Mv3 (also called MS1M-RetinaFace) for training
- LFW (deep funneled) for validation
For MS1Mv3:
- Download and extract the dataset here.
- In the config file `configs/ms1mv3_mbf.py`, set `config.root_dir` to the path of the extracted folder containing the MS1Mv3 dataset.
For LFW:
- Download and extract the deep funneled version here. In the config file `configs/ms1mv3_mbf.py`, set the `img_dir` value of `config.lfwpair_kwargs` to the path of the extracted folder containing the LFW images.
- Download `pairs.txt` here. In the config file `configs/ms1mv3_mbf.py`, set the `pairs_txt_path` value of `config.lfwpair_kwargs` to the path of the `pairs.txt` file.
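The two `config.lfwpair_kwargs` values work together: `pairs.txt` lists image pairs by person name and image number, and `img_dir` tells the loader where those images live. The sketch below shows that mapping, assuming the standard LFW `pairs.txt` layout (a header line, then `name idx1 idx2` for a matching pair or `name1 idx1 name2 idx2` for a mismatched pair) and the deep funneled folder layout `<img_dir>/<name>/<name>_<NNNN>.jpg`. `parse_lfw_pairs` is a hypothetical helper, not this repo's data loader.

```python
import os


def parse_lfw_pairs(pairs_txt_path, img_dir):
    """Return a list of (path_a, path_b, is_same) tuples from an LFW pairs.txt.

    Hypothetical helper; assumes the standard LFW file and folder layout.
    """
    def image_path(name, idx):
        # LFW images are numbered with 4 digits, e.g. Abel_Pacheco_0001.jpg
        return os.path.join(img_dir, name, f"{name}_{int(idx):04d}.jpg")

    pairs = []
    with open(pairs_txt_path) as f:
        lines = f.read().splitlines()
    # The first line is a header ("<n_folds>\t<pairs_per_fold>"), so skip it.
    for line in lines[1:]:
        fields = line.split()
        if len(fields) == 3:  # matching pair: name idx1 idx2
            name, i, j = fields
            pairs.append((image_path(name, i), image_path(name, j), True))
        elif len(fields) == 4:  # mismatched pair: name1 idx1 name2 idx2
            n1, i, n2, j = fields
            pairs.append((image_path(n1, i), image_path(n2, j), False))
        # Blank or malformed lines are ignored.
    return pairs
```

Calling it with your `pairs_txt_path` and `img_dir`, then checking `os.path.exists` on each returned path, is a quick way to catch a wrong `img_dir` before starting a long training run.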
Run `python train.py`. If needed, adjust `config.batch_size` in the config file to fit your GPU memory.
Then run `tensorboard --logdir ./lightning_logs` to view the training progress in real time.
The best checkpoint is saved in the corresponding folder under `./lightning_logs/`.
The following metrics are logged:
- `mean_target_cosine`: Training metric. The mean, within each batch, of the cosine similarity between each image's embedding and its target class center.
- `train_loss`: ArcFace loss.
- `lfw_auroc`: Validation metric. Area under the ROC curve (AUC ROC) computed from the predictions on LFW image pairs.
- `pos_mean_score`: Validation metric. Mean similarity score over all matching image pairs.
- `neg_mean_score`: Validation metric. Mean similarity score over all mismatched image pairs.
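To make these metric definitions concrete, here is a plain-Python sketch of how they could be computed. It assumes each LFW pair's similarity score is the cosine similarity of the two embeddings and that `is_same` marks matching pairs. The function names are hypothetical; this illustrates the definitions and is not the repo's implementation.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mean_target_cosine(embeddings, class_centers, targets):
    """Batch mean of cos(embedding, center of that sample's target class)."""
    total = sum(cosine(e, class_centers[t]) for e, t in zip(embeddings, targets))
    return total / len(targets)


def pair_metrics(scores, is_same):
    """Validation metrics from one similarity score per LFW pair.

    scores:  similarity of each pair (higher means more alike)
    is_same: True for a matching pair, False for a mismatched pair
    """
    pos = [s for s, same in zip(scores, is_same) if same]
    neg = [s for s, same in zip(scores, is_same) if not same]
    # AUC ROC = probability that a random matching pair outscores a random
    # mismatched pair (ties count as half); O(len(pos) * len(neg)) comparisons.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return {
        "lfw_auroc": wins / (len(pos) * len(neg)),
        "pos_mean_score": sum(pos) / len(pos),
        "neg_mean_score": sum(neg) / len(neg),
    }
```

For the 6,000 LFW pairs (3,000 matching, 3,000 mismatched) the pairwise loop performs 9 million comparisons, which is acceptable for a one-off check; `sklearn.metrics.roc_auc_score(is_same, scores)` gives the same AUC ROC much faster. A widening gap between `pos_mean_score` and `neg_mean_score` usually accompanies a rising `lfw_auroc`.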