[mlir:PDL] Add support for creating ranges in rewrites
This commit adds support for building a concatenated range from a given set of elements, either single element or other ranges, within a rewrite. We could conceptually extend this to support constraining input ranges, but the logic there is quite a bit more complex so it is left for later work when a need arises. Differential Revision: https://reviews.llvm.org/D133719
This commit is contained in:
parent
8c66344ee9
commit
ce57789d8e
|
@ -436,6 +436,48 @@ def PDL_PatternOp : PDL_Op<"pattern", [
|
|||
let hasRegionVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
|
||||
let summary = "Construct a range of pdl entities";
|
||||
let description = [{
|
||||
`pdl.range` operations construct a range from a given set of PDL entities,
|
||||
which all share the same underlying element type. For example, a
|
||||
`!pdl.range<value>` may be constructed from a list of `!pdl.value`
|
||||
or `!pdl.range<value>` entities.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Construct a range of values.
|
||||
%valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
|
||||
|
||||
// Construct a range of types.
|
||||
%typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range<type>
|
||||
|
||||
// Construct an empty range of types.
|
||||
%valueRange = pdl.range : !pdl.range<type>
|
||||
```
|
||||
|
||||
TODO: Range construction is currently limited to rewrites, but it could
|
||||
be extended to constraints under certain circustances; i.e., if we can
|
||||
determine how to extract the underlying elements. If we can't, e.g. if
|
||||
there are multiple sub ranges used for construction, we won't be able
|
||||
to determine their sizes during constraint time.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<PDL_AnyType>:$arguments);
|
||||
let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
|
||||
let assemblyFormat = [{
|
||||
($arguments^ `:` type($arguments))?
|
||||
custom<RangeType>(ref(type($arguments)), type($result))
|
||||
attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::ReplaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -28,6 +28,11 @@ public:
|
|||
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
/// If the given type is a range, return its element type, otherwise return
|
||||
/// the type itself.
|
||||
Type getRangeElementTypeOrSelf(Type type);
|
||||
|
||||
} // namespace pdl
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -992,6 +992,43 @@ def PDLInterp_IsNotNullOp
|
|||
let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::CreateRangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [Pure]> {
|
||||
let summary = "Construct a range of PDL entities";
|
||||
let description = [{
|
||||
`pdl_interp.create_range` operations construct a range from a given set of PDL
|
||||
entities, which all share the same underlying element type. For example, a
|
||||
`!pdl.range<value>` may be constructed from a list of `!pdl.value`
|
||||
or `!pdl.range<value>` entities.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Construct a range of values.
|
||||
%valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
|
||||
|
||||
// Construct a range of types.
|
||||
%typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range<type>
|
||||
|
||||
// Construct an empty range of types.
|
||||
%valueRange = pdl_interp.create_range : !pdl.range<type>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<PDL_AnyType>:$arguments);
|
||||
let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
|
||||
let assemblyFormat = [{
|
||||
($arguments^ `:` type($arguments))?
|
||||
custom<RangeType>(ref(type($arguments)), type($result))
|
||||
attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::RecordMatchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -89,6 +89,9 @@ private:
|
|||
void generateRewriter(pdl::OperationOp operationOp,
|
||||
DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue);
|
||||
void generateRewriter(pdl::RangeOp rangeOp,
|
||||
DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue);
|
||||
void generateRewriter(pdl::ReplaceOp replaceOp,
|
||||
DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue);
|
||||
|
@ -668,8 +671,8 @@ SymbolRefAttr PatternLowering::generateRewriter(
|
|||
for (Operation &rewriteOp : *rewriter.getBody()) {
|
||||
llvm::TypeSwitch<Operation *>(&rewriteOp)
|
||||
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
|
||||
pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
|
||||
pdl::TypeOp, pdl::TypesOp>([&](auto op) {
|
||||
pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
|
||||
pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
|
||||
this->generateRewriter(op, rewriteValues, mapRewriteValue);
|
||||
});
|
||||
}
|
||||
|
@ -775,6 +778,16 @@ void PatternLowering::generateRewriter(
|
|||
}
|
||||
}
|
||||
|
||||
void PatternLowering::generateRewriter(
|
||||
pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
SmallVector<Value, 4> replOperands;
|
||||
for (Value operand : rangeOp.getArguments())
|
||||
replOperands.push_back(mapRewriteValue(operand));
|
||||
rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
|
||||
rangeOp.getLoc(), rangeOp.getType(), replOperands);
|
||||
}
|
||||
|
||||
void PatternLowering::generateRewriter(
|
||||
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
|
||||
function_ref<Value(Value)> mapRewriteValue) {
|
||||
|
|
|
@ -397,6 +397,39 @@ StringRef PatternOp::getDefaultDialect() {
|
|||
return PDLDialect::getDialectNamespace();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
|
||||
Type &resultType) {
|
||||
// If arguments were provided, infer the result type from the argument list.
|
||||
if (!argumentTypes.empty()) {
|
||||
resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
|
||||
return success();
|
||||
}
|
||||
// Otherwise, parse the type as a trailing type.
|
||||
return p.parseColonType(resultType);
|
||||
}
|
||||
|
||||
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
|
||||
Type resultType) {
|
||||
if (argumentTypes.empty())
|
||||
p << ": " << resultType;
|
||||
}
|
||||
|
||||
LogicalResult RangeOp::verify() {
|
||||
Type elementType = getType().getElementType();
|
||||
for (Type operandType : getOperandTypes()) {
|
||||
Type operandElementType = getRangeElementTypeOrSelf(operandType);
|
||||
if (operandElementType != elementType) {
|
||||
return emitOpError("expected operand to have element type ")
|
||||
<< elementType << ", but got " << operandElementType;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::ReplaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -59,6 +59,12 @@ bool PDLType::classof(Type type) {
|
|||
return llvm::isa<PDLDialect>(type.getDialect());
|
||||
}
|
||||
|
||||
Type pdl::getRangeElementTypeOrSelf(Type type) {
|
||||
if (auto rangeType = type.dyn_cast<RangeType>())
|
||||
return rangeType.getElementType();
|
||||
return type;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -237,6 +237,40 @@ static Type getGetValueTypeOpValueType(Type type) {
|
|||
return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::CreateRangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
|
||||
Type &resultType) {
|
||||
// If arguments were provided, infer the result type from the argument list.
|
||||
if (!argumentTypes.empty()) {
|
||||
resultType =
|
||||
pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
|
||||
return success();
|
||||
}
|
||||
// Otherwise, parse the type as a trailing type.
|
||||
return p.parseColonType(resultType);
|
||||
}
|
||||
|
||||
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
|
||||
TypeRange argumentTypes, Type resultType) {
|
||||
if (argumentTypes.empty())
|
||||
p << ": " << resultType;
|
||||
}
|
||||
|
||||
LogicalResult CreateRangeOp::verify() {
|
||||
Type elementType = getType().getElementType();
|
||||
for (Type operandType : getOperandTypes()) {
|
||||
Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
|
||||
if (operandElementType != elementType) {
|
||||
return emitOpError("expected operand to have element type ")
|
||||
<< elementType << ", but got " << operandElementType;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::SwitchAttributeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -99,10 +99,14 @@ enum OpCode : ByteCodeField {
|
|||
CheckTypes,
|
||||
/// Continue to the next iteration of a loop.
|
||||
Continue,
|
||||
/// Create a type range from a list of constant types.
|
||||
CreateConstantTypeRange,
|
||||
/// Create an operation.
|
||||
CreateOperation,
|
||||
/// Create a range of types.
|
||||
CreateTypes,
|
||||
/// Create a type range from a list of dynamic types.
|
||||
CreateDynamicTypeRange,
|
||||
/// Create a value range.
|
||||
CreateDynamicValueRange,
|
||||
/// Erase an operation.
|
||||
EraseOp,
|
||||
/// Extract the op from a range at the specified index.
|
||||
|
@ -265,6 +269,7 @@ private:
|
|||
void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
|
||||
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
|
||||
|
@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
|
|||
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
|
||||
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
|
||||
pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
|
||||
pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
|
||||
pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
|
||||
pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
|
||||
pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
|
||||
pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
|
||||
pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
|
||||
pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
|
||||
pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
|
||||
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
|
||||
|
@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
|
|||
else
|
||||
writer.appendPDLValueList(op.getInputResultTypes());
|
||||
}
|
||||
void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
|
||||
// Append the correct opcode for the range type.
|
||||
TypeSwitch<Type>(op.getType().getElementType())
|
||||
.Case(
|
||||
[&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
|
||||
.Case([&](pdl::ValueType) {
|
||||
writer.append(OpCode::CreateDynamicValueRange);
|
||||
});
|
||||
|
||||
writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
|
||||
writer.appendPDLValueList(op->getOperands());
|
||||
}
|
||||
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
|
||||
// Simply repoint the memory index of the result to the constant.
|
||||
getMemIndex(op.getResult()) = getMemIndex(op.getValue());
|
||||
}
|
||||
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
|
||||
writer.append(OpCode::CreateTypes, op.getResult(),
|
||||
writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
|
||||
getRangeStorageIndex(op.getResult()), op.getValue());
|
||||
}
|
||||
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
|
||||
|
@ -1103,9 +1120,11 @@ private:
|
|||
void executeCheckResultCount();
|
||||
void executeCheckTypes();
|
||||
void executeContinue();
|
||||
void executeCreateConstantTypeRange();
|
||||
void executeCreateOperation(PatternRewriter &rewriter,
|
||||
Location mainRewriteLoc);
|
||||
void executeCreateTypes();
|
||||
template <typename T>
|
||||
void executeDynamicCreateRange(StringRef type);
|
||||
void executeEraseOp(PatternRewriter &rewriter);
|
||||
template <typename T, typename Range, PDLValue::Kind kind>
|
||||
void executeExtract();
|
||||
|
@ -1172,8 +1191,18 @@ private:
|
|||
}
|
||||
|
||||
/// Read a list of values from the bytecode buffer. The values may be encoded
|
||||
/// as either Value or ValueRange elements.
|
||||
void readValueList(SmallVectorImpl<Value> &list) {
|
||||
/// either as a single element or a range of elements.
|
||||
void readList(SmallVectorImpl<Type> &list) {
|
||||
for (unsigned i = 0, e = read(); i != e; ++i) {
|
||||
if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
|
||||
list.push_back(read<Type>());
|
||||
} else {
|
||||
TypeRange *values = read<TypeRange *>();
|
||||
list.append(values->begin(), values->end());
|
||||
}
|
||||
}
|
||||
}
|
||||
void readList(SmallVectorImpl<Value> &list) {
|
||||
for (unsigned i = 0, e = read(); i != e; ++i) {
|
||||
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
|
||||
list.push_back(read<Value>());
|
||||
|
@ -1292,6 +1321,39 @@ private:
|
|||
return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
|
||||
}
|
||||
|
||||
/// Assign the given range to the given memory index. This allocates a new
|
||||
/// range object if necessary.
|
||||
template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
|
||||
void assignRangeToMemory(RangeT &&range, unsigned memIndex,
|
||||
unsigned rangeIndex) {
|
||||
// Utility functor used to type-erase the assignment.
|
||||
auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
|
||||
// If the input range is empty, we don't need to allocate anything.
|
||||
if (range.empty()) {
|
||||
rangeMemory[rangeIndex] = {};
|
||||
} else {
|
||||
// Allocate a buffer for this type range.
|
||||
llvm::OwningArrayRef<T> storage(llvm::size(range));
|
||||
llvm::copy(range, storage.begin());
|
||||
|
||||
// Assign this to the range slot and use the range as the value for the
|
||||
// memory index.
|
||||
allocatedRangeMemory.emplace_back(std::move(storage));
|
||||
rangeMemory[rangeIndex] = allocatedRangeMemory.back();
|
||||
}
|
||||
memory[memIndex] = &rangeMemory[rangeIndex];
|
||||
};
|
||||
|
||||
// Dispatch based on the concrete range type.
|
||||
if constexpr (std::is_same_v<T, Type>) {
|
||||
return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
|
||||
} else if constexpr (std::is_same_v<T, Value>) {
|
||||
return assignRange(allocatedValueRangeMemory, valueRangeMemory);
|
||||
} else {
|
||||
llvm_unreachable("unhandled range type");
|
||||
}
|
||||
}
|
||||
|
||||
/// The underlying bytecode buffer.
|
||||
const ByteCodeField *curCodeIt;
|
||||
|
||||
|
@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() {
|
|||
popCodeIt();
|
||||
}
|
||||
|
||||
void ByteCodeExecutor::executeCreateTypes() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
|
||||
void ByteCodeExecutor::executeCreateConstantTypeRange() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
|
||||
unsigned memIndex = read();
|
||||
unsigned rangeIndex = read();
|
||||
ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
|
||||
|
||||
// Allocate a buffer for this type range.
|
||||
llvm::OwningArrayRef<Type> storage(typesAttr.size());
|
||||
llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
|
||||
allocatedTypeRangeMemory.emplace_back(std::move(storage));
|
||||
|
||||
// Assign this to the range slot and use the range as the value for the
|
||||
// memory index.
|
||||
typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
|
||||
memory[memIndex] = &typeRangeMemory[rangeIndex];
|
||||
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
|
||||
rangeIndex);
|
||||
}
|
||||
|
||||
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
|
||||
|
@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
|
|||
|
||||
unsigned memIndex = read();
|
||||
OperationState state(mainRewriteLoc, read<OperationName>());
|
||||
readValueList(state.operands);
|
||||
readList(state.operands);
|
||||
for (unsigned i = 0, e = read(); i != e; ++i) {
|
||||
StringAttr name = read<StringAttr>();
|
||||
if (Attribute attr = read<Attribute>())
|
||||
|
@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
|
|||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
|
||||
unsigned memIndex = read();
|
||||
unsigned rangeIndex = read();
|
||||
SmallVector<T> values;
|
||||
readList(values);
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\n * " << type << "s: ";
|
||||
llvm::interleaveComma(values, llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
|
||||
assignRangeToMemory(values, memIndex, rangeIndex);
|
||||
}
|
||||
|
||||
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
|
||||
Operation *op = read<Operation *>();
|
||||
|
@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
|
|||
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
|
||||
Operation *op = read<Operation *>();
|
||||
SmallVector<Value, 16> args;
|
||||
readValueList(args);
|
||||
readList(args);
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << " * Operation: " << *op << "\n"
|
||||
|
@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
|
|||
case Continue:
|
||||
executeContinue();
|
||||
break;
|
||||
case CreateConstantTypeRange:
|
||||
executeCreateConstantTypeRange();
|
||||
break;
|
||||
case CreateOperation:
|
||||
executeCreateOperation(rewriter, *mainRewriteLoc);
|
||||
break;
|
||||
case CreateTypes:
|
||||
executeCreateTypes();
|
||||
case CreateDynamicTypeRange:
|
||||
executeDynamicCreateRange<Type>("Type");
|
||||
break;
|
||||
case CreateDynamicValueRange:
|
||||
executeDynamicCreateRange<Value>("Value");
|
||||
break;
|
||||
case EraseOp:
|
||||
executeEraseOp(rewriter);
|
||||
|
|
|
@ -243,3 +243,20 @@ module @unbound_rewrite_op {
|
|||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @range_op
|
||||
module @range_op {
|
||||
// CHECK: module @rewriters
|
||||
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value)
|
||||
// CHECK: %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range<value>
|
||||
// CHECK: %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range<value>
|
||||
// CHECK: pdl_interp.finalize
|
||||
pdl.pattern : benefit(1) {
|
||||
%operand = pdl.operand
|
||||
%root = operation "foo.op"(%operand : !pdl.value)
|
||||
rewrite %root {
|
||||
%emptyRange = pdl.range : !pdl.range<value>
|
||||
%range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range<value>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -237,6 +237,23 @@ pdl.pattern : benefit(1) {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
pdl.pattern : benefit(1) {
|
||||
%operand = pdl.operand
|
||||
%resultType = pdl.type
|
||||
%root = pdl.operation "baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type)
|
||||
|
||||
rewrite %root {
|
||||
// expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
|
||||
%range = pdl.range %operand, %resultType : !pdl.value, !pdl.type
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::ResultsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::CreateOperationOp
|
||||
// pdl_interp::CreateOperationOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
pdl_interp.func @rewriter() {
|
||||
|
@ -23,3 +23,15 @@ pdl_interp.func @rewriter() {
|
|||
} : (!pdl.type) -> (!pdl.operation)
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::CreateRangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) {
|
||||
// expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
|
||||
%range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
|
|
@ -568,6 +568,48 @@ module @ir attributes { test.create_op_infer_results } {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::CreateRangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module @patterns {
|
||||
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||
pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
|
||||
|
||||
^pat1:
|
||||
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
|
||||
|
||||
^end:
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
module @rewriters {
|
||||
pdl_interp.func @success(%root: !pdl.operation) {
|
||||
%rootOperand = pdl_interp.get_operand 0 of %root
|
||||
%rootOperands = pdl_interp.get_operands of %root : !pdl.range<value>
|
||||
%operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range<value>
|
||||
|
||||
%operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type
|
||||
%operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range<type>
|
||||
%typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range<type>
|
||||
|
||||
%op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range<value>) -> (%typeRange : !pdl.range<type>)
|
||||
pdl_interp.erase %root
|
||||
pdl_interp.finalize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test.create_range_1
|
||||
// CHECK: %[[INPUTS:.*]]:2 = "test.input"()
|
||||
// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32)
|
||||
module @ir attributes { test.create_range_1 } {
|
||||
%values:2 = "test.input"() : () -> (i32, i32)
|
||||
"test.op"(%values#0, %values#1) : (i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::CreateTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue