Comments (2)

Valentine233 commented on May 20, 2024

With the following optimizations:

  1. Add VecConvert for int8 (signed char) to bf16.
    Improvement of graph_0_cpp_fused__to_copy_index_1: 96.359ms -> 35.693ms. The generated kernel follows, with a scalar sketch of the conversion after it.
cpp_fused__to_copy_index_1 = async_compile.cpp_pybinding(['const int64_t*', 'const signed char*', 'const signed char*', 'const signed char*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include <ATen/record_function.h>
#include "/tmp/torchinductor_liaoxuan/nd/cndd7co72iqjtof53ikp4l7yibmqrbjkni3cu6xj5p7hywloe5yg.h"
extern "C" void kernel(const int64_t* in_ptr0,
                       const signed char* in_ptr1,
                       const signed char* in_ptr2,
                       const signed char* in_ptr3,
                       bfloat16* out_ptr0,
                       bfloat16* out_ptr1,
                       bfloat16* out_ptr2)
{
    RECORD_FUNCTION("graph_0_cpp_fused__to_copy_index_1", c10::ArrayRef<c10::IValue>({}));
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(58720256L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(x0)];
                    auto tmp1 = decltype(tmp0)(tmp0 + 8);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 8L), "index out of bounds: 0 <= tmp3 < 8L")
                    auto tmp4 = at::vec::Vectorized<signed char>::loadu(in_ptr1 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                    auto tmp6 = at::vec::Vectorized<signed char>::loadu(in_ptr2 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp7 = at::vec::convert<bfloat16>(tmp6);
                    auto tmp8 = at::vec::Vectorized<signed char>::loadu(in_ptr3 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp9 = at::vec::convert<bfloat16>(tmp8);
                    tmp5.store(out_ptr0 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp7.store(out_ptr1 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp9.store(out_ptr2 + static_cast<long>(x1 + (58720256L*x0)), 16);
                }
            }
        }
    }
}
''')
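For reference, here is a minimal scalar sketch of what the vectorized gather above computes (hypothetical standalone C++; the function name, shapes, and the expert count of 8 are illustrative assumptions, not the generated code):

#include <c10/util/BFloat16.h>
#include <cstdint>

// Scalar equivalent of the kernel above: select an expert row by index,
// wrap negative indices into [0, 8), and widen each int8 weight to bf16.
void gather_and_upcast(const int64_t* index,      // [batch] expert ids
                       const signed char* weight, // [8, row_len] int8 weights
                       c10::BFloat16* out,        // [batch, row_len] bf16 output
                       int64_t batch, int64_t row_len) {
    for (int64_t b = 0; b < batch; ++b) {
        int64_t idx = index[b];
        if (idx < 0) idx += 8;  // same wrap as tmp1/tmp2/tmp3 above
        for (int64_t i = 0; i < row_len; ++i) {
            // VecConvert performs this int8 -> bf16 widening 16 lanes at a time.
            out[b * row_len + i] =
                c10::BFloat16(static_cast<float>(weight[idx * row_len + i]));
        }
    }
}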
  2. Extend the condition of is_mkldnn_optimized_format in mkldnn matmul: accept stride[0] == 0.
    Input example: size=[1, 4096, 2], stride=[0, 1, 4096]; a repro sketch follows this item.
    Improvement of aten::bmm: 60.056ms -> 3.738ms.
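
A repro sketch of that stride pattern (hypothetical ATen C++; the bf16 dtype and the left operand's shape are assumptions for illustration):

#include <ATen/ATen.h>

at::Tensor bmm_with_zero_batch_stride() {
    // size = [1, 4096, 2], stride = [0, 1, 4096]: a zero batch stride, as
    // produced by broadcasting/reinterpreting views in the compiled graph.
    at::Tensor weight = at::empty_strided(
        {1, 4096, 2}, {0, 1, 4096},
        at::TensorOptions().dtype(at::kBFloat16));
    weight.fill_(1.0);  // give the shared storage defined values

    at::Tensor x = at::randn({1, 8, 4096}, at::kBFloat16);
    // With is_mkldnn_optimized_format accepting stride[0] == 0, this bmm can
    // stay on the fast oneDNN path instead of hitting the slow fallback.
    return at::bmm(x, weight);
}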

Overall profiling

-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       graph_0_cpp_fused__to_copy_index_1        89.42%      35.693ms        89.42%      35.693ms      35.693ms             1
                                aten::bmm         9.35%       3.731ms         9.36%       3.738ms     934.462us             4
                    Torch-Compiled Region         0.70%     277.698us        99.94%      39.890ms      39.890ms             1
       graph_0_cpp_fused_index_mul_silu_2         0.24%      96.978us         0.24%      96.978us      96.978us             1
                aten::_weight_int8pack_mm         0.09%      35.892us         0.10%      41.160us      41.160us             1
                 TorchDynamo Cache Lookup         0.06%      24.675us         0.06%      24.675us      24.675us             1
                               aten::topk         0.05%      20.133us         0.05%      20.133us      20.133us             1
            inductor::_reinterpret_tensor         0.03%      12.107us         0.03%      12.107us       1.009us            12
    graph_0_cpp_fused_div_index_mul_sum_3         0.02%       9.435us         0.02%       9.435us       9.435us             1
                         aten::as_strided         0.01%       5.363us         0.01%       5.363us       2.682us             2
                              aten::empty         0.01%       5.268us         0.01%       5.268us       5.268us             1
                       aten::resolve_conj         0.00%       1.367us         0.00%       1.367us       0.171us             8
             graph_0_cpp_fused__softmax_0         0.00%       1.224us         0.00%       1.224us       1.224us             1
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 39.915ms

Valentine233 commented on May 20, 2024

With the bmm fallback, the weight is converted from int8 to bf16 and oneDNN then computes on the bf16 weight; with bmm decomposition, the type conversion and the bmm are fused into one cpp kernel. The fallback causes the regression because this case is memory bound at batch size 1, so the extra materialization of the bf16 weight costs a full pass over memory.
Synced with Jiong: it is better to decompose bmm in lowering for the memory-bound case, as sketched below.
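
A conceptual sketch of the two strategies (hypothetical ATen C++; bmm_fallback is illustrative, not the actual lowering code):

#include <ATen/ATen.h>

// Fallback path: materialize a full bf16 copy of the int8 weight, then call
// oneDNN bmm. The extra pass over the large weight dominates at batch size 1.
at::Tensor bmm_fallback(const at::Tensor& x, const at::Tensor& w_int8) {
    at::Tensor w_bf16 = w_int8.to(at::kBFloat16);  // extra read/write of the whole weight
    return at::bmm(x, w_bf16);
}

// Decomposing bmm in lowering instead lets the compiler fuse the int8 -> bf16
// widening into the generated matmul kernel, so the weight is streamed once.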
