bjacob commented on June 8, 2024

It's ok, I think I have the patch ready soon.

from shark.

bjacob commented on June 8, 2024

llvm/llvm-project#83180 is merged, so you'll get it in the next integrate or can cherry-pick it locally until then to verify it fixed your issue.

Shukla-Gaurav commented on June 8, 2024

> I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?

I think we should not close this until we have concluded how to handle bf16 on CPU, i.e. how to verify that the model produces correct outputs through the ONNX pipeline.

MaheshRavishankar commented on June 8, 2024

Can you give me the iree-run-module command that shows the error? That will help me reproduce it. For example, the %1 above is strange (it is still correct and shouldn't affect correctness), but a repro will help triage.

Shukla-Gaurav commented on June 8, 2024

@MaheshRavishankar

  1. The commands are:
~/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu  conv2d.linalg.mlir > conv2d.bf16.vmfb 2>iree-compile.log
~/iree-build/tools/iree-run-module --module=conv2d.bf16.vmfb --input="2x8x12x16xbf16=@inference_input.0.bin.txt" > inference.log

I am attaching inference_input.0.bin, which is all ones.
inference_input.0.bin.txt

  2. iree-run-module runs fine, but its result does not match the result of the PyTorch conv2d module.
    (The input is all ones in both cases.)
import torch
import torch.nn as nn

class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )
    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.ones(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
  3. You can also use the repo https://github.com/nod-ai/SHARK-TestSuite to test/cross-check the conv2d results with your fix:
git clone https://github.com/nod-ai/SHARK-TestSuite
# activate your iree_venv or torch_mlir_venv
cd SHARK-TestSuite/e2eshark
python ./run.py --runupto inference --mode onnx -c ~/torch-mlir/build -i ~/iree-build --tests pytorch/operators/conv2d/ --hfhome ~/hf_home/ --verbose -d bf16 -r test-conv2d-bf16

Shukla-Gaurav commented on June 8, 2024

conv2d.bf16.linalg.mlir.txt
conv2d.fp32.linalg.mlir.txt
iree-compile-conv2d-bf16.log
iree-compile-conv2d-fp32.log

Shukla-Gaurav commented on June 8, 2024

Running conv2d at different precisions, keeping all the constants (weight and bias) the same:
conv2d.bf16.compile.log
conv2d.fp32.compile.log

MaheshRavishankar commented on June 8, 2024

I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?

Shukla-Gaurav commented on June 8, 2024

  1. Linear module for the reference output:
    (weight, bias and input have been fixed to simplify comparing the IRs and outputs.
    All these values fit in bf16, so x.to(torch.bfloat16) won't change them.)
import torch
import torch.nn as nn
class op_linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(3, 4)
        self.linear_layer.weight = nn.Parameter(
            torch.tensor([[-0.4199,  0.4180, -0.0293],
                        [ 0.4297, -0.4434,  0.0162],
                        [-0.2061, -0.4004,  0.2773],
                        [-0.5469,  0.0449,  0.3242]], 
                        dtype=torch.float32), 
            requires_grad=False)
        self.linear_layer.bias = nn.Parameter(
            torch.tensor([ 0.2236,  0.3184, -0.1709, -0.4883], 
                        dtype=torch.float32),
            requires_grad=False)
        self.layers = nn.Sequential(self.linear_layer)

    def forward(self, x):
        return self.layers(x)

model = op_linear()
# fp32 computation.
test_input = torch.tensor(
        [[-0.4062, -0.6953,  1.8516],
        [ 0.4961,  0.9609, -0.7500],
        [ 0.4766, -1.4531,  0.6172],
        [-0.4785,  1.0859, -0.9922],
        [ 2.1094, -0.0107,  0.3496],
        [-0.6562, -0.0116,  1.7812],
        [ 0.0114, -0.1279,  1.7266],
        [-0.1289,  0.6250,  1.3516]],
        dtype=torch.float32)
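output = model(test_input)  # f32 result; f32_golden_output below is output.to(torch.bfloat16)

# bf16 computation.
test_input = test_input.to(torch.bfloat16)  # Tensor.to returns a new tensor, it is not in-place
model.to(torch.bfloat16)  # Module.to converts the parameters in place
bf16_golden_output = model(test_input)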
  2. Linalg IR:
    linear.bf16.linalg.mlir.txt
    linear.fp32.linalg.mlir.txt

  3. Applying --iree-global-opt-enable-demote-contraction-inputs-to-bf16 to the fp32 IR and --iree-llvmcpu-enable-ukernels=all to the bf16 IR, to compare the outputs of the different paths:

/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-global-opt-enable-demote-contraction-inputs-to-bf16  --iree-input-type=tm_tensor --mlir-print-ir-after-all --mlir-disable-threading linear.fp32.linalg.mlir > linear.contraction.vmfb 2>contraction-flag-linear-fp32-iree_compile.log
/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-enable-ukernels=all --mlir-print-ir-after-all --mlir-disable-threading  --iree-input-type=tm_tensor linear.bf16.linalg.mlir > linear.ukernel.vmfb 2>ukernel-flag-linear-bf16-iree_compile.log

contraction-flag-linear-fp32-iree_compile.log
ukernel-flag-linear-bf16-iree_compile.log

/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.contraction.vmfb --input="8x3xf32=@inference_input.0.bin.txt"
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.ukernel.vmfb --input="8x3xbf16=@inference_input_bf16.0.bin.txt"

inference_input.0.bin.txt
inference_input_bf16.0.bin.txt

  4. Comparing the different outputs:
f32_golden_output=
(=model(input).to(bf16))
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0078, -0.0198,  0.0483]], dtype=torch.bfloat16)

bf16_golden_output=
(=model.to(bf16); input.to(bf16); model(input))
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0078, -0.0198,  0.0486]], dtype=torch.bfloat16)

contraction_flag_output=
(this is quite close to the golden output, but here we trace the f32 model (f32 linalg IR) and only demote certain ops to bf16.)
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0079, -0.0198,  0.0486]], dtype=torch.bfloat16)

ukernel_flag_output=
(this is the same as the IREE bf16 inference output without this flag; it does not match the golden output.)
tensor([[ 0.0498,  0.4824,  0.7031,  0.3047],
        [ 0.4395,  0.0938, -0.8672, -0.9570],
        [-0.6055,  1.1797,  0.4863, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.7500],
        [-0.6797,  1.2344, -0.5039, -1.5234],
        [ 0.4434,  0.0723,  0.4629,  0.4492],
        [ 0.1147,  0.4082,  0.3574,  0.0586],
        [ 0.5000,  0.0078, -0.0195,  0.0469]], dtype=torch.bfloat16)

print(torch.allclose(bf16_golden_output, contraction_flag_output, atol=1e-04, rtol=1e-03))  # True
print(torch.allclose(bf16_golden_output, ukernel_flag_output, atol=1e-04, rtol=1e-03))  # False

I'm not sure how PyTorch handles the bf16 computation, but its result is close to doing the computation in f32 and then demoting the result to bf16.
(NOTE: all the inputs are small enough to fit in the bf16 type.)
For the same inputs, the IREE bf16 inference result (the ukernel_flag_output) does not match within the above-mentioned tolerances.
(The contraction_flag_output is close, but we want to start from a bf16 model, not an f32 model.)

Does this require a look at how IREE handles bf16 loads in the CPU backends? Thoughts?
@stellaraccident @kumardeepakamd @MaheshRavishankar @bjacob

I can create an ONNX linear module with the same inputs and run it on ONNX Runtime to have one more reference output, if that helps. Thanks!

MaheshRavishankar commented on June 8, 2024

Really, I don't know what the compiler itself can decide here. This is always going to mismatch, because the reference is doing different things (and different references will do different things). What IREE does is basically "do what it is told to do", except that if the linalg op says the input type is bf16 and the output type is bf16, it actually does the accumulation in f32. IMO that is being too smart: it should do the accumulation in bf16 as well, because that's what it was told to do. So the only action item I can think of is to make IREE less smart.
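
To make the accumulation point concrete, here is a minimal pure-Python sketch (an illustration only, not IREE's or PyTorch's code; the helper names round_f32, round_bf16 and dot are hypothetical). It emulates f32 rounding with struct and bf16 round-to-nearest-even on top of it, then computes the same bf16 dot product twice: once keeping the running sum in f32, as described above, and once rounding the running sum to bf16 after every addition.

```python
import struct

# Illustrative sketch; all helper names are hypothetical.


def round_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_bf16(x: float) -> float:
    """Round to f32, then to bf16 with round-to-nearest-even (NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # bias so that dropping 16 bits rounds to nearest even
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]


def dot(a, b, accumulate_in_bf16: bool) -> float:
    """bf16 x bf16 dot product; the result is returned as bf16 either way."""
    acc = 0.0
    for x, y in zip(a, b):
        prod = round_f32(x * y)           # exact: a bf16*bf16 product fits in f32
        if accumulate_in_bf16:
            acc = round_bf16(acc + prod)  # round the running sum after every addition
        else:
            acc = round_f32(acc + prod)   # keep the running sum in f32
    return round_bf16(acc)                # final demotion to the bf16 output type


a = [1.0] + [2.0 ** -9] * 8               # all exactly representable in bf16
b = [1.0] * 9
print(dot(a, b, accumulate_in_bf16=False))  # 1.015625
print(dot(a, b, accumulate_in_bf16=True))   # 1.0
```

Every input is exact in bf16 and only the place where rounding happens differs: with a bf16 accumulator each small addend is below half an ULP of 1.0 and is rounded away, while the f32 accumulator keeps them.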

stellaraccident commented on June 8, 2024

Usually, when I've seen these kinds of issues resolved before, it required a much more careful drill-down rather than a high-level "this vs. that" comparison. There is no abstract answer to these things at that precision: a rounding/truncation mode difference or a 1 ULP difference at any stage is enough to result in a 1% error for a datatype like this. If you are trying to get complete correspondence, then none of that can be ignored.

You'll have to dig deeper, and if you are still looking at results as textual floating point, you'll likely miss the difference.

One advantage of having an ONNX reference is that it is much easier to hack on at that level than PyTorch (i.e. you can build it, set a breakpoint or print specific values in a kernel, etc.).
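
As a sketch of what comparing results below the textual level can look like, the following measures the distance between two bf16 values in ULPs, assuming the inputs are raw bf16 bit patterns (e.g. read from an output buffer); the helper names are hypothetical. It shows that one ULP near 1.0 is already a relative difference of about 0.8%, and that -0.5000 vs -0.7500 (4th row, last column of the golden and ukernel outputs above) are 64 ULPs apart.

```python
import struct

# Illustrative sketch; all helper names are hypothetical.


def bf16_to_float(bits: int) -> float:
    """Decode a 16-bit bf16 encoding: it is the upper half of an f32 encoding."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def ordered(bits: int) -> int:
    """Map a sign-magnitude bf16 encoding onto a monotonic integer scale."""
    magnitude = bits & 0x7FFF
    return -magnitude if bits & 0x8000 else magnitude


def ulp_distance(a_bits: int, b_bits: int) -> int:
    """Number of representable bf16 steps between two encodings (NaNs not handled)."""
    return abs(ordered(a_bits) - ordered(b_bits))


one, next_up = 0x3F80, 0x3F81           # 1.0 and the next bf16 value above it
rel = (bf16_to_float(next_up) - bf16_to_float(one)) / bf16_to_float(one)
print(bf16_to_float(next_up), ulp_distance(one, next_up), rel)  # 1.0078125 1 0.0078125
print(ulp_distance(0xBF00, 0xBF40))     # -0.5 vs -0.75: 64
```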

Shukla-Gaurav commented on June 8, 2024

It seems the mismatch is due to the different rounding mechanisms used by PyTorch and IREE. I ran a few simple add/mul tests, and the outputs mostly differ by one bit. PyTorch simply truncates the last 16 bits after computing the result, while IREE seems to round it.
The following example illustrates this behavior:

  1. Multiplication (1.3667e+30 * 5) ~ 6.8335e+30
    Attaching linalg IR and iree-compile log.
    mul.linalg.mlir.txt
    mul-bf16-iree_compile.log
    /home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=mul.vmfb --input="1xbf16=[5]"
pytorch-cpu output: 6.8136e+30
IREE-cpu output: 6.8532e+30 

Their binary representations differ exactly in the last bit.
If we do the multiplication in f32 and drop the last 16 bits, we get exactly the same output as PyTorch.
I used the following C++ code to check the 16-bit truncation result:

#include <iostream>
struct bfloat16{
   unsigned short int data;
   public:
   bfloat16(){
      data = 0;
   }
   //cast to float
   operator float(){
      unsigned int proc = data<<16;
      return *reinterpret_cast<float*>(&proc);
   }
   //cast to bfloat16
   bfloat16& operator =(float float_val){
      data = (*reinterpret_cast<unsigned int *>(&float_val))>>16;
      return *this;
   }
};

//an example that enumerates all the possible values between 1.0f and 300.0f
using namespace std;

int main(){
   bfloat16 x;
   x = 6.8335e+30;
   // for(x=1.0f;x<300.0f;x.data++){
   cout<<x.data<<" "<<x<<endl;
   // }
   return 0;
}

bjacob commented on June 8, 2024

Great analysis, thanks! Indeed, the f32 value 6.8335e+30 has the binary encoding 0x72ac8075.
Truncating this to bf16 means replacing this by a value that is a multiple of 0x10000, so the two conceivable candidates are 0x72ac000 or 0x72ad000. These are respectively 6.81362e+30 and 6.85324e+30. However, as the next digit after 0x72ac... is a 8 (the 8 in 0x72ac8075), the nearest value really is the latter, 0x72ad000 = 6.85324e+30. And this isn't even a tie (it would be a tie if the value were 0x72ac8000), so there are no questions of tie-breaks here ("tie to nearest-even"). So there is no question that the nearest value is 0x72ad000 = 6.85324e+30. So, IREE is being correct here, and PyTorch is incorrect. What PyTorch does here has a name, "rounding towards zero", it can be sometimes useful for some really specialized uses, but doesn't make sense as the default way to round all values in a workload.

bjacob commented on June 8, 2024

I don't know what code path is used in the code that you ran, but I checked this PyTorch f32 -> bf16 rounding helper,
https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h#L76

And it does return the correct result,

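#include <cstdio>
#include <c10/util/BFloat16.h>  // defines c10::detail::round_to_nearest_even

int main() {
    float f = 6.8335e+30;
    printf("round_to_nearest_even(%g) = 0x%x\n", f,
           static_cast<unsigned>(c10::detail::round_to_nearest_even(f)));
}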

prints

round_to_nearest_even(6.8335e+30) = 0x72ad

And not 0x72ac as in your above PyTorch result. So, it seems like different code paths within PyTorch don't agree with each other, and at least this one agrees with us.

bjacob commented on June 8, 2024

> I used following c++ code to check 16-bits truncation result:
>
>    //cast to bfloat16
>    bfloat16& operator =(float float_val){
>       data = (*reinterpret_cast<unsigned int *>(&float_val))>>16;
>       return *this;
>    }

This implements the same incorrect rounding-towards-zero as we discussed above. Just dropping the bottom 16 bits like this fails to account for the possibility that their value might be 0x8000 or more, requiring rounding upwards to the next representable value (always when it is above 0x8000, and on an exact tie at 0x8000 when round-to-nearest-even calls for it).

(Side note: this also has undefined behavior in C++, as an unsigned int object coexists with a float object at the same memory location. The only way to implement a bitcast like this in C++ prior to C++20 is to copy data with something like memcpy, or go down to one of the few POD types that support aliasing, such as char, unsigned char or std::byte. Or if using C++20, use the new std::bit_cast for that.)

Shukla-Gaurav commented on June 8, 2024

@bjacob Thanks for the explanation! For the multiplication example, the IREE output is closer, as PyTorch is simply truncating the result. Can we add a way to explicitly specify the rounding mode we want in IREE, since we need PyTorch results as our reference for the model outputs?

And this is what I did for the PyTorch bf16 multiplication:

>>> x = torch.tensor([1.3667e30], dtype=torch.bfloat16)
>>> y = torch.tensor([5], dtype=torch.bfloat16)
>>> x*y
tensor([6.8136e+30], dtype=torch.bfloat16)

Shukla-Gaurav commented on June 8, 2024

I also got a weird example, @bjacob
2. Addition (-0.0112 - 0.4882) ~ -0.4994
Attaching linalg IR and iree-compile log.
add.linalg.mlir.txt
add-bf16-iree_compile.log
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=add.vmfb --input="1xbf16=[-0.4882]"

pytorch-cpu output: -0.5
IREE-cpu output: -0.75 

bjacob commented on June 8, 2024

Wow, funny bug that you found here! It appears to be a parsing bug, in how iree-run-module parses the --input flag. Indeed, it produces expected results when the specified array element has no more than two digits after the decimal point, and it reproduces whenever it has 3 or more digits after the decimal point.

~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.48]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.492188
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.49]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.5
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.488]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75

FYI @benvanik

bjacob commented on June 8, 2024

The parsing itself is correct, though - iree_hal_parse_element_unsafe does parse the correct value and its caller iree_hal_parse_buffer_elements does store it in the destination buffer.

And yet, something is producing incorrect results only when the --input parameter specified more than two decimals...

bjacob commented on June 8, 2024

The bug reproduces whenever the specified --input element rounds to -0.488281 as a bfloat16 (encoding 0xbefa). It does not reproduce whenever it rounds to the previous bfloat16 value -0.486328 (encoding 0xbef9). In both cases, our f32 <-> bfloat16 conversion helpers produce correct results, and the parsing is correct too as noted above. Just for some reason, the bfloat16 value -0.488281 (encoding 0xbefa) runs into some arithmetic bug elsewhere.
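
To see which decimal inputs land on which encoding, here is a small Python sketch. It assumes the parsed value is rounded f32 -> bf16 with round-to-nearest-even, which, per the comments above, is what the runtime helpers correctly do; the helper names are hypothetical. Both -0.488 and -0.4882 become 0xbefa, while -0.48 and -0.49 land on other encodings.

```python
import struct

# Illustrative sketch; all helper names are hypothetical.


def to_bf16_bits(x: float) -> int:
    """Round x to f32, then to bf16 with round-to-nearest-even (NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF


def from_bf16_bits(b: int) -> float:
    """Decode a bf16 encoding by placing it in the upper half of an f32."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


for text in ["-0.48", "-0.49", "-0.488", "-0.4882"]:
    b = to_bf16_bits(float(text))
    print(f"{text:>8} -> 0x{b:04x} = {from_bf16_bits(b)}")

print(from_bf16_bits(0xBEFA), from_bf16_bits(0xBEF9))  # -0.48828125 -0.486328125
```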

bjacob commented on June 8, 2024

And the other operand, which is hardcoded as a constant in the above testcase, also matters. Here is a testcase taking both operands as arguments:

#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (0)>
#map2 = affine_map<(d0) -> (d0)>
module {
  func.func @main_graph(%arg0: tensor<1xbf16>, %arg1: tensor<1xbf16>) -> tensor<1xbf16> {
    %0 = tensor.empty() : tensor<1xbf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xbf16>, tensor<1xbf16>) outs(%0 : tensor<1xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %2 = arith.addf %in, %in_0 : bf16
      linalg.yield %2 : bf16
    } -> tensor<1xbf16>
    return %1 : tensor<1xbf16>
  }
}

With that, I find that for this to reproduce, the other operand needs to be bfloat16 0xbc31 or more negative (decimal value -0.010803).

bjacob commented on June 8, 2024

This actually minimizes down to a testcase that performs no bfloat16 arithmetic and only an f32->bfloat16 truncf:

#map = affine_map<(d0) -> (d0)>
module {
  func.func @main_graph(%arg0: tensor<1xf32>) -> tensor<1xbf16> {
    %0 = tensor.empty() : tensor<1xbf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs(%0 : tensor<1xbf16>) {
    ^bb0(%in0: f32, %out: bf16):
      %3 = arith.truncf %in0 : f32 to bf16
      linalg.yield %3 : bf16
    } -> tensor<1xbf16>
    return %1 : tensor<1xbf16>
  }
}
~/iree-build tools/iree-run-module --module=/tmp/repro2.vmfb --input="1xf32=[-0.499081]" --device=local-task
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75

bjacob commented on June 8, 2024

@rsuderman this might be for you :-)

What --mlir-print-ir-after-all shows for the testcase in the previous comment:

// -----// IR Dump After CSE (cse) //----- //
module {
  func.func @main_graph_dispatch_0_generic() {
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<f32>
    memref.assume_alignment %0, 64 : memref<f32>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<i16>
    memref.assume_alignment %1, 64 : memref<i16>
    %2 = memref.load %0[] : memref<f32>
    %3 = arith.truncf %2 : f32 to bf16
    %4 = arith.bitcast %3 : bf16 to i16
    memref.store %4, %1[] : memref<i16>
    return
  }
}

// -----// IR Dump After ConvertToLLVM (iree-convert-to-llvm) //----- //
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} {
  llvm.func @main_graph_dispatch_0_generic(%arg0: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg1: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg2: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}) -> i32 {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(16 : i32) : i32
    %2 = llvm.mlir.constant(32768 : i32) : i32
    %3 = llvm.mlir.constant(2130706432 : i32) : i32
    %4 = llvm.mlir.constant(2139095040 : i32) : i32
    %5 = llvm.mlir.constant(8388607 : i32) : i32
    %6 = llvm.mlir.constant(31 : i32) : i32
    %7 = llvm.mlir.constant(23 : i32) : i32
    %8 = llvm.mlir.constant(63 : index) : i64
    %9 = llvm.mlir.constant(0 : index) : i64
    %10 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
    %11 = llvm.extractvalue %10[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)> 
    %12 = llvm.load %11 : !llvm.ptr -> !llvm.ptr
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.and %13, %8  : i64
    %15 = llvm.icmp "eq" %14, %9 : i64
    "llvm.intr.assume"(%15) : (i1) -> ()
    %16 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
    %17 = llvm.extractvalue %16[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)> 
    %18 = llvm.getelementptr %17[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
    %19 = llvm.load %18 : !llvm.ptr -> !llvm.ptr
    %20 = llvm.ptrtoint %19 : !llvm.ptr to i64
    %21 = llvm.and %20, %8  : i64
    %22 = llvm.icmp "eq" %21, %9 : i64
    "llvm.intr.assume"(%22) : (i1) -> ()
    %23 = llvm.load %12 : !llvm.ptr -> f32
    %24 = llvm.bitcast %23 : f32 to i32
    %25 = llvm.lshr %24, %6  : i32
    %26 = llvm.sub %2, %25  : i32
    %27 = llvm.and %24, %5  : i32
    %28 = llvm.add %27, %26  : i32
    %29 = llvm.lshr %28, %7  : i32
    %30 = llvm.lshr %28, %29  : i32
    %31 = llvm.and %24, %4  : i32
    %32 = llvm.add %31, %28  : i32
    %33 = llvm.and %32, %4  : i32
    %34 = llvm.icmp "uge" %31, %3 : i32
    %35 = llvm.select %34, %31, %33 : i1, i32
    %36 = llvm.trunc %29 : i32 to i1
    %37 = llvm.and %34, %36  : i1
    %38 = llvm.select %37, %27, %30 : i1, i32
    %39 = llvm.shl %25, %6  : i32
    %40 = llvm.or %39, %35  : i32
    %41 = llvm.or %40, %38  : i32
    %42 = llvm.lshr %41, %1  : i32
    %43 = llvm.trunc %42 : i32 to i16
    llvm.store %43, %19 : i16, !llvm.ptr
    llvm.return %0 : i32
  }
}

In the first part of the above log, our arith.truncf op is still as in the original source. In the second part, IR Dump After ConvertToLLVM (iree-convert-to-llvm), it has been lowered to LLVM ops implementing the truncation. The testcase shows that this lowering produces incorrect results when the value being truncated (here -0.499081) is just below an exponent threshold, so that rounding it to the nearest bfloat16 value crosses that threshold (-0.499081 should become -0.5, bumping the exponent).
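
To make the failure concrete, here is a line-by-line Python transliteration of the lowered ops %23 to %43 in the ConvertToLLVM dump above, next to a plain round-to-nearest-even reference. It is a sketch that assumes 32-bit wraparound for llvm.add/llvm.sub and reuses the SSA numbers as (hypothetical) variable names. Tracing -0.499081 through it suggests where -0.75 comes from: the carry out of the 23-bit mantissa correctly bumps the exponent (%32/%33), but the mantissa that gets OR-ed into the result is the biased mantissa shifted right by one (%30), whose top bit is now set, rather than a cleared mantissa. The result is 0xbf40 (-0.75) instead of 0xbf00 (-0.5).

```python
import struct

# Illustrative sketch; variable and helper names are hypothetical.
MASK32 = 0xFFFFFFFF


def f32_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bf16_value(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def lowered_truncf(x: float) -> int:
    """Transliteration of %23..%43 from the iree-convert-to-llvm dump above."""
    v24 = f32_bits(x)                          # %24: bitcast f32 -> i32
    v25 = v24 >> 31                            # %25: sign bit
    v26 = (32768 - v25) & MASK32               # %26: bias 0x8000 (0x7FFF if negative)
    v27 = v24 & 8388607                        # %27: 23-bit mantissa field (0x7FFFFF)
    v28 = (v27 + v26) & MASK32                 # %28: biased mantissa
    v29 = v28 >> 23                            # %29: 1 if the mantissa overflowed
    v30 = v28 >> v29                           # %30: shifted right by one on overflow
    v31 = v24 & 2139095040                     # %31: exponent field (0x7F800000)
    v33 = ((v31 + v28) & MASK32) & 2139095040  # %32/%33: the carry bumps the exponent
    v34 = v31 >= 2130706432                    # %34: exponent field >= 254 (0x7F000000)
    v35 = v31 if v34 else v33                  # %35
    v38 = v27 if (v34 and v29 & 1) else v30    # %36..%38
    v41 = (v25 << 31) | v35 | v38              # %39..%41
    return (v41 >> 16) & 0xFFFF                # %42/%43: keep the upper 16 bits


def reference_truncf(x: float) -> int:
    """f32 -> bf16 with round-to-nearest-even (NaN not handled)."""
    u = f32_bits(x)
    return ((u + 0x7FFF + ((u >> 16) & 1)) >> 16) & 0xFFFF


for x in [1.0, -0.499081]:
    got, want = lowered_truncf(x), reference_truncf(x)
    print(x, hex(got), bf16_value(got), hex(want), bf16_value(want))
# 1.0 0x3f80 1.0 0x3f80 1.0
# -0.499081 0xbf40 -0.75 0xbf00 -0.5
```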

bjacob commented on June 8, 2024

@rsuderman, here is what the equivalent f32->bf16 truncation code in the runtime does specifically to fix up this case (it is actually generic in bit widths, but in particular it does f32->bf16):

https://github.com/openxla/iree/blob/01c4c57/runtime/src/iree/base/internal/math.h#L389-L390

bjacob commented on June 8, 2024

@rsuderman, here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it):
https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L76

In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing.

The link to IREE math.h in the previous comment has a comment explaining the magic trick here.

        // Note: software implementations that try to be fast tend to get this
        // conditional increment of exp and zeroing of mantissa for free by
        // simplying incrementing the whole uint32 encoding of the float value,
        // so that the mantissa overflows into the exponent bits.
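
The trick can be sketched in a few lines of Python. This mirrors the logic of PyTorch's round_to_nearest_even as linked above (the linked helper special-cases NaN first, which is omitted here; the helper names are hypothetical) and cross-checks it against a slower, explicit round-to-nearest-even on hand-picked f32 encodings: values just below and above the halfway point, both kinds of tie, a mantissa carry into the exponent, and the -0.499081 case from this thread.

```python
import struct

# Illustrative sketch; all helper names are hypothetical.


def f32_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


def rne_bias_trick(u: int) -> int:
    """f32 encoding -> bf16 encoding: add 0x7FFF (+1 if the kept half is odd),
    then drop the low 16 bits. A mantissa carry simply ripples into the exponent."""
    return ((u + 0x7FFF + ((u >> 16) & 1)) >> 16) & 0xFFFF


def rne_explicit(u: int) -> int:
    """The same rounding spelled out: compare the dropped half against 0x8000."""
    kept, dropped = u >> 16, u & 0xFFFF
    if dropped > 0x8000 or (dropped == 0x8000 and kept & 1):
        kept += 1                      # may carry from the mantissa into the exponent
    return kept & 0xFFFF


cases = {
    0x3F807FFF: 0x3F80,                # just below half: round down
    0x3F808000: 0x3F80,                # tie, kept half even: unchanged
    0x3F818000: 0x3F82,                # tie, kept half odd: round up to even
    0x3F808001: 0x3F81,                # just above half: round up
    0x3F7FFFFF: 0x3F80,                # mantissa carry bumps the exponent (-> 1.0)
    f32_bits(-0.499081): 0xBF00,       # -> -0.5
}
for u, expected in cases.items():
    assert rne_bias_trick(u) == rne_explicit(u) == expected, hex(u)
print("all cases agree")
```

The branch-free form needs no special case for the exponent bump, which is why it is attractive for a compiler lowering.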

rsuderman commented on June 8, 2024

> @rsuderman , here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it): pytorch/pytorch@e1502c0/c10/util/BFloat16.h#L76
>
> In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing.
>
> The link to IREE math.h in the previous comment has a comment explaining the magic trick here.
>
>         // Note: software implementations that try to be fast tend to get this
>         // conditional increment of exp and zeroing of mantissa for free by
>         // simplying incrementing the whole uint32 encoding of the float value,
>         // so that the mantissa overflows into the exponent bits.

Great :/. I was pretty sure I had managed to implement the rounding behavior correctly, but I did not have an aggressive test case to evaluate it with. I assume this means there is an error in our bf16 truncf implementation? I am not sure I have time in the near future to debug the exact bit errors; is it possible you could take a look?

bjacob commented on June 8, 2024

@Shukla-Gaurav, this seems to work. I'll fix up any failing unit tests and send it to @rsuderman for review.
llvm/llvm-project#83180

Shukla-Gaurav commented on June 8, 2024

Thanks a lot @bjacob for actively working on this. Will try the patch with other test cases/models as well.

kumardeepakamd commented on June 8, 2024
