
alphatensor's Introduction

AlphaTensor

This is code accompanying the publication

Fawzi, A. et al. Discovering faster matrix multiplication algorithms with reinforcement learning. Nature 610 (2022)

There are 4 independent directories:

  • algorithms contains algorithms discovered by AlphaTensor, represented as factorizations of matrix multiplication tensors, and a Colab showing how to load these.

  • benchmarking contains a script that can be used to measure the actual speed of matrix multiplication algorithms on an NVIDIA V100 GPU.

  • nonequivalence contains 14,236 nonequivalent algorithms discovered by AlphaTensor for the same matrix multiplication problem (multiplying 4x4 matrices), and a Colab that verifies their nonequivalence.

  • recombination contains the code we used to decompose larger matrix multiplication tensors by recombining factorizations of smaller ones.

Installation

  • algorithms: No installation required.

  • benchmarking: See README in the subdirectory.

  • nonequivalence: No installation required.

  • recombination: A machine with Python 3 installed is required. The required dependencies (numpy and absl-py) can be installed by executing pip3 install -r alphatensor/recombination/requirements.txt.

Usage

  • algorithms: The notebook explore_factorizations.ipynb can be opened via Open In Colab. When running the code, you will be asked to upload a file containing the factorizations. Please select either of the compressed NumPy files factorizations_r.npz (containing algorithms in standard arithmetic) or factorizations_f2.npz (algorithms in arithmetic modulo 2). A minimal local-loading sketch follows this list.

  • benchmarking: See README in the subdirectory, and Supplement D of the paper.

  • nonequivalence: The notebook inspect_factorizations_notebook.ipynb can be opened via Open In Colab. When running the code, you will be asked to upload a file. Please select the compressed NumPy file alphatensor_14236_factorizations.npz. This uploads the factorizations found by AlphaTensor; the notebook then computes invariants certifying that they are all nonequivalent. For more details, see Supplement B of the paper.

  • recombination: Execute python3 -m alphatensor.recombination.example on the command line, from the parent directory that contains the alphatensor repository as a subdirectory. For more details, see Supplement H of the paper.
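
For reference, here is a minimal sketch (not part of the repository) of loading a factorization file from a local copy rather than through the Colab upload prompt. It assumes only the .npz layout described in the issues below:

import numpy as np

# Load the standard-arithmetic factorizations from a local copy of the
# file. allow_pickle=True is needed because each value is a list of the
# three factor matrices U, V, W.
factorizations = dict(np.load("factorizations_r.npz", allow_pickle=True))

# Keys name the matrix sizes (m, k, n); the number of columns of U is
# the number of scalar multiplications the algorithm performs.
for key, uvw in factorizations.items():
    u, v, w = map(np.array, uvw)
    print(key, "->", u.shape[1], "multiplications")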

Citing this work

If you use the code or data in this package, please cite:

@Article{AlphaTensor2022,
  author  = {Fawzi, Alhussein and Balog, Matej and Huang, Aja and Hubert, Thomas and Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Ruiz, Francisco J. R. and Schrittwieser, Julian and Swirszcz, Grzegorz and Silver, David and Hassabis, Demis and Kohli, Pushmeet},
  journal = {Nature},
  title   = {Discovering faster matrix multiplication algorithms with reinforcement learning},
  year    = {2022},
  volume  = {610},
  number  = {7930},
  pages   = {47--53},
  doi     = {10.1038/s41586-022-05172-4}
}

License and disclaimer

Copyright 2022 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.


alphatensor's Issues

Conversion of algorithms into Rust

Hi DeepMind team,

This is truly amazing work! Great job, and thank you so much for publishing these findings. As a developer, I'm keen to make use of these algorithmic improvements. I've started to port the algorithms to Rust here: https://github.com/drbh/simd-alphatensor-rs by programmatically generating the code. However, I'm not sure how best to implement and share these improvements. As the creators of this work, do you have any suggestions for packaging the algorithms into a library? Are there any implementation considerations that should be taken into account when building a library around this work?

Thanks again!!

Difference between categories

What are the differences between these directories? For example, if I want faster multiplication of two matrices, which directory should I use?

4x4x4 algorithm for real arithmetic?

From your article, I was hoping to find an algorithm for the 4x4x4 case with 47 products.
However, the article presents a mod 2 algorithm. In your collection of factorizations over real arithmetic, the 4x4x4 case has 49 products. Recursive application of the 2x2x2 Strassen algorithm also requires only 49 products (7 multiplications at the block level, each expanded into 7 scalar multiplications).

Do you actually have an algorithm which outperforms recursive Strassen for the general 4x4x4 case?
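
For anyone who wants to verify the counts, here is a short sketch (assuming factorizations_r.npz is in the working directory) that reads the rank of the 4x4x4 factorization directly:

import numpy as np
from ast import literal_eval

# The rank (number of columns of U) is the number of scalar products.
factorizations = dict(np.load("factorizations_r.npz", allow_pickle=True))
for key, uvw in factorizations.items():
    if literal_eval(key) == (4, 4, 4):
        u = np.array(uvw[0])
        print("4x4x4 products:", u.shape[1])  # 49 in standard arithmetic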

Equations in text form for your convenience

Here is some code to generate the equations in text form so no one has to type out the characters from the JPGs in the paper. The indices are slightly different from those in the paper (I might have missed a transpose somewhere), but the result comes out correct, so it is probably fine. Note that it prints over 7 MB of data, so it might be a good idea to pipe the output to a file before looking at it: python3 print_equations.py > equations.txt

The individual equations could be simplified a bit more (e.g. by using SymPy), but I wanted to keep the dependencies to a minimum.

Make sure to run the code in the same directory as the files factorizations_f2.npz and factorizations_r.npz.

print_equations.py

import numpy as np
from ast import literal_eval as make_tuple

np.random.seed(0)

"""
The *.npz files contain a dict with keys like "(2,3,4)" and values containing
a list of matrices U, V and W. For example, for the 2-by-2 times 2-by-2 case,
we have the following matrices:

U =
[[ 0  1  1  0  1  1  0]
 [ 0  0 -1  1  0  0  0]
 [ 1  1  1  0  1  0  0]
 [-1 -1 -1  0  0  0  1]]

V =
[[0 0 0 0 1 1 0]
 [1 1 0 0 1 0 1]
 [0 1 1 1 1 0 0]
 [0 1 1 0 1 0 1]]

W =
[[ 0  0  0  1  0  1  0]
 [ 0 -1  0  0  1 -1 -1]
 [-1  1 -1 -1  0  0  0]
 [ 1  0  0  0  0  0  1]]

Each column of U is dotted with the vectorized matrix A.
Likewise, each column of V is dotted with the vectorized matrix B.
The resulting vectors are multiplied pointwise, and W maps the vector of
pointwise products to the entries of the product matrix C = A times B.
Also see the function `multiply` below.
"""

# There are two factorization files: one for standard arithmetic and one
# for arithmetic modulo 2.
for filename, mod in [
    ("factorizations_r.npz", None),
    ("factorizations_f2.npz", 2),
]:
    # Load the factorizations. Note that allow_pickle=True allows arbitrary
    # code execution. A JSON file would have been a better format choice
    # since nothing here is stored in NumPy format anyway.
    factorizations = dict(np.load(filename, allow_pickle=True))

    # Test each factorization
    for key, UVW in factorizations.items():
        U, V, W = map(np.array, UVW)

        m, k, n = make_tuple(key)

        print(f"\nMultiply {m}-by-{k} matrix A with {k}-by-{n} matrix B")
        if mod is not None:
            print(f"using mod {mod} arithmetic")
        print()

        # Check that shapes are correct
        assert m * k == U.shape[0]
        assert k * n == V.shape[0]
        assert m * n == W.shape[0]
        assert U.shape[1] == V.shape[1]
        assert U.shape[1] == W.shape[1]

        # Generate two random matrices for testing
        A = np.random.randint(10, size=(m, k))
        B = np.random.randint(10, size=(k, n))

        def multiply(A, B, U, V, W):
            # Multiply two matrices A and B using index matrices U, V and W
            a = A.ravel()
            b = B.ravel()

            tmp = (U.T @ a) * (V.T @ b)
            c = W @ tmp
            C = c.reshape(n, m).T

            return C

        # Multiply matrices
        C = multiply(A, B, U, V, W)

        # Check that result is correct, taking potential mod 2 into account
        if mod is None:
            assert np.allclose(C, A @ B)
        else:
            assert np.allclose(C % mod, (A @ B) % mod)

        def make_code(variables, factors):
            # Generate code like "(a11 + a21 - a22)"
            parts = []

            for variable, factor in zip(variables, factors):
                # Simplify +1 and -1 factors
                if factor == 1:
                    factor = " + "
                elif factor == -1:
                    factor = " - "
                elif factor < 0:
                    factor = f" {factor} * "
                elif factor > 0:
                    factor = f" + {factor} * "
                else:
                    continue

                parts.append(factor + variable)

            code = "".join(parts).lstrip(" +")

            if len(parts) > 1:
                code = "(" + code + ")"

            return code

        def make_variables(var, m, n):
            # Generate variables like a11, a12, a21, a22
            # or maybe a_1_1, a_1_2, a_2_1, a_2_2.
            # For larger matrices, we need a separator to avoid
            # confusing e.g. a_1_11 with a_11_1.
            separator = "_" if max(m, n, k) > 9 else ""
            return [f"{var}{separator}{i + 1}{separator}{j + 1}"
                for i in range(m) for j in range(n)]

        A_variables = make_variables("a", m, k)
        B_variables = make_variables("b", k, n)
        C_variables = make_variables("c", m, n)
        h_variables = [f"h{i + 1}" for i in range(U.shape[1])]

        lines = [
            ", ".join(A_variables) + " = A.ravel()",
            ", ".join(B_variables) + " = B.ravel()",
        ]

        # Generate code for computation of temporary vector
        for h, u, v in zip(h_variables, U.T, V.T):
            sa = make_code(A_variables, u)
            sb = make_code(B_variables, v)

            lines.append(f"{h} = {sa} * {sb}")

        # Generate code for computation
        for c, w in zip(C_variables, W):
            lines.append(f"{c} = " + make_code(h_variables, w).strip("()"))

        lines.append("C = np.array([" + ", ".join(C_variables) +
            f"]).reshape({n}, {m}).T")

        code = "\n".join(lines)

        print(code)

        # Verify that code generates the correct result
        exec(code)

        if mod is None:
            assert np.allclose(C, A @ B)
        else:
            assert np.allclose(C % mod, (A @ B) % mod)

For example, the generated code for general 2-by-2 times 2-by-2 matrix multiplication is:

a11, a12, a21, a22 = A.ravel()
b11, b12, b21, b22 = B.ravel()
h1 = (a21 - a22) * b12
h2 = (a11 + a21 - a22) * (b12 + b21 + b22)
h3 = (a11 - a12 + a21 - a22) * (b21 + b22)
h4 = a12 * b21
h5 = (a11 + a21) * (b11 + b12 + b21 + b22)
h6 = a11 * b11
h7 = a22 * (b12 + b22)
c11 = h4 + h6
c12 = - h2 + h5 - h6 - h7
c21 = - h1 + h2 - h3 - h4
c22 = h1 + h7
C = np.array([c11, c12, c21, c22]).reshape(2, 2).T

For 4-by-4 times 4-by-4 matrix multiplication in $\mathbb{Z}_2$, i.e. when doing mod 2 arithmetic (missed that on first reading), the code is:

a11, a12, a13, a14, a21, a22, a23, a24, a31, a32, a33, a34, a41, a42, a43, a44 = A.ravel()
b11, b12, b13, b14, b21, b22, b23, b24, b31, b32, b33, b34, b41, b42, b43, b44 = B.ravel()
h1 = a13 * b31
h2 = (a13 + a22 + a23) * (b21 + b24 + b34)
h3 = (a13 + a21 + a23) * (b11 + b13 + b33)
h4 = (a13 + a23) * (b11 + b13 + b21 + b24 + b31 + b33 + b34)
h5 = a11 * b11
h6 = (a11 + a31) * (b11 + b12 + b14 + b21 + b24 + b31 + b32)
h7 = (a11 + a31 + a33) * (b12 + b31 + b32)
h8 = (a11 + a12 + a13 + a22 + a23 + a31 + a32) * (b21 + b24)
h9 = (a12 + a41 + a42) * (b11 + b13 + b23)
h10 = (a12 + a42 + a43) * (b22 + b31 + b32)
h11 = (a12 + a42) * (b11 + b13 + b21 + b22 + b23 + b31 + b32)
h12 = (a11 + a12 + a13 + a21 + a23 + a41 + a42) * (b11 + b13)
h13 = (a11 + a12 + a13 + a31 + a33 + a42 + a43) * (b31 + b32)
h14 = a41 * (b12 + b13 + b23 + b41 + b42)
h15 = (a14 + a41 + a44) * (b12 + b41 + b42)
h16 = (a14 + a44) * (b12 + b33 + b41 + b42 + b43)
h17 = (a11 + a31 + a32) * (b14 + b21 + b24)
h18 = (a14 + a32 + a34 + a41 + a44) * (b41 + b42)
h19 = (a14 + a32 + a34) * (b22 + b41 + b42)
h20 = (a14 + a34) * (b22 + b34 + b41 + b42 + b44)
h21 = a22 * (b23 + b24 + b34 + b41 + b43)
h22 = (a14 + a22 + a24) * (b23 + b41 + b43)
h23 = (a14 + a43 + a44) * (b33 + b41 + b43)
h24 = (a14 + a21 + a23 + a43 + a44) * b33
h25 = (a14 + a22 + a34 + a43) * (b22 + b34 + b41 + b43)
h26 = a33 * (b12 + b32 + b34 + b41 + b44)
h27 = (a14 + a24) * (b14 + b23 + b41 + b43 + b44)
h28 = (a14 + a21 + a24) * (b14 + b41 + b44)
h29 = (a14 + a32 + a34 + a42 + a43) * b22
h30 = (a14 + a22 + a24 + a43 + a44) * (b41 + b43)
h31 = a14 * b41
h32 = (a14 + a33 + a34) * (b34 + b41 + b44)
h33 = (a21 + a31 + a41) * (b12 + b13 + b14)
h34 = (a14 + a22 + a24 + a41 + a42) * b23
h35 = (a24 + a34 + a44) * (b42 + b43 + b44)
h36 = (a14 + a22 + a23 + a33 + a34) * b34
h37 = (a23 + a33 + a43) * (b32 + b33 + b34)
h38 = (a22 + a32 + a42) * (b22 + b23 + b24)
h39 = a12 * b21
h40 = (a14 + a21 + a24 + a33 + a34) * (b41 + b44)
h41 = a43 * (b22 + b32 + b33 + b41 + b43)
h42 = a21 * (b13 + b14 + b33 + b41 + b44)
h43 = (a14 + a21 + a24 + a31 + a32) * b14
h44 = (a14 + a24 + a32 + a41) * (b14 + b23 + b41 + b42)
h45 = a32 * (b14 + b22 + b24 + b41 + b42)
h46 = (a14 + a21 + a33 + a44) * (b12 + b33 + b41 + b44)
h47 = (a14 + a31 + a33 + a41 + a44) * b12
c11 = h1 + h5 + h31 + h39
c12 = h1 + h2 + h3 + h4 + h21 + h22 + h27 + h28 + h31 + h42
c13 = h5 + h6 + h7 + h17 + h19 + h20 + h26 + h31 + h32 + h45
c14 = h9 + h10 + h11 + h14 + h15 + h16 + h23 + h31 + h39 + h41
c21 = h1 + h7 + h10 + h13 + h15 + h18 + h19 + h29 + h31 + h47
c22 = h16 + h20 + h23 + h24 + h25 + h26 + h30 + h32 + h35 + h36 + h37 + h40 + h41 + h46
c23 = h15 + h18 + h19 + h20 + h26 + h31 + h32 + h47
c24 = h15 + h16 + h18 + h19 + h23 + h29 + h31 + h41
c31 = h3 + h5 + h9 + h12 + h22 + h23 + h24 + h30 + h31 + h34
c32 = h22 + h23 + h24 + h27 + h28 + h30 + h31 + h42
c33 = h14 + h15 + h16 + h18 + h27 + h28 + h33 + h35 + h40 + h42 + h43 + h44 + h46 + h47
c34 = h14 + h15 + h16 + h22 + h23 + h30 + h31 + h34
c41 = h2 + h8 + h17 + h28 + h31 + h32 + h36 + h39 + h40 + h43
c42 = h21 + h22 + h27 + h28 + h31 + h32 + h36 + h40
c43 = h19 + h20 + h28 + h31 + h32 + h40 + h43 + h45
c44 = h18 + h19 + h20 + h21 + h22 + h25 + h27 + h29 + h30 + h34 + h35 + h38 + h44 + h45
C = np.array([c11, c12, c13, c14, c21, c22, c23, c24, c31, c32, c33, c34, c41, c42, c43, c44]).reshape(4, 4).T

For MATLAB users: note that .ravel() in Python flattens row-wise, not column-wise. The same goes for reshape.
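
A quick illustration of the difference:

import numpy as np

# NumPy defaults to row-major (C) order; MATLAB flattens column-major.
A = np.array([[1, 2],
              [3, 4]])
print(A.ravel())           # [1 2 3 4]  row-wise (NumPy default)
print(A.ravel(order="F"))  # [1 3 2 4]  column-wise (MATLAB-style)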

Fault Tolerance aspect

Dear Authors,
First of all, thanks for the great work and congratulations on the breakthrough!

In the paper, the fault-tolerance aspect of the multiplication algorithms is left out; could you please shed some light on it? I believe this has serious implications.

Clock speed set to 1530 MHz when V100 has a max boost clock of 1380 MHz

In your GPU benchmark, you set the persistence mode to ON and then lock the GPU clocks to 1530,1530 as follows:

process = subprocess.Popen(
    'sudo nvidia-smi --lock-gpu-clocks=1530,1530'.split(' '),
    stdout=subprocess.PIPE)

How does that work on a V100 GPU with a base clock of 1245 MHz and a boost clock of 1380 MHz? To my understanding, the V100S has a higher clock, but GCP offers the V100, not the V100S.
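
One way to check which clocks a given card actually supports before locking them is a standard nvidia-smi query, written here in the same subprocess style as the benchmark script (output parsing omitted):

import subprocess

# List the clock combinations the GPU supports; whether an out-of-range
# lock request is clamped or rejected depends on the driver.
process = subprocess.Popen(
    'nvidia-smi -q -d SUPPORTED_CLOCKS'.split(' '),
    stdout=subprocess.PIPE)
output, _ = process.communicate()
print(output.decode())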

How to apply it on GPU

That's amazing work! We are university students from China. Our group project focuses on applying these algorithms on GPUs to make matrix multiplication more efficient. After studying them, however, we found that some of the algorithms can only be used in Z2, unlike Strassen's algorithm, which is universal. We would like to know how you applied them to GPUs to obtain considerable improvements. Thank you.
