FlagGems

Chinese version

Introduction

FlagGems is a high-performance general operator library implemented in OpenAI Triton. It aims to provide a suite of kernel functions to accelerate LLM training and inference.

By registering with PyTorch's ATen backend, FlagGems enables a seamless transition: users can switch to the Triton operator library without modifying their model code, continue using the ATen backend as usual, and still see significant performance gains. The Triton language offers readability and ease of use, with performance comparable to CUDA, so developers can contribute to FlagGems with minimal learning investment.
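The substitution mechanism can be illustrated with a toy, dependency-free sketch. The names `use_gems_like`, `_dispatch`, and `add` below are hypothetical, invented for illustration; the real library registers Triton kernels with PyTorch's dispatcher, but the idea is the same: callers go through a dispatch table that a context manager can temporarily repoint, so model code never changes.

```python
# Toy sketch only; not FlagGems' actual implementation. It mimics how a
# context manager can swap the implementation behind an op while every
# call site stays untouched.
import contextlib

_dispatch = {"add": lambda a, b: a + b}  # default "ATen-like" implementation

@contextlib.contextmanager
def use_gems_like():
    """Temporarily swap in an alternative 'add' kernel, then restore it."""
    original = _dispatch["add"]
    _dispatch["add"] = lambda a, b: ("gems", a + b)  # tagged stand-in kernel
    try:
        yield
    finally:
        _dispatch["add"] = original  # restore the default backend

def add(a, b):
    # Callers always dispatch through the table, so switching backends
    # requires no change at this call site.
    return _dispatch["add"](a, b)
```

Inside a `with use_gems_like():` block, a call to `add` routes to the substituted implementation; once the block exits, the default is restored.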

Features

Automatic Codegen

FlagGems provides automatic code generation that developers can use to conveniently produce single pointwise operators and fused pointwise operators. The code generator handles a variety of needs, such as standard pointwise computation, non-tensor arguments, and specifying output data types.

Normal Pointwise Operator

Decorating a pointwise operator function with pointwise_dynamic saves developers from manually handling tensor addressing, tensor reads/writes, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, and so on. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code.

import triton
import triton.language as tl

# pointwise_dynamic is provided by FlagGems (the import path may vary by version)
from flag_gems.utils import pointwise_dynamic

@pointwise_dynamic
@triton.jit
def abs_func(x):
    return tl.abs(x)
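As a rough, dependency-free analogy for what the generated wrapper does (the real pointwise_dynamic emits Triton kernels that handle tiling, broadcasting, and memory layout), a decorator can lift a scalar function to operate elementwise. The `pointwise` decorator below is a hypothetical toy, not part of FlagGems:

```python
# Conceptual analogy only: the real pointwise_dynamic generates Triton
# kernels with tiling, broadcasting, and layout handling. This toy
# decorator merely lifts a scalar function to elementwise over lists.
def pointwise(func):
    def wrapper(*lists):
        # Apply the scalar computation to corresponding elements.
        return [func(*elems) for elems in zip(*lists)]
    return wrapper

@pointwise
def abs_func(x):
    return abs(x)

abs_func([-1.0, 2.0, -3.0])  # → [1.0, 2.0, 3.0]
```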

Non-Tensor Argument

By default, pointwise_dynamic treats all parameters as tensors. By passing a list of boolean values via the is_tensor parameter, developers can specify which parameters are tensors and which are not. In addition, developers can pass dtypes to indicate the data types of the non-tensor parameters, though this is not required. For example, in the following code, the alpha parameter is defined as a non-tensor floating-point number, while x and y are defined as tensors.

@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha

Output Data Type

By default, all output tensors have the same data type as the first input tensor, but this can be customized by providing a list of data types via the output_dtypes parameter. For example, in the following code, the output tensor type is specified as torch.bool.

@pointwise_dynamic(output_dtypes=[torch.bool])
@triton.jit
def ge(x, y):
    return x >= y

Changelog

v1.0

  • support BLAS operators: addmm, bmm, mm
  • support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu
  • support reduction operators: cumsum, layernorm, mean, softmax

v2.0

  • support BLAS operators: mv, outer
  • support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid
  • support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm
  • support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding

Quick Start

Requirements

  1. Triton >= 2.2.0
  2. PyTorch >= 2.1.2
  3. Transformers >= 4.40.2

Installation

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .

Usage

Import

  1. Enable permanently

    import flag_gems
    flag_gems.enable()
  2. Enable temporarily

    import flag_gems
    with flag_gems.use_gems():
        pass
  3. Example

    import torch
    import flag_gems
    
    M, N, K = 1024, 1024, 1024
    A = torch.randn((M, K), dtype=torch.float16, device="cuda")
    B = torch.randn((K, N), dtype=torch.float16, device="cuda")
    with flag_gems.use_gems():
        C = torch.mm(A, B)

Execute

  1. Test Operator Accuracy

    • Run reference on CUDA
      cd tests
      pytest test_xx_ops.py
    • Run reference on CPU
      cd tests
      pytest test_xx_ops.py --device cpu
  2. Test Model Accuracy

    cd examples
    pytest model_xx_test.py
  3. Test Operator Performance

    • Test CUDA performance
      cd benchmark
      pytest test_xx_perf.py -s
    • Test end-to-end performance
      cd benchmark
      pytest test_xx_perf.py -s --mode cpu
  4. Run tests with logging information

    pytest program.py --log-cli-level debug

    This is not recommended for performance testing.

Supported Operators

Operators will be implemented according to OperatorList.md.

Supported Models

  • Bert-base-uncased
  • Llama-2-7b

Supported Platforms

Platform      float16   float32   bfloat16
Nvidia A100   ✓         ✓         ✓

Performance

The following chart shows the speedup of FlagGems compared with PyTorch ATen library in eager mode. The speedup is calculated by averaging the speedup on each shape, representing the overall performance of the operator.
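Concretely, the averaging works as sketched below. The latencies are hypothetical placeholder values, invented for illustration; real figures come from running the scripts under benchmark/.

```python
# Hypothetical per-shape latencies in milliseconds (placeholder values, not
# measured results). Speedup is ATen time divided by FlagGems time for each
# shape, then averaged across shapes to give the operator's overall speedup.
aten_ms = {(1024, 1024): 0.40, (2048, 2048): 1.60, (4096, 4096): 6.40}
gems_ms = {(1024, 1024): 0.25, (2048, 2048): 1.00, (4096, 4096): 4.00}

per_shape_speedup = [aten_ms[s] / gems_ms[s] for s in aten_ms]
mean_speedup = sum(per_shape_speedup) / len(per_shape_speedup)  # ≈ 1.6
```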

(Chart: operator speedup of FlagGems over ATen)

Contributions

If you are interested in contributing to the FlagGems project, please refer to CONTRIBUTING.md. Any contributions would be highly appreciated.

Contact us

If you have any questions about our project, please submit an issue, or contact us through [email protected].

License

The FlagGems project is licensed under the Apache License 2.0.

Contributors

strongspoon, jokmingwong, bowen12992, iclementine, tongxin, fatjhon, mard1no, gwokhiujin, phoenixdong, pingzhuu
