forked from OSchip/llvm-project
[mlir][linalg] add conv_2d_nhwc_fhwc named op
This operation should be supported as a named op because, when the operands are viewed as having canonical layouts with decreasing strides, the "reduction" dimensions of the filter (h, w, and c) are contiguous relative to each output channel. This layout is the simplest to deal with when the convolution is lowered to a matrix multiplication, so future transforms/vectorizations of `conv2d` may find this named op convenient. Differential Revision: https://reviews.llvm.org/D126995
This commit is contained in:
parent
66bd14697b
commit
cca662b849
|
@ -1311,6 +1311,104 @@ structured_op: !LinalgStructuredOpConfig
|
||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: K
|
scalar_arg: K
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
|
metadata: !LinalgOpMetadata
|
||||||
|
name: conv_2d_nhwc_fhwc
|
||||||
|
cpp_class_name: Conv2DNhwcFhwcOp
|
||||||
|
doc: |-
|
||||||
|
Performs 2-D convolution.
|
||||||
|
|
||||||
|
Layout:
|
||||||
|
* Input: NHWC.
|
||||||
|
* Kernel: FHWC.
|
||||||
|
|
||||||
|
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||||
|
them to the same data type as the accumulator/output.
|
||||||
|
implements:
|
||||||
|
- LinalgConvolutionOpInterface
|
||||||
|
structured_op: !LinalgStructuredOpConfig
|
||||||
|
args:
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: I
|
||||||
|
kind: input_tensor
|
||||||
|
type_var: T1
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
|
||||||
|
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: K
|
||||||
|
kind: input_tensor
|
||||||
|
type_var: T2
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s3,
|
||||||
|
s7, s9)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: O
|
||||||
|
kind: output_tensor
|
||||||
|
type_var: U
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
|
||||||
|
s1, s5, s10)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: strides
|
||||||
|
kind: index_attr
|
||||||
|
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
|
||||||
|
(s2, s6)>
|
||||||
|
default_indices:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: dilations
|
||||||
|
kind: index_attr
|
||||||
|
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
|
||||||
|
(s4, s8)>
|
||||||
|
default_indices:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
indexing_maps: !LinalgIndexingMapsConfig
|
||||||
|
static_indexing_maps:
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||||
|
s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||||
|
s9, s10] -> (d3, d4, d5, d6)>
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||||
|
s9, s10] -> (d0, d1, d2, d3)>
|
||||||
|
iterator_types:
|
||||||
|
- parallel
|
||||||
|
- parallel
|
||||||
|
- parallel
|
||||||
|
- parallel
|
||||||
|
- reduction
|
||||||
|
- reduction
|
||||||
|
- reduction
|
||||||
|
assignments:
|
||||||
|
- !ScalarAssign
|
||||||
|
arg: O
|
||||||
|
value: !ScalarExpression
|
||||||
|
scalar_fn:
|
||||||
|
kind: binary
|
||||||
|
fn_name: add
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: O
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_fn:
|
||||||
|
kind: binary
|
||||||
|
fn_name: mul
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_fn:
|
||||||
|
kind: type
|
||||||
|
fn_name: cast_signed
|
||||||
|
type_var: U
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: I
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_fn:
|
||||||
|
kind: type
|
||||||
|
fn_name: cast_signed
|
||||||
|
type_var: U
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: K
|
||||||
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: conv_2d_nhwc_hwcf_q
|
name: conv_2d_nhwc_hwcf_q
|
||||||
cpp_class_name: Conv2DNhwcHwcfQOp
|
cpp_class_name: Conv2DNhwcHwcfQOp
|
||||||
|
|
|
@ -178,6 +178,38 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conv_2d_nhwc_fhwc
|
||||||
|
func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||||
|
// CHECK: %{{.+}} = linalg.conv_2d_nhwc_fhwc
|
||||||
|
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
|
||||||
|
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
|
||||||
|
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
|
||||||
|
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
|
||||||
|
strides = dense<1> : tensor<2xi64>}
|
||||||
|
ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
|
||||||
|
outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
return %0 : tensor<?x?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
|
||||||
|
func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
|
||||||
|
// CHECK: %{{.+}} = linalg.conv_2d_nhwc_fhwc
|
||||||
|
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
|
||||||
|
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
|
||||||
|
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
|
||||||
|
// CHECK-SAME: outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
|
||||||
|
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
|
||||||
|
strides = dense<1> : tensor<2xi64>}
|
||||||
|
ins (%input, %filter: tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
|
||||||
|
outs (%init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
|
||||||
|
return %0 : tensor<?x126x126x64xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
|
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
|
||||||
func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
|
func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
|
||||||
// CHECK: linalg.conv_2d_nhwc_hwcf
|
// CHECK: linalg.conv_2d_nhwc_hwcf
|
||||||
|
|
Loading…
Reference in New Issue