
minichain's Introduction

A tiny library for coding with large language models. Check out the MiniChain Zoo to get a sense of how it works.

Coding

  • Code (math_demo.py): Annotate Python functions that call language models.
@prompt(OpenAI(), template_file="math.pmpt.tpl")
def math_prompt(model, question):
    "Prompt to call GPT with a Jinja template"
    return model(dict(question=question))

@prompt(Python(), template="import math\n{{code}}")
def python(model, code):
    "Prompt to call Python interpreter"
    code = "\n".join(code.strip().split("\n")[1:-1])
    return model(dict(code=code))

def math_demo(question):
    "Chain them together"
    return python(math_prompt(question))
  • Chains (Space): MiniChain builds a graph (think PyTorch) of all the calls you make, for debugging and error handling.

show(math_demo,
     examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?",
               "What is the sum of the 10 first positive integers?"],
     subprompts=[math_prompt, python],
     out_type="markdown").queue().launch()
...
Question:
A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?
Code:
2 + 2/2

Question:
{{question}}
Code:
  • Installation
pip install minichain
export OPENAI_API_KEY="sk-***"

Examples

This library allows us to implement several popular approaches in a few lines of code.

It supports the following backends.

  • OpenAI (Completions / Embeddings)
  • Hugging Face 🤗
  • Google Search
  • Python
  • Manifest-ML (AI21, Cohere, Together)
  • Bash

Why Mini-Chain?

There are several very popular libraries for prompt chaining, notably LangChain, Promptify, and GPTIndex. These libraries are useful, but they are extremely large and complex. MiniChain aims to implement the core prompt-chaining functionality in a tiny, digestible library.

Tutorial

MiniChain is based on annotating functions as prompts.


@prompt(OpenAI())
def color_prompt(model, input):
    return model(f"Answer 'Yes' if this is a color, {input}. Answer:")

Prompt functions act like Python functions, except they are lazy; to access the result you need to call run().

if color_prompt("blue").run() == "Yes":
    print("It's a color")

Alternatively you can chain prompts together. Prompts are lazy, so if you want to manipulate them you need to add @transform() to your function. For example:

@transform()
def said_yes(input):
    return input == "Yes"


@prompt(OpenAI())
def adjective_prompt(model, input):
    return model(f"Give an adjective to describe {input}. Answer:")
adjective = adjective_prompt("rainbow")
if said_yes(color_prompt(adjective)).run():
    print("It's a color")

We also include an argument template_file, which tells the prompt to use a template written in the Jinja language. This allows us to separate the prompt text from the Python code.

@prompt(OpenAI(), template_file="math.pmpt.tpl")
def math_prompt(model, question):
    return model(dict(question=question))
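The templated prompt runs like any other; here is a small usage sketch, assuming math.pmpt.tpl renders the question into the few-shot prompt shown earlier:

# .run() executes the lazily built call and returns the model's completion.
answer = math_prompt("What is the sum of 1 through 5?").run()
print(answer)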

Visualization

MiniChain has a built-in prompt visualization system using Gradio. If you construct a function that calls a prompt chain you can visualize it by calling show and launch. This can be done directly in a notebook as well.

show(math_demo,
     examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?",
              "What is the sum of the 10 first positive integers?"],
     subprompts=[math_prompt, python],
     out_type="markdown").queue().launch()

Memory

MiniChain does not build in an explicit stateful memory class. We recommend implementing it as a queue.


Here is a class you might find useful to keep track of responses.

from dataclasses import dataclass
from typing import List, Tuple

MEMORY_LIMIT = 2  # the Chat example keeps the last two responses

@dataclass
class State:
    memory: List[Tuple[str, str]]
    human_input: str = ""

    def push(self, response: str) -> "State":
        # Drop the oldest turn once the limit is reached, then append the new exchange.
        memory = self.memory if len(self.memory) < MEMORY_LIMIT else self.memory[1:]
        return State(memory + [(self.human_input, response)])

See the full Chat example. It keeps track of the last two responses that it has seen.
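A hedged sketch of how this class might be wired into a chat prompt; chat.pmpt.tpl and the variable names handed to the template are assumptions, not part of the library:

@prompt(OpenAI(), template_file="chat.pmpt.tpl")  # hypothetical template
def chat_prompt(model, state):
    # Assume the template renders the remembered turns plus the new input.
    return model(dict(memory=state.memory, human_input=state.human_input))

state = State([], human_input="Hello there")
response = chat_prompt(state).run()
state = state.push(response)  # drops the oldest turn once MEMORY_LIMIT is reached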

Tools and Agents

MiniChain does not provide agents or tools. If you want that functionality, you can use the tool_num argument of model, which allows you to select from multiple possible backends. It's easy to add new backends of your own (see the GradioExample).

@prompt([Python(), Bash()])
def math_prompt(model, input, lang):
    return model(input, tool_num=0 if lang == "python" else 1)
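For instance, the same prompt can then dispatch to either interpreter (a usage sketch; the inputs are made up):

# Dispatch to the Python backend (tool_num=0) or the Bash backend (tool_num=1).
print(math_prompt("print(2 + 2)", "python").run())
print(math_prompt("echo $((2 + 2))", "bash").run())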

Documents and Embeddings

MiniChain does not manage documents and embeddings. We recommend using the Hugging Face Datasets library with built-in FAISS indexing.


Here is the implementation.

import datasets
import numpy as np

# Load and index a dataset
olympics = datasets.load_from_disk("olympics.data")
olympics.add_faiss_index("embeddings")

@prompt(OpenAIEmbed())
def get_neighbors(model, inp, k):
    # Embed the input, then look up the k nearest documents in the FAISS index.
    embedding = model(inp)
    res = olympics.get_nearest_examples("embeddings", np.array(embedding), k)
    return res.examples["content"]

This creates a K-nearest neighbors (KNN) prompt that looks up the 3 closest documents based on embeddings of the question asked. See the full Retrieval-Augmented QA example.

We recommend creating these embeddings offline using the batch map functionality of the datasets library.

def embed(x):
    emb = openai.Embedding.create(input=x["content"], engine=EMBEDDING_MODEL)
    return {"embeddings": [np.array(emb['data'][i]['embedding'])
                           for i in range(len(emb["data"]))]}
x = dataset.map(embed, batch_size=BATCH_SIZE, batched=True)
x.save_to_disk("olympics.data")

There are other ways to do this, such as SQLite or Weaviate.
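To make the retrieval step concrete, here is a minimal sketch of chaining get_neighbors into a question-answering prompt; qa.pmpt.tpl is a hypothetical template that receives the retrieved passages and the question:

@prompt(OpenAI(), template_file="qa.pmpt.tpl")  # hypothetical template
def qa_prompt(model, passages, question):
    return model(dict(passages=passages, question=question))

def retrieval_qa(question):
    # Retrieve the 3 nearest passages, then answer conditioned on them.
    return qa_prompt(get_neighbors(question, 3), question)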

Typed Prompts

MiniChain can automatically generate a prompt header for you that aims to ensure the output follows a given typed specification. For example, if you run the following code, MiniChain will produce a prompt that returns a list of Player objects.

from dataclasses import dataclass
from enum import Enum
from typing import List

class StatType(Enum):
    POINTS = 1
    REBOUNDS = 2
    ASSISTS = 3

@dataclass
class Stat:
    value: int
    stat: StatType

@dataclass
class Player:
    player: str
    stats: List[Stat]


@prompt(OpenAI(), template_file="stats.pmpt.tpl", parser="json")
def stats(model, passage):
    out = model(dict(passage=passage, typ=type_to_prompt(Player)))
    return [Player(**j) for j in out]

Specifically, it will provide your template with a string typ that you can use. For example, the string might look like the following:

You are a highly intelligent and accurate information extraction system. You take passage as input and your task is to find parts of the passage to answer questions.

You need to output a list of JSON encoded values

You need to classify in to the following types for key: "color":

RED
GREEN
BLUE


Only select from the above list, or "Other".


You need to classify in to the following types for key: "object":

String



You need to classify in to the following types for key: "explanation":

String

[{ "color" : "color" ,  "object" : "object" ,  "explanation" : "explanation"}, ...]

Make sure every output is exactly seen in the document. Find as many as you can.

This will then be converted to an object automatically for you.
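As a usage sketch (the passage text is made up), the typed prompt can then be run like any other and returns parsed Player objects:

players = stats("Harden had 10 points, 5 rebounds, and 7 assists.").run()
for player in players:
    print(player.player, player.stats)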


minichain's Issues

Python 3.11 data class error

When using Python 3.11, importing MiniChain causes an error.

mfinlayson  ~/MiniChain $  python -c "import minichain"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/mfinlayson/MiniChain/minichain/__init__.py", line 1, in <module>
    from .backend import (
  File "/Users/mfinlayson/MiniChain/minichain/backend.py", line 309, in <module>
    @dataclass
     ^^^^^^^^^
  File "/Users/mfinlayson/.pyenv/versions/3.11.9/lib/python3.11/dataclasses.py", line 1232, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/Users/mfinlayson/.pyenv/versions/3.11.9/lib/python3.11/dataclasses.py", line 1222, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mfinlayson/.pyenv/versions/3.11.9/lib/python3.11/dataclasses.py", line 958, in _process_class
    cls_fields.append(_get_field(cls, name, type, kw_only))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mfinlayson/.pyenv/versions/3.11.9/lib/python3.11/dataclasses.py", line 815, in _get_field
    raise ValueError(f'mutable default {type(f.default)} for field '
ValueError: mutable default <class 'minichain.backend.RunLog'> for field run_log is not allowed: use default_factory

This can be fixed by changing RunLog() to field(default_factory=RunLog) in backend.py, though I'm not sure if this would introduce other bugs.
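For reference, the suggested change would look roughly like this (a sketch; the surrounding class in minichain/backend.py is abridged and its name is assumed):

from dataclasses import dataclass, field

@dataclass
class Backend:  # abridged
    # Before: run_log: RunLog = RunLog()  <- mutable default, rejected on Python 3.11
    run_log: RunLog = field(default_factory=RunLog)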

[feature request] untangle `gradio` dependency from minichain "core"

gradio is a fairly heavyweight dependency and may not be necessary in production deployments. currently, import minichain requires gradio to be installed via

from .gradio import GradioConf, show

and

import gradio as gr

would it be possible to untangle gradio from minichain core and make it optional? more generally, would it be possible to build a minichain "core" which additional components then build off of?
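One common pattern for this (a sketch of a possible approach, not the library's current behavior) is to defer the gradio import and only fail when the visualization entry point is actually used:

# minichain/__init__.py (sketch): keep gradio optional
try:
    from .gradio import GradioConf, show
except ImportError:
    def show(*args, **kwargs):
        raise ImportError("show() requires the optional gradio dependency: pip install gradio")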

error when running examples

I got the error below when running math_demo.ipynb; do you know what is wrong here?

TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 gradio = show(math_demo,
      2     examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?",
      3               "What is the sum of the 10 first positive integers?",],
      4     # "Carla is downloading a 200 GB file. She can download 2 GB/minute, but 40% of the way through the download, the download fails. Then Carla has to restart the download from the beginning. How load did it take her to download the file in minutes?"],
      5     subprompts=[math_prompt, python],
      6     out_type="json",
      7     description=desc,
      8 )
      9 if __name__ == "__main__":
     10     gradio.launch()

File ~\anaconda3\envs\sicl-llm\lib\site-packages\minichain\gradio.py:316, in show(prompt, examples, subprompts, fields, initial_state, out_type, keys, description, code, css, show_advanced)
    314 inputs = list([gr.Textbox(label=f) for f in fields])
    315 examples = gr.Examples(examples=examples, inputs=inputs)
--> 316 query_btn = gr.Button(label="Run")
    317 constructor = constructor.add_inputs(inputs)
    319 with gr.Box():
    320     # Intermediate prompt displays

File ~\anaconda3\envs\sicl-llm\lib\site-packages\gradio\component_meta.py:157, in updateable.<locals>.wrapper(*args, **kwargs)
    155     return None
    156 else:
--> 157     return fn(self, **kwargs)

TypeError: __init__() got an unexpected keyword argument 'label'
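The error comes from newer gradio releases, where gr.Button no longer accepts a label keyword; a likely fix (an assumption about the intended behavior, not a confirmed patch) is to pass the button text as value:

# In minichain/gradio.py, line 316 (sketch): newer gradio takes the button text as `value`.
query_btn = gr.Button(value="Run")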

Graph schedulers?

Something like dask to run graphs in parallel, where safe?

This maybe deserves a separate issue, but what about integrating with a prompt templating language like lmql or guidance

access the rendered template during development/testing

Hi, great looking library. Love the simplicity!
I'm playing with things and exploring the examples, but I'm sort of curious how I could print out my rendered template prior to running it. For example, from this basic example:

@prompt(OpenAI(), template_file="math.pmpt.tpl")
def math_prompt(model, question):
    "Prompt to call GPT with a Jinja template"
    return model(dict(question=question))

I see that a call to math_prompt('what is 1+1?').run() will execute, but I was thinking that with the lazy loading something like math_prompt("what is 1+1?").render() would print out the rendered Jinja template. I see how I could just use Jinja for that, but in the flow of debugging it seems like a useful utility function. Happy to make a PR if you give me a hint or let me know how you would think about this.

Minor update... part of the reason is that the function might itself provide new inputs to the template using hard-to-predict computations. As a simple example, I look up today's date and add it to the template:

@prompt(Mock(answers=["blah"]), template_file="test.pmpt.tpl")
def date_prompt(model, question):
    "Prompt to call GPT with a Jinja template"
    today = date.today().strftime("%A %B %d, %Y")
    print(model.prompt.template_fill(dict(question=question, today=today)))
    return model(dict(question=question, today=today))

where test.pmpt.tpl is

Todays date is {{today}}
{{question}}

The snippet above worked to print out the final template during the run, but maybe this could be done on the prompt or chain itself.
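In the meantime, a hypothetical helper along these lines (render_prompt is an assumption, using Jinja directly rather than any MiniChain API) can preview the filled template without calling the backend:

from datetime import date
from jinja2 import Environment, FileSystemLoader

def render_prompt(question):
    # Fill the Jinja template directly, for debugging only; no model call is made.
    env = Environment(loader=FileSystemLoader("."))
    today = date.today().strftime("%A %B %d, %Y")
    return env.get_template("test.pmpt.tpl").render(question=question, today=today)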
