[mlir:PDL] Add support for creating ranges in rewrites

This commit adds support for building a concatenated range from
a given set of elements, either single element or other ranges, within a
rewrite. We could conceptually extend this to support constraining
input ranges, but the logic there is quite a bit more complex so it is
left for later work when a need arises.

Differential Revision: https://reviews.llvm.org/D133719
This commit is contained in:
River Riddle 2022-09-09 16:31:24 -07:00
parent 8c66344ee9
commit ce57789d8e
12 changed files with 363 additions and 28 deletions

View File

@ -436,6 +436,48 @@ def PDL_PatternOp : PDL_Op<"pattern", [
let hasRegionVerifier = 1;
}
//===----------------------------------------------------------------------===//
// pdl::RangeOp
//===----------------------------------------------------------------------===//
def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
let summary = "Construct a range of pdl entities";
let description = [{
`pdl.range` operations construct a range from a given set of PDL entities,
which all share the same underlying element type. For example, a
`!pdl.range<value>` may be constructed from a list of `!pdl.value`
or `!pdl.range<value>` entities.
Example:
```mlir
// Construct a range of values.
%valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
// Construct a range of types.
%typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range<type>
// Construct an empty range of types.
%valueRange = pdl.range : !pdl.range<type>
```
TODO: Range construction is currently limited to rewrites, but it could
be extended to constraints under certain circustances; i.e., if we can
determine how to extract the underlying elements. If we can't, e.g. if
there are multiple sub ranges used for construction, we won't be able
to determine their sizes during constraint time.
}];
let arguments = (ins Variadic<PDL_AnyType>:$arguments);
let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
let assemblyFormat = [{
($arguments^ `:` type($arguments))?
custom<RangeType>(ref(type($arguments)), type($result))
attr-dict
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//

View File

@ -28,6 +28,11 @@ public:
static bool classof(Type type);
};
/// If the given type is a range, return its element type, otherwise return
/// the type itself.
Type getRangeElementTypeOrSelf(Type type);
} // namespace pdl
} // namespace mlir

View File

