[mlir][linalg] add conv_2d_nhwc_fhwc named op

This operation should be supported as a named op because, when the operands are viewed as having canonical layouts with decreasing strides, the "reduction" dimensions of the filter (h, w, and c) are contiguous relative to each output channel. When lowered to a matrix multiplication, this layout is the simplest to deal with, and thus future transforms/vectorizations of `conv2d` may find this named op convenient.

Differential Revision: https://reviews.llvm.org/D126995
commit cca662b849 (parent 66bd14697b)
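The contiguity argument can be read off the indexing maps in the definition below: the kernel is indexed as (f, kh, kw, c) while kh, kw, and c are the reduction loops, so the data reduced for each output channel f is innermost in the FHWC layout. As a rough sketch only (not part of the patch; it assumes unit strides/dilations, f32 operands, and made-up shapes and function name), the named op corresponds to a linalg.generic of roughly this form:

#input_map  = affine_map<(n, oh, ow, f, kh, kw, c) -> (n, oh + kh, ow + kw, c)>
#kernel_map = affine_map<(n, oh, ow, f, kh, kw, c) -> (f, kh, kw, c)>
#output_map = affine_map<(n, oh, ow, f, kh, kw, c) -> (n, oh, ow, f)>

func.func @conv_2d_nhwc_fhwc_as_generic(%input: tensor<1x4x4x3xf32>,
    %filter: tensor<8x2x2x3xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
  // The three trailing loops (kh, kw, c) reduce over the filter's trailing
  // dimensions, which sit contiguously per output channel f in FHWC.
  %0 = linalg.generic {
         indexing_maps = [#input_map, #kernel_map, #output_map],
         iterator_types = ["parallel", "parallel", "parallel", "parallel",
                           "reduction", "reduction", "reduction"]}
       ins(%input, %filter : tensor<1x4x4x3xf32>, tensor<8x2x2x3xf32>)
       outs(%init : tensor<1x3x3x8xf32>) {
  ^bb0(%in: f32, %ker: f32, %out: f32):
    %mul = arith.mulf %in, %ker : f32
    %acc = arith.addf %out, %mul : f32
    linalg.yield %acc : f32
  } -> tensor<1x3x3x8xf32>
  return %0 : tensor<1x3x3x8xf32>
}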
@@ -1311,6 +1311,104 @@ structured_op: !LinalgStructuredOpConfig
                - !ScalarExpression
                  scalar_arg: K
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: conv_2d_nhwc_fhwc
  cpp_class_name: Conv2DNhwcFhwcOp
  doc: |-
    Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: FHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: I
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
  - !LinalgOperandDefConfig
    name: K
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s3,
      s7, s9)>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
      s1, s5, s10)>
  - !LinalgOperandDefConfig
    name: strides
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
      (s2, s6)>
    default_indices:
    - 1
    - 1
  - !LinalgOperandDefConfig
    name: dilations
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
      (s4, s8)>
    default_indices:
    - 1
    - 1
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d3, d4, d5, d6)>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d0, d1, d2, d3)>
  iterator_types:
  - parallel
  - parallel
  - parallel
  - parallel
  - reduction
  - reduction
  - reduction
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: O
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: I
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: K
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: conv_2d_nhwc_hwcf_q
  cpp_class_name: Conv2DNhwcHwcfQOp
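The doc string above notes that operands are promoted to the accumulator/output element type (via cast_signed, i.e. sign-extension for integers). As a hedged illustration only (not part of the patch; shapes and function name are made up), this means the op also accepts mixed precision, e.g. i8 inputs accumulating into an i32 output:

func.func @conv_2d_nhwc_fhwc_i8_i32(%input: tensor<1x6x6x4xi8>,
    %filter: tensor<16x3x3x4xi8>, %init: tensor<1x4x4x16xi32>) -> tensor<1x4x4x16xi32> {
  // Both i8 operands are sign-extended to i32 before the multiply-accumulate.
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins(%input, %filter : tensor<1x6x6x4xi8>, tensor<16x3x3x4xi8>)
    outs(%init : tensor<1x4x4x16xi32>) -> tensor<1x4x4x16xi32>
  return %0 : tensor<1x4x4x16xi32>
}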
@@ -178,6 +178,38 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
// -----

// CHECK-LABEL: func @conv_2d_nhwc_fhwc
func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
    outs (%init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
  return %0 : tensor<?x126x126x64xf32>
}

// -----

// CHECK-LABEL: func @conv_2d_nhwc_hwcf
func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
  // CHECK: linalg.conv_2d_nhwc_hwcf
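The new tests exercise only the default unit strides and dilations. Purely as a sketch under assumed shapes (not part of the patch), a strided, dilated use would pass non-unit values through the same attributes, which feed the s2/s6 (strides) and s4/s8 (dilations) symbols of the maps above:

func.func @conv_2d_nhwc_fhwc_strided(%input: tensor<1x10x10x3xf32>,
    %filter: tensor<8x3x3x3xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
  // Stride 2 and dilation 2 in both spatial dimensions:
  // output extent = floor((10 - 2*(3-1) - 1) / 2) + 1 = 3.
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<2> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins(%input, %filter : tensor<1x10x10x3xf32>, tensor<8x3x3x3xf32>)
    outs(%init : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
  return %0 : tensor<1x3x3x8xf32>
}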