
RL4F: Generating Natural Language Feedback with Reinforcement Learning

Code repository for the ACL 2023 paper "RL4F: Generating Natural Language Feedback with Reinforcement Learning". Afra Feyza Akyurek, Ekin Akyurek, Aman Madaan, Ashwin Kalyan, Peter Clark, Derry Wijaya, Niket Tandon. Check out the project page for a brief introduction.

This codebase is primarily based on the RL4LMs repository. We provide custom data classes and reward functions to implement RL4F.

Installation

git clone https://github.com/feyzaakyurek/rl4f.git
cd rl4f
pip install -e .

Introduction to the Codebase

rl4lms/data_pools/custom_text_generation_pools.py: This file contains the custom dataset loading classes. Make sure to specify the correct data paths in the respective classes.

scripts/training/task_configs/: The YAML config files are stored under this path. This is where we specify the training and evaluation arguments, as well as the model and output paths.

rl4f_scripts/: Sample sh scripts for running experiments, such as supervised critique generation and PPO training.

openai_key: If you run RL4F with one of the OpenAI models, place your API key in a file and give the path to that file in your YAML config. Note that RL4F runs using the OpenAI API may incur significant charges.

wandb_key: We track our runs using wandb. Put your API key in this file; it is used by the sh scripts.
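A minimal sketch of setting up both key files from the shell, assuming each key is stored as a single line of text. The helper names `save_key` and `load_wandb_key` are hypothetical and not part of the repository; check the sh scripts and your YAML config for the paths they actually read. `WANDB_API_KEY` is the environment variable the wandb client reads for authentication, so exporting it is one way a launch script could consume the key file.

```shell
#!/usr/bin/env bash

# Hypothetical helper: write an API key to a file that only the current user can read.
save_key() {
  local key="$1" path="$2"
  # Run in a subshell so the restrictive umask does not leak into the caller.
  ( umask 077 && printf '%s\n' "$key" > "$path" )
}

# Hypothetical helper: export the wandb key from a file so wandb can pick it up.
load_wandb_key() {
  local key_file="${1:-wandb_key}"
  if [ ! -f "$key_file" ]; then
    echo "wandb key file not found: $key_file" >&2
    return 1
  fi
  # Strip whitespace (including the trailing newline) from the stored key.
  WANDB_API_KEY="$(tr -d '[:space:]' < "$key_file")"
  export WANDB_API_KEY
}

# Example usage (placeholders, not real keys):
#   save_key "<your-openai-key>" openai_key   # then reference this path in your YAML config
#   save_key "<your-wandb-key>" wandb_key
#   load_wandb_key wandb_key
```

The umask only affects newly created files; an existing key file keeps its current permissions.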

Running Experiments

Datasets

Download the pretrained checkpoints and data from this link.

All scripts can be found under rl4f_scripts. For example, check out the rl4f_scripts/run_alphabetize_sup.sh script for warm-starting a pretrained T5-large for supervised critique generation on alphabetization. Alternatively, you can load the released checkpoint from the drive link above. For PPO training, set the checkpoint path in scripts/training/task_configs/alphabetize/t5large_ppo_on_supervised.yaml and run rl4f_scripts/run_alphabetize_ppo.sh.
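A minimal sketch of the two-stage alphabetization workflow described above, run from the repository root. The helper names `configs_used_by` and `run_alphabetize` and the `--skip-sup` flag are hypothetical; the sketch only assumes that the launch scripts mention their YAML configs by path and can be invoked with `bash`. Use `--skip-sup` if you start PPO from the released checkpoint instead of training your own supervised model.

```shell
#!/usr/bin/env bash

# List the YAML config paths mentioned in one or more launch scripts.
configs_used_by() {
  grep -ohE '[A-Za-z0-9_./-]+\.ya?ml' "$@" | sort -u
}

# Run supervised warm-starting, then PPO, stopping if the supervised stage fails.
# Before calling this, set the checkpoint path in
# scripts/training/task_configs/alphabetize/t5large_ppo_on_supervised.yaml
# (the released checkpoint, or where your supervised run saves its model).
run_alphabetize() {
  local repo="." skip_sup=0
  while [ "$#" -gt 0 ]; do
    case "$1" in
      --skip-sup) skip_sup=1 ;;
      *) repo="$1" ;;
    esac
    shift
  done
  (
    cd "$repo" || exit 1
    if [ "$skip_sup" -eq 0 ]; then
      bash rl4f_scripts/run_alphabetize_sup.sh || exit 1
    fi
    bash rl4f_scripts/run_alphabetize_ppo.sh
  )
}

# Example usage from the repository root:
#   configs_used_by rl4f_scripts/run_alphabetize_ppo.sh
#   run_alphabetize .               # supervised stage, then PPO
#   run_alphabetize . --skip-sup    # PPO only, from the released checkpoint
```

Running each stage in a subshell keeps the `cd` from changing your interactive shell's working directory.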

FAQ

If you receive an error saying that your torch installation does not support sm_86, try uninstalling torch and reinstalling it with conda, using the cudatoolkit version that matches your environment. For example:

conda install pytorch==1.11.0 cudatoolkit=11.3 -c pytorch
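To check whether a torch build ships kernels for your GPU architecture, before and after reinstalling, you can inspect `torch.cuda.get_arch_list()`, which returns the architectures the build was compiled for (an empty list for a CPU-only build). The helper `has_arch` below is a hypothetical sketch that only tests membership in a space-separated list, so the torch call itself appears as a commented usage line.

```shell
#!/usr/bin/env bash

# Return 0 if ARCH (e.g. sm_86) appears in a space-separated list of architectures.
has_arch() {
  local arch_list="$1" arch="$2"
  # Pad with spaces so "sm_86" does not falsely match inside e.g. "sm_860".
  case " $arch_list " in
    *" $arch "*) return 0 ;;
    *) return 1 ;;
  esac
}

# Example usage (requires torch to be installed):
#   archs="$(python -c 'import torch; print(" ".join(torch.cuda.get_arch_list()))')"
#   has_arch "$archs" sm_86 || echo "This torch build has no sm_86 kernels; reinstall as above."
```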
