dyishiou / collm

This project forked from zyang1580/collm

The implementation for the work "CoLLM: Integrating Collaborative Embeddings into Large Language Models for Recommendation".

License: BSD 3-Clause "New" or "Revised" License

Python 60.20% Jupyter Notebook 39.80%

CoLLM: Integrating Collaborative Embeddings into Large Language Models for Recommendation

Yang Zhang, Fuli Feng, Jizhi Zhang, Keqin Bao, Qifan Wang and Xiangnan He.

University of Science and Technology of China

This repository is constructed based on MiniGPT-4!

Introduction

We introduce CoLLM, a novel method that effectively integrates collaborative information into LLMs by harnessing the capability of external traditional models to capture the information. Similar to existing approaches (e.g., TALLRec), CoLLM starts by converting recommendation data into language prompts (prompt construction), which are then encoded and inputted into an LLM to generate recommendations (hybrid encoding and LLM prediction). We have specific designs for incorporating collaborative information:

  • When constructing prompts, we add user/item ID fields in addition to text descriptions to represent collaborative information.

  • When encoding prompts, alongside the LLM's tokenization and embedding of textual information, we employ a conventional collaborative model to generate user/item representations that capture collaborative information and map them into the LLM's token embedding space. This is done by the CIE module shown in the overview figure.
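The hybrid encoding described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the repository's implementation: the placeholder tokens `<UserID>`/`<ItemID>`, the helper names (`ToyMappingMLP`, `hybrid_encode`), and the tiny dimensions are hypothetical. It assumes the user/item representations come from a collaborative model such as MF and pass through an MLP (the mapping part of CIE) before being spliced into the sequence of token embeddings at the ID positions.

```python
import random

# Hypothetical placeholder tokens marking where ID embeddings go;
# the actual CoLLM prompt files may use different markers.
USER_TOKEN = "<UserID>"
ITEM_TOKEN = "<ItemID>"


def linear(x, weight, bias):
    """Return weight @ x + bias using plain Python lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


class ToyMappingMLP:
    """Two-layer MLP mapping a collaborative embedding (size d_rec)
    into the LLM token-embedding space (size d_llm)."""

    def __init__(self, d_rec, d_hidden, d_llm, seed=0):
        rng = random.Random(seed)
        self.w1 = [[rng.gauss(0.0, 0.1) for _ in range(d_rec)] for _ in range(d_hidden)]
        self.b1 = [0.0] * d_hidden
        self.w2 = [[rng.gauss(0.0, 0.1) for _ in range(d_hidden)] for _ in range(d_llm)]
        self.b2 = [0.0] * d_llm

    def __call__(self, emb):
        return linear(relu(linear(emb, self.w1, self.b1)), self.w2, self.b2)


def hybrid_encode(tokens, token_table, user_emb, item_embs, mlp):
    """Embed ordinary tokens with the LLM's embedding table and replace each
    ID placeholder with the MLP-mapped collaborative embedding.

    item_embs holds one collaborative embedding per ITEM_TOKEN, in order.
    """
    items = iter(item_embs)
    encoded = []
    for tok in tokens:
        if tok == USER_TOKEN:
            encoded.append(mlp(user_emb))
        elif tok == ITEM_TOKEN:
            encoded.append(mlp(next(items)))
        else:
            encoded.append(token_table[tok])
    return encoded


if __name__ == "__main__":
    d_rec, d_llm = 4, 8
    rng = random.Random(1)
    words = ["User", "watched", "Will", "they", "like", "?"]
    token_table = {w: [rng.gauss(0.0, 1.0) for _ in range(d_llm)] for w in words}
    # Stand-ins for rows of a trained MF model's user/item embedding tables.
    user_emb = [0.2, -0.1, 0.5, 0.3]
    item_embs = [[0.1, 0.4, -0.2, 0.0], [-0.3, 0.2, 0.1, 0.6]]
    tokens = ["User", USER_TOKEN, "watched", ITEM_TOKEN,
              "Will", "they", "like", ITEM_TOKEN, "?"]
    mlp = ToyMappingMLP(d_rec, d_hidden=6, d_llm=d_llm)
    seq = hybrid_encode(tokens, token_table, user_emb, item_embs, mlp)
    print(len(seq), "vectors of size", len(seq[0]))  # 9 vectors of size 8
```

Only the MLP would be trained in the default CIE-tuning setup; the token table and the collaborative embeddings are inputs here, which mirrors the frozen-LLM, frozen-rec-model configuration.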

For training, we take a two-step tuning method:

  • Tuning the LoRA module with the text-only input.

  • Tuning the CIE module with both the text and the user/item ID data.
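A minimal sketch of which components each tuning step updates, assuming the `freeze_rec`, `freeze_proj`, and `freeze_lora` hyper-parameters of the training configs. The module labels (`"lora"`, `"proj"`, `"rec"`) and the `FreezeFlags`/`trainable_modules` names are hypothetical; in both steps the base LLM weights stay frozen.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FreezeFlags:
    # Mirrors the freeze_* hyper-parameters in the training config files.
    freeze_rec: bool
    freeze_proj: bool
    freeze_lora: bool


def trainable_modules(flags):
    """Return the labels of the components that receive gradient updates."""
    modules = set()
    if not flags.freeze_lora:
        modules.add("lora")
    if not flags.freeze_proj:
        modules.add("proj")  # mapping MLP inside the CIE module
    if not flags.freeze_rec:
        modules.add("rec")   # collaborative model inside the CIE module
    return modules


# Step 1: LoRA tuning on text-only prompts.
STAGE1 = FreezeFlags(freeze_rec=True, freeze_proj=True, freeze_lora=False)
# Step 2: CIE tuning (default: mapping MLP only) on prompts with user/item IDs.
STAGE2 = FreezeFlags(freeze_rec=True, freeze_proj=False, freeze_lora=True)

if __name__ == "__main__":
    print("step 1:", sorted(trainable_modules(STAGE1)))
    print("step 2:", sorted(trainable_modules(STAGE2)))
```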

[Figure: Overview of CoLLM]

Getting Started

Installation

1. Prepare the code and the environment

Clone our repository, create a Python environment, and activate it with the following commands:

git clone https://github.com/zyang1580/CoLLM.git
cd CoLLM
conda env create -f environment.yml
conda activate minigpt4

Code Structure:

├──minigpt4: Core code of CoLLM, following the structure of MiniGPT-4.
    ├── models: Defines our CoLLM model architecture.
    ├── datasets: Defines dataset classes.
    ├── task: An overall task class, defining the model and datasets used, the training epochs, and evaluation.
    ├── runners: A runner class to train and evaluate a model based on a task.
    ├── common: Commonly used functions.
├──dataset: Dataset pre-processing.
├──prompt: Used prompts.
├──train_configs: Training configuration files, setting hyperparameters.
├──train_collm_xx.py: CoLLM training file.
├──baseline_train_xx.py: Baseline training file.

2. Prepare the pretrained Vicuna weights

The current version of CoLLM is built on the v0 version of Vicuna-7B. Please refer to MiniGPT-4's instructions here to prepare the Vicuna weights. The final weights should be in a single folder with a structure similar to the following:

vicuna_weights
├── config.json
├── generation_config.json
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00003.bin
...   

Then, set the path to the Vicuna weights in the "llama_model" field of a training config file, e.g., here for CoLLM-MF.

3. Prepare the Datasets

You can process the data yourself using the code provided in the ./dataset directory. Alternatively, you can download our pre-processed data from here.

Training

The training of CoLLM contains two stages:

1. LoRA Tuning

To equip the LLM with cold-start recommendation capabilities, we first fine-tune the LoRA module to learn the recommendation task independently of collaborative information. That is, we use only the text-only part of the prompt (e.g., "prompt_/tallrec_movie.txt") to generate predictions, and minimize the prediction errors to tune the LoRA module for recommendation.

When implementing, you need to set the hyper-parameters in the training config file (e.g., train_configs/collm_pretrain_mf_ood.yaml) as follows:

- freeze_rec: True # freeze the collaborative rec model
- freeze_proj: True # freeze the mapping function
- freeze_lora: False # tune the LoRA module
- prompt_path: "prompt_/tallrec_movie.txt" # use the prompt without the user/item IDs
- ckpt: None # no pretrained LoRA/CIE checkpoint; you can also simply delete this hyper-parameter
- evaluate: False # training mode
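To make the difference between the two prompt types concrete, here is a sketch that renders the same interaction with and without ID fields. The template wording and the placeholder tokens are hypothetical; the actual text lives in `prompt_/tallrec_movie.txt` (stage 1) and `prompt_/collm_movie.txt` (stage 2).

```python
# Hypothetical templates for illustration only; the real prompt files may differ.
USER_TOKEN = "<UserID>"
ITEM_TOKEN = "<ItemID>"


def build_prompt(history, target, with_ids):
    """Render a Yes/No recommendation prompt.

    with_ids=False gives a text-only prompt (LoRA tuning, stage 1);
    with_ids=True adds ID placeholders whose embeddings the CIE module
    would supply (CIE tuning, stage 2).
    """
    item_tag = f" {ITEM_TOKEN}" if with_ids else ""
    user = f"The user {USER_TOKEN}" if with_ids else "The user"
    watched = ", ".join(f'"{title}"{item_tag}' for title in history)
    return (f"{user} has watched {watched}. "
            f'Will the user enjoy "{target}"{item_tag}? Answer Yes or No.')


if __name__ == "__main__":
    history = ["Toy Story (1995)", "Heat (1995)"]
    print(build_prompt(history, "Casino (1995)", with_ids=False))
    print(build_prompt(history, "Casino (1995)", with_ids=True))
```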

To launch the first-stage training, run the following command. In our experiments, we use two A100 GPUs.

CUDA_VISIBLE_DEVICES=6,7 WORLD_SIZE=2 nohup torchrun --nproc-per-node 2 --master_port=11139 train_collm_mf_din.py --cfg-path=train_configs/collm_pretrain_mf_ood.yaml > log.out &

Our CoLLM checkpoint from this training stage can be downloaded here (7B). Note that the model obtained in this stage can be regarded as a version of TALLRec.

2. CIE Tuning

In this step, we tune the CIE module while keeping all other components frozen. The objective of this step is to enable the CIE module to learn how to extract collaborative information and map it effectively for the LLM to use in recommendation. To achieve this, we use prompts containing user/item IDs to generate predictions and tune the CIE module to minimize the prediction errors. When implementing, you need to set the hyper-parameters in the training config file as follows:

- freeze_rec: True # freeze the collaborative rec model
- freeze_proj: False # tune the mapping function
- freeze_lora: True # freeze the LoRA module
- pretrained_path: pretrained_collab_model_path # pretrained collab. model path
- evaluate: False # training mode
- prompt_path: "prompt_/collm_movie.txt" # use the prompt with the user/item IDs
- ckpt: step1_checkpoint_path # with pretrained LoRA

Then run the same command as in stage 1. Our final CoLLM checkpoint can be downloaded here (7B).

Note: By default, at this stage, we use a pretrained collaborative model and tune only the mapping module (an MLP). Alternatively, you can fine-tune the complete CIE module (mapping + collaborative model), with or without pretraining the collaborative model first. This approach might yield better performance, as discussed in Section 5.3.2 of our paper. It is controlled by the two hyper-parameters above: freeze_rec and pretrained_path.
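The note above corresponds to three combinations of `freeze_rec` and `pretrained_path`. The helper here is a hypothetical sketch of that mapping; rejecting a frozen collaborative model without pretrained weights is our own assumption (a frozen, randomly initialised model would carry no collaborative signal), not a check the repository is known to perform.

```python
from typing import Optional

DEFAULT = "tune the mapping MLP only; collaborative model pretrained and frozen (default)"
FINETUNE_BOTH = "tune the mapping MLP and fine-tune the pretrained collaborative model"
FROM_SCRATCH = "tune the mapping MLP and train the collaborative model from scratch"


def cie_tuning_variant(freeze_rec: bool, pretrained_path: Optional[str]) -> str:
    """Describe which CIE components stage 2 updates for a flag combination."""
    if freeze_rec:
        if not pretrained_path:
            # Assumption: a frozen, randomly initialised collaborative model
            # provides no collaborative information, so treat it as an error.
            raise ValueError("freeze_rec=True requires pretrained_path to be set")
        return DEFAULT
    return FINETUNE_BOTH if pretrained_path else FROM_SCRATCH


if __name__ == "__main__":
    print(cie_tuning_variant(True, "pretrained_collab_model_path"))
    print(cie_tuning_variant(False, "pretrained_collab_model_path"))
    print(cie_tuning_variant(False, None))
```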

Evaluation

Set the hyper-parameters in the training config file as follows:

- ckpt: your_checkpoint_path # trained model path
- evaluate: True # only evaluate

Then run the same command as for the first-stage training.

Acknowledgement

  • MiniGPT-4: Our repository is built upon MiniGPT-4!
  • Vicuna: The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!

If you're using CoLLM code in your research or applications, please cite our paper using this BibTeX:

@article{zhang2023collm,
  title={CoLLM: Integrating Collaborative Embeddings into Large Language Models for Recommendation},
  author={Zhang, Yang and Feng, Fuli and Zhang, Jizhi and Bao, Keqin and Wang, Qifan and He, Xiangnan},
  journal={arXiv preprint arXiv:2310.19488},
  year={2023}
}

License

This repository is under the BSD 3-Clause License. Much of the code is based on MiniGPT-4 (BSD 3-Clause License, here), which is built upon LAVIS (BSD 3-Clause License, here).
