
goldandrabbit / cross-stitch-networks-for-multi-task-learning


This project forked from helloyide/cross-stitch-networks-for-multi-task-learning


A TensorFlow implementation of the paper arXiv:1604.03539

License: MIT License

Python 100.00%


Cross-stitch-Networks-for-Multi-task-Learning

This project is a TensorFlow implementation of the multi-task learning method described in the paper Cross-stitch Networks for Multi-task Learning.

Arguments

  • --lr, learning rate
  • --n_epoch, number of epochs
  • --n_batch_size, mini-batch size
  • --reg_lambda, L2 regularization lambda
  • --keep_prob, dropout keep probability
  • --cross_stitch_enabled, whether to use cross-stitch units
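The flags above could be wired up with `argparse` roughly as follows. This is a minimal sketch, not the repository's actual parser:

- `build_parser` is a hypothetical name.
- The defaults are copied from the Training section below.
- The argument types are assumptions.
- Treating `--cross_stitch_enabled` as an on/off switch that defaults to off is also an assumption.

```python
import argparse


def build_parser():
    """Hypothetical parser for the flags listed above.

    Defaults follow the Training section; the types and the store_true
    form of --cross_stitch_enabled are assumptions, not the repo's code.
    """
    parser = argparse.ArgumentParser(
        description="Train two Fashion-MNIST classifiers jointly")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=30,
                        help="number of epochs")
    parser.add_argument("--n_batch_size", type=int, default=128,
                        help="mini-batch size")
    parser.add_argument("--reg_lambda", type=float, default=1e-5,
                        help="L2 regularization lambda")
    parser.add_argument("--keep_prob", type=float, default=0.8,
                        help="dropout keep probability")
    # Assumed: an on/off switch that is off unless the flag is given.
    parser.add_argument("--cross_stitch_enabled", action="store_true",
                        help="insert cross-stitch units between the sub-networks")
    return parser
```

For example, `build_parser().parse_args(["--lr", "0.01", "--cross_stitch_enabled"])` yields a namespace with `lr=0.01` and `cross_stitch_enabled=True`. All other values keep their defaults.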

Dataset

Fashion-MNIST

Fashion-MNIST is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes:

Label  Description    Label  Description
0      T-shirt/top    5      Sandal
1      Trouser        6      Shirt
2      Pullover       7      Sneaker
3      Dress          8      Bag
4      Coat           9      Ankle boot

For multi-task learning, I created a second label for each image, derived from its original label:

Label  Original labels  Description
0      5, 7, 9          Shoes
1      3, 6, 8          For Women
2      0, 1, 2, 4       Other

The network trains both classifiers jointly: a 10-class classifier for the original labels and a 3-class classifier for the new ones.
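The second label is a fixed lookup on the first, so deriving both targets takes only a few lines. The mapping below is copied from the table above. `COARSE_LABEL` and `make_task_labels` are hypothetical names for this sketch.

```python
# Coarse (task 2) label for each Fashion-MNIST label, per the table above:
# 0 = Shoes, 1 = For Women, 2 = Other.
COARSE_LABEL = {
    5: 0, 7: 0, 9: 0,        # Sandal, Sneaker, Ankle boot -> Shoes
    3: 1, 6: 1, 8: 1,        # Dress, Shirt, Bag -> For Women
    0: 2, 1: 2, 2: 2, 4: 2,  # T-shirt/top, Trouser, Pullover, Coat -> Other
}


def make_task_labels(fine_labels):
    """Return (task1_labels, task2_labels) for a sequence of original labels."""
    fine = [int(y) for y in fine_labels]
    coarse = [COARSE_LABEL[y] for y in fine]
    return fine, coarse
```

For instance, `make_task_labels([9, 3, 0, 7])` returns the original labels unchanged as the first list and `[0, 1, 2, 0]` as the second.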

Network

Without task sharing

As a baseline, I built a network without cross stitch. It simply places two convolutional neural networks side by side. Each network handles one task, and their parameters are not shared. The total loss is the sum of the two sub-networks' losses.
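The baseline objective can be sketched for a single example as the sum of two softmax cross-entropy terms: one over the 10 original classes and one over the 3 new classes. This is an illustration under two assumptions:

- Each task uses plain softmax cross-entropy.
- The two terms are added without weights.

Batch averaging and the L2 penalty from the Training section are omitted, and the function names are hypothetical.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one example, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def total_loss(logits_task1, label_task1, logits_task2, label_task2):
    """Baseline objective: unweighted sum of the two sub-network losses."""
    return (cross_entropy(logits_task1, label_task1)
            + cross_entropy(logits_task2, label_task2))
```

With all-zero logits, each term reduces to the log of its class count. The total is therefore log 10 + log 3 = log 30.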

Here is an overview of this structure:

[Figure: network structure without task sharing]

Both convolutional sub-networks have the same architecture:

Layer    Output size                    Filter size / stride
conv1    28x28x32                       3x3 / 1
pool_1   14x14x32                       2x2 / 2
conv2    14x14x64                       3x3 / 1
pool_2   7x7x64                         2x2 / 2
fc_3     1024
output   10 or 3 (depends on the task)
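The output sizes in the table can be checked with a little shape arithmetic. The convolutions keep 28x28 at stride 1, which implies "SAME"-style padding; this sketch assumes that padding and unpadded 2x2 pooling. The helper names are hypothetical.

```python
def conv_same(shape, out_channels, stride=1):
    """Convolution with "SAME" padding: spatial size becomes ceil(size / stride)."""
    h, w, _ = shape
    return (-(-h // stride), -(-w // stride), out_channels)


def max_pool(shape, size=2, stride=2):
    """Pooling without padding: floor((size_in - size) / stride) + 1."""
    h, w, c = shape
    return ((h - size) // stride + 1, (w - size) // stride + 1, c)


def sub_network_shapes(n_classes, input_shape=(28, 28, 1)):
    """Trace the output shape of every layer in one sub-network."""
    shapes = {}
    shapes["conv1"] = conv_same(input_shape, 32)
    shapes["pool_1"] = max_pool(shapes["conv1"])
    shapes["conv2"] = conv_same(shapes["pool_1"], 64)
    shapes["pool_2"] = max_pool(shapes["conv2"])
    h, w, c = shapes["pool_2"]
    shapes["flatten"] = h * w * c  # number of inputs to fc_3
    shapes["fc_3"] = 1024
    shapes["output"] = n_classes   # 10 for the original labels, 3 for the new ones
    return shapes
```

The flattened `pool_2` output has 7 x 7 x 64 = 3136 values. So `fc_3` holds a 3136 x 1024 weight matrix in each sub-network.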

With Cross Stitch

A cross-stitch unit is a transformation applied between layers. It models the relationship between tasks as a linear combination of their activations.

[Figure: a cross-stitch unit as a linear combination of the tasks' activations]
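For two tasks A and B, the cross-stitch unit in the paper can be written as below. Here:

- x_A^{ij} and x_B^{ij} are the two sub-networks' activations at location (i, j) of a layer.
- The α values are learned.
- The outputs with a tilde are passed to the next layer of each sub-network.

The identity initialization listed under Training corresponds to α_AA = α_BB = 1 and α_AB = α_BA = 0.

```latex
\begin{bmatrix} \tilde{x}_A^{ij} \\ \tilde{x}_B^{ij} \end{bmatrix}
=
\begin{bmatrix} \alpha_{AA} & \alpha_{AB} \\ \alpha_{BA} & \alpha_{BB} \end{bmatrix}
\begin{bmatrix} x_A^{ij} \\ x_B^{ij} \end{bmatrix}
```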

The network learns this relationship by itself. Compared with manually tuning which parts of the network are shared, this end-to-end approach works better.
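The following plain-Python sketch shows one cross-stitch unit acting on two tasks' flattened activations, using one 2x2 mixing matrix shared across all positions. It makes two simplifying assumptions:

- In the TensorFlow model, `alpha` would be a trainable variable updated by backpropagation. Here it is just a fixed list.
- `cross_stitch` and `IDENTITY` are hypothetical names.

```python
def cross_stitch(x_a, x_b, alpha):
    """Mix two tasks' activations with a 2x2 matrix shared across positions.

    x_a, x_b: flattened activations of equal length from tasks A and B.
    alpha: [[a_AA, a_AB], [a_BA, a_BB]]; output A = a_AA * x_a + a_AB * x_b, etc.
    """
    if len(x_a) != len(x_b):
        raise ValueError("both tasks must produce activations of the same shape")
    (a_aa, a_ab), (a_ba, a_bb) = alpha
    out_a = [a_aa * xa + a_ab * xb for xa, xb in zip(x_a, x_b)]
    out_b = [a_ba * xa + a_bb * xb for xa, xb in zip(x_a, x_b)]
    return out_a, out_b


# Identity weights: each task keeps only its own activations (no sharing).
IDENTITY = [[1.0, 0.0], [0.0, 1.0]]
```

With `IDENTITY` the unit passes both inputs through unchanged. With off-diagonal weights such as `[[0.9, 0.1], [0.1, 0.9]]`, each task receives a small share of the other task's activations.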

Here is an overview of this structure:

[Figure: network structure with cross stitch]

The convolutional sub-networks have the same architecture as above. As the paper suggests, cross-stitch units are added only after the pooling layers and the fully connected layers.
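Where the units sit can be made explicit as a layer schedule shared by both branches. This sketch assumes "fully connected layers" means the hidden layer `fc_3` and not the output layers. The outputs have 10 and 3 units, so the two tasks' activations there do not match in shape and cannot be combined by a cross-stitch unit. `build_schedule` and the step names are hypothetical.

```python
# Layer order of each sub-network, taken from the architecture table.
LAYERS = ["conv1", "pool_1", "conv2", "pool_2", "fc_3", "output"]

# Assumed placement: after both pooling layers and the hidden fully connected
# layer. The output layers are excluded because their sizes differ (10 vs. 3).
STITCH_AFTER = ("pool_1", "pool_2", "fc_3")


def build_schedule(cross_stitch_enabled):
    """Return the sequence of operations applied to both task branches."""
    schedule = []
    for name in LAYERS:
        schedule.append(name)
        if cross_stitch_enabled and name in STITCH_AFTER:
            schedule.append("cross_stitch_after_" + name)
    return schedule
```

`build_schedule(False)` reproduces the baseline layer order. `build_schedule(True)` inserts three cross-stitch steps.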

Training

  • The input images are standardized with z-score normalization.
  • L2 regularization is applied to the convolutional and fully connected layers, with lambda = 1e-5.
  • Dropout uses keep_prob = 0.8.
  • Batch normalization is used.
  • Sub-network weights are initialized with He initialization.
  • Cross-stitch weights are initialized to the identity matrix (i.e., no sharing between tasks at the start).
  • The learning rate is a constant 0.001.
  • Trained for 30 epochs with a batch size of 128.
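Three of the preprocessing and initialization steps above are sketched below in plain Python:

- **z-score standardization.** The text does not say whether the statistics are computed per image or over the whole training set. This sketch simply standardizes whatever values it is given.
- **He-normal weights.** These use std = sqrt(2 / fan_in). Weights are returned as a fan_in x fan_out matrix; a conv kernel would be reshaped from this.
- **Identity cross-stitch weights.**

The function names are hypothetical.

```python
import math
import random
import statistics


def z_score(values):
    """Standardize to zero mean and unit (population) standard deviation."""
    mean = statistics.mean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / std for v in values]


def he_normal(fan_in, fan_out, rng):
    """He initialization: zero-mean normal samples with std = sqrt(2 / fan_in)."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


def identity_alpha(n_tasks=2):
    """Cross-stitch weights that start with no sharing between tasks."""
    return [[1.0 if i == j else 0.0 for j in range(n_tasks)] for i in range(n_tasks)]


# conv1 sees 3x3 patches of a 1-channel image, so fan_in = 3 * 3 * 1.
conv1_weights = he_normal(3 * 3 * 1, 32, random.Random(0))
alpha = identity_alpha()
```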

Evaluation

The overall accuracy is the average of the accuracies of all sub-tasks.

With cross-stitch units, accuracy on the test set improves by more than 1%.
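The overall metric described above is an unweighted mean of the per-task accuracies. The sketch below assumes each task supplies a list of predicted class indices and a list of true labels; the function names are hypothetical.

```python
def task_accuracy(predictions, labels):
    """Fraction of examples whose predicted class matches the label."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need equally sized, non-empty prediction and label lists")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def overall_accuracy(per_task):
    """Unweighted mean of the per-task accuracies.

    per_task: list of (predictions, labels) pairs, one pair per task.
    """
    accs = [task_accuracy(p, y) for p, y in per_task]
    return sum(accs) / len(accs)
```

For example, one task at 75% and another at 100% give an overall accuracy of 87.5%.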

[Figure: test accuracy and total loss. Orange: without sharing; blue: with cross stitch.]

Result

For Fashion-MNIST, the new labels are derived from the original ones, so the two classification tasks are highly related. I also used this technique to build a gender-age classifier on the VGGFace2 dataset, whose labels are more independent. In both experiments, cross stitch improved accuracy. Although this project trains only two tasks, it can easily be extended to more.
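One natural way to extend cross stitch beyond two tasks is to replace the 2x2 mixing matrix with an n x n matrix. Output i is then a weighted sum of all n tasks' activations. This is a sketch of that generalization, not code from this repository. It assumes all tasks produce activations of the same shape at the stitched layer, and `cross_stitch_n` is a hypothetical name.

```python
def cross_stitch_n(activations, alpha):
    """Generalized cross-stitch unit for n tasks.

    activations: list of n equally long flattened activation vectors.
    alpha: n x n mixing matrix; output i is sum_j alpha[i][j] * activations[j].
    """
    n = len(activations)
    if len(alpha) != n or any(len(row) != n for row in alpha):
        raise ValueError("alpha must be an n x n matrix for n tasks")
    length = len(activations[0])
    if any(len(x) != length for x in activations):
        raise ValueError("all tasks must produce activations of the same shape")
    return [
        [sum(alpha[i][j] * activations[j][k] for j in range(n)) for k in range(length)]
        for i in range(n)
    ]
```

The identity matrix again means no sharing. A matrix with every entry equal to 1/n would give every task the average of all tasks' activations.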

I did not pretrain the sub-networks as the paper suggests, and I used a different initialization strategy. More tuning might give better results.
