
Python 3.8.12 torch torchvision datasets transformers HuggingFace

CycleTransformer

Both Text-to-Image translation and Image-to-Text translation have been active areas of research in the recent past [5,6,7]. Both are difficult and interesting problems: the Image-to-Text task demands that the generated caption faithfully describe the image, while the Text-to-Image task demands that the generated image be a faithful visual representation of the given text. Usually only one task is handled at a time, and the methods are tailored to extracting data from one domain and translating it to the other.

Recently, several works [2,3,4] took inspiration from CycleGAN's use of duality for unpaired data [1] by leveraging the cycle-consistency duality for paired data from different domains, such as text and images. Inspired by those papers and by recent advancements in deep learning and NLP, in this assignment we propose a novel architecture, CycleTransformer, to handle both Text-to-Image translation and Image-to-Text translation on paired data, using a unified architecture of transformers and CNNs and enforcing cycle consistency.

Table of Contents

  • Requirements
  • Repository Structure
  • Usage Example
  • Pretrained Weights and Organized Dataset
  • Model
  • Team
  • Examples
  • References

Requirements

The code was tested on python v3.8.12 with the following libraries:

Library Version
datasets 1.17.0
matplotlib 3.4.3
numpy 1.21.3
pillow 8.4.0
pytorch 1.10.0+cu111
pytorch-fid 0.2.1
rouge_score 0.0.4
scikit-image 0.18.3
scipy 1.7.1
torchvision 0.11.1+cu111
tqdm 4.63.0
transformers 4.15.0

We recommend using conda to deploy the project:

git clone https://github.com/HilaManor/CycleTransformer.git && cd CycleTransformer
conda create --name CycleTransformer python=3.8.12 pytorch=1.10.0 torchvision=0.11.1 cudatoolkit=11.1 numpy=1.21.3 scikit-image=0.18.3 matplotlib=3.4.3 scipy=1.7.1 pandas=1.3.4 pillow=8.4.0 tqdm -c pytorch -c conda-forge
conda activate CycleTransformer
pip install transformers==4.15.0 datasets==1.17.0 rouge_score==0.0.4 pytorch-fid==0.2.1

IMPORTANT - Fixing Hugging Face Bug

The transformers version we worked with has a bug that doesn't allow the use of tensors in the ViT feature extraction method. We had to fix this bug in the library's code to allow complete gradient flow (for the consistency cycle).
This means that, until the bug is fixed in the official repo, you must patch it yourself before running our code.

To fix the bug, edit feature_extraction_utils.py, located in <python_base_folder>/site-packages/transformers/:
at line 144 (inside the function as_tensor(value), declared at line 142 of transformers v4.15.0):
add:

elif isinstance(value, (list, torch.Tensor)):
    return torch.stack(value)
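
For orientation, the sketch below shows roughly where the added branch sits. It is an illustrative reconstruction only, not the exact library source (the body of as_tensor differs between transformers versions); only the two-line elif branch quoted above is the actual fix.

# Illustrative sketch only -- in the real feature_extraction_utils.py, numpy and
# torch are already available. The added `elif` branch returns a stacked torch
# tensor directly (instead of going through numpy), so gradients keep flowing.
import numpy as np
import torch

def as_tensor(value):
    if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
        value = np.array(value)
    elif isinstance(value, (list, torch.Tensor)):   # <-- the added branch (the fix)
        return torch.stack(value)
    return torch.tensor(value)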

Repository Structure

├── code - the code for training the CycleTransformer model
└── config - configurations for the CycleTransformer model

Usage Example

Training the model

python main.py --epochs <training_epochs> --val_epochs <validation_every_x_epochs> --config <path_to_yaml_file> [--baseline]

At each validation epoch the validation loss is shown and some images and captions are created from the validation split.
If the optional --baseline flag is given, the baseline models will be trained instead.
Use --help for more information on the parameters.
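
For example, a training run might look like this (the config filename below is a placeholder for the YAML file in the config directory, and the epoch counts are arbitrary):

python main.py --epochs 100 --val_epochs 10 --config config/config.yaml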

Generating Images and Captions

python main.py --out_dir <path_to_trained_model_dir> [--text <optional_text_prompt>] [--img_path <optional_image_path>] [--amount <amount_of_images_to_generate>]

Generates images for the test split and captions for them, comparing against the ground truths.
If the optional --text is given, images will be generated from that text. The number of generated images is set by --amount.
If the optional --img_path is given, a text caption will be generated for the given image.
Use --help for more information on the parameters.
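
For example (the trained-model directory, prompt and image path below are placeholders):

python main.py --out_dir <path_to_trained_model_dir> --text "this flower has thin pink petals and a yellow center" --amount 4
python main.py --out_dir <path_to_trained_model_dir> --img_path my_flower.jpg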

Pretrained Weights and Organized Dataset

Pretrained weights for the cycle-consistent model are hosted here, and pretrained weights for the baseline models are hosted here.

The dataset, organized as we used it (without class splits), is available here for simplicity.

Model

The CycleTransformer model is comprised of a Text-to-Image branch and an Image-to-Text branch.
The Text-to-Image branch uses DistilBERT for text embedding. We concatenate a random noise vector, sampled from the standard normal distribution, to this embedding and then feed it to an image generator model. The Text-to-Image branch is trained using perceptual and reconstruction losses.
The Image-to-Text branch is an encoder-decoder structure, composed of a distilled DeiT model as the feature extractor and GPT-2 for text generation. This branch is trained using a language-modelling loss.
Read our short paper for more details about the model.
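
For orientation, below is a minimal, self-contained sketch of the two branches described above. It is not the authors' implementation: the checkpoint names, the toy linear generator, the noise dimension, the image size and the [CLS]-token pooling are all illustrative assumptions.

# Minimal sketch of the two branches (not the authors' code).
import torch
import torch.nn as nn
from transformers import DistilBertModel, VisionEncoderDecoderModel

class TextToImage(nn.Module):
    """Toy Text-to-Image branch: DistilBERT embedding + noise -> image generator."""
    def __init__(self, noise_dim=64, img_size=64):
        super().__init__()
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.noise_dim = noise_dim
        self.img_size = img_size
        # Stand-in generator: a single linear layer mapping embedding+noise to RGB pixels.
        self.generator = nn.Sequential(
            nn.Linear(self.text_encoder.config.dim + noise_dim, 3 * img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, input_ids, attention_mask):
        # Pool the first-token embedding, concatenate noise z ~ N(0, I), generate an image.
        emb = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        noise = torch.randn(emb.size(0), self.noise_dim, device=emb.device)
        img = self.generator(torch.cat([emb, noise], dim=-1))
        return img.view(-1, 3, self.img_size, self.img_size)

# Image-to-Text branch: distilled DeiT encoder + GPT-2 decoder, trained with a
# language-modelling loss (the VisionEncoderDecoderModel wrapper is one way to wire it).
img2txt = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/deit-base-distilled-patch16-224", "gpt2")

# Cycle consistency (conceptually): chain text -> image -> text and image -> text -> image,
# and penalize the round-trip difference in addition to the per-branch losses above.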

model

Team

Hila Manor and Matan Kleiner

Examples

image

Comparison between the baseline models and the cycle-consistent model on more generation results, for the Text-to-Image task and the Image-to-Text task. The baseline models' generations are more varied in shape and display more coherent colors. The generated sentences are also more diverse in the language used, and the different elements of the flower are described in greater detail.

image

Comparison between the results of the baseline and cycle-consistent Image-to-Text models, for custom images on the Image-to-Text task. All the images were found online and are under a Creative Commons license. The first three flowers (Gilboa Iris, Nazareth Iris and Mountain Tulip, respectively) are not among the flower species in the original dataset. The last flower (Daffodil) is a species present in the original dataset. The captions created by the baseline model are more accurate and do not include colors that don't appear in the input image. The cycle-consistent model also creates good captions, but in the case of the Nazareth Iris and the Daffodil it mentions that the flower's color is purple, which is not the case.

image

Comparison between the results of the baseline and cycle-consistent Text-to-Image models, for custom text prompts on the Text-to-Image task. The first two sentences are simple, describing a single flower with one prominent color. In both cases, both models create a blob of the specified color in the center of the image, where the baseline model's blob is a bit more flower-like in shape. The next two sentences are more complicated: one describes more than one flower, and the other describes three different-colored parts of the flower. The first of these produced a similar image from both models. The second caused the baseline model to generate a blob that merges two of the described colors, whereas the cycle-consistent model generated a colorful blob without relating it to the mentioned colors. The last two sentences are the most complex. Both models generate similar results, while the baseline model's results are a bit more pleasing to the eye.

References

  1. Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A Efros. 2017. Unpaired image-to-image translation using cycle-consistent adversarial networks, In Proceedings of the IEEE international conference on computer vision, pages 2223–2232.
  2. Mohammad R. Alam, Nicole A. Isoda, Mitch C. Manzanares, Anthony C. Delgado, and Antonius F. Panggabean. 2021. TextCycleGAN: cyclical-generative adversarial networks for image captioning, In Artificial Intelligence and Machine Learning for Multi-Domain Operations Applications III, volume 11746, pages 213 – 220. International Society for Optics and Photonics, SPIE
  3. Satya Krishna Gorti and Jeremy Ma. 2018. Text-to-image-to-text translation using cycle consistent adversarial networks, arXiv preprint, arXiv:1808.04538
  4. Keisuke Hagiwara, Yusuke Mukuta, and Tatsuya Harada. 2019. End-to-end learning using cycle consistency for image-to-caption transformations, arXiv preprint, arXiv:1903.10118
  5. Xiujun Li, Xi Yin, Chunyuan Li, Pengchuan Zhang, Xiaowei Hu, Lei Zhang, Lijuan Wang, Houdong Hu, Li Dong, Furu Wei, et al. 2020. Oscar: Object semantics aligned pre-training for vision-language tasks, In European Conference on Computer Vision, pages 121–137. Springer.
  6. Alex Nichol, Prafulla Dhariwal, Aditya Ramesh, Pranav Shyam, Pamela Mishkin, Bob McGrew, Ilya Sutskever, and Mark Chen. 2021. Glide: Towards photorealistic image generation and editing with text-guided diffusion models, arXiv preprint, arXiv:2112.10741
  7. Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. 2021. Zero-shot text-to-image generation, In International Conference on Machine Learning, pages 8821–8831. PMLR.
