
๐Ÿ Goat: Fine-tuned LLaMA Outperforms GPT-4 on Arithmetic Tasks

[Paper] | [Adapter Weights] | [Dataset] | [Colab]

Demo

  1. Addition
  2. Subtraction
  3. Multiplication
  4. Division

Local Setup

git clone https://github.com/liutiedong/goat.git 
cd goat
pip install -r requirements.txt

Dataset (dataset.ipynb)

Run dataset.ipynb to generate the dataset.json file, or download it from the HuggingFace dataset tiedong/goat (https://huggingface.co/datasets/tiedong/goat). Each instance in the dataset contains:

  • instruction: a human instruction created by inserting an arithmetic expression into a randomly chosen template and adding some natural language noise. It serves as the prompt fed to the model for instruction fine-tuning.
  • input: a randomly generated arithmetic expression. It can be used in place of 'instruction' during training when we want to focus on arithmetic and avoid the influence of natural language.
  • output: the target output for the model to learn. It contains CoTs for multi-digit multiplication and division.
  • answer: the direct numerical answer to the arithmetic task. It can be used to test the learnability of various sub-tasks.

Example:

{
    "instruction": "What is 94140209+73?",
    "input": "94140209 + 73",
    "output": "94140209 + 73 = 94140282",
    "answer": "94140282"
},
{
    "instruction": "Compute 8432862 - 659016175?",
    "input": "8432862 - 659016175",
    "output": "8432862 - 659016175 = -650583313",
    "answer": "-650583313"
},
{
    "instruction": "Calculate 37 times 3066",
    "input": "37 * 3066",
    "output": "37 * 3066 = 3066 * (30 + 7) = 3066 * 30 + 3066 * 7 = 91980 + 21462 = 113442",
    "answer": "113442"
},
{
    "instruction": "Determine the numerical value of 5697/47.",
    "input": "5697 / 47",
    "output": "5697 - 47 * 100 = 5697 - 4700 = 997\n997 - 47 * 20 = 997 - 940 = 57\n57 - 47 * 1 = 57 - 47 = 10\nTherefore, 5697 / 47 = 121 R 10",
    "answer": "121 R 10"
},
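The division CoT in the last example follows a long-division pattern that can be reproduced programmatically. A minimal sketch (the `division_cot` helper below is hypothetical, not part of the repository):

```python
def division_cot(a, b):
    """Reproduce the repeated-subtraction chain-of-thought for a / b."""
    lines = []
    rem, quotient = a, 0
    while rem >= b:
        # Largest power of 10 such that b * 10**k still fits into the remainder
        k = 0
        while b * 10 ** (k + 1) <= rem:
            k += 1
        digit = rem // (b * 10 ** k)   # next quotient digit
        sub = digit * b * 10 ** k      # amount subtracted this step
        lines.append(f"{rem} - {b} * {digit * 10 ** k} = {rem} - {sub} = {rem - sub}")
        quotient += digit * 10 ** k
        rem -= sub
    tail = f"{quotient} R {rem}" if rem else f"{quotient}"
    lines.append(f"Therefore, {a} / {b} = {tail}")
    return "\n".join(lines)
```

For instance, `division_cot(5697, 47)` reproduces the 'output' field of the division example above, line by line.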

Feel free to modify dataset.ipynb to create your own data.

It is good to start with a simple sub-task, say 8-digit by 8-digit addition:

import random

# randint is inclusive at both ends, so use 10**8 - 1 to keep numbers at 8 digits
pairs = [(random.randint(10**7, 10**8 - 1), random.randint(10**7, 10**8 - 1)) for k in range(100000)]

Fine-tuning on these 100,000 training samples takes less than 2 hours on an A10 GPU to reach near-perfect accuracy.
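Each generated pair can then be turned into a dataset instance in the format shown above. A minimal sketch, assuming a single hard-coded template (the actual dataset.ipynb samples from the many templates in goat.json):

```python
import json
import random

TEMPLATE = "What is {} + {}?"  # hypothetical; dataset.ipynb samples from goat.json

def make_instance(a, b):
    """Build one addition instance in the dataset's four-field format."""
    return {
        "instruction": TEMPLATE.format(a, b),
        "input": f"{a} + {b}",
        "output": f"{a} + {b} = {a + b}",  # addition needs no CoT
        "answer": str(a + b),
    }

pairs = [(random.randint(10**7, 10**8 - 1), random.randint(10**7, 10**8 - 1))
         for _ in range(1000)]
data = [make_instance(a, b) for a, b in pairs]
with open("dataset.json", "w") as f:
    json.dump(data, f, indent=4)
```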

Template (goat.json)

template.txt contains several hundred natural language instructions. Instructions that are more commonly used are duplicated to increase their chance of being sampled. Instructions generated using ChatGPT are listed at the end without duplication. Note that some instructions may not be coherent or grammatically correct after an arithmetic expression is inserted, but this should not be a problem as long as we do not train on the input.
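Duplicating an instruction in the template list is a simple way to get weighted sampling, since uniform random choice then favors the duplicated entries. A tiny sketch with hypothetical templates:

```python
import random

# Hypothetical templates; the duplicate is twice as likely to be chosen.
templates = [
    "What is {expr}?",
    "What is {expr}?",
    "Kindly calculate {expr} for me.",
]
instruction = random.choice(templates).format(expr="37 * 3066")
```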

To add more instructions for training, put the new instructions in template.txt under the templates folder. Then run python convert_txt_to_json.py to convert it to goat.json, which is used by dataset.ipynb to generate the dataset for fine-tuning.
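The conversion is essentially a line-by-line read followed by a JSON dump. A minimal sketch of what convert_txt_to_json.py might do (the actual script in the repository may handle paths and formatting differently):

```python
import json

def convert_txt_to_json(txt_path, json_path):
    """Read one instruction template per line and dump them as a JSON list."""
    with open(txt_path) as f:
        templates = [line.strip() for line in f if line.strip()]
    with open(json_path, "w") as f:
        json.dump(templates, f, indent=4)
```

For example, `convert_txt_to_json("templates/template.txt", "templates/goat.json")`.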

Training (finetune.py)

Example usage:

python finetune.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --data_path 'dataset.json' \
    --output_dir './weights'

We train our model using the following command:

python finetune.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --data_path 'dataset.json' \
    --output_dir './weights' \
    --batch_size 128 \
    --micro_batch_size 16 \
    --num_epochs 1 \
    --learning_rate 1e-4 \
    --cutoff_len 512 \
    --val_set_size 0 \
    --lora_r 64 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --lora_target_modules '[q_proj,v_proj,k_proj,o_proj]'

Inference (app.py)

This file downloads the LoRA weights from HuggingFace (tiedong/goat-lora-7b) and launches a Gradio interface for inference.

Example usage:

python app.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --lora_weights 'tiedong/goat-lora-7b'

Alternatively, host your own Goat Gradio demo directly in Colab with this notebook.

Citation

@article{liu2023goat,
  title={Goat: Fine-tuned LLaMA Outperforms GPT-4 on Arithmetic Tasks},
  author={Liu, Tiedong and Low, Bryan Kian Hsiang},
  journal={arXiv preprint arXiv:2305.14201},
  year={2023}
}

Acknowledgements

Our implementation is mainly based on Alpaca-LoRA.


Issues

License File

Hi!

Thanks a lot for releasing this code and data.

I saw that the data on HF is under Apache 2.0 license but I couldn't see a license file for the code.

Could you add one?

Thanks a lot!

Right way to prompt the Goat model; also a request for a chat interface

@liutiedong, what is the right way to prompt the model?
Say I have the following code; how would you edit it to fix the prompting:

def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
    prompt = "The following is a conversation between a human and an AI assistant named Goat. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    flag = False
    # Walk backwards through the history, keeping as many recent turns as fit
    for x in history[::-1]:
        if tokenizer(prompt + history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
    else:
        return None


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False

With the above code I'm getting hallucinations; can you suggest a fix?
