
llava-next's Issues

Training/Finetuning code please

Hi, dear authors:
llava-next looks like really insightful exploratory work. Please kindly release the training and inference code as soon as possible, thank you very much.

cannot import name 'LlavaLlamaForCausalLM' from 'llava.model'

I'm getting this error in the example code of llava-next:

ImportError                               Traceback (most recent call last)
Cell In[5], line 1
----> 1 from llava.model.builder import load_pretrained_model
      2 from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
      3 from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX

File /opt/conda/lib/python3.10/site-packages/llava/model/builder.py:23
     21 import torch
     22 from llava.model import *
---> 23 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
     24 from llava.utils import rank0_print
     27 def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, **kwargs):

File /opt/conda/lib/python3.10/site-packages/llava/__init__.py:1
----> 1 from .model import LlavaLlamaForCausalLM

ImportError: cannot import name 'LlavaLlamaForCausalLM' from 'llava.model' (/opt/conda/lib/python3.10/site-packages/llava/model/__init__.py)
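A quick diagnostic that sometimes narrows this down (just a sketch, not a fix): check which llava package is actually being picked up, since a pip-installed llava under site-packages can shadow a local LLaVA-NeXT checkout, and the two do not export the same names.

import importlib.util

# find_spec does not execute llava/__init__.py, so it works even though importing llava fails.
spec = importlib.util.find_spec("llava")
print(spec.origin)  # should point at the LLaVA-NeXT checkout you intend to use, not a stale site-packages copy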

Add prompt format and sample inference code to HF model repos

Currently, the tokenizer_config is the same as the Llama 3 model's, which isn't instructive as to how to pass in images.

Adding a very short code snippet outlining how to load the model and run inference would be a great addition. Same for the video repos.

Ideally, inference could be done with either AutoModelForCausalLM or a LlavaLlama model (although I guess that has to be created, since the LLaVA-NeXT Llama 3 model differs?).
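For illustration, the kind of snippet being requested could look like the sketch below, assuming a transformers-compatible conversion of the checkpoint exists; the llava-hf/llama3-llava-next-8b-hf repo id and the exact prompt layout are my assumptions, not something stated in this thread.

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_id = "llava-hf/llama3-llava-next-8b-hf"  # assumed converted checkpoint
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)

# The prompt format is the key piece the issue asks to document; this is an
# approximation of the Llama-3 chat layout with an <image> placeholder.
prompt = (
    "<|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(out[0], skip_special_tokens=True))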

Only outputs [1, 2] tokens for 'lmms-lab/LLaVA-NeXT-Video-7B-DPO' video demo inference

The output of output_ids is tensor([[1, 2]], device='cuda:0').
The other output of the demo script is:

Question: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER:
Please provide a detailed description of the video, focusing on the main subjects, their actions, and the background scenes ASSISTANT:

Response:

LLaVA-NeXT-Video Training Hyperparameters

Hello Developers,

Thank you for your outstanding work. Could you please provide the training hyperparameters used for the LLaVA-NeXT-Video and LLaVA-NeXT-Video-DPO models?

LLaMA3-8B video inference

As llama3-llava-next-8b and LLaVA-NeXT-Video-7B-DPO seem to have the same interface, is it possible to make llama3-llava-next-8b process multiple frames of one video in a single forward pass?

Basically, I don't understand what makes LLaVA-NeXT-Video a processor for multiple frames, and I can't find the related code. According to the blog post, the difference is processing patches versus frames, hence the initial question.
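For context, the repo's video demo appears to pack all sampled frames into one pixel tensor and pass modalities="video" to generate() (the traceback further down in this thread shows that call shape), which is what routes the input through the frame path rather than the image-patch path. A rough sketch under those assumptions follows; whether the llama3-llava-next-8b checkpoint accepts this path is exactly the open question here, and decord plus 16 frames are arbitrary choices on my part.

import copy
import numpy as np
from decord import VideoReader, cpu  # frame reader used by the video demo scripts (assumption)

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

tokenizer, model, image_processor, _ = load_pretrained_model(
    "lmms-lab/llama3-llava-next-8b", None, "llava_llama3", device_map="auto"
)
model.eval()

# Sample a fixed number of frames uniformly from the clip (16 is arbitrary).
vr = VideoReader("video.mp4", ctx=cpu(0))
idx = np.linspace(0, len(vr) - 1, 16, dtype=int)
frames = vr.get_batch(idx).asnumpy()  # (16, H, W, 3)

# Preprocess every frame into one tensor and hand it over as a single "video" input.
video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda()

conv = copy.deepcopy(conv_templates["llava_llama_3"])
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe this video.")
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(
    conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).cuda()

output_ids = model.generate(
    input_ids,
    images=[video],        # one tensor holding all sampled frames
    modalities=["video"],  # frame (video) path instead of the image-patch path
    do_sample=False,
    max_new_tokens=256,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])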

Example doesn't work

Copy-pasted the example from here: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/inference/docs/LLaVA-NeXT.md

I get this:

Loading checkpoint shards: 100%|█████████████████████████████| 4/4 [00:02<00:00,  1.50it/s]
Traceback (most recent call last):
  File "/home/user/Programs/LLaVA/test.py", line 36, in <module>
    cont = model.generate(
  File "/home/user/Programs/LLaVA/env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/Programs/LLaVA/llava/model/language_model/llava_llama.py", line 113, in generate
    (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
  File "/home/user/Programs/LLaVA/llava/model/llava_arch.py", line 199, in prepare_inputs_labels_for_multimodal
    raise ValueError("vision_tower_image_size is not found in the vision tower.")
ValueError: vision_tower_image_size is not found in the vision tower.

git log shows me at:

commit 9496e2d4667a220a3d70c268bb30bf9ed2ced724 (HEAD, next/inference)
Author: ZhangYuanhan-AI <[email protected]>
Date:   Fri May 10 05:24:02 2024 +0000

    Fix file paths and scaling factor calculation in video_demo.py and video_demo.sh

Data to evaluate Video Detailed Description.

Hi,

I'm deeply inspired by your great work!

Could you please provide some information on the data used to evaluate the detailed captioning ability of the model (not the evaluation script, but the videos and annotations)?

I really appreciate any help you can provide.

poor quality output for qwen 72b

server:

export CUDA_VISIBLE_DEVICES="3,4,5,6"
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30010 --host="0.0.0.0" --tp-size=4 --random-seed=1234 --context-length=32768 &> 72b.log &

client:

"""
Usage:
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# Installing latest sglang.

# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4

python3 http_qwen_llava_test.py

Output:
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
"""

import argparse
import asyncio
import json
import time
import copy

import aiohttp
import requests

from llava.conversation import (
    default_conversation,
    conv_templates,
    SeparatorStyle,
    conv_llava_llama_3,
    conv_qwen,
)


async def send_request(url, data, delay=0):
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            output = await resp.json()
    return output


async def test_concurrent(args):
    url = f"{args.host}:{args.port}"

    prompt = "<image>\nPlease generate caption towards this image."
    conv_template = copy.deepcopy(conv_qwen)
    conv_template.append_message(role="user", message=prompt)
    prompt_with_template = conv_template.get_prompt()
    response = []
    for i in range(1):
        response.append(
            send_request(
                url + "/generate",
                {
                    "text": prompt_with_template,
                    "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
                    "sampling_params": {
                        "max_new_tokens": 1024,
                        "temperature": 0,
                        "top_p": 1.0,
                        "presence_penalty": 2,
                        "frequency_penalty": 2,
                        "stop": "<|im_end|>",
                    },
                },
            )
        )

    rets = await asyncio.gather(*response)
    for ret in rets:
        print(ret["text"])


def test_streaming(args):
    url = f"{args.host}:{args.port}"
    prompt = "<image>\nGive detailed information."
    conv_template = copy.deepcopy(conv_qwen)
    conv_template.append_message(role="user", message=prompt)
    prompt_with_template = conv_template.get_prompt()
    pload = {
        "text": prompt_with_template,
        "sampling_params": {
            "max_new_tokens": 1024,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 2,
            "frequency_penalty": 2,
            "stop": "<|im_end|>",
        },
        #"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "image_data": "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.png",
        "stream": True,
    }
    response = requests.post(
        url + "/generate",
        json=pload,
        stream=True,
    )

    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://0.0.0.0")
    parser.add_argument("--port", type=int, default=80)
    args = parser.parse_args()
    # asyncio.run(test_concurrent(args))
    test_streaming(args)

just gives:

Big Ben

No matter how I prompt, the output is extremely terse even if accurate.

I changed the image, but otherwise this is the default script from sglang: https://github.com/sgl-project/sglang/blob/main/examples/usage/llava/http_qwen_llava_test.py

If I try increasing temperature to 0.5, I get no response at all and it just fails:

INFO:     172.16.0.42:27134 - "POST /generate HTTP/1.1" 200 OK
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 265, in __call__
    await wrap(partial(self.listen_for_disconnect, receive))
  File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 261, in wrap
    await func()
  File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 238, in listen_for_disconnect
    message = await receive()
  File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/uvicorn/protocols/http/httptools_impl.py", line 568, in receive
    await self.message_event.wait()
  File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/asyncio/locks.py", line 214, in wait
    await fut
asyncio.exceptions.CancelledError: Cancelled by cancel scope 7dcc7e79ada0

During handling of the above exception, another exception occurred:

  + Exception Group Traceback (most recent call last):
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/uvicorn/protocols/http/httptools_impl.py", line 411, in run_asgi
  |     result = await app(  # type: ignore[func-returns-value]
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__
  |     return await self.app(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/fastapi/applications.py", line 1054, in __call__
  |     await super().__call__(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/applications.py", line 123, in __call__
  |     await self.middleware_stack(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/middleware/errors.py", line 186, in __call__
  |     raise exc
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/middleware/errors.py", line 164, in __call__
  |     await self.app(scope, receive, _send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
  |     await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
  |     raise exc
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
  |     await app(scope, receive, sender)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/routing.py", line 756, in __call__
  |     await self.middleware_stack(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/routing.py", line 776, in app
  |     await route.handle(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/routing.py", line 297, in handle
  |     await self.app(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/routing.py", line 77, in app
  |     await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
  |     raise exc
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
  |     await app(scope, receive, sender)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/routing.py", line 75, in app
  |     await response(scope, receive, send)
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 258, in __call__
  |     async with anyio.create_task_group() as task_group:
  |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 678, in __aexit__
  |     raise BaseExceptionGroup(
  | exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 261, in wrap
    |     await func()
    |   File "/home/ubuntu/miniconda3/envs/sglang/lib/python3.10/site-packages/starlette/responses.py", line 250, in stream_response
    |     async for chunk in self.body_iterator:
    |   File "/home/ubuntu/sglang/python/sglang/srt/server.py", line 89, in stream_results
    |     async for out in tokenizer_manager.generate_request(obj, request):
    |   File "/home/ubuntu/sglang/python/sglang/srt/managers/tokenizer_manager.py", line 143, in generate_request
    |     pixel_values, image_hash, image_size = await self.get_pixel_values(
    | TypeError: cannot unpack non-iterable NoneType object
    +------------------------------------

I don't understand why a TypeError appears just because temperature=0.5; very odd. Is that sglang's fault?

How to load the largest models?

Loading lmms-lab/llava-next-72b and lmms-lab/llava-next-110b with device_map='auto' does not seem to work and results in NotImplementedError: Cannot copy out of meta tensor; no data!, even though I am trying to load the model on 8x 40GB GPUs. Is there a minimal example of how to run inference with the largest models?
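One workaround that is sometimes suggested for the meta-tensor error, sketched here as an assumption rather than a verified fix: give from_pretrained an explicit per-GPU memory budget so accelerate materializes the weights instead of leaving them on the meta device. load_pretrained_model forwards extra keyword arguments to from_pretrained; the "llava_qwen" model name below is also a guess for the Qwen-based 72B checkpoint.

from llava.model.builder import load_pretrained_model

# 8 x 40GB GPUs, with some headroom left for activations.
max_memory = {i: "38GiB" for i in range(8)}

tokenizer, model, image_processor, max_length = load_pretrained_model(
    "lmms-lab/llava-next-72b",
    None,
    "llava_qwen",            # guessed model name for the Qwen-based checkpoint
    device_map="auto",
    max_memory=max_memory,
)
model.eval()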

training code

Hello, I am trying to find the training code, but there seems to be only inference code.

Can you please point me to the training code?

The infer script does not exist in the project

Thanks for your work. When I ran the inference script 'video_detail_description_eval_shard.sh', I couldn't find llava/eval/evaluate_benchmark_video_detail_description.py. Am I missing some steps?

what the "32K" means?

I wonder what the "32K" signifies when using the "lmms-lab/LLaVA-NeXT-Video-7B-32K" checkpoint.

LLaVA-NeXT-Video-34B-DPO model fails to start with errors

I am using the LLaVA-NeXT-Video-34B-DPO model with the code from the video_inference branch of https://github.com/LLaVA-VL/LLaVA-NeXT (git checkout video_inference). Why can't I get my test to run? It reports errors:
1. stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) raises an error; after commenting it out, error 2 appears.
2. In llava_arch.LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal, concat_images = torch.cat(images_list, dim=0) is None.

Has anyone managed to get this model running? LLaVA-NeXT-Video-34B-DPO uses clip-vit-large-patch14-336; is there perhaps a problem with the code?

sglang support?

The announcement blog post indicates inference can be done with sglang, but attempting to load the 7b model with the sglang backend:

python -m sglang.launch_server --model-path ~/models/lmms-lab_LLaVA-NeXT-Video-7B-DPO --port 30000

results in this KeyError:

/home/user/sglang/venv/lib/python3.11/site-packages/transformers/models/llava/configuration_llava.py:103: FutureWarning: The `vocab_size` argument is deprecated and will be removed in v4.42, since it can be inferred from the `text_config`. Passing this argument has no effect
  warnings.warn(
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
/home/user/sglang/venv/lib/python3.11/site-packages/transformers/models/llava/configuration_llava.py:143: FutureWarning: The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.
  warnings.warn(
Rank 0: load weight begin.
/home/user/sglang/venv/lib/python3.11/site-packages/transformers/models/llava/configuration_llava.py:143: FutureWarning: The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.
  warnings.warn(
/home/user/sglang/venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/home/user/sglang/venv/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Process Process-1:
router init state: Traceback (most recent call last):
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/managers/router/manager.py", line 68, in start_router_process
    model_client = ModelRpcClient(server_args, port_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 619, in __init__
    self.model_server.exposed_init_model(0, server_args, port_args)
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 70, in exposed_init_model
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 287, in __init__
    self.load_model()
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 326, in load_model
    model.load_weights(
  File "/home/user/sglang/venv/lib/python3.11/site-packages/sglang/srt/models/llava.py", line 285, in load_weights
    param = params_dict[name]
            ~~~~~~~~~~~^^^^^^
KeyError: 'model.vision_resampler.mm_projector.0.bias'

Why does running the following code keep downloading files from huggingface?

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from PIL import Image
import requests
import copy
import torch

pretrained = "llama3-llava-next-8b"
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args

model.eval()
model.tie_weights()

image = Image.open("2.jpeg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]

conv_template = "llava_llama_3" # Make sure you use correct chat template for different models
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [image.size]

cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=256,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)

The image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LLaVA-Bench," "SEED-Bench," "MMBench-CN," "MMBench," "TextVQA," "VizWiz," "GQA," "BLIP-2," "InstructBLIP," "Owen-VL-Chat," and "LLaVA-1.5." These labels suggest that the chart is comparing the performance of different models or systems across various benchmarks or tasks, such as machine translation, visual question answering, and text-based question answering.\n\nThe chart is color-coded, with each color representing a different model or system. The points on the chart are connected to form a polygon, which shows the relative performance of each model across the different benchmarks. The closer the point is to the outer edge of the

The error is reported as follows:
raise EnvironmentError(
OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like meta-llama/Meta-Llama-3-8B-Instruct is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

How to deploy this model via API?

How do we deploy this model via an API? Can I deploy it with vLLM or lmdeploy? I can't find any example of running it with Hugging Face transformers.

I want to deploy the 72b and 110b models.

Request for NExTQA Dataset Evaluation Prompt and More Results on Challenging Datasets for Fair Comparison

To my knowledge, the videos in the NExTQA dataset are relatively short, with an average length of 44 seconds, and a static bias has been noted [1] in the ActivityNet QA dataset. Could you present further results on more demanding datasets, such as EgoSchema [2], for a fair comparison? Additionally, could you provide the evaluation prompt used for the NExTQA dataset?

[1] Lei, Jie et al. “Revealing Single Frame Bias for Video-and-Language Learning.” ArXiv abs/2206.03428 (2022): n. pag.
[2] Mangalam, Karttikeya et al. “EgoSchema: A Diagnostic Benchmark for Very Long-form Video Language Understanding.” ArXiv abs/2308.09126 (2023): n. pag.

Finetuning Scripts

Thanks for your work on this! When will fine-tuning scripts be made available?

Very poor Chinese OCR performance

Wasn't Chinese OCR data added in stage 1.5? Why does Chinese text recognition still not work at all?

What is the conv-mode for LLaVA-NeXT-Video-7B-32K

I tried the following conv-mode values:
vicuna_v1
mistral_direct
Llava_llama_2
llama_2
mistral_instruct

and encountered the error below:
AttributeError: 'LlavaMistralConfig' object has no attribute 'attention_bias'
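One small thing that may help (a sketch, not a confirmed answer): the registered conversation template names can be listed directly, so you can check which conv-mode strings actually exist before passing one on the command line.

from llava.conversation import conv_templates

# Print every conversation template key registered in this codebase.
print(sorted(conv_templates.keys()))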

Conversation.copy() does not work well

When copy() is called on the llama_v3 version of Conversation, the tokenizer and other attributes are not copied, causing an error at self.tokenizer.apply_chat_template() in get_prompt().
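A possible workaround, sketched only from how other examples in this thread build prompts (so treat it as an assumption, not the intended fix): deep-copy the template instead of calling Conversation.copy(), so the tokenizer and the rest of the attributes come along.

import copy
from llava.conversation import conv_templates

# copy.deepcopy carries every attribute over, unlike the template's own copy().
conv = copy.deepcopy(conv_templates["llava_llama_3"])
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()  # should now reach apply_chat_template with the copied tokenizer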

No Attribute 'apply_chat_template'

Has anyone encountered this error when running llama3-llava-next-8b?

File "/LLaVA-NeXT/llava/conversation.py", line 107, in get_prompt
    return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
AttributeError: 'str' object has no attribute 'apply_chat_template'

I am confused because I am already using the latest version of transformers.
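The traceback shows conv.tokenizer holding a plain string rather than a tokenizer object when get_prompt() runs. A hedged workaround sketch (an assumption, not an official fix) is to attach a real tokenizer to the template before building the prompt; meta-llama/Meta-Llama-3-8B-Instruct is the tokenizer this template is reported to use elsewhere in this thread.

import copy
from transformers import AutoTokenizer
from llava.conversation import conv_templates

conv = copy.deepcopy(conv_templates["llava_llama_3"])
if not hasattr(conv.tokenizer, "apply_chat_template"):  # e.g. it is a string or None
    conv.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()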

Llama-3 conversation template

In the file conversation.py, the Llama-3 chat prompt is built at line 107:
self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=False)
which means the tokens <|start_header_id|> and <|end_header_id|> will be inserted automatically by the tokenizer's chat template. However, the token <|start_header_id|> also appears in the roles (line 353):
roles=("<|start_header_id|>user", "<|start_header_id|>assistant"),
So the token <|start_header_id|> gets duplicated in the output like this:
<|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n....<|eot_id|><|start_header_id|><|start_header_id|>assistant<|end_header_id|>\n\n...
Is this the correct behavior?
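A small reproduction of the point above, under the assumption that one has access to the gated Meta-Llama-3-8B-Instruct tokenizer: the chat template already wraps the role in <|start_header_id|>...<|end_header_id|>, so a role string that itself begins with <|start_header_id|> ends up with the token twice.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Pass the role exactly as it appears in roles=(...) at line 353.
messages = [{"role": "<|start_header_id|>user", "content": "hello"}]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
# -> ...<|start_header_id|><|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>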

Problem running LLaVA-NeXT-Video-34B-DPO

Dear authors,

thank you for your great work. I have tested LLaVA-NeXT-Video-7B-DPO on various videos and it shows excellent results. But when I try to run the 34B-DPO model, I encounter the following error:

Traceback (most recent call last):
  File "/mnt/qb/work/ponsmoll/pba178/project/LLaVA-NeXT/batch.py", line 151, in <module>
    run_inference()
  File "/mnt/qb/work/ponsmoll/pba178/project/LLaVA-NeXT/batch.py", line 133, in run_inference
    output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
  File "/mnt/qb/work/ponsmoll/pba178/.conda/llavan/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/qb/work/ponsmoll/pba178/project/LLaVA-NeXT/llavavid/model/language_model/llava_llama.py", line 120, in generate
    return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
  File "/mnt/qb/work/ponsmoll/pba178/.conda/llavan/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/qb/work/ponsmoll/pba178/.conda/llavan/lib/python3.10/site-packages/transformers/generation/utils.py", line 1576, in generate
    result = self._sample(
  File "/mnt/qb/work/ponsmoll/pba178/.conda/llavan/lib/python3.10/site-packages/transformers/generation/utils.py", line 2760, in _sample
    unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
  File "/mnt/qb/work/ponsmoll/pba178/.conda/llavan/lib/python3.10/site-packages/transformers/generation/stopping_criteria.py", line 137, in __call__
    is_done = is_done | criteria(input_ids, scores, **kwargs)
  File "/mnt/qb/work/ponsmoll/pba178/project/LLaVA-NeXT/llavavid/mm_utils.py", line 245, in __call__
    outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
  File "/mnt/qb/work/ponsmoll/pba178/project/LLaVA-NeXT/llavavid/mm_utils.py", line 234, in call_for_batch
    if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
