We hit the error ":0: error: loc("clip_by_norm_11/Select"): currently unsupported operand types: 'tensor<1x1xf32>' and 'tensor<?x?xf32>'" when converting tf.Select to mhlo.select (in mlir::mhlo::createLegalizeTFPass()). Currently, TF doesn't support this.
The current workaround is to blacklist SelectOp with "export TAO_OP_TYPE_CLUSTERING_BLACK_LIST='Select'".
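The MLIR unit test is as follows. (Note: the bare `tensor` types and the valueless `dense` on `%cst_2` below appear to have lost their `<...>` payload when this was pasted, e.g. `tensor<i32>`, `tensor<f32>`, `dense<true> : tensor<i1>`; they probably need to be restored before the file will parse.)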
#loc0 = loc(unknown)
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 891 : i32}} {
func @main(%arg0: tensor<?xi32> loc(unknown), %arg1: tensor<?xi32> loc(unknown), %arg2: tensor<?xi32> loc(unknown), %arg3: tensor<?xi32> loc(unknown), %arg4: tensor loc(unknown), %arg5: tensor loc(unknown), %arg6: tensor<1xi32> loc(unknown), %arg7: tensor<1xi32> loc(unknown), %arg8: tensor<1xi32> loc(unknown), %arg9: tensor<2xi32> loc(unknown), %arg10: tensor<2xi32> loc(unknown), %arg11: tensor<2xi32> loc(unknown), %arg12: tensor<2xi32> loc(unknown), %arg13: tensor<2xi32> loc(unknown), %arg14: tensor<2xi32> loc(unknown), %arg15: tensor<2xi32> loc(unknown), %arg16: tensor<2xi32> loc(unknown), %arg17: tensor<2xi32> loc(unknown), %arg18: tensor<2xi32> loc(unknown), %arg19: tensor<2xi32> loc(unknown), %arg20: tensor<2xi32> loc(unknown), %arg21: tensor<?xi32> loc(unknown), %arg22: tensor<?xi32> loc(unknown), %arg23: tensor<?xi32> loc(unknown), %arg24: tensor<?xi32> loc(unknown), %arg25: tensor<?x?xf32> loc(unknown), %arg26: tensor<?x?xf32> loc(unknown), %arg27: tensor<?x?xf32> loc(unknown), %arg28: tensor<?x?xf32> loc(unknown), %arg29: tensor<?x?xf32> loc(unknown), %arg30: tensor<?x?xf32> loc(unknown), %arg31: tensor<?x?xf32> loc(unknown), %arg32: tensor<?x?xf32> loc(unknown), %arg33: tensor<?x?xf32> loc(unknown), %arg34: tensor<?x?xf32> loc(unknown), %arg35: tensor loc(unknown), %arg36: tensor<?x?xf32> loc(unknown)) -> (tensor<?xi32>, tensor<?xi32>, tensor<?x?xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor, tensor<?x1xi1>) attributes {tf.entry_function = {control_outputs = "", disc.input_shape_10 = dense<0> : tensor<2xi32>, disc.input_shape_11 = dense<0> : tensor<2xi32>, disc.input_shape_12 = dense<0> : tensor<2xi32>, disc.input_shape_13 = dense<0> : tensor<2xi32>, disc.input_shape_14 = dense<0> : tensor<2xi32>, disc.input_shape_15 = dense<0> : tensor<2xi32>, disc.input_shape_16 = dense<0> : 
tensor<2xi32>, disc.input_shape_17 = dense<0> : tensor<2xi32>, disc.input_shape_18 = dense<0> : tensor<2xi32>, disc.input_shape_19 = dense<0> : tensor<2xi32>, disc.input_shape_20 = dense<0> : tensor<2xi32>, disc.input_shape_6 = dense<0> : tensor<1xi32>, disc.input_shape_7 = dense<0> : tensor<1xi32>, disc.input_shape_8 = dense<0> : tensor<1xi32>, disc.input_shape_9 = dense<0> : tensor<2xi32>, disc.input_value_0 = dense<> : tensor<0xi32>, disc.input_value_1 = dense<> : tensor<0xi32>, disc.input_value_2 = dense<> : tensor<0xi32>, disc.input_value_3 = dense<[0, 1]> : tensor<2xi32>, disc.input_value_4 = dense<0> : tensor, disc.input_value_5 = dense<-1> : tensor, input_placements = "gpu,gpu,gpu,gpu,gpu,gpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,cpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu", inputs = "gradients_graph_model_gnn_layer_3_mul_grad_broadcastgradientargs_1_arg,gradients_graph_model_gnn_layer_3_mul_1_grad_broadcastgradientargs_1_arg,gradients_graph_model_gnn_layer_3_mul_2_grad_broadcastgradientargs_1_arg,graph_model_gnn_layer_0_strided_slice_1_stack_tao_declustered_0_arg,graph_model_gnn_layer_0_concat_axis_tao_declustered_0_arg,graph_model_gnn_layer_0_expanddims_dim_tao_declustered_0_arg,gradients_graph_model_gnn_layer_0_embedding_lookup_grad_expanddims_0_arg,gradients_graph_model_gnn_layer_0_embedding_lookup_2_grad_expanddims_0_arg,gradients_graph_model_gnn_layer_0_embedding_lookup_4_grad_expanddims_0_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_concatoffset_0_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_shapen_0_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_concatoffset_1_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_shapen_1_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_concatoffset_2_arg,gradients_graph_model_gnn_layer_3_concat_1_grad_shapen_2_arg,gradients_graph_model_gnn_layer_3_mul_grad_shape_1_0_arg,gradients_graph_model_gnn_layer_3_mul_1_grad_shape_1_0_arg,gradients_graph_model_gnn_l
ayer_3_mul_2_grad_shape_1_0_arg,gradients_graph_model_gnn_layer_3_embedding_lookup_grad_concat_0_arg,gradients_graph_model_gnn_layer_3_embedding_lookup_2_grad_concat_0_arg,gradients_graph_model_gnn_layer_3_embedding_lookup_4_grad_concat_0_arg,graph_model_gnn_layer_0_strided_slice_7_tao_declustered_0_arg,graph_model_gnn_layer_0_strided_slice_4_tao_declustered_0_arg,graph_model_gnn_layer_0_strided_slice_10_tao_declustered_0_arg,graph_model_gnn_layer_0_concat_tao_declustered_0_arg,gradients_graph_model_gnn_layer_3_relu_grad_relugrad_0_arg,graph_model_gnn_layer_3_expanddims_0_arg,graph_model_gnn_layer_3_expanddims_1_0_arg,graph_model_gnn_layer_3_expanddims_2_0_arg,graph_model_gnn_layer_3_edge_0_weight_matmul_readvariableop_0_arg,graph_model_gnn_layer_3_embedding_lookup_0_arg,graph_model_gnn_layer_3_edge_1_weight_matmul_readvariableop_0_arg,graph_model_gnn_layer_3_embedding_lookup_2_0_arg,graph_model_gnn_layer_3_edge_2_weight_matmul_readvariableop_0_arg,graph_model_gnn_layer_3_embedding_lookup_4_0_arg,clip_by_norm_10_greater_y_0_arg,clip_by_norm_10_ones_like_0_arg", output_placements = "cpu,cpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu,gpu", outputs = "gradients/concat_7:0,gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/Maximum:0,gradients/concat:0,clip_by_norm_11/Select:0,clip_by_norm_12/Select:0,clip_by_norm_13/Select:0,clip_by_norm_11/Greater:0,clip_by_norm_11/Sum:0,clip_by_norm_12/Greater:0,clip_by_norm_12/Sum:0,clip_by_norm_13/Greater:0,clip_by_norm_13/Sum:0,gradients/graph_model/gnn_layer_3/Edge_0_Weight/MatMul_grad/MatMul_1:0,gradients/graph_model/gnn_layer_3/Edge_1_Weight/MatMul_grad/MatMul_1:0,gradients/graph_model/gnn_layer_3/Edge_2_Weight/MatMul_grad/MatMul_1:0,gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/ones_like/Const:0,gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/ExpandDims:0"}} {
%cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor loc(#loc0)
%cst_0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor loc(#loc0)
%cst_1 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> loc(#loc0)
%cst_2 = "tf.Const"() {value = dense : tensor} : () -> tensor loc(#loc1)
%0 = "tf.GreaterEqual"(%arg24, %cst_0) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor) -> tensor<?xi1> loc(#loc2)
%1 = "tf.ZerosLike"(%arg24) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>) -> tensor<?xi32> loc(#loc3)
%2 = "tf.Maximum"(%arg24, %1) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32> loc(#loc4)
%3 = "tf.GatherV2"(%arg25, %2, %cst_0) {_XlaAlreadyClustered = true, batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<?xi32>, tensor) -> tensor<?x?xf32> loc(#loc5)
%4 = "tf.Shape"(%3) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>) -> tensor<2xi32> loc(#loc6)
%5 = "tf.Fill"(%4, %cst_2) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2xi32>, tensor) -> tensor<?x?xi1> loc(#loc7)
%6 = "tf.ZerosLike"(%3) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc8)
%7 = "tf.ExpandDims"(%0, %cst) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi1>, tensor) -> tensor<?x1xi1> loc(#loc9)
%8 = "tf.LogicalAnd"(%7, %5) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x1xi1>, tensor<?x?xi1>) -> tensor<?x?xi1> loc(#loc10)
%9 = "tf.Select"(%8, %3, %6) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc11)
%10 = "tf.Slice"(%9, %arg9, %arg10) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc12)
%11 = "tf.Slice"(%9, %arg11, %arg12) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc13)
%12 = "tf.Slice"(%9, %arg13, %arg14) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc14)
%13 = "tf.Reshape"(%arg23, %arg8) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> loc(#loc15)
%14 = "tf.Reshape"(%arg22, %arg6) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> loc(#loc16)
%15 = "tf.Reshape"(%arg21, %arg7) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> loc(#loc17)
%16 = "tf.ConcatV2"(%14, %15, %13, %cst_0) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor) -> tensor<?xi32> loc(#loc18)
%17 = "tf.Mul"(%10, %arg26) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc19)
%18 = "tf.Reshape"(%17, %arg15) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc20)
%19 = "tf.MatMul"(%18, %arg29) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = false, transpose_b = true} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc21)
%20 = "tf.MatMul"(%arg30, %18) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = true, transpose_b = false} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc22)
%21 = "tf.Square"(%20) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc23)
%22 = "tf.Sum"(%21, %cst_1) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", keep_dims = true} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<1x1xf32> loc(#loc24)
%23 = "tf.Greater"(%22, %arg35) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xf32>, tensor) -> tensor<1x1xi1> loc(#loc25)
%24 = "tf.Select"(%23, %22, %arg36) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xi1>, tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<1x1xf32> loc(#loc26)
%25 = "tf.Reshape"(%19, %arg18) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc27)
%26 = "tf.Mul"(%11, %arg27) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc28)
%27 = "tf.Reshape"(%26, %arg16) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc29)
%28 = "tf.MatMul"(%27, %arg31) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = false, transpose_b = true} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc30)
%29 = "tf.MatMul"(%arg32, %27) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = true, transpose_b = false} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc31)
%30 = "tf.Square"(%29) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc32)
%31 = "tf.Sum"(%30, %cst_1) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", keep_dims = true} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<1x1xf32> loc(#loc33)
%32 = "tf.Greater"(%31, %arg35) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xf32>, tensor) -> tensor<1x1xi1> loc(#loc34)
%33 = "tf.Select"(%32, %31, %arg36) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xi1>, tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<1x1xf32> loc(#loc35)
%34 = "tf.Reshape"(%28, %arg19) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc36)
%35 = "tf.Mul"(%12, %arg28) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc37)
%36 = "tf.Reshape"(%35, %arg17) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc38)
%37 = "tf.MatMul"(%36, %arg33) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = false, transpose_b = true} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc39)
%38 = "tf.MatMul"(%arg34, %36) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = true, transpose_b = false} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc40)
%39 = "tf.Square"(%38) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>) -> tensor<?x?xf32> loc(#loc41)
%40 = "tf.Sum"(%39, %cst_1) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0", keep_dims = true} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<1x1xf32> loc(#loc42)
%41 = "tf.Greater"(%40, %arg35) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xf32>, tensor) -> tensor<1x1xi1> loc(#loc43)
%42 = "tf.Select"(%41, %40, %arg36) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x1xi1>, tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<1x1xf32> loc(#loc44)
%43 = "tf.Reshape"(%37, %arg20) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> loc(#loc45)
%44 = "tf.ConcatV2"(%25, %34, %43, %cst_0) {_XlaAlreadyClustered = true, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor) -> tensor<?x?xf32> loc(#loc46)
return %16, %2, %44, %24, %33, %42, %23, %22, %32, %31, %41, %40, %20, %29, %38, %cst_2, %7 : tensor<?xi32>, tensor<?xi32>, tensor<?x?xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<1x1xi1>, tensor<1x1xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor, tensor<?x1xi1> loc(#loc0)
} loc(#loc0)
} loc(#loc0)
#loc1 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/ones_like/Const")
#loc2 = loc("gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/GreaterEqual")
#loc3 = loc("gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/zeros_like")
#loc4 = loc("gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/Maximum")
#loc5 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/GatherV2")
#loc6 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/ones_like/Shape")
#loc7 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/ones_like")
#loc8 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/zeros_like_1")
#loc9 = loc("gradients/graph_model/gnn_layer_0/UnsortedSegmentSum_grad/ExpandDims")
#loc10 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/and")
#loc11 = loc("gradients/graph_model/gnn_layer_3/UnsortedSegmentSum_grad/Select")
#loc12 = loc("gradients/graph_model/gnn_layer_3/concat_1_grad/Slice")
#loc13 = loc("gradients/graph_model/gnn_layer_3/concat_1_grad/Slice_1")
#loc14 = loc("gradients/graph_model/gnn_layer_3/concat_1_grad/Slice_2")
#loc15 = loc("gradients/graph_model/gnn_layer_0/embedding_lookup_4_grad/Reshape_1")
#loc16 = loc("gradients/graph_model/gnn_layer_0/embedding_lookup_grad/Reshape_1")
#loc17 = loc("gradients/graph_model/gnn_layer_0/embedding_lookup_2_grad/Reshape_1")
#loc18 = loc("gradients/concat_7")
#loc19 = loc("gradients/graph_model/gnn_layer_3/mul_grad/Mul_1")
#loc20 = loc("gradients/graph_model/gnn_layer_3/mul_grad/Reshape_1")
#loc21 = loc("gradients/graph_model/gnn_layer_3/Edge_0_Weight/MatMul_grad/MatMul")
#loc22 = loc("gradients/graph_model/gnn_layer_3/Edge_0_Weight/MatMul_grad/MatMul_1")
#loc23 = loc("clip_by_norm_11/ArithmeticOptimizer/ReplaceMulWithSquare_mul")
#loc24 = loc("clip_by_norm_11/Sum")
#loc25 = loc("clip_by_norm_11/Greater")
#loc26 = loc("clip_by_norm_11/Select")
#loc27 = loc("gradients/graph_model/gnn_layer_3/embedding_lookup_grad/Reshape")
#loc28 = loc("gradients/graph_model/gnn_layer_3/mul_1_grad/Mul_1")
#loc29 = loc("gradients/graph_model/gnn_layer_3/mul_1_grad/Reshape_1")
#loc30 = loc("gradients/graph_model/gnn_layer_3/Edge_1_Weight/MatMul_grad/MatMul")
#loc31 = loc("gradients/graph_model/gnn_layer_3/Edge_1_Weight/MatMul_grad/MatMul_1")
#loc32 = loc("clip_by_norm_12/ArithmeticOptimizer/ReplaceMulWithSquare_mul")
#loc33 = loc("clip_by_norm_12/Sum")
#loc34 = loc("clip_by_norm_12/Greater")
#loc35 = loc("clip_by_norm_12/Select")
#loc36 = loc("gradients/graph_model/gnn_layer_3/embedding_lookup_2_grad/Reshape")
#loc37 = loc("gradients/graph_model/gnn_layer_3/mul_2_grad/Mul_1")
#loc38 = loc("gradients/graph_model/gnn_layer_3/mul_2_grad/Reshape_1")
#loc39 = loc("gradients/graph_model/gnn_layer_3/Edge_2_Weight/MatMul_grad/MatMul")
#loc40 = loc("gradients/graph_model/gnn_layer_3/Edge_2_Weight/MatMul_grad/MatMul_1")
#loc41 = loc("clip_by_norm_13/ArithmeticOptimizer/ReplaceMulWithSquare_mul")
#loc42 = loc("clip_by_norm_13/Sum")
#loc43 = loc("clip_by_norm_13/Greater")
#loc44 = loc("clip_by_norm_13/Select")
#loc45 = loc("gradients/graph_model/gnn_layer_3/embedding_lookup_4_grad/Reshape")
#loc46 = loc("gradients/concat")
How to repro:
Save the MLIR above to select.mlir, then run:
tf-opt -xla-legalize-tf select.mlir
I cannot attach the input proto here; I can upload it to the DingDing group if needed.
What is the right solution? (Please correct me if I'm wrong.)
- In ConvertSelectOp, return failure when the `then` or `else` operand has a dynamic shape?
- Infer the shape of the `then` or `else` operand, either before createLegalizeTFPass or inside ConvertSelectOp?