From cca662b8495537b67f7b70090a33fc36154de81c Mon Sep 17 00:00:00 2001
From: Christopher Bate
Date: Thu, 2 Jun 2022 20:06:24 -0600
Subject: [PATCH] [mlir][linalg] add conv_2d_nhwc_fhwc named op

This operation should be supported as a named op because when the
operands are viewed as having canonical layouts with decreasing
strides, then the "reduction" dimensions of the filter (h, w, and c)
are contiguous relative to each output channel. When lowered to a
matrix multiplication, this layout is the simplest to deal with, and
thus future transforms/vectorizations of `conv2d` may find using this
named op convenient.

Differential Revision: https://reviews.llvm.org/D126995
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 98 +++++++++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 32 ++++++
 2 files changed, 130 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3cc8f8e32cc9..522bec689f12 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1311,6 +1311,104 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwc_fhwc
+  cpp_class_name: Conv2DNhwcFhwcOp
+  doc: |-
+    Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s3,
+      s7, s9)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1, s5, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf_q
   cpp_class_name: Conv2DNhwcHwcfQOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 992da7e80ad0..b984a1baaa94 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -178,6 +178,38 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
+func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+    outs (%init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+  return %0 : tensor<?x126x126x64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_2d_nhwc_hwcf
 func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
   // CHECK:      linalg.conv_2d_nhwc_hwcf
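
A minimal usage sketch of the new named op, assuming illustrative shapes: the @conv_2d_nhwc_fhwc_example function name and the 1x8x8x32 / 16x3x3x32 / 1x6x6x16 shapes below are hypothetical, not taken from the patch. Because the kernel layout is FHWC, each output channel f reads a contiguous 3 * 3 * 32 = 288 element filter slice under a canonical row-major layout, so an im2col-style lowering can view this convolution as a (1 * 6 * 6) x 288 times 288 x 16 matrix multiplication, which is the property the commit message highlights.

// Illustrative shapes: 8x8 spatial input, 3x3 kernel, unit strides and
// dilations, so the output is 6x6; 32 input channels, 16 output channels.
func.func @conv_2d_nhwc_fhwc_example(%input: tensor<1x8x8x32xf32>,
                                     %filter: tensor<16x3x3x32xf32>,
                                     %init: tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> {
  // With FHWC, the f-th filter occupies one contiguous 288-element slice,
  // matching the reduction dimensions (h, w, c) of the computation.
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
       ins(%input, %filter : tensor<1x8x8x32xf32>, tensor<16x3x3x32xf32>)
       outs(%init : tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  return %0 : tensor<1x6x6x16xf32>
}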