
ait's Introduction

All in Tokens: Unifying Output Space of Visual Tasks via Soft Token


By Jia Ning*, Chen Li*, Zheng Zhang*, Zigang Geng, Qi Dai, Kun He, Han Hu

Introduction

AiT, initially described in our arXiv paper (arXiv:2301.02229), is a framework that unifies the output space of visual tasks. We demonstrate a single unified model that simultaneously handles two typical visual tasks, instance segmentation and depth estimation, which have discrete/fixed-length and continuous/varied-length outputs, respectively. We propose several new techniques that take the particularities of visual tasks into account: 1) Soft tokens. We employ soft tokens to represent the task output. Unlike the hard tokens of a common VQ-VAE, which are assigned one-hot to entries of a discrete codebook/vocabulary, soft tokens are assigned softly to the codebook embeddings. Soft tokens improve the accuracy of both next-token inference and decoding of the task output. 2) Mask augmentation. Many visual tasks have corrupted, undefined, or invalid values in their label annotations, e.g., the occluded areas of depth maps. We show that a mask augmentation technique greatly benefits these tasks. With these new techniques and other designs, the proposed general-purpose task solver performs both instance segmentation and depth estimation well. In particular, we achieve 0.275 RMSE on NYUv2 depth estimation, setting a new record on this benchmark.
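
To make the soft-token idea concrete, below is a minimal, illustrative PyTorch sketch (not the implementation used in this repository) contrasting the one-hot assignment of a standard VQ-VAE with a soft assignment over the codebook embeddings:

import torch
import torch.nn.functional as F

def quantize(z, codebook, temperature=1.0):
    """Contrast hard (one-hot) and soft codebook assignment.

    z:        (N, D) latent vectors from the encoder
    codebook: (K, D) codebook embeddings
    """
    # Similarity of each latent to each codebook entry (negative squared distance).
    logits = -torch.cdist(z, codebook).pow(2) / temperature   # (N, K)

    # Hard tokens (standard VQ-VAE): one-hot assignment to the nearest code.
    hard_ids = logits.argmax(dim=-1)                          # (N,)
    z_hard = codebook[hard_ids]                               # (N, D)

    # Soft tokens: a probability over all codes; the quantized feature is the
    # probability-weighted average of the codebook embeddings.
    probs = F.softmax(logits, dim=-1)                         # (N, K)
    z_soft = probs @ codebook                                 # (N, D)

    return z_hard, z_soft, probs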


Results and Models

Results on COCO instance segmentation

| Model | Box AP | Mask AP | VQ-VAE Model | Task-Solver Model |
| --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 43.3 | 34.2 | vqvae_insseg.pt | model |
| AiT(SwinV2-B) w/o soft token | 43.6 | 31.1 (-3.1) | vqvae_insseg.pt | model |

Results on NYUv2 depth estimation

| Model | D1 | D2 | D3 | Abs Rel | RMSE | Log10 | VQ-VAE Model | Task-Solver Model |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 0.934 | 0.991 | 0.998 | 0.087 | 0.305 | 0.037 | vqvae_depth.pt | model |
| AiT-P(SwinV2-B) | 0.940 | 0.992 | 0.998 | 0.085 | 0.301 | 0.036 | vqvae_depth.pt | model |
| AiT(SwinV2-B) w/o soft token | 0.932 | 0.991 | 0.998 | 0.089 | 0.318 | 0.038 | vqvae_depth.pt | model |
| AiT(SwinV2-L) | 0.949 | 0.993 | 0.999 | 0.079 | 0.284 | 0.034 | vqvae_depth.pt | model |
| AiT-P(SwinV2-L) | 0.954 | 0.994 | 0.999 | 0.076 | 0.275 | 0.033 | vqvae_depth.pt | model |

Joint training results on COCO and NYUv2

| Model | Box AP | Mask AP | RMSE | VQ-VAE Model | Task-Solver Model |
| --- | --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 42.2 | 34.1 | 0.310 | vqvae_depth.pt / vqvae_insseg.pt | model |

Usage

Installation

We recommend PyTorch >= 1.10; the other required packages are listed in requirements.txt. To install boundary-iou-api, please use the following command:

git clone https://github.com/bowenc0221/boundary-iou-api && cd boundary-iou-api && pip install -e .

Data/Pre-training model Preparation

  1. Download the NYU Depth V2 dataset, the COCO dataset, and our preprocessed box-cropped binary instance masks (named maskcoco), then organize the data according to the following directory structure:
AiT
├── ait
├── vae
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   ├── maskcoco
│   ├── nyu_depth_v2
  2. Create the data links using the following commands:
ln -s data ait/data
ln -s data vae/data
  3. Download the pre-trained backbone models swin_v2_base_densesimmim.pth and swin_v2_large_densesimmim.pth.

Training

Training VQ-VAE on depth estimation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_depth_vqvae_dist.py  configs/depth/ait_depth_vqvae.py --cfg-options <custom-configs>

Training VQ-VAE on instance segmentation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_insseg_vqvae_dist.py  configs/insseg/ait_insseg_vqvae.py --cfg-options <custom-configs>

Training task-solver on depth estimation:

cd ait

# Train auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for AR training

# Train parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for parallel training

Training task-solver on object detection

cd ait
python -m torch.distributed.launch --nproc_per_node=16 --nnodes=2 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_detonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth

Note: We use the pre-trained object detection model to initialize the instance segmentation and joint-training models to save training cost. Please download the pre-trained model (ait_det_swinv2b_wodec.pth) before training in the instance segmentation or joint-training settings.

Training task-solver on instance segmentation

python -m torch.distributed.launch --nproc_per_node=16 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt load_from=ait_det_swinv2b_wodec.pth

Joint training on instance segmentation and depth estimation

python -m torch.distributed.launch --nproc_per_node=16 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt load_from=ait_det_swinv2b_wodec.pth

Inference

Evaluate on depth estimation

cd ait

# Evaluating auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

# Evaluating parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Evaluate on instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt --eval <model_checkpoint>

Evaluate on both depth estimation and instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Citation

@article{ning2023all,
  title={All in Tokens: Unifying Output Space of Visual Tasks via Soft Token},
  author={Ning, Jia and Li, Chen and Zhang, Zheng and Geng, Zigang and Dai, Qi and He, Kun and Hu, Han},
  journal={arXiv preprint arXiv:2301.02229},
  year={2023}
}

ait's People

Contributors

ancientmooner, hust-nj


ait's Issues

Small typo

Hi, great work! I just noticed a small typo:
in the inference section of the readme, what should be <model_checkpoint> is written as <model_checkpiont>.

Some problems with visualizing the depth of pred and gt.

Thanks for your work. I am having some problems visualizing the predicted and ground-truth depth. This is where I visualize them:

    for pred_d, depth_gt in results:
        '''visualize pred_d here'''
        pred_crop, gt_crop = cropping_img(pred_d, depth_gt)
        '''after reshaping, visualize pred_crop and gt_crop here'''
        computed_result = eval_depth(pred_crop, gt_crop)

This is the command:
CUDA_VISIBLE_DEVICES=5,6,7 python -m torch.distributed.launch --nproc_per_node=3 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval ait_joint_swinv2b.pth

However, the visualizations of pred_d, pred_crop, and gt_crop all look very similar, and the saved images come out almost entirely white.
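
An almost all-white result often indicates that the raw metric depths were written without first being mapped to a display range. Below is a minimal, illustrative sketch (not from this repository) of one way to save the two maps with a shared color scale; it assumes pred_crop and gt_crop are, or can be reshaped to, 2-D tensors or arrays:

import matplotlib.pyplot as plt
import numpy as np

def save_depth_vis(pred, gt, path):
    """Save predicted and ground-truth depth side by side with a shared color range."""
    pred = np.asarray(pred.detach().cpu()) if hasattr(pred, "detach") else np.asarray(pred)
    gt = np.asarray(gt.detach().cpu()) if hasattr(gt, "detach") else np.asarray(gt)
    vmin, vmax = float(gt.min()), float(gt.max())   # shared display range taken from the GT
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, img, title in zip(axes, (pred, gt), ("pred", "gt")):
        im = ax.imshow(img, cmap="viridis", vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.axis("off")
    fig.colorbar(im, ax=list(axes), shrink=0.8)     # one colorbar for both panels
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)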

Training time

Hi, interesting work! Can you share the approximate time to train the VQVAE and the task solver on both tasks? Thanks!

Ask for models and data

The model weights and data you released are inaccessible. Can you please make these weights and data publicly available again? Very much looking forward to your response!

'PublicAccessNotPermitted' when downloading the checkpoints

Hi, thank you for the excellent work!
I ran into some trouble when downloading the checkpoints with wget; it raises a 'PublicAccessNotPermitted' error. I would like to know how to download them properly, especially the pre-trained backbone models.
Thank you in advance!

Single Image Inference

How can I perform inference on my own set of images? What changes do I need to make for data preprocessing? Do I need to change the val dict under data in AiT/ait/configs/swinv2b_480reso_depthonly.py?

Error(s) in loading state_dict for VQVAE

Thank you for your nice work!
However, after training the VQ-VAE on depth estimation, I tried to train the task-solver on depth estimation and the following error came up:

Error(s) in loading state_dict for VQVAE:
        Missing key(s) in state_dict: "encoder.0.weight", "encoder.0.bias", "encoder.2.weight", "encoder.2.bias", "encoder.4.weight", "encoder.4.bias", "encoder.6.weight", "encoder.6.bias", "encoder.8.weight", "encoder.8.bias", "encoder.10.net.0.weight", "encoder.10.net.0.bias", "encoder.10.net.2.weight", "encoder.10.net.2.bias", "encoder.10.net.4.weight", "encoder.10.net.4.bias", "encoder.11.net.0.weight", "encoder.11.net.0.bias", "encoder.11.net.2.weight", "encoder.11.net.2.bias", "encoder.11.net.4.weight", "encoder.11.net.4.bias", "encoder.12.weight", "encoder.12.bias", "decoder.0.weight", "decoder.0.bias", "decoder.2.net.0.weight", "decoder.2.net.0.bias", "decoder.2.net.2.weight", "decoder.2.net.2.bias", "decoder.2.net.4.weight", "decoder.2.net.4.bias", "decoder.3.net.0.weight", "decoder.3.net.0.bias", "decoder.3.net.2.weight", "decoder.3.net.2.bias", "decoder.3.net.4.weight", "decoder.3.net.4.bias", "decoder.4.weight", "decoder.4.bias", "decoder.6.weight", "decoder.6.bias", "decoder.8.weight", "decoder.8.bias", "decoder.10.weight", "decoder.10.bias", "decoder.12.weight", "decoder.12.bias", "decoder.14.weight", "decoder.14.bias", "_vq_vae._embedding", "_vq_vae._ema_cluster_size", "_vq_vae._ema_w".

How can I solve it? Thank you.

denorm twice in eval_coco.py

Hello! I found that /vae/utils/eval_coco.py de-normalizes the reconstructed image twice at line 45.

if hasattr(vae, 'get_codebook_indices'):
    code = vae.get_codebook_indices(mask)
    remask = vae.decode(code)[0, 0, :, :].cpu().numpy() * 0.5 + 0.5  # why denorm here?

Because the attribute use_norm is True in the decode method, decode already de-normalizes the image, yet the code de-normalizes it again after decoding.
I will try to investigate the effect when evaluating.
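
A minimal sketch of the change this suggests (illustrative only; it assumes decode already applies the de-normalization when use_norm is True):

if hasattr(vae, 'get_codebook_indices'):
    code = vae.get_codebook_indices(mask)
    # decode() already de-normalizes when use_norm is True,
    # so the extra "* 0.5 + 0.5" is dropped here
    remask = vae.decode(code)[0, 0, :, :].cpu().numpy()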

Unable to evaluate the results

Hello,

I am trying to run these models to evaluate the results, but I am not able to due to errors at runtime.

The best "result" I could get is with this Dockerfile (at the root of the project):

FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu18.04

ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=Etc/UTC
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8

# Install system dependencies
RUN apt-get update && \
    apt-get install -y \
    git \
    wget \
    python3-pip \
    python3-dev \
    python3-opencv \
    python3-six

RUN python3 -m pip install --upgrade pip

RUN pip3 install setuptools openmim

# Install PyTorch and torchvision
RUN pip3 install torch torchvision torchaudio -f https://download.pytorch.org/whl/cu111/torch_stable.html
RUN python3 -m pip install h5py albumentations tensorboardX gdown scipy

RUN python3 -m mim install mmcv

# Upgrade pip

WORKDIR /

RUN wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat -O nyu_depth_v2_labeled.mat

RUN git clone https://github.com/vinvino02/GLPDepth.git --depth 1

RUN mv GLPDepth/code/utils/logging.py GLPDepth/code/utils/glp_depth_logging.py


# Set the working directory
WORKDIR /app


RUN python3 ../GLPDepth/code/utils/extract_official_train_test_set_from_mat.py ../nyu_depth_v2_labeled.mat ../GLPDepth/datasets/splits.mat ./data/nyu_depth_v2/official_splits/


# RUN ln -s data ait/data


COPY requirements.txt requirements.txt

RUN python3 -m pip install -r requirements.txt

COPY . .

RUN rm -rf .git

I built the image with:

sudo docker build -t mde . -f Dockerfile

And ran it with:

sudo docker run --name mde-test --gpus all --ipc=host -it --rm -v $(pwd):/app mde

Finally, I run the evaluation command. For example:

cd ait
python3 -m torch.distributed.launch --nproc_per_node=1 code/train.py configs/swinv2b_480reso_parallel_depthonly.py  --cfg-options model.task_heads.depth.vae_cfg.pretrained=../models/vqvae_depth_2bp.pt --eval ../models/ait_depth_swinv2b_parallel.pth

In this way, the inference process is launched, but eventually it fails with an uninformative error:

eval task depth
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 654/654, 2.5 task/s, elapsed: 262s, ETA:     0sERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -9) local_rank: 0 (pid: 34) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/run.py", line 713, in run
    )(*cmd_args)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/launcher/api.py", line 261, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
===================================================
code/train.py FAILED
---------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
---------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-26_03:01:18
  host      : f50427e7ad50
  rank      : 0 (local_rank: 0)
  exitcode  : -9 (pid: 34)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 34
===================================================

Are the authors able to provide the versions of all the software they are using? In particular:

  • Linux version and distribution
  • CUDA version
  • Python version
  • Package versions (in the requirements, some versions are missing)
  • Any other relevant environment information

Thanks.

Swin-S and Swin-Ti weights

Thank you for releasing your code! I am wondering if you happen to have any pre-trained checkpoints for Swin-S and Swin-Ti, or even just the ImageNet-1k weights? The ImageNet-1k pre-trained weights would be preferable, as I can't seem to find them released anywhere with matching sizes.

Thanks!

train/visualize on single GPU

Hello!
I am trying to evaluate it on one GPU, but I ran into a lot of errors.
I am new to this; do you have code for running on a single GPU?
Best wishes

There may be a bug in the dataset; it might cause over-fitting.

Thanks for sharing your work.

    transform = [
        A.Crop(x_min=41, y_min=0, x_max=601, y_max=480),
        A.HorizontalFlip(),
        A.RandomCrop(crop_size[0], crop_size[1]),
    ]

In dataset/nyudepthv2.py, I found that the image is first cropped to a fixed region, and after that a RandomCrop to the target crop size is applied.
Maybe albumentations could change the order of the transforms?
I am not sure.
