Comments (2)
With the following optimizations:
- Add VecConvert for uint8 to bf16.
  Improvement of graph_0_cpp_fused__to_copy_index_1: 96.359ms -> 35.693ms. The generated kernel is shown below, followed by a short usage-level sketch of the conversion.
cpp_fused__to_copy_index_1 = async_compile.cpp_pybinding(['const int64_t*', 'const signed char*', 'const signed char*', 'const signed char*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include <ATen/record_function.h>
#include "/tmp/torchinductor_liaoxuan/nd/cndd7co72iqjtof53ikp4l7yibmqrbjkni3cu6xj5p7hywloe5yg.h"
extern "C" void kernel(const int64_t* in_ptr0,
                       const signed char* in_ptr1,
                       const signed char* in_ptr2,
                       const signed char* in_ptr3,
                       bfloat16* out_ptr0,
                       bfloat16* out_ptr1,
                       bfloat16* out_ptr2)
{
    RECORD_FUNCTION("graph_0_cpp_fused__to_copy_index_1", c10::ArrayRef<c10::IValue>({}));
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(58720256L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(x0)];
                    auto tmp1 = decltype(tmp0)(tmp0 + 8);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 8L), "index out of bounds: 0 <= tmp3 < 8L")
                    auto tmp4 = at::vec::Vectorized<signed char>::loadu(in_ptr1 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                    auto tmp6 = at::vec::Vectorized<signed char>::loadu(in_ptr2 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp7 = at::vec::convert<bfloat16>(tmp6);
                    auto tmp8 = at::vec::Vectorized<signed char>::loadu(in_ptr3 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp9 = at::vec::convert<bfloat16>(tmp8);
                    tmp5.store(out_ptr0 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp7.store(out_ptr1 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp9.store(out_ptr2 + static_cast<long>(x1 + (58720256L*x0)), 16);
                }
            }
        }
    }
}
''')
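For reference, the int8 -> bf16 conversion this kernel relies on can be exercised directly through the same ATen vec calls that appear above. The sketch below is only an illustration of that usage, not part of the change itself; the header paths and the availability of at::vec::convert for this type pair on a given PyTorch version are assumptions.

// Usage-level sketch (assumption: ATen vec headers from a recent PyTorch build).
// Mirrors the generated kernel above: load 16 int8 values, convert to bf16, store.
#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>

void int8_to_bf16(const signed char* src, c10::BFloat16* dst, long n) {
    for (long i = 0; i + 16 <= n; i += 16) {
        auto v8  = at::vec::Vectorized<signed char>::loadu(src + i, 16);
        auto vbf = at::vec::convert<c10::BFloat16>(v8);  // conceptually: widen via float, round to bf16
        vbf.store(dst + i, 16);
    }
}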
- Extend the condition of is_mkldnn_optimized_format in Mkldnn Matmul to accept stride[0] = 0.
  Input example: size=[1, 4096, 2], stride=[0, 1, 4096].
  Improvement of aten::bmm: 60.056ms -> 3.738ms.
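A rough sketch of the relaxed layout check is below. This is only an illustration written for this note, not the actual is_mkldnn_optimized_format implementation in ATen; the function name and the exact dense-layout conditions are assumptions.

// Illustrative sketch only: accept a batch stride of 0 (broadcast batch dim) in
// addition to the dense batch stride, so an input such as size=[1, 4096, 2],
// stride=[0, 1, 4096] is still treated as an mkldnn-friendly layout.
#include <ATen/ATen.h>

static bool is_mkldnn_matmul_friendly_3d(const at::Tensor& t) {
    if (t.dim() != 3) {
        return false;
    }
    auto sizes = t.sizes();
    auto strides = t.strides();
    // Batch stride may be the dense value or 0 when the batch dim is broadcast.
    bool batch_ok = strides[0] == sizes[1] * sizes[2] || strides[0] == 0;
    // The inner two dims must be plain row-major or column-major.
    bool row_major = strides[1] == sizes[2] && strides[2] == 1;
    bool col_major = strides[1] == 1 && strides[2] == sizes[1];
    return batch_ok && (row_major || col_major);
}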
Overall profiling
----------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name                                        Self CPU %     Self CPU  CPU total %    CPU total CPU time avg   # of Calls
----------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
graph_0_cpp_fused__to_copy_index_1              89.42%     35.693ms       89.42%     35.693ms     35.693ms            1
aten::bmm                                        9.35%      3.731ms        9.36%      3.738ms    934.462us            4
Torch-Compiled Region                            0.70%    277.698us       99.94%     39.890ms     39.890ms            1
graph_0_cpp_fused_index_mul_silu_2               0.24%     96.978us        0.24%     96.978us     96.978us            1
aten::_weight_int8pack_mm                        0.09%     35.892us        0.10%     41.160us     41.160us            1
TorchDynamo Cache Lookup                         0.06%     24.675us        0.06%     24.675us     24.675us            1
aten::topk                                       0.05%     20.133us        0.05%     20.133us     20.133us            1
inductor::_reinterpret_tensor                    0.03%     12.107us        0.03%     12.107us      1.009us           12
graph_0_cpp_fused_div_index_mul_sum_3            0.02%      9.435us        0.02%      9.435us      9.435us            1
aten::as_strided                                 0.01%      5.363us        0.01%      5.363us      2.682us            2
aten::empty                                      0.01%      5.268us        0.01%      5.268us      5.268us            1
aten::resolve_conj                               0.00%      1.367us        0.00%      1.367us      0.171us            8
graph_0_cpp_fused__softmax_0                     0.00%      1.224us        0.00%      1.224us      1.224us            1
----------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 39.915ms
With the bmm fallback, the weight is converted from int8 to bf16 and oneDNN uses the bf16 weight. With bmm decomposition, the type conversion and the bmm are fused into one cpp kernel. The bmm fallback causes the regression because this case is memory bound with batch size 1.
Synced with Jiong: it is better to decompose bmm for the memory-bound case in lowering.
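The difference between the two strategies can be illustrated with a small libtorch snippet (placeholder shapes, not the Inductor lowering itself): the fallback materializes a bf16 copy of the int8 weight and then calls aten::bmm, while the decomposed form keeps the conversion element-wise so a compiler can fuse it with the reduction instead of writing the bf16 weight out to memory.

// Rough illustration with placeholder shapes; batch size 1 makes the case memory bound.
#include <torch/torch.h>
#include <iostream>

int main() {
    auto x = torch::randn({1, 1, 4096}).to(torch::kBFloat16);        // bf16 activation
    auto w = torch::randint(-128, 127, {1, 4096, 2}, torch::kInt8);  // int8 weight

    // Fallback style: materialize a bf16 weight, then call aten::bmm on it.
    auto y_fallback = torch::bmm(x, w.to(torch::kBFloat16));

    // Decomposed style (conceptually): keep the int8 -> bf16 conversion element-wise
    // so it can be fused with the reduction rather than producing a bf16 weight tensor.
    auto y_decomposed = (x.unsqueeze(-1) * w.to(torch::kBFloat16).unsqueeze(1)).sum(-2);

    std::cout << torch::allclose(y_fallback, y_decomposed, /*rtol=*/1e-2, /*atol=*/1e-1)
              << std::endl;
    return 0;
}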