[mlir][Vector] Vector transform skeleton

Differential Revision: https://reviews.llvm.org/D134722
This commit is contained in:
Nicolas Vasilache 2022-06-21 06:52:01 -07:00
parent 5577207d6d
commit c4ce8a40fa
9 changed files with 198 additions and 2 deletions

View File

@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)

View File

@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS VectorTransformOps.td)
mlir_tablegen(VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(VectorTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRVectorTransformOpsIncGen)
add_mlir_doc(VectorTransformOps VectorTransformOps Dialects/ -gen-op-doc)

View File

@ -0,0 +1,37 @@
//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
namespace vector {
class VectorOp;
} // namespace vector
} // namespace mlir
//===----------------------------------------------------------------------===//
// Vector Transform Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h.inc"
namespace mlir {
class DialectRegistry;
namespace vector {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H

View File

@ -0,0 +1,19 @@
//===- VectorTransformOps.td - Vector transform ops --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef VECTOR_TRANSFORM_OPS
#define VECTOR_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
#endif // VECTOR_TRANSFORM_OPS

View File

@ -66,6 +66,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Dialect.h"
@ -115,12 +116,13 @@ inline void registerAllDialects(DialectRegistry &registry) {
// clang-format on
// Register all dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
memref::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
affine::registerTransformDialectExtension(registry);
vector::registerTransformDialectExtension(registry);
// Register all external models.
arith::registerBufferizableOpInterfaceExternalModels(registry);

View File

@ -1,4 +1,5 @@
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
add_subdirectory(Utils)

View File

@ -0,0 +1,19 @@
add_mlir_dialect_library(MLIRVectorTransformOps
VectorTransformOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/TransformOps
DEPENDS
MLIRVectorTransformOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRVectorDialect
MLIRVectorTransforms
MLIRParser
MLIRPDLDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRVectorDialect
)

View File

@ -0,0 +1,51 @@
//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::transform;
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
class VectorTransformDialectExtension
: public transform::TransformDialectExtension<
VectorTransformDialectExtension> {
public:
VectorTransformDialectExtension() {
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<vector::VectorDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
void mlir::vector::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<VectorTransformDialectExtension>();
}

View File

@ -3349,6 +3349,32 @@ cc_library(
],
)
cc_library(
name = "VectorTransformOps",
srcs = [
"lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp",
],
hdrs = [
"include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h",
],
includes = ["include"],
deps = [
":AffineDialect",
":ArithDialect",
":AsmParser",
":IR",
":VectorDialect",
":VectorTransformOpsIncGen",
":VectorTransforms",
":PDLDialect",
":Parser",
":SideEffectInterfaces",
":TransformDialect",
":TransformUtils",
"//llvm:Support",
],
)
gentbl_cc_library(
name = "VectorPassIncGen",
strip_include_prefix = "include",
@ -6798,6 +6824,7 @@ cc_library(
":VectorToLLVM",
":VectorToSCF",
":VectorToSPIRV",
":VectorTransformOps",
":VectorTransforms",
":X86VectorDialect",
":X86VectorTransforms",
@ -8373,6 +8400,19 @@ td_library(
],
)
td_library(
name = "VectorTransformOpsTdFiles",
srcs = [
"include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td",
],
includes = ["include"],
deps = [
":PDLDialectTdFiles",
":SCFTdFiles",
":TransformDialectTdFiles",
],
)
gentbl_cc_library(
name = "MaskableOpInterfaceIncGen",
strip_include_prefix = "include",
@ -8461,6 +8501,26 @@ gentbl_cc_library(
deps = [":VectorOpsTdFiles"],
)
gentbl_cc_library(
name = "VectorTransformOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
"include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h.inc",
),
(
["-gen-op-defs"],
"include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td",
deps = [
":VectorTransformOpsTdFiles",
],
)
cc_library(
name = "MaskableOpInterface",
srcs = ["lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp"],