huanglk / lightning-transformers

This project forked from lightning-universe/lightning-transformers

Flexible components pairing 🤗 Transformers with PyTorch Lightning

Home Page: https://lightning-transformers.readthedocs.io

License: Apache License 2.0

Python 99.59% Makefile 0.41%

lightning-transformers's Introduction

Installation

pip install lightning-transformers

From Source

git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .

What is Lightning-Transformers

Lightning Transformers provides LightningModules, LightningDataModules and Strategies to use 🤗 Transformers with the PyTorch Lightning Trainer.

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="emotion",
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    n_gram=4,
    smooth=False,
    val_target_max_length=142,
    num_beams=None,
    compute_generate_metrics=True,
)
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
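
The n_gram=4 and smooth=False arguments above look like settings for a BLEU-style translation metric; that reading is an assumption, since this README does not say which metric is used. As a rough sketch of what an n-gram order of 4 means, the plain-Python snippet below computes clipped n-gram precisions for orders 1 to 4 on a made-up sentence pair. The helper names are hypothetical and are not part of Lightning Transformers.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count every contiguous n-gram in a list of tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def modified_precision(candidate, reference, n):
    """Clipped n-gram precision: each match is capped by the reference count."""
    cand_counts = ngrams(candidate, n)
    ref_counts = ngrams(reference, n)
    total = sum(cand_counts.values())
    if total == 0:
        return 0.0
    clipped = sum(min(count, ref_counts[gram]) for gram, count in cand_counts.items())
    return clipped / total


# Made-up example sentences, tokenized on whitespace.
candidate = "the cat sat on the mat".split()
reference = "the cat is on the mat".split()

# One precision per n-gram order, from unigrams up to 4-grams.
precisions = [modified_precision(candidate, reference, n) for n in range(1, 5)]
print(precisions)
```

In this example no 4-gram of the candidate appears in the reference, so the last precision is zero. Without smoothing, a BLEU score built from the geometric mean of these precisions would also be zero, which is the kind of case a smoothing option is meant to soften.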

Lightning Transformers supports many 🤗 tasks and datasets. See the documentation.

Billion Parameter Model Support

Big Model Inference

It is easy to enable large model support for the pre-built LightningModule 🤗 tasks.

Below is an example to enable automatic model partitioning (across CPU/GPU and even leveraging disk space) to run text generation using a 6B parameter model.

import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingTransformer,
)

with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        low_cpu_mem_usage=True,
        device_map="auto",  # automatically partitions the model based on the available hardware.
    )

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))

For more information see Big Transformers Model Inference.
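
To give some intuition for what partitioning a model across GPU, CPU and disk involves, here is a deliberately simplified sketch. It is not the algorithm behind device_map="auto"; it only illustrates the general idea of filling the fastest device first and spilling the remaining layers to slower ones. The function name, layer sizes and memory budgets are all hypothetical.

```python
def assign_layers(layer_sizes_gb, budgets_gb):
    """Greedily place layers, in order, on the first device with room left.

    layer_sizes_gb: per-layer sizes in GB (hypothetical values).
    budgets_gb: dict of device name -> capacity in GB, ordered from
                fastest to slowest device.
    """
    devices = list(budgets_gb.items())
    placement = {}
    dev_idx = 0
    used = 0.0
    for layer_idx, size in enumerate(layer_sizes_gb):
        # Move to the next (slower) device once the current one is full.
        while dev_idx < len(devices) and used + size > devices[dev_idx][1]:
            dev_idx += 1
            used = 0.0
        if dev_idx == len(devices):
            raise MemoryError(f"layer {layer_idx} does not fit on any device")
        placement[layer_idx] = devices[dev_idx][0]
        used += size
    return placement


# Six 2 GB layers, a 5 GB GPU, 4 GB of CPU RAM, and plenty of disk.
placement = assign_layers([2.0] * 6, {"cuda:0": 5.0, "cpu": 4.0, "disk": 100.0})
print(placement)
```

With these made-up numbers, the first two layers land on the GPU, the next two on the CPU, and the last two on disk.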

Big Model Training with DeepSpeed

Below is an example of how you can train a 6B parameter transformer model using Lightning Transformers and DeepSpeed.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

model = LanguageModelingTransformer(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
    tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
    deepspeed_sharding=True,  # defer model initialization so the pretrained weights can be sharded while loading
)

dm = LanguageModelingDataModule(
    batch_size=1,
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    strategy="deepspeed_stage_3",
    precision=16,
    max_epochs=1,
)

trainer.fit(model, dm)

For more information see DeepSpeed Training with Big Transformers Models or the Model Parallelism documentation.
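
A back-of-the-envelope calculation shows why sharding matters at this scale. The sketch below assumes mixed-precision training with an Adam-style optimizer, which costs roughly 16 bytes of model state per parameter, and assumes that stage 3 splits all of that state evenly across GPUs. It ignores activations, buffers and fragmentation, so treat the numbers as a lower bound. The function name is hypothetical.

```python
def zero3_model_state_gb(num_params, num_gpus):
    """Approximate per-GPU memory (in GB) for model states under stage 3 sharding.

    Assumes 2 bytes (fp16 weights) + 2 bytes (fp16 gradients) + 12 bytes
    (fp32 weights, momentum and variance for Adam) per parameter, all
    partitioned evenly across GPUs. Activations are not counted.
    """
    bytes_per_param = 2 + 2 + 12
    return num_params * bytes_per_param / num_gpus / 1e9


# A 6B-parameter model on 1 GPU versus 8 GPUs.
single_gpu = zero3_model_state_gb(6e9, 1)
eight_gpus = zero3_model_state_gb(6e9, 8)
print(f"1 GPU: {single_gpu:.0f} GB, 8 GPUs: {eight_gpus:.0f} GB per GPU")
```

Under these assumptions, the model states alone need about 96 GB on a single GPU but about 12 GB per GPU when sharded across eight.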

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our huge community on Slack!
