316 lines
12 KiB
C++
316 lines
12 KiB
C++
//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
|
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdl_interp;
|
|
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PDLInterp Dialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void PDLInterpDialect::initialize() {
|
|
addOperations<
|
|
#define GET_OP_LIST
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
|
|
>();
|
|
}
|
|
|
|
template <typename OpT>
|
|
static LogicalResult verifySwitchOp(OpT op) {
|
|
// Verify that the number of case destinations matches the number of case
|
|
// values.
|
|
size_t numDests = op.getCases().size();
|
|
size_t numValues = op.getCaseValues().size();
|
|
if (numDests != numValues) {
|
|
return op.emitOpError(
|
|
"expected number of cases to match the number of case "
|
|
"values, got ")
|
|
<< numDests << " but expected " << numValues;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::CreateOperationOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult CreateOperationOp::verify() {
|
|
if (!getInferredResultTypes())
|
|
return success();
|
|
if (!getInputResultTypes().empty()) {
|
|
return emitOpError("with inferred results cannot also have "
|
|
"explicit result types");
|
|
}
|
|
OperationName opName(getName(), getContext());
|
|
if (!opName.hasInterface<InferTypeOpInterface>()) {
|
|
return emitOpError()
|
|
<< "has inferred results, but the created operation '" << opName
|
|
<< "' does not support result type inference (or is not "
|
|
"registered)";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static ParseResult parseCreateOperationOpAttributes(
|
|
OpAsmParser &p,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
|
|
ArrayAttr &attrNamesAttr) {
|
|
Builder &builder = p.getBuilder();
|
|
SmallVector<Attribute, 4> attrNames;
|
|
if (succeeded(p.parseOptionalLBrace())) {
|
|
auto parseOperands = [&]() {
|
|
StringAttr nameAttr;
|
|
OpAsmParser::UnresolvedOperand operand;
|
|
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
|
|
p.parseOperand(operand))
|
|
return failure();
|
|
attrNames.push_back(nameAttr);
|
|
attrOperands.push_back(operand);
|
|
return success();
|
|
};
|
|
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
|
|
return failure();
|
|
}
|
|
attrNamesAttr = builder.getArrayAttr(attrNames);
|
|
return success();
|
|
}
|
|
|
|
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
|
|
CreateOperationOp op,
|
|
OperandRange attrArgs,
|
|
ArrayAttr attrNames) {
|
|
if (attrNames.empty())
|
|
return;
|
|
p << " {";
|
|
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
|
|
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
|
|
p << '}';
|
|
}
|
|
|
|
static ParseResult parseCreateOperationOpResults(
|
|
OpAsmParser &p,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
|
|
SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
|
|
if (failed(p.parseOptionalArrow()))
|
|
return success();
|
|
|
|
// Handle the case of inferred results.
|
|
if (succeeded(p.parseOptionalLess())) {
|
|
if (p.parseKeyword("inferred") || p.parseGreater())
|
|
return failure();
|
|
inferredResultTypes = p.getBuilder().getUnitAttr();
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, parse the explicit results.
|
|
return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
|
|
p.parseColonTypeList(resultTypes) || p.parseRParen());
|
|
}
|
|
|
|
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
|
|
OperandRange resultOperands,
|
|
TypeRange resultTypes,
|
|
UnitAttr inferredResultTypes) {
|
|
// Handle the case of inferred results.
|
|
if (inferredResultTypes) {
|
|
p << " -> <inferred>";
|
|
return;
|
|
}
|
|
|
|
// Otherwise, handle the explicit results.
|
|
if (!resultTypes.empty())
|
|
p << " -> (" << resultOperands << " : " << resultTypes << ")";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::ForEachOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|
Value range, Block *successor, bool initLoop) {
|
|
build(builder, state, range, successor);
|
|
if (initLoop) {
|
|
// Create the block and the loop variable.
|
|
// FIXME: Allow passing in a proper location for the loop variable.
|
|
auto rangeType = range.getType().cast<pdl::RangeType>();
|
|
state.regions.front()->emplaceBlock();
|
|
state.regions.front()->addArgument(rangeType.getElementType(),
|
|
state.location);
|
|
}
|
|
}
|
|
|
|
ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
// Parse the loop variable followed by type.
|
|
OpAsmParser::Argument loopVariable;
|
|
OpAsmParser::UnresolvedOperand operandInfo;
|
|
if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
|
|
parser.parseKeyword("in", " after loop variable") ||
|
|
// Parse the operand (value range).
|
|
parser.parseOperand(operandInfo))
|
|
return failure();
|
|
|
|
// Resolve the operand.
|
|
Type rangeType = pdl::RangeType::get(loopVariable.type);
|
|
if (parser.resolveOperand(operandInfo, rangeType, result.operands))
|
|
return failure();
|
|
|
|
// Parse the body region.
|
|
Region *body = result.addRegion();
|
|
Block *successor;
|
|
if (parser.parseRegion(*body, loopVariable) ||
|
|
parser.parseOptionalAttrDict(result.attributes) ||
|
|
// Parse the successor.
|
|
parser.parseArrow() || parser.parseSuccessor(successor))
|
|
return failure();
|
|
|
|
result.addSuccessors(successor);
|
|
return success();
|
|
}
|
|
|
|
void ForEachOp::print(OpAsmPrinter &p) {
|
|
BlockArgument arg = getLoopVariable();
|
|
p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
|
|
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
p << " -> ";
|
|
p.printSuccessor(getSuccessor());
|
|
}
|
|
|
|
LogicalResult ForEachOp::verify() {
|
|
// Verify that the operation has exactly one argument.
|
|
if (getRegion().getNumArguments() != 1)
|
|
return emitOpError("requires exactly one argument");
|
|
|
|
// Verify that the loop variable and the operand (value range)
|
|
// have compatible types.
|
|
BlockArgument arg = getLoopVariable();
|
|
Type rangeType = pdl::RangeType::get(arg.getType());
|
|
if (rangeType != getValues().getType())
|
|
return emitOpError("operand must be a range of loop variable type");
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::FuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
|
|
FunctionType type, ArrayRef<NamedAttribute> attrs) {
|
|
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
|
|
}
|
|
|
|
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
auto buildFuncType =
|
|
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
|
function_interface_impl::VariadicFlag,
|
|
std::string &) { return builder.getFunctionType(argTypes, results); };
|
|
|
|
return function_interface_impl::parseFunctionOp(
|
|
parser, result, /*allowVariadic=*/false, buildFuncType);
|
|
}
|
|
|
|
void FuncOp::print(OpAsmPrinter &p) {
|
|
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::GetValueTypeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Given the result type of a `GetValueTypeOp`, return the expected input type.
|
|
static Type getGetValueTypeOpValueType(Type type) {
|
|
Type valueTy = pdl::ValueType::get(type.getContext());
|
|
return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl::CreateRangeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
|
|
Type &resultType) {
|
|
// If arguments were provided, infer the result type from the argument list.
|
|
if (!argumentTypes.empty()) {
|
|
resultType =
|
|
pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
|
|
return success();
|
|
}
|
|
// Otherwise, parse the type as a trailing type.
|
|
return p.parseColonType(resultType);
|
|
}
|
|
|
|
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
|
|
TypeRange argumentTypes, Type resultType) {
|
|
if (argumentTypes.empty())
|
|
p << ": " << resultType;
|
|
}
|
|
|
|
LogicalResult CreateRangeOp::verify() {
|
|
Type elementType = getType().getElementType();
|
|
for (Type operandType : getOperandTypes()) {
|
|
Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
|
|
if (operandElementType != elementType) {
|
|
return emitOpError("expected operand to have element type ")
|
|
<< elementType << ", but got " << operandElementType;
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchAttributeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchOperandCountOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchOperationNameOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchResultCountOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchTypeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// pdl_interp::SwitchTypesOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen Auto-Generated Op and Interface Definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
|