Giter Site home page Giter Site logo

Comments (1)

dingning97 avatar dingning97 commented on September 24, 2024

I used a very old version of MiniGPT4; the last commit was probably "commit-22d8888" from May 1, 2023. You can access that legacy code through this link: https://github.com/Vision-CAIR/MiniGPT-4/tree/22d8888ca2cf0aac862f537e7d22ef5830036808

Place the following script, "generate_captions.py", in the root folder of MiniGPT4 and modify it as needed.
##############################################

"""generate_captions.py"""
import argparse
import random
from time import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import os
import torchvision
from copy import deepcopy
import datetime
from pathlib import Path

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

import ssl
# SECURITY NOTE(review): this swaps the stdlib's default HTTPS context factory
# for the unverified one, so HTTPS requests made through the standard library
# in this process skip certificate verification. Presumably a workaround for
# weight/checkpoint downloads on machines with broken certificate chains --
# confirm it is still needed before running on an untrusted network.
ssl._create_default_https_context = ssl._create_unverified_context
# Module-level conversation object copied from the CONV_VISION template.
# get_single_caption() clears and reuses it for every image.
chat_state = CONV_VISION.copy()
# Instruction text that make_captions() passes to get_single_caption() for every image.
input_prompt = 'Describe this image. Do not say anything that you are not sure.'

def prepare_model(args):
    """Build the model and image processor and wrap them in a ``Chat``.

    Args:
        args: parsed command-line namespace. It is handed to ``Config`` as-is,
            and ``args.gpu_id`` selects the CUDA device.

    Returns:
        A ``Chat`` instance whose model has been moved to ``cuda:<gpu_id>``.
    """
    print('Initializing Chat')
    cfg = Config(args)

    # Model: look the class up by the architecture named in the config.
    model_config = cfg.model_cfg
    model_config.device_8bit = args.gpu_id
    target_device = 'cuda:{}'.format(args.gpu_id)
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config)
    model = model.to(target_device)

    # Image processor: taken from the cc_sbu_align dataset's "train" entry.
    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    processor_cls = registry.get_processor_class(vis_processor_cfg.name)
    vis_processor = processor_cls.from_config(vis_processor_cfg)

    chat = Chat(model, vis_processor, device=target_device)
    print('Initialization Finished')
    return chat

def get_single_caption(chat, image, input_prompt):
    """Generate a caption for a single image.

    Args:
        chat: object exposing ``model.encode_img`` and ``answer`` (the value
            returned by ``prepare_model``).
        image: image input forwarded unchanged to ``chat.model.encode_img``.
            The caller in this file passes a tensor with a leading batch
            dimension of 1 that is already on the chat device.
        input_prompt: instruction text placed after the image placeholder.

    Returns:
        The value returned by ``chat.answer(..., return_text_only=True)``;
        the caller prints it and stores it as the caption.
    """
    # The module-level conversation is shared by every call, so drop the
    # previous image's messages to caption each image independently.
    chat_state.messages = []

    # Captioning is inference only: disable gradient tracking so no autograd
    # state is kept for the image encoder or for generation.
    with torch.no_grad():
        image_emb, _ = chat.model.encode_img(image)

        # First-role message: the image placeholder followed by the prompt.
        chat_state.append_message(chat_state.roles[0], "<Img><ImageHere></Img>")
        chat_state.messages[-1][1] = ' '.join([chat_state.messages[-1][1], input_prompt])

        llm_message = chat.answer(
            conv=chat_state, img_list=[image_emb], num_beams=1, temperature=0.1,
            max_new_tokens=300, max_length=2000, return_text_only=True)

    return llm_message

def make_captions(args):
    """Caption a contiguous slice of an image-folder dataset and save the result.

    Samples ``args.start`` .. ``args.start + args.num - 1`` of the
    ``ImageFolder`` at ``args.data_path`` are captioned one at a time. If the
    dataset ends inside that range, captioning stops at the last sample. The
    captions are collected in a list and written with ``torch.save`` to
    ``args.save_dir`` as ``minigpt4_caption_imagenet_train_<start>_<last>.pth``,
    where ``<last>`` is the index of the last sample actually captioned.

    Args:
        args: parsed command-line namespace. Reads ``start``, ``num``,
            ``bsize``, ``save_dir`` and ``data_path``, and is also passed to
            ``prepare_model``. ``bsize`` is only used to check that ``start``
            and ``num`` are multiples of it.

    Raises:
        AssertionError: if ``num`` or ``start`` is not a multiple of ``bsize``.
        IndexError: if ``num`` is positive and ``start`` lies past the end of
            the dataset.
    """
    start, num, bsize = args.start, args.num, args.bsize
    assert 0 == num % bsize, 'num must be a multiple of bsize'
    assert 0 == start % bsize, 'start must be a multiple of bsize'
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    chat_model = prepare_model(args)
    dataset = torchvision.datasets.ImageFolder(
        args.data_path,
        transform=deepcopy(chat_model.vis_processor)
    )
    num_samples = len(dataset)

    # Clamp the requested range to the dataset so the final (short) slice is
    # handled by the loop bounds rather than by patching the file name later.
    stop = start + num
    if num > 0:
        if start >= num_samples:
            raise IndexError(
                f'start index {start} is out of range for a dataset of '
                f'{num_samples} samples'
            )
        stop = min(stop, num_samples)
    end = stop - 1

    # Build the name from the real last index. Replacing the digits of the
    # planned end inside an already-formatted name could also rewrite other
    # parts of it (for example the "4" in "minigpt4" or the start index).
    save_fname = f'minigpt4_caption_imagenet_train_{start}_{end}.pth'
    all_result = []
    time_start = time()

    print(f'##@@ START = {start}    NUM = {num}')
    for index in range(start, stop):
        img, _ = dataset[index]
        # Add a batch dimension and move the image to the chat device.
        pred_string = get_single_caption(
            chat_model, img.unsqueeze(0).to(chat_model.device), input_prompt
        )
        print(pred_string)
        all_result.append(pred_string)

        if index == num_samples - 1:  # reached the end of the dataset
            print('this is the last sample')

    print('##@@ END\n', 'total running time :', str(datetime.timedelta(seconds=int(time()-time_start))))
    save_fname = os.path.join(args.save_dir, save_fname)
    torch.save(all_result, save_fname)
    print('result saved at', save_fname)

def parse_args():
    """Parse the command-line options of the captioning script.

    Returns:
        The ``argparse.Namespace`` with ``data_path``, ``save_dir``,
        ``cfg_path``, ``gpu_id``, ``start``, ``num``, ``bsize`` and
        ``options`` (``None`` unless ``--options`` is given).
    """
    cli = argparse.ArgumentParser(description="Demo")

    # String-valued options: (flag, default).
    path_options = (
        ("--data_path", '/path/to/imagenet/train'),
        ("--save_dir", '/cache/output/minigpt4_captions'),
        ("--cfg_path", 'eval_configs/minigpt4_eval.yaml'),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, type=str, default=fallback)

    # Integer-valued options: (flag, default).
    int_options = (
        ("--gpu_id", 0),
        ("--start", 0),
        ("--num", 100),
        ("--bsize", 5),
    )
    for flag, fallback in int_options:
        cli.add_argument(flag, default=fallback, type=int)

    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)

    return cli.parse_args()

# Script entry point: parse the command line and caption the requested range.
if __name__ == '__main__':
    args = parse_args()
    make_captions(args)

from efficient-computing.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games; make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.