facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.

Home Page: https://facebookresearch.github.io/xformers/

License: Other

Python 49.53% Shell 0.13% C++ 36.11% Cuda 11.43% C 2.70% CMake 0.10%

xformers's Introduction



xFormers - Toolbox to Accelerate Research on Transformers

xFormers is:

  • Customizable building blocks: Independent/customizable building blocks that can be used without boilerplate code. The components are domain-agnostic and xFormers is used by researchers in vision, NLP and more.
  • Research first: xFormers contains bleeding-edge components that are not yet available in mainstream libraries like PyTorch.
  • Built with efficiency in mind: Because speed of iteration matters, components are as fast and memory-efficient as possible. xFormers contains its own CUDA kernels, but dispatches to other libraries when relevant.

Installing xFormers

  • Install latest stable with conda:
# (python 3.10/3.11 only)
conda install xformers -c xformers
  • (RECOMMENDED, linux & win) Install latest stable with pip: Requires PyTorch 2.4.1
# [linux only] cuda 11.8 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
# [linux only] cuda 12.1 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
# [linux & win] cuda 12.4 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu124
# [linux only] (EXPERIMENTAL) rocm 6.1 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/rocm6.1
  • Development binaries:
# Use either conda or pip, same requirements as for the stable version above
conda install xformers -c xformers/label/dev
pip install --pre -U xformers
  • Install from source: for instance, if you want to use xFormers with another version of PyTorch (including nightly releases)
# (Optional) Makes the build much faster
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
# (this can take dozens of minutes)

Benchmarks

[Figures] Memory-efficient MHA benchmarks and benchmarks for ViTs. Setup: A100 on f16, measured total time for a forward+backward pass

Note that this is exact attention, not an approximation, just by calling xformers.ops.memory_efficient_attention
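A minimal usage sketch, assuming the xformers.ops API with inputs laid out as [batch, seq_len, num_heads, head_dim]; the shapes below are illustrative:

import torch
import xformers.ops as xops

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Exact attention, computed without materializing the full attention matrix
out = xops.memory_efficient_attention(q, k, v)

# Causal variant, expressed through an attention bias
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())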

More benchmarks

xFormers provides many components, and more benchmarks are available in BENCHMARKS.md.

(Optional) Testing the installation

This command will provide information on an xFormers installation, and what kernels are built/available:

python -m xformers.info

Using xFormers

Key Features

  1. Optimized building blocks, beyond PyTorch primitives
    1. Memory-efficient exact attention - up to 10x faster
    2. sparse attention
    3. block-sparse attention
    4. fused softmax (see the sketch after this list)
    5. fused linear layer
    6. fused layer norm
    7. fused dropout(activation(x+bias))
    8. fused SwiGLU
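A minimal sketch of calling one of these fused blocks, assuming the xformers.triton softmax entry point (the exact import path may differ across versions):

import torch
from xformers.triton import softmax  # assumed entry point for the fused Triton softmax

x = torch.randn(8, 2048, 2048, device="cuda", dtype=torch.float16)
y = softmax(x)  # intended as a drop-in for torch.softmax(x, dim=-1) on CUDA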

Install troubleshooting

  • Make sure that NVCC and the current CUDA runtime match. Depending on your setup, you may be able to change the CUDA runtime with module unload cuda; module load cuda/xx.x, possibly also nvcc
  • Make sure that the version of GCC you're using matches the current NVCC capabilities
  • Make sure that the TORCH_CUDA_ARCH_LIST env variable is set to the architectures that you want to support. A suggested setup (slow to build but comprehensive) is export TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.2;7.5;8.0;8.6"
  • If the build from source OOMs, it's possible to reduce the parallelism of ninja with MAX_JOBS (eg MAX_JOBS=2)
  • If you encounter UnsatisfiableError when installing with conda, make sure you have PyTorch installed in your conda environment, and that your setup (PyTorch version, cuda version, python version, OS) match an existing binary for xFormers

License

xFormers has a BSD-style license, as found in the LICENSE file.

Citing xFormers

If you use xFormers in your publication, please cite it by using the following BibTeX entry.

@Misc{xFormers2022,
  author =       {Benjamin Lefaudeux and Francisco Massa and Diana Liskovich and Wenhan Xiong and Vittorio Caggiano and Sean Naren and Min Xu and Jieru Hu and Marta Tintore and Susan Zhang and Patrick Labatut and Daniel Haziza and Luca Wehrstedt and Jeremy Reizenstein and Grigory Sizov},
  title =        {xFormers: A modular and hackable Transformer modelling library},
  howpublished = {\url{https://github.com/facebookresearch/xformers}},
  year =         {2022}
}

Credits

The following repositories are used in xFormers, either in close to original form or as an inspiration:

xformers's People

Contributors

abdbarho, artkorenev, blefaudeux, bottler, clashluke, danthe3rd, dianaml0, drisspg, eltociear, erip, fmassa, jianyuh, jieru-hu, kashif, kit1980, lvaleriu, lw, mpu, nottombrown, optyang, patricklabatut, pradeep90, qianfengz, scxiao, sgrigory, spdraptor, tenpercent, yuanandonly, zhiqwang, zyan0


xformers's Issues

switch Triton parts on/off depending on the torch AMP context

🚀 Feature

There are a couple of fused parts which are faster than PyTorch in fp16, but slower in fp32 for now. By enabling them we risk users running the library in fp32 and finding it slow compared to PyTorch. We could either check the input or the torch AMP context, and fall back to PyTorch if we're in a known perf hole

Motivation

Things should be fast by default, regardless of what the user does

Support Tensor parallel or pipeline parallel out of the box

🚀 Feature

Support tensor parallelism or model parallelism as a built-in feature, through FairScale? cc @min-xu-ai @anj-s @suchenzang @VitaliyLi @iyerr3

Motivation

This is typically extra work for users, but squarely in the model space. xFormers could invest in a generic engineering effort here

Pitch

Ideally, either a factory option or a dedicated model wrapper, and make this mostly transparent to users

Alternatives

Not doing that

Additional context

Discussed with @suchenzang

Issues when using blocksparse to build simple local block attentions

โ“ Questions and Help

@blefaudeux @dianaml0

  1. I found the doc here a bit confusing: it passes a layout of size [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE], but the code comments here say the layout should have size [HEAD, SEQ, SEQ]. The layout I passed to build the local block attention is
            blocks = seq_len // block_size
            layout = torch.eye(blocks)

Is this the right way to do this? (See the layout sketch after this question.)

  2. What are the environment requirements for Triton to work? I got this error triton-lang/triton#322 with CUDA 11.1 on an A100. I tried updating my CUDA toolkit to 11.3 and 11.4.2, but neither worked for me

Thanks in advance!
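As a sketch of the per-head layout discussed above, assuming the [HEAD, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] convention from the code comments (the names here are illustrative, not from the docs):

import torch

heads = 4
seq_len, block_size = 1024, 64
blocks = seq_len // block_size

# Block-diagonal (local) pattern, replicated for every head
layout = torch.eye(blocks, dtype=torch.long).unsqueeze(0).expand(heads, blocks, blocks).contiguous()
print(layout.shape)  # torch.Size([4, 16, 16])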

[OpenFold] Custom parts / AlphaFold

🚀 Feature

Check out the discussion here: OpenFold is missing a couple of custom/optimized parts in PyTorch, and maybe xFormers could host some of these

Motivation

  • Transformer related
  • Custom parts, not handled out of the box by the PyTorch primitives in a fast enough fashion
  • Possibly good candidates for Triton, functorch or Sputnik, but this would need more reading

Pitch

Implement some of these primitives in xformers

Alternatives

Well, do nothing

Additional context

AlphaFold2 supplementary data: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf

ViT example broke with the model_factory refactor

๐Ÿ› Bug

Breaks on the config; we should also add a quick run on that one to make sure this does not happen again

Command

python3 xformers/benchmarks/benchmark_vit_timm.py

Callstack

Global seed set to 42
Traceback (most recent call last):
  File "/home/lefaudeux/Git/xformers/xformers/benchmarks/benchmark_vit_timm.py", line 364, in <module>
    lm = VisionTransformer(
  File "/home/lefaudeux/Git/xformers/xformers/benchmarks/benchmark_vit_timm.py", line 188, in __init__
    config = xFormerConfig(xformer_config)
  File "/home/lefaudeux/Git/xformers/xformers/factory/model_factory.py", line 40, in __init__
    if config["block_type"] == "encoder":
KeyError: 'block_type'

Dropout uses the *memory address* of seeds instead of reading seeds from memory

๐Ÿ› Bug

From reading the code for k_dropout_fw and k_dropout_bw, it seems to me that the seeds are never read from memory and the code simply uses the memory address of the seed.
For example:

    rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets)

Here seed is still a memory address and not an integer.

As a result, when k_dropout_fw is passed in two identical seed tensors with different memory addresses, it produces different results.

To Reproduce

Setting the Pytorch seed should produce the same seed used in dropout, and should produce the same dropout mask.
However, that's not the case

import torch
from xformers.triton.dropout import dropout

x = torch.randn(3, 5, device='cuda')
print(x)

torch.manual_seed(0)
torch.cuda.manual_seed(0)
print(dropout(x, 0.5))

torch.manual_seed(0)
torch.cuda.manual_seed(0)
print(dropout(x, 0.5))
 tensor([[ 0.4821, -1.6949,  0.8196,  1.9093, -1.0018],
        [ 0.4030, -1.5175, -0.3187, -0.0959,  2.7204],
        [ 1.0645, -0.1254,  0.3978, -2.9882,  0.2232]], device='cuda:0')

tensor([[ 0.9642,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.8059, -3.0350,  0.0000, -0.1918,  5.4409],
        [ 0.0000, -0.2507,  0.7955,  0.0000,  0.0000]], device='cuda:0')

tensor([[ 0.9642, -3.3897,  0.0000,  3.8186, -2.0037],
        [ 0.0000,  0.0000, -0.6374,  0.0000,  5.4409],
        [ 2.1290, -0.2507,  0.7955,  0.0000,  0.4464]], device='cuda:0')
  • PyTorch Version (e.g., 1.0): 1.10.1
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): pip install -e . (from master)
  • Python version: 3.8
  • CUDA/cuDNN version: 11.3
  • GPU models and configuration: V100

Benchmark with Timm

🚀 Feature

timm is the SOTA in the field for vision transformers; it would be a good touchstone to check where vanilla xFormers speed lands

Motivation

Speed is important for xFormers to be compelling

Blocksparse crashes in the encoder benchmark

๐Ÿ› Bug

Running python3 xformers/benchmarks/benchmark_encoder.py --activations relu --plot -emb 256 -bs 32 -heads 16 -mlp MLP -a blocksparse with latest Triton crashes on an illegal memory access.

The dedicated blocksparse matmul benchmark runs fine, and so does the unit test

Expected behavior

Well, not crashing

Environment

PyTorch version: 1.9.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-52-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 450.80.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.1+cu111
[pip3] torch-tb-profiler==0.2.1
[pip3] torchaudio==0.9.1
[pip3] torchvision==0.10.1+cu111
[conda] numpy                     1.19.5                   pypi_0    pypi
[conda] pytorch-sphinx-theme      0.0.24                   pypi_0    pypi
[conda] torch                     1.9.1+cu111              pypi_0    pypi
[conda] torch-tb-profiler         0.2.1                    pypi_0    pypi
[conda] torchaudio                0.9.1                    pypi_0    pypi
[conda] torchvision               0.10.1+cu111             pypi_0    pypi

Additional context

Looks like something which can happen depending on the layout, not super clear why

[perf] merging computation across triton kernels

Is it possible to reasonably merge our matmul, activation, normalization, attention and soon-to-be-added GLU kernels into one large kernel without writing a fused kernel?

If we're able to do that, I propose adding triton kernels for RevNet so that we can compile the entire model instead of just parts of it. Unless anyone else wants to, I could work on implementing that.

[microGPT] Crash on sampling "sample not on the right device"

๐Ÿ› Bug

Maybe something changed in Lightning, but the sampling at the end of microGPT now fails because the sample is not on the right device

Command

python3 example/microGPT.py

Additional context

Callstack

Traceback (most recent call last):
  File "/home/lefaudeux/Git/xformers/examples/microGPT.py", line 325, in <module>
    y = sample(model, x, steps=1000, temperature=1.0, sample=True, top_k=10)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/lefaudeux/Git/xformers/examples/microGPT.py", line 246, in sample
    logits = model(x_cond)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/examples/microGPT.py", line 158, in forward
    prediction = self.model(src)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/factory/model_factory.py", line 159, in forward
    memory = encoder(memory, input_mask=encoder_input_mask)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/factory/block_factory.py", line 315, in forward
    x = self.wrap_ff(x)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/components/residual.py", line 48, in forward
    return inputs[0] + self.layer(*inputs, *args, **kwargs)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/components/residual.py", line 69, in forward
    return self.sublayer(*x_norm, *args, **kwargs)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/components/feedforward/fused_mlp.py", line 64, in forward
    return self.mlp(inputs)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/triton/dropout.py", line 156, in forward
    return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py", line 94, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/lefaudeux/Git/xformers/xformers/triton/dropout.py", line 48, in forward
    k_dropout_fw[grid](
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/triton/code_gen.py", line 676, in __call__
    return self.kernel(*wargs, **kwargs, grid=self.grid)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/triton/code_gen.py", line 724, in __call__
    return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
  File "/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.9/site-packages/triton/code_gen.py", line 583, in __call__
    raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
ValueError: Arguments at index [0, 1, 2, 3] are on the wrong device. Only CUDA is supported at the moment

feat(triton): InplaceNorm + InstanceNorm

I'd love to run LayerNorm in place and ideally also add InstanceNorm (by extracting the core normalization from LayerNorm) as HomebrewNLP is currently using a slow PyTorch-level implementation with a correct backward pass.

While we're at it, optionally fusing GLU and GLUv2 (gelu(f(x)) * g(x) + gelu(h(x))) with various activation functions and normalization might give another speed boost.

To add this myself, I'd need to fully understand triton's pointers and how to access the output instead of input in your LayerNorm implementation. Could you help me with that? or would you instead implement this yourself? Is this even in the scope of xformers?

Can't pickle with xformers after 81bc427

๐Ÿ› Bug

I'm running into pickling errors after commit 81bc427

Pickling is needed for some DDP contexts.

Pseudocode

import pickle

import torch.nn as nn
from xformers.factory import xFormer, xFormerConfig

class ViTB(nn.Module):
    def __init__(self):
        super().__init__()
        # self.model_config is assumed to hold a generic ViT-B block config (omitted here)
        self.xformer = xFormer.from_config(xFormerConfig([self.model_config]))  # eg: Generic ViT-B conf

pickle.dumps(ViTB())

I haven't bisected, but on head (093224e at this time) this results in:

AttributeError: Can't pickle local object '_init_from_params.<locals>.init_method'

Expected behavior

Model pickles correctly (needed for certain DDP contexts).

Environment

  • PyTorch Version (e.g., 1.0): 1.10.2 (tried earlier versions, same issue).
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source): conda

`scaled_dot_product_attention` shouldn't have a `causal` flag

The causal flag in

causal: bool = False,
was added to enable fast triton kernels for softmax.

I believe we shouldn't have this flag, as it makes it look like we can combine attn_mask and causal, which doesn't look like it's the case.
Plus, the behavior of causal is embedded inside attn_mask already.

IMO we should remove causal, and instead have a custom class UpperTriangularMatrix which can be passed as an attn_mask and then gets dispatched to the efficient causal softmax
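A minimal sketch of the point being made, assuming an additive mask convention (illustrative, not library code): the causal pattern is already expressible through attn_mask alone, which is why the extra flag is redundant.

import torch

seq = 8
# Additive causal mask: 0 on/below the diagonal, -inf above it
causal_mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)
# Passing this as attn_mask reproduces causal=True; a dedicated mask type
# (e.g. the proposed UpperTriangularMatrix) could then be detected and
# dispatched to the efficient causal softmax.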

[feat] Make the batch dimension broadcastable

🚀 Feature

Support different key/query/value batch sizes, if 1 or the same value

Motivation

(popped up when working on salina) PyTorch MHA handles key/query/value with batch sizes from 1 to B and will broadcast if needed. We don't do that and hard-crash instead, so the users have to align dimensions from the outside
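A minimal sketch of the current workaround, assuming [batch, seq, embed] inputs (illustrative only):

import torch

B, S, E = 16, 128, 64
q = torch.randn(B, S, E)
k = torch.randn(1, S, E)  # single shared key/value batch
v = torch.randn(1, S, E)

# What users currently have to do by hand before calling the attention,
# and what this feature would do automatically:
k = k.expand(B, S, E)
v = v.expand(B, S, E)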

Pitch

Alternatives

Additional context

[bug] LRA config was broken by the recent mask changes

๐Ÿ› Bug

Broken LRA setup with Nystrom
-> should be fixed
-> could we make this part of CI somehow ?

Command

python3 run_tasks.py --attention nystrom --task listops --config code/config.json --world_size 1 --tb_dir logs/tb/banana --checkpoint_dir logs/banana

Asserts during the initial tracing phase (when computing the FLOPs) for Nystrom

Expected behavior

Runs normally

Additional context

Stacktrace
-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/private/home/lefaudeux/git/xformers/xformers/benchmarks/LRA/run_tasks.py", line 272, in benchmark
    model = build_model(args, config)
  File "/private/home/lefaudeux/git/xformers/xformers/benchmarks/LRA/run_tasks.py", line 80, in build_model
    args.logger.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 247, in total
    stats = self._analyze()
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 550, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 175, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/jit/_trace.py", line 1160, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/factory/model_factory.py", line 161, in forward
    x = encoders(x, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/reversible.py", line 138, in forward
    return _ReversibleFunction.apply(x, self.blocks, block_kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/reversible.py", line 114, in forward
    x = block(x, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/reversible.py", line 65, in forward
    y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/reversible.py", line 41, in forward
    return self.net(*args, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/residual.py", line 78, in forward
    return self.sublayer(*x_norm, *args, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/multi_head_dispatch.py", line 307, in forward
    y = self.attention(q=q, k=k, v=v, att_mask=att_mask)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/private/home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/private/home/lefaudeux/git/xformers/xformers/components/attention/nystrom.py", line 171, in forward
    assert att_mask.dtype == torch.bool
AssertionError

integrate with Lightning ecosystem CI

Hello, and so happy to see you use PyTorch Lightning! 🎉
Just wondering if you have already heard about the new PyTorch Lightning (PL) ecosystem CI, to which we would like to invite you... You can check out our blog post about it: Stay Ahead of Breaking Changes with the New Lightning Ecosystem CI ⚡
As you use the PL framework for your cool project, we would like to enhance your experience and offer you safe updates to our future releases. At the moment you run tests with a particular PL version, but it may accidentally happen that the next version is incompatible with your project... 😕 We do not intend to change anything on our project side, but we do have a solution: the ecosystem CI tests both your latest development head and ours, so we can find incompatibilities very early and prevent eventually releasing a bad version... 👍

What do you need to do?

What will you get?

  • scheduled nightly testing configured for development/stable versions
  • slack notification if something went wrong to investigate
  • testing also on a multi-GPU machine as our gift to you 🍰

cc: @Borda

[feat] patch torch and huggingface for improved performance in existing applications

I'd love it if we could patch existing functions like torch.nn.functional.softmax with our faster xformers/triton implementation. This would allow users of pre-defined models from HuggingFace to simply call xformers.patch() and run twice as fast.
Do you think that'd be possible, or would we have to rewrite all APIs first to match PyTorch's?

I'd propose writing a patch() function which replaces the appropriate PyTorch modules with our implementation at run-time. Alternative implementations are more than welcome, as my proposal wouldn't allow jumping into the source code of the new PyTorch function, rendering this change invisible to the user.
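A minimal sketch of the run-time patching idea, with hypothetical names (xformers does not currently ship a patch() helper; fused_softmax stands in for whichever xFormers/Triton kernel would be substituted):

import torch

def patch_softmax(fused_softmax):
    # Swap the global entry point so that pre-defined models (e.g. from HuggingFace)
    # pick up the faster kernel without any code change on their side.
    original = torch.nn.functional.softmax

    def wrapper(input, dim=None, **kwargs):
        if input.is_cuda and dim in (-1, input.ndim - 1):
            return fused_softmax(input)
        return original(input, dim=dim, **kwargs)

    torch.nn.functional.softmax = wrapper
    return original  # keep a handle around so the patch can be reverted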

`scaled_dot_product_attention` shouldn't slice `attn_mask`

In

if att_mask.ndim == 2:
    if not att_mask.is_sparse:
        att_mask = att_mask[:seq, :seq]
    else:
        logging.warning(
            "Mismatching attention mask and sequence length. On the fly correction but this will be slow"
        )
        # Loosing sparsity on purpose,
        # expectation is that moving back and forth dense/sparse will negate the speedup
        att_mask = att_mask.to_dense().squeeze(0)[:seq, :seq]
else:
    assert (
        not att_mask.is_sparse
    ), "Sparse masks with a batch dimension are not supported for now"
    att_mask = att_mask[:, :seq, :seq]

the core function scaled_dot_product_attention was made to accept attn_masks which are not of the correct shape.

While this looks like a convenience function, IMO this shouldn't be handled at such a core level as scaled_dot_product_attention.
Indeed, this behavior can lead to silent bugs (especially if the current slicing strategy doesn't work, e.g., for images), imposes serious performance penalties (by casting sparse tensors to dense tensors), and as such should be handled at a higher level.

Happy to discuss alternatives

[follow up] Check Nystrom + causal

๐Ÿ› Possible Bug

  • Nystrom does not pass the test in #104
  • with the same #104, the Nystrom-specific test does not pass if causal is set (plus NaNs everywhere)

causal was not part of the original paper; it could be that something got lost in translation. Some reference here

[feat] Compositional Attention

🚀 Feature

Intriguing paper: keep the softmax(QKt) and V untangled, in that retrievals (*V_i in the vanilla attention) can have a look at all the searches, that is, each one can be evaluated against all the softmax(QKt)_j on a per-head basis ("heads" become how many searches and how many retrievals you support, possibly different numbers)

Motivation

Interesting take for some tasks; does not seem life-changing for classical MLM, but seems very relevant to reasoning- or vision-related tasks

Pitch

Implement this and see how it goes in something like DINO?

Alternatives

Not doing it

Additional context

Paper
Reference implementation

Error when running example

๐Ÿ› Bug

Running the code from here yields an error.

Command

To Reproduce

Execute code in example

Traceback (most recent call last):
  File "train.py", line 112, in <module>
    config = xFormerConfig(my_config)
  File "/Users/erip/Code/xformers/xformers/factory/model_factory.py", line 40, in __init__
    if config["block_type"] == "encoder":
KeyError: 'block_type'

Steps to reproduce the behavior:

  1. Paste code from example into test.py
  2. python test.py
  3. profit

Expected behavior

Environment

Please copy and paste the output from the
environment collection script from PyTorch
(or fill out the checklist below manually).

You can run the script with:

# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env

PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.29)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.7.11 (default, Jul 27 2021, 07:03:16) [Clang 10.0.0 ] (64-bit runtime)
Python platform: Darwin-19.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.4
[pip3] torch==1.10.1
[pip3] torchtext==0.11.1
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.21.4 pypi_0 pypi
[conda] pytorch 1.10.1 py3.7_0 pytorch
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchtext 0.11.1 py37 pytorch

Additional context

[LRA] Use key masking instead of attention mask

๐Ÿ› Bug

Right now the LRA implementation uses attention masking (see) for the MLM task, which is probably wrong for a couple of attentions (would need investigation). Key masking would probably make more sense

[improvement] lower Favor+causal memory consumption

🚀 Feature

Lower favor+causal memory consumption

Motivation

Using a lot of memory for an approximation kind of defeats the purpose.

Pitch

Would make FAVOR more usable for NLP; not sure how much of a priority this is

Alternatives

Not doing anything, there are other options for causal approximations (Nystrom for instance)

Additional context

The previous implementation (prior to #104) was more memory efficient but not trainable, since the variables were modified in place

Does xFormers support weight tying?

โ“ Questions and Help

Tying weights between encoder and decoder embedding layers is often useful for convergence and task performance. Is there a mechanism for sharing weights between the encoder and decoder?
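No xFormers-specific mechanism is described here, but weight tying itself is plain PyTorch; a minimal sketch, assuming both embedding layers have the same dimensions (names are illustrative):

import torch.nn as nn

vocab, dim = 32000, 512
encoder_embedding = nn.Embedding(vocab, dim)
decoder_embedding = nn.Embedding(vocab, dim)

# Tie the weights: both layers now share a single parameter tensor
decoder_embedding.weight = encoder_embedding.weight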

Building blocksparse kernels

โ“ Questions and Help

It looks like sputnik kernels are not built when running the setup script.

I built the source with

λ ~/dev/xformers: python setup.py build develop

I get the following error:

λ ~/dev/xformers: python xformers/benchmarks/benchmark_triton_blocksparse.py
WARNING:root:Unsupported device, Triton code generation may fail
M=128, N=128, K=128 - pytorch - sdd - 16:  - 0.12TFlops
Traceback (most recent call last):
  File "/home/domluna/dev/xformers/xformers/benchmarks/benchmark_triton_blocksparse.py", line 142, in <module>
    bench_matmul(torch.float16, shapes)
  File "/home/domluna/dev/xformers/xformers/benchmarks/benchmark_triton_blocksparse.py", line 116, in bench_matmul
    ms = triton.testing.do_bench(lambda: testcase.function())[0]
  File "/home/domluna/anaconda3/lib/python3.9/site-packages/triton/testing.py", line 131, in do_bench
    fn()
  File "/home/domluna/dev/xformers/xformers/benchmarks/benchmark_triton_blocksparse.py", line 116, in <lambda>
    ms = triton.testing.do_bench(lambda: testcase.function())[0]
  File "/home/domluna/dev/xformers/xformers/benchmarks/benchmark_triton_blocksparse.py", line 97, in sparse_step
    return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
  File "/home/domluna/dev/xformers/xformers/components/attention/core.py", line 79, in _matmul_with_mask
    return mask.matmul_with_mask(a, b)
  File "/home/domluna/dev/xformers/xformers/components/attention/_sputnik_sparse.py", line 255, in matmul_with_mask
    out = _sddmm.apply(
  File "/home/domluna/dev/xformers/xformers/components/attention/_sputnik_sparse.py", line 94, in forward
    out = _sddmm_func(a, b, row_indices, row_offsets, column_indices)
  File "/home/domluna/dev/xformers/xformers/components/attention/_sputnik_sparse.py", line 61, in _sddmm_func
    return torch.ops.xformers.sddmm_sputnik(
NotImplementedError: Could not run 'xformers::sddmm_sputnik' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'xformers::sddmm_sputnik' is only available for these backends: [CPU, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

CPU: registered at /home/domluna/dev/xformers/xformers/components/attention/csrc/cpu/sddmm.cpp:100 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/PythonFallbackKernel.cpp:47 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
AutogradLazy: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
AutogradXPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/core/VariableFallbackKernel.cpp:59 [backend fallback]
Tracer: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/torch/csrc/autograd/TraceTypeManual.cpp:291 [backend fallback]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/autocast_mode.cpp:305 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Even though there's the warning about code generation possibly failing, it seems to work fine when the ops are not sparse. Triton kernels showing improvements:

λ ~/dev/xformers: python xformers/benchmarks/benchmark_triton_softmax.py
WARNING:root:Unsupported device, Triton code generation may fail
 ------------- Type: torch.float16 -------------
| Units: GB/s                                    |B=8, M=384, K=128   |B=8, M=784, K=512   |B=4, M=1024, K=768  |B=4, M=2048, K=1024 |B=2, M=2048, K=2048 |B=2, M=2048, K=4096 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |178.7               |332.6               |293.2               |335.9               |303.5               |314.1               |319.1               |307.0               |
|triton - fw                                     |136.5               |348.4               |343.1               |374.8               |374.0               |380.3               |385.5               |378.1               |
|triton - causal - fw                            |128.0               |328.5               |339.9               |364.1               |327.0               |390.1               |411.7               |360.4               |
|pytorch - log - fw                              |195.8               |348.4               |337.5               |367.1               |365.7               |376.6               |383.3               |351.4               |
|triton - log - fw                               |150.3               |351.2               |352.3               |375.0               |376.5               |381.9               |386.4               |381.0               |

QApplication: invalid style override passed, ignoring it.
 ------------- Type: torch.float32 -------------
| Units: GB/s                                    |B=8, M=384, K=128   |B=8, M=784, K=512   |B=4, M=1024, K=768  |B=4, M=2048, K=1024 |B=2, M=2048, K=2048 |B=2, M=2048, K=4096 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |307.2               |361.1               |364.9               |368.2               |368.1               |372.4               |379.0               |246.4               |
|triton - fw                                     |265.7               |371.0               |374.5               |383.8               |381.4               |387.8               |385.2               |379.7               |
|triton - causal - fw                            |256.0               |417.0               |438.1               |427.2               |479.3               |586.7               |492.8               |646.7               |
|pytorch - log - fw                              |310.1               |364.3               |366.6               |366.3               |379.4               |382.4               |386.8               |219.1               |
|triton - log - fw                               |275.4               |371.8               |377.4               |384.6               |381.4               |387.8               |385.5               |379.6               |

 ------------- Type: torch.float16 -------------
| Units: GB/s                                    |B=8, M=384, K=128   |B=8, M=784, K=512   |B=4, M=1024, K=768  |B=4, M=2048, K=1024 |B=2, M=2048, K=2048 |B=2, M=2048, K=4096 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |77.2                |100.4               |98.3                |106.6               |104.0               |106.9               |108.4               |105.0               |
|triton - fw+bw                                  |73.8                |128.2               |127.6               |136.6               |136.6               |139.7               |142.2               |140.8               |
|triton - causal - fw+bw                         |72.9                |131.8               |132.1               |141.2               |143.7               |159.9               |156.8               |156.7               |
|pytorch - log - fw+bw                           |85.3                |125.4               |124.1               |135.8               |134.1               |138.6               |140.8               |134.3               |
|triton - log - fw+bw                            |75.9                |128.4               |127.9               |136.5               |136.6               |140.0               |142.3               |140.8               |

 ------------- Type: torch.float32 -------------
| Units: GB/s                                    |B=8, M=384, K=128   |B=8, M=784, K=512   |B=4, M=1024, K=768  |B=4, M=2048, K=1024 |B=2, M=2048, K=2048 |B=2, M=2048, K=4096 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |126.4               |107.0               |106.8               |110.3               |109.8               |110.8               |111.8               |98.0                |
|triton - fw+bw                                  |130.8               |136.3               |135.7               |141.1               |140.9               |142.5               |142.9               |125.7               |
|triton - causal - fw+bw                         |128.8               |146.2               |147.4               |149.9               |159.9               |174.9               |162.4               |54.1                |
|pytorch - log - fw+bw                           |145.1               |134.9               |134.5               |139.4               |139.4               |141.1               |142.1               |116.7               |
|triton - log - fw+bw                            |132.0               |136.4               |135.8               |141.1               |140.9               |142.4               |142.8               |125.2               |

AttentionMask is not scriptable

๐Ÿ› Bug

AttentionMask and any model composed of AttentionMask cannot be scripted.

Command

To Reproduce

Steps to reproduce the behavior:

  1. Run code as below
>>> import torch
>>> from xformers.components.attention import AttentionMask
>>> bool_mask = torch.rand((256, 256)) > 0.5
>>> additive_mask = AttentionMask.from_bool(bool_mask)
>>> torch.jit.script(additive_mask)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/erip/opt/miniconda3/envs/traduis-dev/lib/python3.7/site-packages/torch/jit/_script.py", line 1318, in script
    return torch.jit._recursive.create_script_class(obj)
  File "/Users/erip/opt/miniconda3/envs/traduis-dev/lib/python3.7/site-packages/torch/jit/_recursive.py", line 419, in create_script_class
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
  File "/Users/erip/opt/miniconda3/envs/traduis-dev/lib/python3.7/site-packages/torch/jit/_recursive.py", line 44, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
RuntimeError:
Expression of type apply cannot be used in a type expression:
  File "/Users/erip/Code/xformers/xformers/components/attention/attention_mask.py", line 128
    def __add__(self, other):
        assert isinstance(other, type(self))
                                 ~~~~~~~~~ <--- HERE
        return AttentionMask(self.values + other.values, is_causal=False)

Expected behavior

AttentionMask should be a JIT-able module.
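The failing construct is the isinstance(other, type(self)) check, which TorchScript cannot resolve inside a type expression. A minimal sketch of a scriptable alternative (placed in xformers/components/attention/attention_mask.py; not necessarily how the library will fix it):

    def __add__(self, other):
        # Reference the class directly; TorchScript cannot handle type(self) here
        assert isinstance(other, AttentionMask)
        return AttentionMask(self.values + other.values, is_causal=False)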

Environment

Please copy and paste the output from the
environment collection script from PyTorch
(or fill out the checklist below manually).

You can run the script with:

# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env

Collecting environment information...
PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.29)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.7.11 (default, Jul 27 2021, 07:03:16) [Clang 10.0.0 ] (64-bit runtime)
Python platform: Darwin-19.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.4
[pip3] torch==1.10.1
[pip3] torchtext==0.11.1
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.21.4 pypi_0 pypi
[conda] pytorch 1.10.1 py3.7_0 pytorch
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchtext 0.11.1 py37 pytorch

Additional context

Add LRA to CI

🚀 Feature

Capturing idea from @blefaudeux here. Add LRA to the CI since some bugs are only showing up through manual runs of it. Should be a smaller version of the full LRA.

Motivation

LRA has caught many bugs; it could be useful to add it to the continuous integration testing.

Pitch

LRA with a smaller dataset.

Alternatives

Limit tasks and attention variants to a subset. Run a benchmark in parallel not necessarily linked to CI

Additional context

#145 (comment)

Pre-commit broken on seed isort

๐Ÿ› Bug

I just enabled the pre-commit hooks in my environment following the docs, and they seem to raise an error in an unrelated part of the codebase.

This is the error message I get after a git commit:

Trim Trailing Whitespace.................................................Passed
Check python ast.........................................................Passed
Check for merge conflicts................................................Passed
Don't commit to branch...................................................Passed
Check for added large files..............................................Passed
Fix End of Files.........................................................Passed
black....................................................................Passed
flake8...................................................................Passed
seed isort known_third_party.............................................Failed
- hook id: seed-isort-config
- exit code: 1

Traceback (most recent call last):
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/bin/seed-isort-config", line 8, in <module>
    sys.exit(main())
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/seed_isort_config.py", line 112, in main
    third_party = third_party_imports(filenames, appdirs)
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/seed_isort_config.py", line 60, in third_party_imports
    visitor.visit(ast.parse(f.read(), filename=filename))
  File "/private/home/fmassa/.conda/envs/xformers/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/private/home/fmassa/.conda/envs/xformers/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/private/home/fmassa/.conda/envs/xformers/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/seed_isort_config.py", line 42, in visit_Import
    self._maybe_append_name(name.name)
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/seed_isort_config.py", line 35, in _maybe_append_name
    imp_type = classify_import(name, self.appdirs)
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/aspy/refactor_imports/classify.py", line 135, in classify_import
    found, module_path, is_builtin = _get_module_info(
  File "/private/home/fmassa/.cache/pre-commit/repohjnsozm1/py_env-python3/lib/python3.8/site-packages/aspy/refactor_imports/classify.py", line 103, in _get_module_info
    assert spec.submodule_search_locations is not None
AssertionError

isort....................................................................Passed
mypy.....................................................................Passed

Support shape types for Tensors.

🚀 Feature

Add shape type stubs for Tensor functions. Annotate some of the xformers code with shape types so that future users can add new ones easily. Set up IDE support and iterate on the UX based on user feedback.

Motivation

Shape types allow us to see the shape of a Tensor as it is transformed by various functions. This can help understand existing code, debug problems, and catch shape mismatches without running the entire program.

For example, it can catch broadcasting errors like torch.randn(2, 3) + torch.randn(4, 5).

This requires annotating Tensor variables with their initial shapes. More docs to come.

Pitch

We should be able to see the shape of a Tensor variable as we code (after adding annotations to the function signature).

Alternatives

The existing alternatives are to print(my_tensor.shape), try it out in a notebook, or run the program end-to-end.

Additional context

PEP 646: Variadic types: https://www.python.org/dev/peps/pep-0646/

More docs to come.
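A minimal sketch of what PEP 646-style shape annotations could look like (purely illustrative stubs; this is not an existing xFormers or checker API):

from typing import Generic, TypeVarTuple, Unpack  # Python 3.11+, or typing_extensions

Shape = TypeVarTuple("Shape")

class ShapedTensor(Generic[Unpack[Shape]]):
    """Illustrative stand-in for a shape-annotated torch.Tensor."""

def add(x: ShapedTensor[Unpack[Shape]], y: ShapedTensor[Unpack[Shape]]) -> ShapedTensor[Unpack[Shape]]:
    # A shape-aware checker could verify here that both arguments share the same
    # (or broadcast-compatible) shape, catching mismatches before runtime.
    ...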

[perf] Auto causal <> sparse

🚀 Feature

When applicable, automatically use sparse or blocksparse for causal attention. Right now this requires that people use them explicitly, even if the causal flag is passed, which means that a lot of people could miss the possible optimization.

Motivation

Free perf on the table, can be a little complex to handle all cases, but would make sense to do it directly in xFormers.

Pitch

Sort out the applicable cases first, and in that case defer to the sparse or blocksparse when ScaledDotProduct is used with the causal flag

Alternatives

Warn users of this possible optim

Additional context

Sputnik in Xformers

โ“ Questions and Help

Hello! I'm the author of Sputnik and noticed that you've forked it! I'm interested to understand your use cases. If you have any questions about the library I'm happy to help!

[feat] PyPI package + badges + links + HOWTO

๐Ÿ› Bug

Publish the PyPI package + update all the above

Command

To Reproduce

Steps to reproduce the behavior:

Expected behavior

Environment

Please copy and paste the output from the
environment collection script from PyTorch
(or fill out the checklist below manually).

You can run the script with:

# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

[block factory] Handle attentions which return one or several extra context tensors

🚀 Feature

Luna has an extra "context" path, and I think that several other attentions do something similar (like the attentions which try to keep a long-term memory); it would be nice to handle this possibility in the block_factory for encoders or decoders.

Motivation

Support attentions which pass a context along, or similar

Pitch

Make it possible for attentions to return several items, not just the attention value. Silently consume them if the next layer does not need them

Alternatives

Keep rigid blocks

Additional context

Luna,

[feat] Add a fast implementation of the Rabe and Staats algorithm (memory-efficient attention) on GPU

🚀 Feature

Implement https://arxiv.org/pdf/2112.05682v2.pdf using Triton

Motivation

There are existing implementations in PyTorch, but they're bound to be a little slow. It's actually not that much work to write this down in Triton, so give it a shot. Given the FW speed (should be similar to normal attention, without the memory cost) and the expected BW speed (about 60% of vanilla attention), it feels like a compromise that many would use

Pitch

The required kernel is actually not that far from some of the kernels that we already have, at least for the FW. The chunk strategy proposed by the paper is actually fairly classic in that field, nothing out of the ordinary (see for instance), so it's bound to be pretty fast if correctly implemented.

Alternatives

At least support a pure PyTorch variant in xFormers?
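A minimal pure-PyTorch sketch of the idea, chunking only over queries (the Rabe and Staats algorithm additionally chunks over keys with a running log-sum-exp, which this simplification omits; names are illustrative):

import torch

def query_chunked_attention(q, k, v, chunk_size=1024):
    # q, k, v: [batch, seq, dim]; only chunk_size rows of the attention
    # matrix are materialized at any time.
    scale = q.shape[-1] ** -0.5
    outputs = []
    for start in range(0, q.shape[1], chunk_size):
        q_chunk = q[:, start:start + chunk_size]
        attn = torch.softmax(q_chunk @ k.transpose(-2, -1) * scale, dim=-1)
        outputs.append(attn @ v)
    return torch.cat(outputs, dim=1)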

Encoder-decoder arch doesn't work when sequence lengths are different

๐Ÿ› Bug

I get an error when the sequence lengths to the encoder and decoder are different, e.g. in the code snippet below:

Command

EMB = 384
SEQ_ENC = 128
SEQ_DEC = 64
BATCH = 16
VOCAB = 64

my_config = [
    # A list of the encoder or decoder blocks which constitute the Transformer.
    # Note that a sequence of different encoder blocks can be used, same for decoders
    {
        "reversible": False,  # Optionally make these layers reversible, to save memory
            "block_type": "encoder",
            "num_layers": 3,  # Optional, this means that this config will repeat N times
            "dim_model": EMB,
            "layer_norm_style": "pre",  # Optional, pre/post
            "position_encoding_config": {
                "name": "vocab",  # whatever position encodinhg makes sense
                "seq_len": SEQ_ENC,
                "vocab_size": VOCAB,
            },
            "multi_head_config": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "linformer",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": False,
                    "seq_len": SEQ_ENC,
                },
            },
            "feedforward_config": {
                "name": "MLP",
                "dropout": 0,
                "activation": "relu",
                "hidden_layer_multiplier": 4,
            },
        },
    {
        "reversible": False,  # Optionally make these layers reversible, to save memory

            "block_type": "decoder",
            "num_layers": 3,  # Optional, this means that this config will repeat N times
            "dim_model": EMB,
            "layer_norm_style": "pre",  # Optional, pre/post
            "position_encoding_config": {
                "name": "vocab",  # whatever position encodinhg makes sense
                "seq_len": SEQ_DEC,
                "vocab_size": VOCAB,
            },
            "multi_head_config_masked": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "nystrom",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ_DEC,
                },
            },
            "multi_head_config_cross": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "favor",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ_DEC,
                },
            },
            "feedforward_config": {
                "name": "MLP",
                "dropout": 0,
                "activation": "relu",
                "hidden_layer_multiplier": 4,
            },
        },
]

# This part of xFormers is entirely type checked and needs a config object,
# could be changed in the future
config = xFormerConfig(my_config)
model = xFormer.from_config(config)

#  Test out with dummy inputs
src = (torch.rand((BATCH, SEQ_ENC)) * VOCAB).abs().to(torch.int)
tgt = (torch.rand((BATCH, SEQ_DEC)) * VOCAB).abs().to(torch.int)
y = model(src=src, tgt=tgt)

print(y.shape)

Expected behavior

torch.Size([16, 64, 384])

however, I get:

RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [64, 128, 96, 96]->[64, 128, 96, 96] [64, 64, 96]->[64, 64, 1, 96]

Favor breaks with fp16

๐Ÿ› Bug

The loss goes to NaN immediately

Command

python3 microGPT.py when precision is set to 16

Environment

Unrelated

Add conda package

🚀 Feature

Motivation

Packaging pytorch-based packages can be a bit messy because of their reliance on specific CUDA toolkits. Additionally, packages with native code are often messy and rely on a client compiler which can be painful (e.g., on Windows). In the case of xformers, wheels are built but having a condafied package allows for conda-only builds to consume xformers without relying on pip.

Pitch

Add a conda package which either lands in the conda-forge or pytorch channel.

Alternatives

Continue with a wheel-only distribution. Conda users can still pip install xformers which isn't ideal, but not awful.

Additional context

N/A

Logo doesn't appear on documentation sub-pages

๐Ÿ› Bug

Currently, the xFormers logo only appears on the main docs page and the what_is_xformers page which is present in the same directory as it, but not on the other sub-pages. I was wondering whether setting the Sphinx option html_logo in the conf.py file would fix this.

Would be happy to make a PR for this, let me know what you think.

[perf] Test Triton to fuse/harden some operations enhancement

🚀 Feature

Investigate "hardening" typical Transformer-related operations using Triton. This issue is a rollover from Fairinternal

Motivation

Get faster, possibly fused, blocks

Pitch

Triton makes it relatively easy to consolidate some typical DL primitives into single ad-hoc kernels

Alternatives

Use pytorch primitives

Additional context

Ongoing work for some time
