ziyang412 / fromage

This project forked from kohjingyu/fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".

Home Page: https://jykoh.com/fromage

License: Apache License 2.0

Python 5.89% Jupyter Notebook 94.11%

fromage's Introduction

Unified Embeddings for Multimodal Retrieval via Frozen LLMs

This repository contains the code used for the paper: Unified Embeddings for Multimodal Retrieval via Frozen LLMs

Ziyang Wang, Heba Elfardy, Markus Dreyer, Kevin Small, Mohit Bansal.

Overview

In this work, we present Unified Embeddings for Multimodal Retrieval (UNIMUR), a simple but effective approach that embeds multimodal inputs and retrieves visual and textual outputs via frozen Large Language Models (LLMs). Specifically, UNIMUR jointly retrieves multimodal outputs via a unified multimodal embedding and applies dual alignment training to account for both visual and textual semantics. As a result, unlike previous approaches, UNIMUR significantly reduces the LLM's modality bias towards generating text-only outputs. The unified multimodal embedding also mitigates the inconsistency between visual and textual outputs, yielding coherent multimodal outputs. Furthermore, because visual and textual semantics are trained jointly, UNIMUR achieves strong image and text retrieval ability.

teaser image

Setup

Install Dependencies

  1. (Optional) Set up a new conda environment:
conda create -n unimur python=3.8
conda activate unimur
  2. Install the required libraries:
pip install -r requirements.txt
  3. Add the unimur library to PYTHONPATH:
export PYTHONPATH=$PYTHONPATH:/home/path/to/unimur/

Pretrained Checkpoints

The pruned UNIMUR model weights (linear layers and [RET] embedding) are small and are included in this Git repo. They will be in the unimur_model/ folder after cloning. The checkpoint and model config in unimur_model/ reproduce the results reported in our paper.

Precomputed Embeddings For Image Retrieval

We follow FROMAGe, which uses precomputed visual embeddings of CC3M images for retrieval. Please follow their instructions to download the files, then place cc3m_embeddings.pkl in the unimur_model/ directory.

Training

Preparing CC3M

UNIMUR is trained on the Conceptual Captions dataset (main results trained on CC3M). After following the instructions on the website to download the captions and images, format them into a .tsv file as follows:

caption image
A picture of a dog  dog.png
Tree  tree.png

where each line contains a caption followed by the filename of the corresponding image, separated by a tab. Save these .tsv files into the dataset/ folder (the default names expected are cc3m_train.tsv and cc3m_val.tsv). The repo contains two placeholder files, which you will have to replace with the appropriate data.

The corresponding image files should be saved in the data/ directory. The directory can be changed with the --image-dir runtime flag.
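As a minimal sketch of the formatting step above, the following helper writes caption/filename pairs into the two-column .tsv layout. The function name `write_caption_tsv` and the whitespace clean-up are assumptions made for illustration; they are not part of this repository, and the actual data loader may have additional requirements.

```python
from pathlib import Path


def write_caption_tsv(pairs, out_path):
    """Write (caption, image filename) pairs as a tab-separated file.

    Hypothetical helper, not part of the repository. The header row
    "caption<TAB>image" mirrors the example layout shown above.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        f.write("caption\timage\n")
        for caption, image_name in pairs:
            # Collapse tabs/newlines inside a caption so that each row
            # keeps exactly two tab-separated columns.
            clean_caption = " ".join(caption.split())
            f.write(f"{clean_caption}\t{image_name}\n")
    return out_path
```

For example, `write_caption_tsv([("A picture of a dog", "dog.png"), ("Tree", "tree.png")], "dataset/cc3m_train.tsv")` would produce a file matching the sample shown earlier.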

Training UNIMUR

After preparing the dataset as detailed above, you can start a new training job with the following command:

randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --dataset=cc3m  --val-dataset=cc3m \
    --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \
    --exp_name='unimur_exp' --image-dir='image_data/'  --log-base-dir='exp_log/' \
    --batch-size=120  --val-batch-size=80  --learning-rate=0.0003 --precision='bf16'  --print-freq=100

On 4 V100-32GB GPUs, the model converges within 16 hours (with a batch size of 120). If you run into NCCL issues, you may also have to disable NCCL P2P with export NCCL_P2P_DISABLE=1.

Pruning Model Weights

As UNIMUR only consists of a few pretrained linear layers and the [RET] embedding, we can discard most of the pretrained weights to save on disk space. If you have trained a new model, and wish to do so, you can use unimur/prune_model_ckpt.py to prune the model weights. We used the same script to create the weights in the unimur_model directory.
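The idea behind pruning can be sketched as filtering a checkpoint's state dict down to the entries that were actually trained. This is only an illustration under stated assumptions: the function `prune_state_dict` and the key names in the usage note are hypothetical, and unimur/prune_model_ckpt.py may work differently.

```python
def prune_state_dict(state_dict, keep_substrings):
    """Return a new dict with only the entries worth saving.

    Hypothetical sketch: an entry is kept when its key contains any of
    the given substrings (e.g. the names of the trained linear layers
    and the [RET] embedding). Frozen backbone weights are dropped
    because they can be reloaded from the original pretrained models.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if any(sub in key for sub in keep_substrings)
    }
```

With a mapping such as `{"frozen_llm.layer0.weight": ..., "text_proj.weight": ..., "ret_embedding.weight": ...}` (illustrative names only), calling `prune_state_dict(sd, ["text_proj", "ret_embedding"])` keeps the last two entries and leaves the input dict unchanged.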

Evaluation

Preparing Evaluation Datasets

We evaluate our model on the Visual Dialog and MMDialog datasets.

For the VisDial dataset (val split), please download the validation annotations from here and the raw images from here, and extract everything to the VisualDialog folder.

For the MMDialog dataset (test split), please follow the download instructions here.

Evaluating on multiple tasks

We provide an evaluation script to reproduce our results on dialogue-to-image retrieval on Visual Dialog (Table 3 of our paper). Please update the dataset and checkpoint directories in the script to match your setup.

python evals/eval_visdial.py

Similarly, we provide scripts to reproduce the multimodal response retrieval results on MMDialog (Table 1 of our paper). Because MMDialog uses a different candidate image pool for each dialog, we speed up evaluation by first extracting the visual embeddings of all test images with extract_vis_emb.py (configure your own data and model paths). We then run eval_mmdialog.py to reproduce the multimodal response retrieval results (configure the paths to your own data, model, and saved embeddings).

python unimur/extract_vis_emb.py

python evals/eval_mmdialog.py
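The two-step design above (embed every test image once, then score each dialog against its own candidate pool) can be illustrated with a small, dependency-free sketch. Everything here is an assumption for illustration: the function names, the dict-based embedding cache, and the use of cosine similarity are not taken from eval_mmdialog.py, which may score and aggregate differently.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_candidates(query_emb, candidate_ids, emb_cache):
    """Rank one dialog's candidate pool using embeddings computed once up front.

    emb_cache maps image id -> embedding; only the ids in this dialog's
    pool are looked up, so no image is re-encoded per dialog.
    """
    scored = [(cosine(query_emb, emb_cache[cid]), cid) for cid in candidate_ids]
    scored.sort(key=lambda item: item[0], reverse=True)
    return [cid for _, cid in scored]


def recall_at_k(ranked_ids, target_id, k):
    """Return 1.0 if the ground-truth id appears in the top k results, else 0.0."""
    return 1.0 if target_id in ranked_ids[:k] else 0.0
```

Averaging `recall_at_k` over all dialogs gives a Recall@k figure; the cache lookup is what avoids recomputing image embeddings for overlapping candidate pools.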

License

fromage's People

Contributors

kohjingyu, ziyang412, vishaal27
