
findalexli / mllm-dpo

Repo associated with the paper Multi-modal preference alignment remedies regression of visual instruction tuning on language model

mllm-dpo's Introduction

This repo contains the code and the data for the following paper:

@misc{li2024multimodal,
    title={Multi-modal preference alignment remedies regression of visual instruction tuning on language model},
    author={Shengzhi Li and Rongyu Lin and Shichao Pei},
    year={2024},
    eprint={2402.10884},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

[Arxiv paper] [GitHub] [Data] [Model] [Data]

Developers: Shengzhi Li (TIFIN.AI), Rongyu Lin (KAUST), Shichao Pei (University of Massachusetts Boston)

Affiliations: TIFIN, KAUST, University of Massachusetts Boston

Introduction

This guide provides step-by-step instructions for fine-tuning the LLaVA model with the alignment methods studied in the paper (DPO, SteerLM, rejection sampling, and standard SFT) and for evaluating the result, focusing on visual instruction tuning with the SciGraphQA and LRV-Instruct datasets.

Installation

  1. Unzip the repository and change into its root directory.

  2. Set up the environment:

    conda create -n llava python=3.10 -y
    conda activate llava
    pip install --upgrade pip
    pip install -e .
  3. Install packages for training:

    pip install -e ".[train]"
    pip install flash-attn --no-build-isolation
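After installation, a quick import check can catch a missing package or a failed flash-attn build before a long training run. The sketch below is a hypothetical helper, not part of the repo; it assumes the package installed by pip install -e . is importable as llava (as in upstream LLaVA) and that flash-attn provides the flash_attn module.

```python
import importlib.util
import sys

# Module names are assumptions: "llava" for the package installed by
# `pip install -e .` (this repo is based on LLaVA), plus the training extras.
REQUIRED_MODULES = ["llava", "torch", "flash_attn"]


def check_modules(names):
    """Return {module_name: bool} telling whether each top-level module can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def check_python(min_version=(3, 10)):
    """True if the running interpreter is at least min_version (the conda env uses 3.10)."""
    return sys.version_info[:2] >= min_version


if __name__ == "__main__":
    print(f"Python >= 3.10: {check_python()}")
    for name, ok in check_modules(REQUIRED_MODULES).items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```

find_spec only confirms that a module can be located; it does not import it, so a flash-attn build that fails at import time (for example, a CUDA mismatch) would still show OK. Running python -c "import flash_attn" is the stricter check.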

Data Preparation

  1. Download the datasets and images:

    The images for LRV-Instruct can be downloaded with:

    gdown https://drive.google.com/uc?id=1k9MNV-ImEV9BYEOeLEIb4uGEUZjd3QbM

    The images for SciGraphQA can be downloaded from:

    https://huggingface.co/datasets/alexshengzhili/SciGraphQA-295K-train/resolve/main/img.zip?download=true

  2. Organize the images in ./playground/data:

```
playground/
└── data/
    ├── scigraphqa/
    │   └── images/
    └── lrv_instruct/
        └── images/
```
  3. For DPO, use the preference data in playground/data/dpo_inference0104.with_logpllava-v1.5-13b_2024-02-03.json.
  4. For the non-DPO alignment methods (SteerLM, rejection sampling, and standard SFT), the data folder also provides one file per method: playground/data/steerlm.json, playground/data/rejection_sampling.json, and playground/data/standard_sft.json.
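Before training, it can help to confirm that the directory layout above is in place and to peek at the data files. The following hypothetical helper (not part of the repo) checks the two image directories and, for each JSON file named above, reports its top-level type, record count and the keys of the first record. It deliberately makes no assumption about the data schema, and it assumes it is run from the repo root.

```python
import json
from pathlib import Path

# Directory layout and file names are taken from the steps above; the JSON
# schema is NOT assumed, we only report its top-level shape.
IMAGE_DIRS = ["scigraphqa/images", "lrv_instruct/images"]
DATA_FILES = {
    "dpo": "dpo_inference0104.with_logpllava-v1.5-13b_2024-02-03.json",
    "rejection_sampling": "rejection_sampling.json",
    "standard_sft": "standard_sft.json",
    "steerlm": "steerlm.json",
}


def missing_image_dirs(root):
    """List the expected image directories that do not exist under root."""
    root = Path(root)
    return [d for d in IMAGE_DIRS if not (root / d).is_dir()]


def summarize_json(path):
    """Return (top-level type name, number of records, sorted keys of the first record)."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    records = data if isinstance(data, list) else [data]
    first = records[0] if records else {}
    keys = sorted(first) if isinstance(first, dict) else []
    return type(data).__name__, len(records), keys


if __name__ == "__main__":
    root = Path("playground/data")
    print("missing image dirs:", missing_image_dirs(root))
    for method, name in DATA_FILES.items():
        path = root / name
        if path.exists():
            print(method, summarize_json(path))
        else:
            print(method, "not found:", path)
```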

Training

  1. Use scripts/v1/finetune_dpo.sh for DPO experiments.
  2. Use scripts/v1/finetune_steer.sh for non-DPO experiments (SteerLM, rejection sampling, and standard SFT).
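For orientation, the standard DPO objective (Rafailov et al., 2023) for a single preference pair can be sketched as below. This is the textbook formulation, not the code run by finetune_dpo.sh; beta=0.1 is only an illustrative default. The "with_logp" part of the DPO data file name suggests, though this is an assumption, that reference-model log-probabilities are precomputed and stored in the data.

```python
import math


def softplus(x):
    """log(1 + exp(x)), computed without overflow for large |x|."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Standard DPO loss for one preference pair.

    Each argument is the summed token log-probability of a response given the
    prompt: "chosen" is the preferred answer, "rejected" the dispreferred one,
    under the trainable policy and the frozen reference model respectively.
    """
    # Implicit rewards: how much more likely the policy makes each response
    # than the reference model does, scaled by beta.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) == softplus(-margin)
    return softplus(-margin)
```

When the policy equals the reference model, the margin is 0 and the loss is log 2 (about 0.693); the loss falls as the policy raises the likelihood of the chosen response relative to the rejected one. A real training loop averages this over a batch of tensors.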

Evaluation

  1. Use the evaluation scripts under scripts/v1_5/eval/ to assess your fine-tuned model on various benchmarks. Follow LLaVA's evaluation guidelines and use greedy decoding rather than sampling or beam search, so that results are reproducible and consistent with the model's real-time (chat demo) outputs.
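Greedy decoding picks the single highest-scoring token at every step, so repeated runs on the same input give identical outputs. The toy sketch below illustrates the idea with a hypothetical next_token_scores function standing in for the model; it is not the code used by the evaluation scripts. GREEDY_GENERATE_KWARGS shows the corresponding settings for Hugging Face transformers generate() (no sampling, a single beam); the eval scripts' own flags may be named differently.

```python
def greedy_decode(next_token_scores, prompt, eos, max_new_tokens=20):
    """Greedy decoding: at each step append the single highest-scoring token.

    next_token_scores(tokens) is a hypothetical stand-in for a language model:
    it returns a dict {token: score} for the next position. No randomness is
    involved, so the same prompt always yields the same output.
    """
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)  # ties broken by dict order
        tokens.append(best)
        if best == eos:
            break
    return tokens[len(prompt):]


# Equivalent greedy settings for transformers' generate(): no sampling, one beam.
GREEDY_GENERATE_KWARGS = {"do_sample": False, "num_beams": 1}
```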

We thank the authors of LLaVA and Vicuna, on whose work the original version of this repo is based.

mllm-dpo's Issues

Loading Dataset returns error from huggingface

Running this:

```python
from datasets import load_dataset

dataset = load_dataset("alexshengzhili/mllm-dpo", split="train[0:1]", trust_remote_code=True)
```

returns this error:

```
ArrowInvalid Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json/json.py in _generate_tables(self, files)
121 try:
--> 122 pa_table = paj.read_json(
123 io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)

17 frames
ArrowInvalid: JSON parse error: Column() changed from object to array in row 0

During handling of the above exception, another exception occurred:

ArrowTypeError Traceback (most recent call last)
ArrowTypeError: Expected bytes, got a 'int' object

The above exception was the direct cause of the following exception:

DatasetGenerationError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/datasets/builder.py in _prepare_split_single(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)
2036 if isinstance(e, DatasetGenerationError):
2037 raise
-> 2038 raise DatasetGenerationError("An error occurred while generating the dataset") from e
2039
2040 yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)

DatasetGenerationError: An error occurred while generating the dataset
```
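A possible workaround, untested against this dataset: download the raw files with huggingface_hub.snapshot_download and parse them with Python's json module, bypassing the Arrow-based JSON builder in datasets that raised the error. The final ArrowTypeError ("Expected bytes, got a 'int' object") usually means a field holds strings in some records and integers in others; that diagnosis is an inference from the traceback, not confirmed against the data. load_json_records and download_dataset are hypothetical helpers, and the file layout of the dataset repo is not assumed: every *.json file found is loaded.

```python
import json
from pathlib import Path


def load_json_records(folder):
    """Load every *.json file under folder with the standard json module.

    Returns {relative_path: parsed_json}. Mixed value types within a field are
    kept as-is, since no Arrow schema has to be inferred.
    """
    records = {}
    for path in sorted(Path(folder).rglob("*.json")):
        with open(path, encoding="utf-8") as f:
            records[str(path.relative_to(folder))] = json.load(f)
    return records


def download_dataset(repo_id="alexshengzhili/mllm-dpo"):
    """Download the raw dataset files; requires `pip install huggingface_hub`."""
    from huggingface_hub import snapshot_download  # imported lazily
    return snapshot_download(repo_id=repo_id, repo_type="dataset")


if __name__ == "__main__":
    local_dir = download_dataset()
    for name, data in load_json_records(local_dir).items():
        size = len(data) if isinstance(data, (list, dict)) else 1
        print(name, type(data).__name__, size)
```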
