
leesinliang / microgpt

Implementation of GPT from scratch. Designed to be lightweight and easy to modify.

Home Page: https://microgpt.streamlit.app

License: MIT License

Languages: Python 99.65%, Shell 0.35%
Topics: gpt, gpt-2-text-generation, pytorch, top-k-sampling, top-p-sampling, transformer, transformer-decoder, gpt-scratch, deep-learning, natural-language-processing

Contributors: imgbotapp, leesinliang

microgpt's Issues

Mislabeled file in inference.py

In inference.py, line #11 points to "tokenizer/tokenizer16384_v2.json"; it should be just "tokenizer/tokenizer.json".
Having looked at the separate deploy repo: that one does call for the "16384" file.

Corrected code:

from dataclasses import dataclass
from gpt import GPT
from transformers import GPT2TokenizerFast
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# modify the parameters here
max_length = 512
model_path = "models/microGPT.pth"
tokenizer_path = "tokenizer/tokenizer.json"
n_tokens = 1000
temperature = 0.8
top_k = 0
top_p = 0.9

tokenizer = GPT2TokenizerFast(tokenizer_file=tokenizer_path)

@dataclass
class GPTConfig:
    n_embd = 768
    vocab_size = len(tokenizer.get_vocab())
    max_length = 512
    n_head = 8
    n_layer = 8
    dropout = 0.0
    training = True
    pad_token = tokenizer.convert_tokens_to_ids('[PAD]')
    
config = GPTConfig
model = GPT(config)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model_stat = torch.load(model_path, map_location=device)
model.load_state_dict(model_stat["model_state_dict"])
model = model.to(device)

# When trained on the original dataset (MiniPile, https://arxiv.org/abs/2304.08442), the model can generate code, stories, dialogues, etc.
context = '''Marlene: Good afternoon Houston division, I am so excited to be here with you talking about an exciting quarter for our division. We are so excited to introduce someone who is here with us for the first time. Rachel Ross!
Rachel: Thank you Marlene. In March, I assumed the role of Vice President of Merchandising for the Houston Division. I came from the Michigan Division so the heat and humidity has been quite a change, but being with this division’s team has been so amazing.
Marlene: Rachel, we are glad to have you here and excited about all of the energy you have already brought to the team. First let’s hear from our Division Controller, Akin Akanni, about how we did financially in the Houston Division this quarter.
Akin: Thanks guys, He spoke of how rare it is to receive the amazing level service that he provided in other stores. Thank you Brent for giving our customers highly satisfying service.  We are so proud to have you on our Houston team.
Marlene and Mike: Way to go Brent!
'''
# Encode the prompt as a (1, seq_len) tensor of token ids; device= already places it on the target device
context = torch.tensor(tokenizer.encode(context), dtype=torch.long, device=device).reshape(1, -1)
print(
    tokenizer.decode(
        model.generate(
            context, max_tokens_generate=n_tokens, top_k=top_k, top_p=top_p, temperature=temperature
        ).tolist()
    )
)
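The inference script passes temperature=0.8, top_k=0 and top_p=0.9 to model.generate. As a rough illustration of what these knobs typically do to a single step of next-token logits, here is a plain-Python sketch. It assumes the common conventions (top_k=0 disables top-k, and top-k is applied before top-p); it is not microgpt's GPT.generate, and softmax, filter_probs and sample_token are hypothetical names.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def filter_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return next-token probabilities after temperature scaling, then top-k,
    then top-p (nucleus) filtering. top_k=0 and top_p=1.0 each mean "off".
    Filtered-out tokens get probability 0; the rest are renormalized."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    probs = softmax([x / temperature for x in logits])
    # Candidate token indices, most likely first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    if top_k > 0:
        order = order[:top_k]  # keep only the k most likely tokens

    if top_p < 1.0:
        mass = sum(probs[i] for i in order)  # renormalize over the top-k survivors
        nucleus, cumulative = [], 0.0
        for i in order:
            nucleus.append(i)
            cumulative += probs[i] / mass
            if cumulative >= top_p:  # smallest prefix whose mass reaches top_p
                break
        order = nucleus

    kept = sum(probs[i] for i in order)
    out = [0.0] * len(probs)
    for i in order:
        out[i] = probs[i] / kept
    return out


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Draw one token index from the filtered distribution."""
    probs = filter_probs(logits, temperature, top_k, top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Same settings as the inference script; only indices 0 and 1 keep nonzero probability
print(filter_probs([2.0, 1.0, 0.0, -10.0], temperature=0.8, top_k=0, top_p=0.9))
```

A temperature below 1 sharpens the distribution before filtering; top-p then keeps the smallest set of most likely tokens whose combined probability reaches top_p, so the number of candidates adapts to how confident the model is at each step.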

Missing import

"import os" is missing from /datasets/prepare_dataset.py; the script calls os.path.join, so without the import it fails with a NameError.

import os
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import numpy as np

tokenizer_path="tokenizer/tokenizer.json"
dataset_dir="datasets/"

tokenizer = GPT2TokenizerFast(
    tokenizer_file=tokenizer_path,
    pad_token="[PAD]",
    padding_side="left",
)

dataset = load_dataset("JeanKaddour/minipile")

# Tokenize one example and append the EOS id so document boundaries survive concatenation
def process(data):
    inpt = tokenizer(data['text'])
    inpt['input_ids'].append(tokenizer.eos_token_id)
    out = {'input_ids': inpt['input_ids'], 'len': len(inpt['input_ids'])}
    return out

dataset = dataset.map(
        process,
        remove_columns=['text'],
        num_proc=8,
)

# Write each split to its own flat, memory-mapped {split}.bin file of token ids
for split, dset in dataset.items():
    arr_len = np.sum(dset['len'], dtype=np.uint64)  # total number of tokens in the split
    filename = os.path.join(dataset_dir, f'{split}.bin')
    dtype = np.uint16  # token ids must be < 65536 to fit
    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
    total_batches = 500

    # Copy the tokens over in contiguous shards to keep memory use bounded
    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
        arr_batch = np.concatenate(batch['input_ids'])
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()
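prepare_dataset.py stores each split as one headerless array of uint16 token ids, with the EOS id appended after every document. The stdlib-only sketch below reproduces that layout and shows one common way to draw next-token training pairs from it. The windowing scheme is an assumption for illustration, not taken from microgpt's training code, and pack_documents, save_stream, load_stream and random_window are hypothetical names. Because the file uses uint16, it assumes a vocabulary of fewer than 65,536 tokens.

```python
import random
from array import array

UINT16_MAX = 65535


def pack_documents(docs, eos_id):
    """Concatenate token-id lists into one flat uint16 array, appending eos_id after
    each document (the same layout prepare_dataset.py writes to {split}.bin)."""
    stream = array("H")  # unsigned 16-bit ints in native byte order, like np.uint16
    for ids in docs:
        ids = list(ids) + [eos_id]
        if max(ids) > UINT16_MAX or min(ids) < 0:
            raise ValueError("token ids must fit in uint16 (0..65535)")
        stream.extend(ids)
    return stream


def save_stream(stream, path):
    """Write the raw bytes with no header, as a np.memmap-backed file contains."""
    with open(path, "wb") as f:
        stream.tofile(f)


def load_stream(path):
    """Read a headerless uint16 token file back into an array."""
    stream = array("H")
    with open(path, "rb") as f:
        stream.frombytes(f.read())
    return stream


def random_window(stream, block_size, rng=random):
    """Draw one training pair: y is x shifted by one token (next-token targets)."""
    start = rng.randrange(len(stream) - block_size)
    x = stream[start : start + block_size].tolist()
    y = stream[start + 1 : start + 1 + block_size].tolist()
    return x, y


stream = pack_documents([[5, 6, 7], [8, 9]], eos_id=1)
print(stream.tolist())  # [5, 6, 7, 1, 8, 9, 1]
print(random_window(stream, block_size=4, rng=random.Random(0)))
```

Since the file has no header, a reader must already know the dtype (uint16) and byte order; the total token count follows from the file size divided by two.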