
Hands Segmentation in PyTorch - A Plug and Play Model

If you need hand segmentation for your project, you are in the right place!


If you use the code of this repo and find this project useful,
please consider giving it a star ⭐!

If you use this repo in your project, please cite it using:

@article{camporese2021HandsSeg,
  title   = "Hands Segmentation is All You Need",
  author  = "Camporese, Guglielmo",
  journal = "https://github.com/guglielmocamporese",
  year    = "2021",
  url     = "https://github.com/guglielmocamporese/hands-segmentation-pytorch"
}


Direct Usage from Torch Hub

# Imports
import torch
import torch.hub

# Create the model
model = torch.hub.load(
    repo_or_dir='guglielmocamporese/hands-segmentation-pytorch', 
    model='hand_segmentor', 
    pretrained=True
)

# Inference
model.eval()
img_rnd = torch.randn(1, 3, 256, 256) # [B, C, H, W]
preds = model(img_rnd).argmax(1) # [B, H, W]
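
For example, on a real image (a minimal sketch; the file names here are placeholders, and the hand class is assumed to be label 1):

from PIL import Image
from torchvision import transforms

# Load an RGB image as a [1, 3, H, W] float tensor in [0, 1].
img = Image.open('example.jpg').convert('RGB')  # hypothetical path
x = transforms.ToTensor()(img).unsqueeze(0)

# Per-pixel class indices: 0 = background, 1 = hand (assumed).
with torch.no_grad():
    mask = model(x).argmax(1).squeeze(0)  # [H, W]

# Save the predicted hand mask as a black-and-white image.
Image.fromarray((mask.byte() * 255).numpy()).save('example_mask.png')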

Results on the Validation and Test Datasets

Predictions on some test images


Table

| Dataset         | Partition  | mIoU  |
|-----------------|------------|-------|
| EgoYouTubeHands | Validation | 0.818 |
| EgoYouTubeHands | Test       | 0.816 |
| EgoHands        | Validation | 0.919 |
| EgoHands        | Test       | 0.920 |
| HandOverFace    | Validation | 0.814 |
| HandOverFace    | Test       | 0.768 |
| GTEA            | Validation | 0.960 |
| GTEA            | Test       | 0.949 |
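
For reference, mIoU is the intersection-over-union averaged over the background and hand classes. A minimal sketch of the computation (the function name is mine; the repo's actual evaluation code may differ):

import torch

def mean_iou(preds, targets, num_classes=2):
    # preds, targets: integer label tensors of shape [B, H, W]
    ious = []
    for c in range(num_classes):
        pred_c = preds == c
        target_c = targets == c
        intersection = (pred_c & target_c).sum().item()
        union = (pred_c | target_c).sum().item()
        if union > 0:  # skip classes absent from both tensors
            ious.append(intersection / union)
    return sum(ious) / len(ious)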

What you can do with this code

This code provides:

  • A plug-and-play pretrained model for hand segmentation, usable either directly from Torch Hub (see the Direct Usage from Torch Hub section) or by cloning this repo,
  • A collection of 4 different datasets for hand segmentation (see the Datasets section for more details) that can be used to train a hand segmentation model,
  • scripts for training and evaluating a hand segmentation model (see the Train and Test sections),
  • scripts for finetuning my pre-trained model, which you can download (see the Model section), on a custom dataset (see the Finetuning section),
  • scripts for computing hand segmentation maps on your own unseen custom data, using my pre-trained (or your own) model (see the Predict From a Custom Dataset section).

Install Locally

Once you have cloned the repo, all the commands below should be run inside the main project folder, hands:

# Clone the repo
$ git clone https://github.com/guglielmocamporese/hands-segmentation-pytorch.git hands

# Go inside the project folder
$ cd hands

To run the code you need to have conda installed (version >= 4.9.2).

Furthermore, all the requirements for running the code are specified in the environment.yml file and can be installed with:

# Install the conda env
$ conda env create --file environment.yml

# Activate the conda env
$ conda activate hands

Datasets

I set up a script scripts/download_datasets.sh that downloads and prepares all the datasets described below into the DATA_BASE_PATH folder, specified in the script itself.
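For example (the value of DATA_BASE_PATH is set inside the script; edit it there if you want a different location):

# Download and prepare all the datasets
$ ./scripts/download_datasets.sh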

In this project I considered the following datasets for training the model:

  • EgoHands [link]

    • 4800 labeled frames (100 labeled frames from each of 48 different videos),

    • each frame is 720x1280,

    • 1.3 GB zip file,

  • EgoYouTubeHands (EYTH) [link]

    • 774 labeled frames,

    • each frame is 216x384,

    • 17 MB tar.gz file,

  • GTEA (with GTEA GAZE PLUS) [link]

    • 1067 labeled frames,

    • each frame of GTEA is 405x720, each frame of GTEA GAZE PLUS is 720x960,

    • 250 MB zip file,

  • HandOverFace (HOF) [link]

    • 180 labeled frames,

    • each frame is 384x216,

    • 41 MB tar.gz file.

Model

I used the PyTorch implementation of DeepLabV3 with a ResNet50 backbone. In particular, I trained the model for hand segmentation starting from DeepLabV3 pretrained on COCO train2017.
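
For reference, a minimal sketch of how such a model can be built with torchvision (this is my illustration of the setup described above, not necessarily the repo's exact code; newer torchvision versions use a weights= argument instead of pretrained=True):

import torch
import torchvision

# DeepLabV3 with a ResNet50 backbone, pretrained on COCO train2017.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)

# Replace the final 1x1 classifier so it outputs 2 classes
# (background and hand) instead of the 21 COCO/VOC classes.
model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=1)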

We provide the code for downloading our model checkpoint:

# Download our pre-trained model
$ ./scripts/download_model_checkpoint.sh

This will download the checkpoint checkpoint.ckpt inside the checkpoint folder.

Predict From a Custom Dataset

With this code you can run inference and compute predictions on a set of custom images: you just have to specify the folder that contains the images in the data_base_path variable of the scripts/predict.sh script.

Each prediction computed from the image path/to/image.jpg will be saved at path/to/image.jpg.png.

An example prediction script is available at scripts/predict.sh and is reported here:

python main.py \
	--mode predict \
	--data_base_path 'test_images' \
	--model_checkpoint "checkpoint/checkpoint.ckpt" \
	--model_pretrained
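
Afterwards, a saved prediction can be inspected like this (a hedged sketch: I am assuming the output .png stores nonzero values on hand pixels):

import numpy as np
from PIL import Image

# Load the prediction saved next to the input image.
mask = np.array(Image.open('path/to/image.jpg.png'))

# Fraction of pixels predicted as hand (nonzero values assumed).
print('hand coverage: {:.1%}'.format((mask > 0).mean()))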

Finetuning

An example finetuning script is available at scripts/finetune.sh and is reported here:

python main.py \
	--mode train \
	--epochs 10 \
	--batch_size 16 \
	--gpus 1 \
	--datasets 'eyth eh hof gtea' \
	--height 256 \
	--width 256 \
	--data_base_path 'data' \
	--model_checkpoint 'checkpoint/checkpoint.ckpt' \
	--model_pretrained

Train

An example training script is available at scripts/train.sh and is reported here:

python main.py \
	--mode train \
	--epochs 50 \
	--batch_size 16 \
	--gpus 1 \
	--datasets 'eyth eh hof gtea' \
	--height 256 \
	--width 256 \
	--data_base_path 'data' \
	--model_pretrained

Test

An example testing script is available at scripts/test.sh and is reported here:

python main.py \
	--mode test \
	--data_base_path "data" \
	--model_pretrained \
	--model_checkpoint "checkpoint/checkpoint.ckpt"

Extra

Working with Grayscale Inputs


If you need to work with grayscale images, you just need to:

  • add the in_channels=1 argument to the training, validation, test, or predict command,
  • use model_checkpoint="checkpoints/checkpoint-grayscale.ckpt" as the model checkpoint,

as in the example below.
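
A minimal sketch of a grayscale predict command, assuming the arguments are passed as CLI flags like the other options in the scripts above:

python main.py \
	--mode predict \
	--in_channels 1 \
	--data_base_path 'test_images' \
	--model_checkpoint "checkpoints/checkpoint-grayscale.ckpt" \
	--model_pretrained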

Results with grayscale inputs

The model that uses grayscale inputs has been trained on all the datasets listed above, with all the images converted from RGB to grayscale.

Predictions on some test images


Here you can find the results on the validation and test sets using the grayscale model evaluated on the grayscale datasets.

| Dataset         | Input     | Partition  | mIoU   |
|-----------------|-----------|------------|--------|
| EgoYouTubeHands | Grayscale | Validation | 78.49% |
| EgoYouTubeHands | Grayscale | Test       | 79.36% |
| EgoHands        | Grayscale | Validation | 90.31% |
| EgoHands        | Grayscale | Test       | 90.32% |
| HandOverFace    | Grayscale | Validation | 81.98% |
| HandOverFace    | Grayscale | Test       | 74.50% |
| GTEA            | Grayscale | Validation | 94.89% |
| GTEA            | Grayscale | Test       | 94.01% |

Working with RGBD Inputs

If you need to work with RGBD images, you just need to add the in_channels=4 argument to the training, validation, test, or predict command.

However, at this time there are no available models in this project pre-trained on RGBD datasets.
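
For completeness, a hedged sketch of what an RGBD training command could look like, under the same CLI-flag assumption as above (you would need to provide your own 4-channel data, since no RGBD datasets ship with this project):

python main.py \
	--mode train \
	--in_channels 4 \
	--epochs 50 \
	--batch_size 16 \
	--gpus 1 \
	--height 256 \
	--width 256 \
	--data_base_path 'data'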

hands-segmentation-pytorch's People

Contributors

guglielmocamporese, nconn711


hands-segmentation-pytorch's Issues

Model download problem

Hi. I tried to download the model the way you described:

# Download our pre-trained model
$ ./scripts/download_model_checkpoint.sh

I am unable to download the file. Could you please provide a Google Drive link for me to download it instead? Thank you.

What datasets are used for training?

Hello, I find your pretrained model really works in many cases. I wonder how the pretrained model provided on Torch Hub was trained. Specifically, what datasets were used? Did you use all four datasets (EgoYouTubeHands, EgoHands, GTEA, HandOverFace) for training?

Problem with download_model_checkpoint.sh

Running main.py after running download_model_checkpoint.sh results in the error message "_pickle.UnpicklingError: invalid load key, '<'.".
It looks like the script is downloading the HTML file of a page that displays a warning that the file is too large to run a virus scan, instead of the actual model file.
Downloading the model file directly from the URL works as expected.

repo usage

Hi,

More of a comment than an issue! We used this repo in our winning solution (3rd place) accepted at CVPRW 2022 (arXiv and full paper). We cited the repo in this way:

Guglielmo Camporese. Hands Segmentation in PyTorch - A Plug and Play Model, 2021. Available at https://github.com/guglielmocamporese/hands-segmentation-pytorch.

Consider adding them to the repo. Thanks!

prediction not maintaining the input resolution

Hello,
After running main.py in predict mode, I've encountered a new problem where it produces an additional image with a black mask.

Could you kindly provide guidance on any additional steps or configurations necessary to ensure that the script generates the outputs accurately?

Sure, I'll be happy to give a star to the project repo. Keep up the great work!

Thank you once again for your invaluable support. Your assistance in resolving this matter would be immensely appreciated.

Predict black images

Hello, why does the model predict a black image?

[Predicted mask image: frame_0011]

If I pass as input a jpg image:

[Input image: frame_0011]

Same behaviour with other images.

I downloaded the pretrained model using these lines of code you provided in the README:

# Imports
import torch
import torch.hub

# Create the model
model = torch.hub.load(
    repo_or_dir='guglielmocamporese/hands-segmentation-pytorch',
    model='hand_segmentor',
    pretrained=True
)

Then I created the dir "checkpoint" in your project directory and put the model there.
Finally, I executed main.py passing the predict parameters:

python main.py \
	--mode predict \
	--data_base_path './test_images' \
	--model_checkpoint "./checkpoint/checkpoint.ckpt" \
	--model_pretrained

Note that I'm using Windows; that's why I didn't run the .sh scripts you provided.

Thanks for your time!

Can you share your environment.yaml file?

I'm trying to use your work but I don't know what packages I have to install in the conda environment.
Could you share the environment.yaml file?

Thanks in advance!
