
Multi-Task Attention Network (MTAN)

This repository contains the source code of Multi-Task Attention Network (MTAN) and baselines from the paper, End-to-End Multi-Task Learning with Attention, introduced by Shikun Liu, Edward Johns, and Andrew Davison.

Experiments

Image-to-Image Predictions (One-to-Many)

Under the folder im2im_pred, we provide our proposed network alongside all the baselines on the NYUv2 dataset presented in the paper. All the models are written in PyTorch, so please first make sure you have PyTorch 1.0 or above installed on your machine.

Download the pre-processed NYUv2 dataset that we evaluated in the paper here. We use the pre-computed ground-truth normals from here. The raw 13-class NYUv2 dataset can be downloaded directly from this repo, with segmentation labels defined in this repo.

Unfortunately, I am unable to provide the raw pre-processing code due to an unexpected computer crash.

Update - Jun 2019: I have now released the pre-processed CityScapes dataset with 2-, 7-, and 19-class semantic labels (see the paper for more details) and (inverse) depth labels. Download the [256x512, 2.42GB] version here and the [128x256, 651MB] version here.

Update - Oct 2019: For PyTorch 1.2 users: the mIoU evaluation method has been updated to avoid the "zeros issue" when computing binary masks. Also, to run the code correctly, please move scheduler.step() so that it is called after optimizer.step(), e.g. one line before the last performance-printing step, to fit the updated PyTorch requirements. See more in the official PyTorch documentation here.
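Since PyTorch 1.1, a learning-rate scheduler is expected to be stepped after the optimizer; calling scheduler.step() first skips the first value of the schedule. The loop below is a minimal sketch of the required ordering, assuming a StepLR scheduler, a toy linear model, and random data in place of the SegNet models and NYUv2 batches used here; the model, data, and hyper-parameters are placeholders, not the repository's settings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model and data (assumptions), standing in for a SegNet model and NYUv2 batches.
model = nn.Linear(4, 1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(2)]

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate after every epoch; step_size and gamma are illustrative only.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

num_epochs = 3
for epoch in range(num_epochs):
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()  # update the weights once per batch ...
    scheduler.step()  # ... then advance the schedule once per epoch, after optimizer.step()
    # Per-epoch performance printing would follow here.
    print(f"epoch {epoch}: lr for next epoch = {optimizer.param_groups[0]['lr']:.4f}")

current_lr = optimizer.param_groups[0]["lr"]
```

With step_size=1 and gamma=0.5, each epoch halves the learning rate, so after three epochs it has dropped from 0.1 to 0.0125.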

All the models (files) are built with SegNet and described in the following table:

File Name | Type | Flags | Comments
model_segnet_single.py | Single | task, dataroot | standard single-task learning
model_segnet_stan.py | Single | task, dataroot | our approach applied to a single task
model_segnet_split.py | Multi | weight, dataroot, temp, type | multi-task learning baseline in which the shared network splits at the last layer (also known as hard-parameter sharing)
model_segnet_dense.py | Multi | weight, dataroot, temp | multi-task learning baseline in which each task has its own parameter space (also known as soft-parameter sharing)
model_segnet_cross.py | Multi | weight, dataroot, temp | our implementation of the Cross Stitch Network
model_segnet_mtan.py | Multi | weight, dataroot, temp | our approach
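The hard-parameter-sharing baseline (model_segnet_split.py) keeps one shared network and only branches into task-specific layers at the end. Below is a minimal PyTorch sketch of that idea, assuming a tiny two-layer convolutional trunk in place of SegNet and the three NYUv2 tasks (13-class semantic segmentation, 1-channel depth, 3-channel surface normals). The class name SplitMTL and all layer sizes are hypothetical, and the per-pixel L2 normalisation of the normal output is an assumed design choice that suits a cosine-similarity loss.

```python
import torch
import torch.nn as nn


class SplitMTL(nn.Module):
    """Toy hard-parameter-sharing network: one shared trunk, one head per task.

    Layer sizes are illustrative assumptions, not the SegNet backbone.
    """

    def __init__(self, num_classes: int = 13):
        super().__init__()
        # Shared feature extractor used by every task.
        self.shared = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # The network splits only at the last layer: one 1x1 conv per task.
        self.heads = nn.ModuleDict({
            "semantic": nn.Conv2d(16, num_classes, kernel_size=1),
            "depth": nn.Conv2d(16, 1, kernel_size=1),
            "normal": nn.Conv2d(16, 3, kernel_size=1),
        })

    def forward(self, x: torch.Tensor) -> dict:
        features = self.shared(x)
        out = {name: head(features) for name, head in self.heads.items()}
        # Surface normals are unit vectors, so normalise along the channel axis.
        out["normal"] = nn.functional.normalize(out["normal"], p=2, dim=1)
        return out


torch.manual_seed(0)
model = SplitMTL()
preds = model(torch.randn(2, 3, 32, 32))
print({name: tuple(t.shape) for name, t in preds.items()})
```

Because every head reads the same features, gradients from all three task losses flow into the shared trunk; the soft-sharing, Cross Stitch, and MTAN variants in the table differ in how much of the network each task shares.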

The flags are described below:

Flag Name | Usage | Comments
task | pick one task to train: semantic (semantic segmentation, depth-wise cross-entropy loss), depth (depth estimation, L1-norm loss) or normal (normal prediction, cosine-similarity loss) | only available in single-task learning
dataroot | directory root for the NYUv2 dataset | just put it under the folder im2im_pred to avoid any concerns :D
weight | weighting options for multi-task learning: equal (direct summation of all task losses), DWA (our proposal), uncert (our implementation of the Weight Uncertainty Method) | only available in multi-task learning
temp | hyper-parameter temperature in the DWA weighting option, which determines the softness of task weighting |
type | different versions of the multi-task baseline split: standard, deep, wide | only available in the baseline split
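The temp flag is the temperature T of Dynamic Weight Average (DWA). As described in the paper, each task k gets a relative descending rate w_k(t-1) = L_k(t-1) / L_k(t-2), where L_k is the task's average training loss over an epoch, and the weights are λ_k(t) = K exp(w_k(t-1)/T) / Σ_i exp(w_i(t-1)/T) for K tasks, with all weights set to 1 for the first two epochs. The function below is a minimal pure-Python sketch of that formula; the function name, its signature, and the default temperature are illustrative assumptions rather than the repository's code.

```python
import math
from typing import List, Sequence


def dwa_weights(loss_history: Sequence[Sequence[float]],
                num_tasks: int, temp: float = 2.0) -> List[float]:
    """Dynamic Weight Average weights for the next epoch (illustrative sketch).

    loss_history[e][k] is the average training loss of task k in epoch e.
    With fewer than two recorded epochs, every task gets weight 1.
    """
    if len(loss_history) < 2:
        return [1.0] * num_tasks
    prev, prev2 = loss_history[-1], loss_history[-2]
    # Relative descending rate: w_k = L_k(t-1) / L_k(t-2).
    rates = [prev[k] / prev2[k] for k in range(num_tasks)]
    # Softmax over rates with temperature T, scaled so the weights sum to K.
    exps = [math.exp(r / temp) for r in rates]
    total = sum(exps)
    return [num_tasks * e / total for e in exps]


history = [[1.0, 2.0, 0.5], [0.8, 1.0, 0.5]]  # made-up per-epoch average losses
weights = dwa_weights(history, num_tasks=3, temp=2.0)
batch_losses = [0.7, 0.9, 0.5]  # made-up task losses for one batch
total_loss = sum(w * l for w, l in zip(weights, batch_losses))
```

A task whose loss is falling slowly (rate close to 1) receives a larger weight; a larger T pushes all weights towards 1, which is why the flag controls the softness of task weighting.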

To run any model, cd im2im_pred/ and run python MODEL_NAME.py --FLAG_NAME 'FLAG_OPTION'.
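Each --FLAG_NAME 'FLAG_OPTION' pair is an ordinary command-line option. The snippet below is a hypothetical argparse sketch of how the multi-task flags could be declared and parsed; it is not copied from the scripts, and the lowercase option value dwa, the defaults, and the directory name nyuv2 are assumptions, so check each script's own argument parser before relying on them.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the multi-task flags in the flag table;
    # the defaults and exact option spellings are assumptions.
    parser = argparse.ArgumentParser(description="Multi-task training (illustrative)")
    parser.add_argument("--weight", type=str, default="equal",
                        help="task weighting: equal, dwa or uncert")
    parser.add_argument("--dataroot", type=str, default="nyuv2",
                        help="root directory of the NYUv2 dataset")
    parser.add_argument("--temp", type=float, default=2.0,
                        help="temperature for DWA weighting")
    return parser


# Equivalent to: python MODEL_NAME.py --weight dwa --dataroot nyuv2 --temp 2.0
args = build_parser().parse_args(["--weight", "dwa", "--dataroot", "nyuv2", "--temp", "2.0"])
print(args.weight, args.dataroot, args.temp)
```

Declaring temp with type=float means the quoted shell value '2.0' arrives in the script as a number rather than a string.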

Visual Decathlon Challenge (Many-to-Many)

We have also provided source code for the recently proposed Visual Decathlon Challenge, for which we built MTAN on top of a Wide Residual Network, based on the implementation here.

To run the code, first download the dataset and devkit from the official Visual Decathlon Challenge website here and put them in the folder visual_decathlon. Then, put decathlon_mean_std.pickle into the downloaded dataset folder decathlon-1.0-data.

Finally, run python model_wrn_mtan.py for training; python model_wrn_eval.py --dataset 'imagenet' and python model_wrn_eval.py --dataset 'notimagenet' for evaluation; and python coco_results.py to produce results in COCO format for online evaluation.

Other Notices

  1. The provided code is highly optimised for readability. If you find any unusual behaviour, please open an issue or contact me directly at the email address below.
  2. Training the provided code will result in slightly different performance (depending on the type of task) from the numbers reported in the paper for image-to-image prediction tasks, but the rankings stay the same. If you want to compare any models in the paper for image-to-image prediction tasks, please re-run the models directly with your own training strategy (learning rate, optimiser, etc.) and keep the training strategy consistent across all models to ensure fairness. To compare results in the Visual Decathlon Challenge, you may directly check the results in the paper. To compare with your own research, please build your multi-task network with the same backbone architecture: SegNet for image-to-image tasks, and Wide Residual Network for the Visual Decathlon Challenge.
  3. From my personal experience, designing a better architecture is usually more helpful (and easier) than finding a better task weighting in multi-task learning.

Citation

If you found this code/work to be useful in your own research, please consider citing the following:

@inproceedings{liu2019end,
  title={End-to-End Multi-task Learning with Attention},
  author={Liu, Shikun and Johns, Edward and Davison, Andrew J},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={1871--1880},
  year={2019}
}

Contact

If you have any questions, please contact [email protected].