@ -992,6 +992,43 @@ def PDLInterp_IsNotNullOp
let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [Pure]> {
let summary = "Construct a range of PDL entities";
let description = [{
`pdl_interp.create_range` operations construct a range from a given set of PDL
entities, which all share the same underlying element type. For example, a
`!pdl.range<value>` may be constructed from a list of `!pdl.value`
or `!pdl.range<value>` entities.
Example:
```mlir
// Construct a range of values.
%valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
// Construct a range of types.
%typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range<type>
// Construct an empty range of types.
%valueRange = pdl_interp.create_range : !pdl.range<type>
```
}];
let arguments = (ins Variadic<PDL_AnyType>:$arguments);
let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
let assemblyFormat = [{
($arguments^ `:` type($arguments))?
custom<RangeType>(ref(type($arguments)), type($result))
attr-dict
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// pdl_interp::RecordMatchOp
//===----------------------------------------------------------------------===//

View File

@ -89,6 +89,9 @@ private:
void generateRewriter(pdl::OperationOp operationOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::RangeOp rangeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
@ -668,8 +671,8 @@ SymbolRefAttr PatternLowering::generateRewriter(
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
pdl::TypeOp, pdl::TypesOp>([&](auto op) {
pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
}
@ -775,6 +778,16 @@ void PatternLowering::generateRewriter(
}
}
void PatternLowering::generateRewriter(
pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> replOperands;
for (Value operand : rangeOp.getArguments())
replOperands.push_back(mapRewriteValue(operand));
rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
rangeOp.getLoc(), rangeOp.getType(), replOperands);
}
void PatternLowering::generateRewriter(
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {

View File

@ -397,6 +397,39 @@ StringRef PatternOp::getDefaultDialect() {
return PDLDialect::getDialectNamespace();
}
//===----------------------------------------------------------------------===//
// pdl::RangeOp
//===----------------------------------------------------------------------===//
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
Type &resultType) {
// If arguments were provided, infer the result type from the argument list.
if (!argumentTypes.empty()) {
resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
return success();
}
// Otherwise, parse the type as a trailing type.
return p.parseColonType(resultType);
}
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
Type resultType) {
if (argumentTypes.empty())
p << ": " << resultType;
}
LogicalResult RangeOp::verify() {
Type elementType = getType().getElementType();
for (Type operandType : getOperandTypes()) {
Type operandElementType = getRangeElementTypeOrSelf(operandType);
if (operandElementType != elementType) {
return emitOpError("expected operand to have element type ")
<< elementType << ", but got " << operandElementType;
}
}
return success();
}
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//

View File

@ -59,6 +59,12 @@ bool PDLType::classof(Type type) {
return llvm::isa<PDLDialect>(type.getDialect());
}
Type pdl::getRangeElementTypeOrSelf(Type type) {
if (auto rangeType = type.dyn_cast<RangeType>())
return rangeType.getElementType();
return type;
}
//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//

View File

@ -237,6 +237,40 @@ static Type getGetValueTypeOpValueType(Type type) {
return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
}
//===----------------------------------------------------------------------===//
// pdl::CreateRangeOp
//===----------------------------------------------------------------------===//
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
Type &resultType) {
// If arguments were provided, infer the result type from the argument list.
if (!argumentTypes.empty()) {
resultType =
pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
return success();
}
// Otherwise, parse the type as a trailing type.
return p.parseColonType(resultType);
}
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
TypeRange argumentTypes, Type resultType) {
if (argumentTypes.empty())
p << ": " << resultType;
}
LogicalResult CreateRangeOp::verify() {
Type elementType = getType().getElementType();
for (Type operandType : getOperandTypes()) {
Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
if (operandElementType != elementType) {
return emitOpError("expected operand to have element type ")
<< elementType << ", but got " << operandElementType;
}
}
return success();
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchAttributeOp
//===----------------------------------------------------------------------===//

View File

@ -99,10 +99,14 @@ enum OpCode : ByteCodeField {
CheckTypes,
/// Continue to the next iteration of a loop.
Continue,
/// Create a type range from a list of constant types.
CreateConstantTypeRange,
/// Create an operation.
CreateOperation,
/// Create a range of types.
CreateTypes,
/// Create a type range from a list of dynamic types.
CreateDynamicTypeRange,
/// Create a value range.
CreateDynamicValueRange,
/// Erase an operation.
EraseOp,
/// Extract the op from a range at the specified index.
@ -265,6 +269,7 @@ private:
void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
else
writer.appendPDLValueList(op.getInputResultTypes());
}
void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
// Append the correct opcode for the range type.
TypeSwitch<Type>(op.getType().getElementType())
.Case(
[&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
.Case([&](pdl::ValueType) {
writer.append(OpCode::CreateDynamicValueRange);
});
writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
writer.appendPDLValueList(op->getOperands());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
writer.append(OpCode::CreateTypes, op.getResult(),
writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
getRangeStorageIndex(op.getResult()), op.getValue());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
@ -1103,9 +1120,11 @@ private:
void executeCheckResultCount();
void executeCheckTypes();
void executeContinue();
void executeCreateConstantTypeRange();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
void executeCreateTypes();
template <typename T>
void executeDynamicCreateRange(StringRef type);
void executeEraseOp(PatternRewriter &rewriter);
template <typename T, typename Range, PDLValue::Kind kind>
void executeExtract();
@ -1172,8 +1191,18 @@ private:
}
/// Read a list of values from the bytecode buffer. The values may be encoded
/// as either Value or ValueRange elements.
void readValueList(SmallVectorImpl<Value> &list) {
/// either as a single element or a range of elements.
void readList(SmallVectorImpl<Type> &list) {
for (unsigned i = 0, e = read(); i != e; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
list.push_back(read<Type>());
} else {
TypeRange *values = read<TypeRange *>();
list.append(values->begin(), values->end());
}
}
}
void readList(SmallVectorImpl<Value> &list) {
for (unsigned i = 0, e = read(); i != e; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
list.push_back(read<Value>());
@ -1292,6 +1321,39 @@ private:
return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
}
/// Assign the given range to the given memory index. This allocates a new
/// range object if necessary.
template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
void assignRangeToMemory(RangeT &&range, unsigned memIndex,
unsigned rangeIndex) {
// Utility functor used to type-erase the assignment.
auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
// If the input range is empty, we don't need to allocate anything.
if (range.empty()) {
rangeMemory[rangeIndex] = {};
} else {
// Allocate a buffer for this type range.
llvm::OwningArrayRef<T> storage(llvm::size(range));
llvm::copy(range, storage.begin());
// Assign this to the range slot and use the range as the value for the
// memory index.
allocatedRangeMemory.emplace_back(std::move(storage));
rangeMemory[rangeIndex] = allocatedRangeMemory.back();
}
memory[memIndex] = &rangeMemory[rangeIndex];
};
// Dispatch based on the concrete range type.
if constexpr (std::is_same_v<T, Type>) {
return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
} else if constexpr (std::is_same_v<T, Value>) {
return assignRange(allocatedValueRangeMemory, valueRangeMemory);
} else {
llvm_unreachable("unhandled range type");
}
}
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() {
popCodeIt();
}
void ByteCodeExecutor::executeCreateTypes() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
void ByteCodeExecutor::executeCreateConstantTypeRange() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
// Allocate a buffer for this type range.
llvm::OwningArrayRef<Type> storage(typesAttr.size());
llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
allocatedTypeRangeMemory.emplace_back(std::move(storage));
// Assign this to the range slot and use the range as the value for the
// memory index.
typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
memory[memIndex] = &typeRangeMemory[rangeIndex];
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
rangeIndex);
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
readValueList(state.operands);
readList(state.operands);
for (unsigned i = 0, e = read(); i != e; ++i) {
StringAttr name = read<StringAttr>();
if (Attribute attr = read<Attribute>())
@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
});
}
template <typename T>
void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
SmallVector<T> values;
readList(values);
LLVM_DEBUG({
llvm::dbgs() << "\n * " << type << "s: ";
llvm::interleaveComma(values, llvm::dbgs());
llvm::dbgs() << "\n";
});
assignRangeToMemory(values, memIndex, rangeIndex);
}
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
Operation *op = read<Operation *>();
@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
readValueList(args);
readList(args);
LLVM_DEBUG({
llvm::dbgs() << " * Operation: " << *op << "\n"
@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case Continue:
executeContinue();
break;
case CreateConstantTypeRange:
executeCreateConstantTypeRange();
break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
case CreateTypes:
executeCreateTypes();
case CreateDynamicTypeRange:
executeDynamicCreateRange<Type>("Type");
break;
case CreateDynamicValueRange:
executeDynamicCreateRange<Value>("Value");
break;
case EraseOp:
executeEraseOp(rewriter);

View File

@ -243,3 +243,20 @@ module @unbound_rewrite_op {
}
// -----
// CHECK-LABEL: module @range_op
module @range_op {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value)
// CHECK: %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range<value>
// CHECK: %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range<value>
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%operand = pdl.operand
%root = operation "foo.op"(%operand : !pdl.value)
rewrite %root {
%emptyRange = pdl.range : !pdl.range<value>
%range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range<value>
}
}
}

