
Stable Attribute Group Editing for Reliable Few-shot Image Generation


Overall framework of SAGE.

Description

Official implementation of SAGE for few-shot image generation. Our code is modified from pSp. We also provide a diffusion version of SAGE.

Getting Started

Prerequisites

  • Linux
  • NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
  • Python 3

Installation

  • Clone this repo:
git clone https://github.com/UniBester/SAGE.git
cd SAGE
  • Dependencies:
    We recommend running this repository using Anaconda. All dependencies for defining the environment are provided in environment/environment.yaml.

Pretrained pSp

Here, we use pSp to find the latent codes of real images in the latent domain of a pretrained StyleGAN generator. First follow the instructions to train a pSp model, or directly download the pretrained pSp models we provide.

Training

Preparing your Data

  • First download the Animal Faces, Flowers, VggFaces, or NABirds dataset and organize the files as follows:

    └── data_root
        ├── train                      
        |   ├── cate-id_sample-id.jpg                # train-img
        |   └── ...                                  # ...
        └── valid                      
            ├── cate-id_sample-id.jpg                # valid-img
            └── ...                                  # ...
    

    Here, we provide the organized Animal Faces dataset as an example:

    └── data_root
      ├── train                      
      |   ├── n02085620_25.JPEG_238_24_392_167.jpg              
      |   └── ...                                
      └── valid                      
          ├── n02093754_14.JPEG_80_18_239_163.jpg           
          └── ...                                             
    
  • Several datasets are currently supported.

    • Refer to configs/paths_config.py to define the necessary data paths and model paths for training and evaluation.
    • Refer to configs/transforms_config.py for the transforms defined for each dataset.
    • Finally, refer to configs/data_configs.py for the data paths for the train and valid sets as well as the transforms.
  • If you wish to experiment with your own dataset, make the necessary adjustments in:

    1. data_configs.py to define your data paths.
    2. transforms_config.py to define your own data transforms.
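
The `cate-id_sample-id.jpg` naming convention above suggests that the category can be recovered from the file name alone. The following is a minimal sketch, assuming the category id is everything before the first underscore (as in `n02085620_25.JPEG_238_24_392_167.jpg`). The helpers `category_of` and `group_by_category` are hypothetical and are not part of this repository.

```python
from collections import defaultdict
from pathlib import Path


def category_of(filename: str) -> str:
    """Return the category id, assumed to be the text before the first underscore."""
    name = Path(filename).name
    cate_id, sep, _ = name.partition("_")
    if not sep or not cate_id:
        raise ValueError(f"expected 'cate-id_sample-id' naming, got: {filename!r}")
    return cate_id


def group_by_category(filenames):
    """Map each category id to the sorted list of file names that belong to it."""
    groups = defaultdict(list)
    for name in filenames:
        groups[category_of(name)].append(name)
    return {cate: sorted(names) for cate, names in groups.items()}


if __name__ == "__main__":
    example = [
        "n02085620_25.JPEG_238_24_392_167.jpg",
        "n02093754_14.JPEG_80_18_239_163.jpg",
    ]
    for cate, names in group_by_category(example).items():
        print(cate, len(names))
```

If your own category ids contain underscores, this split rule would need to change; pick a separator that cannot appear inside the category id.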

Get Class Embedding

To train SAGE, first compute the class embedding of each category in both the train and test splits using tools/get_class_embedding.py.

python tools/get_class_embedding.py \
--class_embedding_path=/path/to/save/class/embeddings \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--train_data_path=/path/to/training/data \
--test_batch_size=4 \
--test_workers=4
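
Conceptually, a class embedding summarizes the latent codes of one category. The sketch below assumes it is the element-wise mean of the latent codes of that category's images. This is an assumption made for illustration, not a description of what `tools/get_class_embedding.py` actually does. Latent codes are plain lists of floats here, and `class_embeddings` is a hypothetical helper.

```python
from collections import defaultdict


def class_embeddings(latents):
    """Average latent codes per category.

    `latents` is an iterable of (category_id, latent_code) pairs, where each
    latent code is a list of floats and all codes share the same length.
    """
    sums = {}
    counts = defaultdict(int)
    for cate, code in latents:
        if cate not in sums:
            sums[cate] = [0.0] * len(code)
        elif len(code) != len(sums[cate]):
            raise ValueError(f"latent size mismatch in category {cate!r}")
        # Accumulate the running element-wise sum for this category.
        sums[cate] = [s + x for s, x in zip(sums[cate], code)]
        counts[cate] += 1
    return {cate: [s / counts[cate] for s in total] for cate, total in sums.items()}


if __name__ == "__main__":
    toy = [("cat", [1.0, 3.0]), ("cat", [3.0, 5.0]), ("dog", [2.0, 2.0])]
    print(class_embeddings(toy))
```

In a real pipeline the latent codes would be tensors produced by the pSp encoder rather than Python lists; the grouping-and-averaging logic would stay the same under this assumption.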

Training pSp

The main training script can be found in tools/train.py.
Intermediate training results are saved to opts.exp_dir. This includes checkpoints, train outputs, and test outputs.
Additionally, if you have tensorboard installed, you can visualize tensorboard logs in opts.exp_dir/logs.

Training the pSp Encoder

# Set the GPUs to use.
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Begin training.
python -m torch.distributed.launch \
--nproc_per_node=4 \
tools/train.py \
--dataset_type=af_encode \
--exp_dir=/path/to/experiment/output \
--workers=8 \
--batch_size=8 \
--valid_batch_size=8 \
--valid_workers=8 \
--val_interval=2500 \
--save_interval=5000 \
--start_from_latent_avg \
--l2_lambda=1 \
--sparse_lambda=0.005 \
--orthogonal_lambda=0.0005 \
--A_length=100 \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--class_embedding_path=/path/to/class/embeddings 

Testing

Inference

For 1-shot generation, you should put all your test data under one folder:

└── test_data                    
    ├── img1.jpg                # test-img
    ├── img2.jpg                                  
    └── ...                     

Then, you can use tools/inference_1_shot.py to apply the model on a set of images.
For example,

python tools/inference_1_shot.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--t=10 \
--n_similar_cates=30 \
--beta=0.005

For 3-shot generation, you should put the images of each category in their own folder:

└── data_root
  ├── sample1                      
  |   ├── img1.jpg 
  |   ├── img2.jpg              
  |   └── img3.jpg    
  ├── sample2                      
  |   ├── img1.jpg 
  |   ├── img2.jpg              
  |   └── img3.jpg                             
  └── ...                       
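
Before running 3-shot inference, it can help to confirm that every folder holds exactly three images. The following is a minimal sketch, assuming images are identified by a `.jpg`, `.jpeg`, or `.png` suffix. `check_three_shot_layout` is a hypothetical helper and is not part of this repository.

```python
from pathlib import Path

# Assumed set of accepted image extensions (an illustration, not the repo's rule).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_three_shot_layout(data_root, shots=3):
    """Return {folder_name: image_count} for folders that do not hold exactly `shots` images."""
    problems = {}
    for folder in sorted(Path(data_root).iterdir()):
        if not folder.is_dir():
            continue  # Ignore stray files directly under data_root.
        n_images = sum(
            1
            for f in folder.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
        if n_images != shots:
            problems[folder.name] = n_images
    return problems
```

An empty result means every sample folder matches the expected layout; otherwise the returned mapping names the folders to fix and how many images each currently contains.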

Then, you can use tools/inference_3_shot.py to apply the model on a set of images.
For example,

python tools/inference_3_shot.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--t=10 \
--n_similar_cates=30 \
--beta=0.005
