
inpaint's Introduction

Inpainting Codebase

This repo implements a simple PyTorch codebase for training inpainting models with powerful tools including Docker, PyTorchLightning, and Hydra.

Currently, only DFNet is supported. More methods as well as some additional useful utilities for image inpainting will be implemented.

Prerequisites

We use Docker to run all experiments.

Features

  • PyTorchLightning
    • logging (tensorboard, csv)
    • checkpoint
    • DistributedDataParallel
    • mixed-precision
  • Hydra
    • flexible configuration system
    • logging (stream to file, folder structure)
  • Others
    • save sample results

Setup

Build the environment

Build the image for the first time:

python core.py env prepare

Explanation:

  • When you first run this command, you will be asked to provide three items:
    1. a project name,
    2. the root folder of your training logs,
    3. the root folder of your datasets;
  • then an image is built from /env/Dockerfile,
  • and finally, a container is launched according to docker-compose.yml.

The default docker-compose.yml is shown below; you can modify it as needed before building:

version: "3.9"
services:
    lab:
        container_name: ${PROJECT}
        runtime: nvidia
        build:
            context: env/
            dockerfile: Dockerfile
            args:
                - USER_ID=${UID}
                - GROUP_ID=${GID}
                - USER_NAME=${USER_NAME}
        image: pytorch181_local
        environment:
            - TZ=Asia/Shanghai
            - TORCH_HOME=/data/torch_model
        ipc: host
        hostname: docker
        working_dir: /code
        command: ['sleep', 'infinity']
        volumes:
            - ${CODE_ROOT}:/code
            - ${DATA_ROOT}:/data
            - ${LOG_ROOT}:/outputs
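Docker Compose fills in the ${VAR} references above from the shell environment or from a .env file next to docker-compose.yml. As an illustration only, the hypothetical helper render_env below produces such a file for the seven variables the compose file references. It is not the actual logic of core.py, and it assumes a POSIX host (os.getuid and os.getgid are unavailable on Windows).

```python
import getpass
import os
from pathlib import Path


def render_env(project: str, log_root: str, data_root: str, code_root: str = ".") -> str:
    """Return .env text for every variable referenced in docker-compose.yml.

    Hypothetical helper for illustration; core.py may store these differently.
    """
    def absolute(path: str) -> str:
        # Resolve to an absolute host path so the mounts do not depend on the cwd.
        return str(Path(path).expanduser().resolve())

    values = {
        "PROJECT": project,
        "UID": str(os.getuid()),  # POSIX only
        "GID": str(os.getgid()),
        "USER_NAME": getpass.getuser(),
        "CODE_ROOT": absolute(code_root),
        "DATA_ROOT": absolute(data_root),
        "LOG_ROOT": absolute(log_root),
    }
    return "".join(f"{key}={value}\n" for key, value in values.items())


if __name__ == "__main__":
    # Print the file contents; save them as .env beside docker-compose.yml.
    print(render_env("inpaint", "~/logs", "~/datasets"), end="")
```

Passing UID and GID as build args is what lets the image create a user matching the host account, which is why files written to the mounted folders keep host-friendly ownership.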

Get into the environment

Simply run:

python core.py env

The default user is the same as the host user, to avoid permission issues. You can also enter the container as root:

python core.py env --root
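Entering the container boils down to a docker exec into the running service. A minimal sketch, assuming the container is named after PROJECT (per container_name above) and that bash exists in the image; enter_command is a hypothetical helper, not core.py's real implementation:

```python
import shlex
from typing import List


def enter_command(container: str, root: bool = False, shell: str = "bash") -> List[str]:
    """Build a `docker exec` command that opens an interactive shell.

    Hypothetical helper; core.py's real behaviour may differ.
    """
    cmd = ["docker", "exec", "-it"]
    if root:
        # Without -u, docker exec runs as the container's configured user.
        cmd += ["-u", "root"]
    cmd += [container, shell]
    return cmd


if __name__ == "__main__":
    # Print rather than run, so the sketch works without Docker installed.
    print(shlex.join(enter_command("inpaint", root=True)))
```

To actually open the shell, pass the returned list to subprocess.run on a host where Docker is available.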

Modify the environment at any time

Basically, the environment is determined by four items:

  1. /env/Dockerfile defines how the local Docker image is built, for example installing the packages listed in requirements.txt on top of deepbase/pytorch:latest.
  2. The base Docker image. /env/Dockerfile shows that deepbase/pytorch is the base image; its original Dockerfile is in deepcodebase/docker. You can change the base image to whatever you like.
  3. /env/requirements.txt defines the Python packages to install in the local Docker image.
  4. /docker-compose.yml defines the settings for running the container, for example the volumes and timezone.
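To check which base image item 2 refers to after editing /env/Dockerfile, you can read its FROM instruction. The hypothetical helper base_image below is a convenience sketch: it only handles the common "FROM [--flag] image[:tag] [AS name]" form, and for multi-stage builds it returns the image named by the final FROM (which could itself be an earlier stage name).

```python
from typing import Optional


def base_image(dockerfile_text: str) -> Optional[str]:
    """Return the image named by the last FROM instruction, or None if absent."""
    image = None
    for raw in dockerfile_text.splitlines():
        line = raw.strip()
        # Case-insensitive match; comment lines start with '#' and are skipped.
        if not line.upper().startswith("FROM "):
            continue
        # Drop flags such as --platform=... ; the first remaining token is the image.
        tokens = [t for t in line.split()[1:] if not t.startswith("--")]
        if tokens:
            image = tokens[0]
    return image
```

For example, base_image(open("env/Dockerfile").read()) run from the repository root would print the current base image, assuming the file uses the simple form above.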

After changing any of these settings, you can rebuild the local image by running:

python core.py env prepare --build

Training

Data Preparation

  1. Image data: any image dataset you like, e.g. Places2 or ImageNet. Place your dataset in DATAROOT on your local machine (mapped to /data in the container). By default, DATAROOT/places2 is used for training.
  2. Masks: you can download and use free-form-mask. Decompress the file and place the mask folder under DATAROOT.
  3. By default, the environment expects places2 and mask under /data.
  4. If you use other datasets, remember to modify the settings, especially the data locations under conf/dataset.
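Before launching training, it can help to confirm from inside the container that the default folders actually contain images, since an empty folder typically surfaces later as a confusing dataloader error. A minimal sketch; count_images is a hypothetical helper, and the accepted extensions are an assumption that may not match what the repo's dataset code looks for.

```python
from pathlib import Path
from typing import Iterable

# Assumed extensions; the repo's dataset code may accept a different set.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(root: str, suffixes: Iterable[str] = IMAGE_SUFFIXES) -> int:
    """Recursively count files under `root` whose suffix looks like an image."""
    wanted = {s.lower() for s in suffixes}
    base = Path(root)
    if not base.is_dir():
        return 0  # a missing folder counts as zero images
    return sum(1 for p in base.rglob("*") if p.is_file() and p.suffix.lower() in wanted)


if __name__ == "__main__":
    # Default locations inside the container, per the list above.
    for folder in ("/data/places2", "/data/mask"):
        print(f"{folder}: {count_images(folder)} images")
```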

Running

After entering the environment, you can launch training. Example training commands:

python train.py
python train.py mode=run pl_trainer.gpus=\'3,4\' logging.wandb.notes="tune model"
python train.py +experiment=k80 mode=run logging.wandb.tags='[k80]'
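Each key=value argument above is a Hydra override that addresses a nested config key by its dotted path (for example pl_trainer.gpus inside the pl_trainer group). The escaped quotes in pl_trainer.gpus=\'3,4\' presumably let the quotes survive the shell, so Hydra receives the quoted string '3,4' instead of treating the comma as a sweep separator. The sketch below is a deliberately simplified, hypothetical stand-in that shows only the dotted-path idea; Hydra's real override grammar also parses numbers, lists, sweeps, and the +/~ prefixes (as in +experiment=k80).

```python
from typing import Any, Dict


def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply a `dotted.key=value` override to a nested dict in place.

    Illustrative only: values stay strings, and one pair of matching
    quotes is stripped, so '3,4' becomes the string 3,4.
    """
    key, _, value = override.partition("=")
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
        value = value[1:-1]
    *parents, leaf = key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})  # create missing groups on the way down
    node[leaf] = value
```

For example, starting from {"pl_trainer": {"gpus": -1}}, applying "pl_trainer.gpus='3,4'" replaces only that nested value and leaves the rest of the config untouched.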

This project uses wandb for logging by default. The first time you run training, it will prompt:

wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice:

Just follow the steps; wandb is convenient and easy to use. If you want to use TensorBoard instead, add this override when running:

python train.py logging=tensorboard

Reading Suggestions

Read the official documentation of Hydra and PyTorchLightning to learn more:

  • Hydra: a very powerful and convenient configuration system, and more.
  • PyTorchLightning: you mostly only need to write code for models and data. Say goodbye to massive amounts of code for pipelines, mixed precision, logging, etc.

Results

Results of DFNet

Trained on Places2 for 20 epochs.

License

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

inpaint's People

Contributors

hughplay

Forkers

chenzhaiyu

inpaint's Issues

ValueError: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch

Dear sir,
I appreciate your work, but I got this error even though I followed all the steps:

Config:

model:
  _target_: model.dfnet.DFNet
  c_img: 3
  c_mask: 1
  c_alpha: 3
  mode: nearest
  norm: batch
  act_en: relu
  act_de: leaky_relu
  en_ksize:
  - 7
  - 5
  - 5
  - 3
  - 3
  - 3
  - 3
  - 3
  de_ksize:
  - 3
  - 3
  - 3
  - 3
  - 3
  - 3
  - 3
  - 3
  blend_layers:
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
loss:
  _target_: loss.InpaintLoss
  c_img: 3
  w_l1: 6.0
  w_percep: 0.1
  w_style: 240.0
  w_tv: 0.1
  structure_layers:
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  texture_layers:
  - 0
  - 1
  - 2
optim:
  _target_: torch.optim.AdamW
  lr: 0.002
scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: 20
dataset:
  _target_: dataset.inpaint.InpaintDataModule
  name: places2
  data:
    train_dir: /data/places2/data_large
    val_dir: /data/places2/data_large
    test_dir: /data/places2/test_large
  mask:
    train_dir: /data/mask
    val_dir: /data/mask
    test_dir: /data/mask
  batch_size: 32
  num_workers: 4
  pin_memory: true
pipeline:
  _target_: pipeline.inpainter.LitInpainter
callbacks:
  checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: log/checkpoints
    filename: '{epoch}-{step}-{val_loss:.2f}'
    save_last: true
    monitor: val_loss
    save_top_k: 2
    verbose: true
    mode: min
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
mode: debug
pl_trainer:
  accelerator: ddp
  precision: 16
  amp_backend: native
  amp_level: O2
  profiler: simple
  weights_summary: top
  deterministic: true
  max_epochs: 20
  gpus: -1
  default_root_dir: ./
val_save:
  n_image_per_batch: 8
  n_save: 11
run_test: false
seed: 2021
find_unused_parameters: true

[2021-08-12 22:08:09,938][pytorch_lightning.utilities.seed][INFO] - Global seed set to 2021
[2021-08-12 22:08:11,047][pytorch_lightning.utilities.distributed][INFO] - GPU available: True, used: True
[2021-08-12 22:08:11,047][pytorch_lightning.utilities.distributed][INFO] - TPU available: None, using: 0 TPU cores
[2021-08-12 22:08:11,047][pytorch_lightning.trainer.connectors.accelerator_connector][INFO] - Using native 16bit precision.
[2021-08-12 22:08:11,053][pytorch_lightning.accelerators.gpu][INFO] - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2021-08-12 22:08:11,054][pytorch_lightning.loggers.tensorboard][WARNING] - Missing logger folder: ./lightning_logs
[2021-08-12 22:08:11,056][pytorch_lightning.utilities.seed][INFO] - Global seed set to 2021
[2021-08-12 22:08:11,056][pytorch_lightning.plugins.training_type.ddp][INFO] - initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
[2021-08-12 22:08:11,057][root][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2021-08-12 22:08:12,914][pytorch_lightning.core.lightning][INFO] -
  | Name  | Type        | Params
--------------------------------------
0 | model | DFNet       | 32.9 M
1 | loss  | InpaintLoss | 1.7 M
--------------------------------------
32.9 M    Trainable params
1.7 M     Non-trainable params
34.6 M    Total params
138.464 Total estimated model params size (MB)
Error executing job with overrides: []
Traceback (most recent call last):
  File "train.py", line 43, in <module>
    main()
  File "/usr/local/lib/python3.8/site-packages/hydra/main.py", line 49, in decorated_main
    _run_hydra(
  File "/usr/local/lib/python3.8/site-packages/hydra/_internal/utils.py", line 367, in _run_hydra
    run_and_report(
  File "/usr/local/lib/python3.8/site-packages/hydra/_internal/utils.py", line 214, in run_and_report
    raise ex
  File "/usr/local/lib/python3.8/site-packages/hydra/_internal/utils.py", line 211, in run_and_report
    return func()
  File "/usr/local/lib/python3.8/site-packages/hydra/_internal/utils.py", line 368, in <lambda>
    lambda: hydra.run(
  File "/usr/local/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 110, in run
    _ = ret.return_value
  File "/usr/local/lib/python3.8/site-packages/hydra/core/utils.py", line 233, in return_value
    raise self._return_value
  File "/usr/local/lib/python3.8/site-packages/hydra/core/utils.py", line 160, in run_job
    ret.return_value = task_function(task_cfg)
  File "train.py", line 36, in main
    trainer.fit(model, datamodule)
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 499, in fit
    self.dispatch()
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 546, in dispatch
    self.accelerator.start_training(self)
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 73, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 114, in start_training
    self._results = trainer.run_train()
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 607, in run_train
    self.run_sanity_check(self.lightning_module)
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 846, in run_sanity_check
    self.reset_val_dataloader(ref_model)
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 364, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 325, in _reset_eval_dataloader
    num_batches = len(dataloader) if has_len(dataloader) else float('inf')
  File "/usr/local/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py", line 33, in has_len
    raise ValueError('Dataloader returned 0 length. Please make sure that it returns at least 1 batch')
ValueError: Dataloader returned 0 length. Please make sure that it returns at least 1 batch
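For context on the error above: Lightning raises it when len(dataloader) is 0 for the validation loader. For a map-style dataset with automatic batching, a PyTorch DataLoader yields ceil(n / batch_size) batches, or floor(n / batch_size) when drop_last=True, so the length can be 0 either because the validation folder yields no samples or because it holds fewer samples than batch_size (32 in the config above) with drop_last enabled. Whether the repo's data module sets drop_last is not stated here, so treat this as a sketch of the arithmetic, not a diagnosis. The log shows a single DDP process (MEMBER: 1/1), so the per-process count equals the total.

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style DataLoader in a single process."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return (num_samples + batch_size - 1) // batch_size  # ceiling division
```

With 10 validation images and batch_size=32, the loader has 1 batch without drop_last and 0 batches with it; with no images found it has 0 batches either way.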