View File

@ -237,6 +237,23 @@ pdl.pattern : benefit(1) {
// -----
//===----------------------------------------------------------------------===//
// pdl::RangeOp
//===----------------------------------------------------------------------===//
pdl.pattern : benefit(1) {
%operand = pdl.operand
%resultType = pdl.type
%root = pdl.operation "baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type)
rewrite %root {
// expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
%range = pdl.range %operand, %resultType : !pdl.value, !pdl.type
}
}
// -----
//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
//===----------------------------------------------------------------------===//
// pdl::CreateOperationOp
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
pdl_interp.func @rewriter() {
@ -23,3 +23,15 @@ pdl_interp.func @rewriter() {
} : (!pdl.type) -> (!pdl.operation)
pdl_interp.finalize
}
// -----
//===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
//===----------------------------------------------------------------------===//
pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) {
// expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
%range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type
pdl_interp.finalize
}

View File

@ -568,6 +568,48 @@ module @ir attributes { test.create_op_infer_results } {
// -----
//===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
//===----------------------------------------------------------------------===//
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
^pat1:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root: !pdl.operation) {
%rootOperand = pdl_interp.get_operand 0 of %root
%rootOperands = pdl_interp.get_operands of %root : !pdl.range<value>
%operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range<value>
%operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type
%operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range<type>
%typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range<type>
%op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range<value>) -> (%typeRange : !pdl.range<type>)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.create_range_1
// CHECK: %[[INPUTS:.*]]:2 = "test.input"()
// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32)
module @ir attributes { test.create_range_1 } {
%values:2 = "test.input"() : () -> (i32, i32)
"test.op"(%values#0, %values#1) : (i32, i32) -> ()
}
// -----
//===----------------------------------------------------------------------===//
// pdl_interp::CreateTypeOp
//===----------------------------------------------------------------------===//