llvm-project/mlir/lib/Tools/PDLL/Parser/Parser.cpp

1706 lines
60 KiB
C++

//===- Parser.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <string>
using namespace mlir;
using namespace mlir::pdll;
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
namespace {
/// A recursive-descent parser that builds a PDLL AST (`ast::Module`) from the
/// tokens produced by `Lexer`. Errors are emitted through the lexer, which is
/// constructed with the context's diagnostic engine.
class Parser {
public:
  // NOTE: `curToken` is initialized by lexing from `lexer`; this relies on
  // `lexer` being declared before `curToken` in the field list below, since
  // members are initialized in declaration order.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
        curToken(lexer.lexToken()), curDeclScope(nullptr),
        valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)) {}

  /// Try to parse a new module. Returns failure if the input could not be
  /// parsed.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Allocate a new decl scope (from `scopeAllocator`) whose parent is the
  /// current scope, make it the current scope, and return it.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make `scope` the current decl scope. Unlike the overload above, this
  /// neither allocates a scope nor sets the parent of `scope`.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
  /// Pop the current decl scope, making its parent scope current.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module.
  LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a Value or ValueRange
  /// typed expression.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  //===--------------------------------------------------------------------===//
  // Directives

  LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
  LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  FailureOr<ast::Decl *> parseTopLevelDecl();
  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name. If
  /// one has, emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  /// Parse an operation name such as `dialect.op`. When `allowEmptyName` is
  /// set, a missing name yields an unnamed OpNameDecl instead of an error.
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  /// Parse an optional `<` operation-name `>`.
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *> parseOperationExpr();
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc,
                                                  ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                         llvm::SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, llvm::SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(llvm::SMRange loc,
                                         Optional<StringRef> name,
                                         MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperandsOrResults(llvm::SMRange loc,
                                     Optional<StringRef> name,
                                     MutableArrayRef<ast::Expr *> values,
                                     ast::Type singleTy, ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
                                              ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Rewind the lexer to the start of `tokLoc` and re-lex the current token
  /// from there.
  void resetToken(llvm::SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error at `loc` and return failure.
  LogicalResult emitError(llvm::SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error at the location of the current token and return failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error at `loc` with a note at `noteLoc`, and return failure.
  LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
                                 llvm::SMRange noteLoc, const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The most recently defined decl scope.
  ast::DeclScope *curDeclScope;
  /// Storage for every scope created by `pushDeclScope()`.
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;
};
} // namespace
/// Parse the entire input as a single module. All top-level decls are parsed
/// inside a freshly pushed decl scope, which is popped again on both the
/// success and the failure path.
FailureOr<ast::Module *> Parser::parseModule() {
  const llvm::SMLoc startLoc = curToken.getStartLoc();
  SmallVector<ast::Decl *> topLevelDecls;

  pushDeclScope();
  LogicalResult bodyResult = parseModuleBody(topLevelDecls);
  popDeclScope();

  if (failed(bodyResult))
    return failure();
  return ast::Module::create(ctx, startLoc, topLevelDecls);
}
/// Parse directives and top-level decls into `decls` until the `eof` token of
/// the current buffer is reached. Parsing stops at the first failure.
LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
  while (!curToken.is(Token::eof)) {
    // Directives (e.g. `#include`) append any decls they produce themselves.
    if (curToken.is(Token::directive)) {
      if (failed(parseDirective(decls)))
        return failure();
    } else {
      FailureOr<ast::Decl *> topLevelDecl = parseTopLevelDecl();
      if (failed(topLevelDecl))
        return failure();
      decls.push_back(*topLevelDecl);
    }
  }
  return success();
}
/// Wrap the operation expression `opExpr` in an access of all of its results,
/// typed as a ValueRange and located at `opExpr`.
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
  auto opLoc = opExpr->getLoc();
  ast::Type resultTy = valueRangeTy;
  return ast::AllResultsMemberAccessExpr::create(ctx, opLoc, opExpr, resultTy);
}
/// Try to convert `expr` to `type`, replacing `expr` with a wrapping
/// expression when the conversion needs one. Accepted conversions:
///   * identical types (no-op);
///   * an Operation to an unnamed (more general) Operation type (no-op);
///   * an Operation to a ValueRange or a Value, via an all-results access;
///   * between Value and ValueRange, and between Type and TypeRange (no-op);
///   * a tuple to a tuple of the same size, converting element-wise.
/// Anything else emits an error, with notes attached by `noteAttachFn`.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type exprType = expr->getType();
  if (exprType == type)
    return success();

  // Emits the generic conversion error and attaches any caller-provided
  // notes. Captures by reference; it is only invoked before returning.
  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      exprType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
    // Two operation types are compatible if they have the same name, or if the
    // expected type is more general.
    // NOTE(review): identical types already returned above, so a named
    // `opType` here presumably has a different name. This assumes operation
    // types are uniqued by name — confirm in ast::OperationType.
    if (auto opType = type.dyn_cast<ast::OperationType>()) {
      if (opType.getName())
        return emitConvertError();
      return success();
    }

    // An operation can always convert to a ValueRange.
    if (type == valueRangeTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueRangeTy);
      return success();
    }

    // Allow conversion to a single value by constraining the result range.
    if (type == valueTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueTy);
      return success();
    }
    return emitConvertError();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  if ((exprType == valueTy || exprType == valueRangeTy) &&
      (type == valueTy || type == valueRangeTy))
    return success();
  if ((exprType == typeTy || exprType == typeRangeTy) &&
      (type == typeTy || type == typeRangeTy))
    return success();

  // Handle tuple types.
  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
    auto tupleType = type.dyn_cast<ast::TupleType>();
    if (!tupleType || tupleType.size() != exprTupleType.size())
      return emitConvertError();

    // Build a new tuple expression using each of the elements of the current
    // tuple.
    SmallVector<ast::Expr *> newExprs;
    for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
      // Access element `i` of the original tuple by its numeric member name.
      newExprs.push_back(ast::MemberAccessExpr::create(
          ctx, expr->getLoc(), expr, llvm::to_string(i),
          exprTupleType.getElementTypes()[i]));

      // Add a note naming the failing element, then forward to the caller's
      // note callback (if any). Only used synchronously within this iteration.
      auto diagFn = [&](ast::Diagnostic &diag) {
        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                      i, exprTupleType));
        if (noteAttachFn)
          noteAttachFn(diag);
      };
      if (failed(convertExpressionTo(newExprs.back(),
                                     tupleType.getElementTypes()[i], diagFn)))
        return failure();
    }
    // The rebuilt tuple takes its element names from the target type.
    expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
                                  tupleType.getElementNames());
    return success();
  }

  return emitConvertError();
}
//===----------------------------------------------------------------------===//
// Directives
/// Dispatch on the spelling of the current directive token. Only `#include`
/// is supported; any other directive is diagnosed.
LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
  StringRef directiveName = curToken.getSpelling();
  if (directiveName != "#include")
    return emitError("unknown directive `" + directiveName + "`");
  return parseInclude(decls);
}
/// Parse `#include "file"`. Only `.pdll` files are accepted. Their decls are
/// parsed straight into `decls`, as if they were part of the including module.
LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
  llvm::SMRange directiveLoc = curToken.getLoc();
  consumeToken(Token::directive);

  // The directive must be followed by a string literal naming the file.
  if (!curToken.isString())
    return emitError(directiveLoc,
                     "expected string file name after `include` directive");
  llvm::SMRange fileLoc = curToken.getLoc();
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  if (!filename.endswith(".pdll"))
    return emitError(fileLoc, "expected include filename to end with `.pdll`");
  if (failed(lexer.pushInclude(filename)))
    return emitError(fileLoc,
                     "unable to open include file `" + filename + "`");

  // Lex a fresh lookahead from the included file, then restore the lookahead
  // that had already been lexed from the including file once it is parsed.
  Token resumeToken = curToken;
  curToken = lexer.lexToken();
  LogicalResult includeResult = parseModuleBody(decls);
  curToken = resumeToken;
  return includeResult;
}
//===----------------------------------------------------------------------===//
// Decls
/// Parse one top-level declaration (currently only `Pattern`). If it is
/// named, it is registered in the current decl scope after checking that the
/// name is not already defined.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  if (curToken.isNot(Token::kw_Pattern))
    return emitError("expected top-level declaration, such as a `Pattern`");

  FailureOr<ast::Decl *> decl = parsePatternDecl();
  if (failed(decl))
    return failure();

  // Anonymous decls are not visible to later lookups.
  const ast::Name *declName = (*decl)->getName();
  if (!declName)
    return decl;
  if (failed(checkDefineNamedDecl(*declName)))
    return failure();
  curDeclScope->add(*decl);
  return decl;
}
/// Parse a named attribute of an operation expression: `name` or
/// `name = expr`. The name may be a string literal, an identifier, or a
/// keyword. Without an explicit value, the attribute gets a `unit` value.
FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
  const bool isBareName =
      curToken.is(Token::identifier) || curToken.isKeyword();
  if (!curToken.isString() && !isBareName)
    return emitError("expected identifier or string attribute name");
  std::string attrNameStr = curToken.isString()
                                ? curToken.getStringValue()
                                : curToken.getSpelling().str();
  const auto &attrName =
      ast::Name::create(ctx, attrNameStr, curToken.getLoc());
  consumeToken();

  // No `=`: the attribute is implicitly a UnitAttr.
  if (!consumeIf(Token::equal)) {
    ast::Expr *unitValue =
        ast::AttributeExpr::create(ctx, attrName.getLoc(), "unit");
    return ast::NamedAttributeDecl::create(ctx, attrName, unitValue);
  }

  FailureOr<ast::Expr *> attrExpr = parseExpr();
  if (failed(attrExpr))
    return failure();
  return ast::NamedAttributeDecl::create(ctx, attrName, *attrExpr);
}
/// Parse a pattern declaration:
///   `Pattern` identifier? (`with` metadata-list)? `{` body `}`
/// The body is parsed in the PatternMatch context. The first operation
/// rewrite statement (e.g. `erase`) in the body must exist and must be the
/// body's last statement.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  llvm::SMRange patternLoc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  llvm::SaveAndRestore<ParserContext> contextGuard(parserContext,
                                                   ParserContext::PatternMatch);

  // Optional pattern name.
  const ast::Name *patternName = nullptr;
  if (curToken.is(Token::identifier)) {
    patternName =
        &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Optional `with` metadata list.
  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with)) {
    if (failed(parsePatternDeclMetadata(metadata)))
      return failure();
  }

  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start pattern body");
  FailureOr<ast::CompoundStmt *> parsedBody = parseCompoundStmt();
  if (failed(parsedBody))
    return failure();
  ast::CompoundStmt *body = *parsedBody;

  // Advance to the first operation rewrite statement, if any.
  auto stmtIt = body->begin(), stmtEnd = body->end();
  while (stmtIt != stmtEnd && !isa<ast::OpRewriteStmt>(*stmtIt))
    ++stmtIt;

  if (stmtIt == stmtEnd)
    return emitError(patternLoc,
                     "expected Pattern body to terminate with an operation "
                     "rewrite statement, such as `erase`");
  auto trailingIt = std::next(stmtIt);
  if (trailingIt != stmtEnd)
    return emitError((*trailingIt)->getLoc(),
                     "Pattern body was terminated by an operation "
                     "rewrite statement, but found trailing statements");
  return createPatternDecl(patternLoc, patternName, metadata, body);
}
/// Parse the comma-separated metadata list that follows `Pattern ... with`.
/// Supported entries:
///   * `benefit(<integer>)` - the value must fit in 16 bits.
///   * `recursion`          - marks the pattern as having bounded recursion.
/// Each entry may be given at most once.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  // Locations of entries seen so far, used to diagnose duplicates.
  Optional<llvm::SMRange> benefitLoc, hasBoundedRecursionLoc;

  do {
    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef entryName = curToken.getSpelling();
    llvm::SMRange entryLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    if (entryName == "benefit") {
      // benefit(<integer-value>)
      if (benefitLoc)
        return emitErrorAndNote(entryLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      uint16_t benefitValue = 0;
      // getAsInteger returns true on error, including values that overflow
      // the destination type.
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);
      metadata.benefit = benefitValue;
      benefitLoc = entryLoc;
      if (failed(parseToken(Token::r_paren,
                            "expected `)` after pattern benefit")))
        return failure();
    } else if (entryName == "recursion") {
      if (hasBoundedRecursionLoc)
        return emitErrorAndNote(
            entryLoc, "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = entryLoc;
    } else {
      return emitError(entryLoc, "unknown pattern metadata");
    }
  } while (consumeIf(Token::comma));
  return success();
}
/// Parse `<` expr `>`, the type expression of a constraint such as
/// `Attr<type-expr>`. The current token must be the opening `<`.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);
  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr))
    return failure();
  if (failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}
/// Check that `name` can be defined, i.e. `curDeclScope->lookup` finds no
/// existing decl with that name. Otherwise, emit an error pointing at the
/// previous definition and return failure.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");
  ast::Decl *existing = curDeclScope->lookup(name.getName());
  if (!existing)
    return success();
  return emitErrorAndNote(
      name.getLoc(), "`" + name.getName() + "` has already been defined",
      existing->getName()->getLoc(), "see previous definition here");
}
/// Create a variable decl and, unless it is anonymous (empty name) or the
/// wildcard `_`, register it in the current decl scope. Redefinition of an
/// existing name is diagnosed before the decl is created.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type, ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // Anonymous and `_` variables are local to their definition point: they
  // never enter the scope, so they can never clash.
  const bool entersScope = !name.empty() && name != "_";
  if (entersScope && failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  if (entersScope)
    curDeclScope->add(varDecl);
  return varDecl;
}
/// Convenience overload of `defineVariableDecl` for a variable without an
/// initializer expression.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Expr *noInitializer = nullptr;
  return defineVariableDecl(name, nameLoc, type, noInitializer, constraints);
}
/// Parse either a single constraint or a bracketed, comma-separated list
/// (`[c0, c1, ...]`), appending every parsed constraint to `constraints`. One
/// type-constraint location is shared across the whole list, so the
/// variable's type may be constrained at most once.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  Optional<llvm::SMRange> typeConstraint;
  auto parseOne = [&]() -> LogicalResult {
    FailureOr<ast::ConstraintRef> parsed =
        parseConstraint(typeConstraint, constraints);
    if (failed(parsed))
      return failure();
    constraints.push_back(*parsed);
    return success();
  };

  if (consumeIf(Token::l_square)) {
    do {
      if (failed(parseOne()))
        return failure();
    } while (consumeIf(Token::comma));
    return parseToken(Token::r_square, "expected `]` after constraint list");
  }
  return parseOne();
}
/// Parse one constraint reference. It is either a builtin constraint
/// (`Attr<type>?`, `Op<name>?`, `Type`, `TypeRange`, `Value<type>?`,
/// `ValueRange<type>?`) or an identifier naming a user-defined constraint.
/// `typeConstraint` records where the constrained entity's type was last
/// constrained; constraining it a second time is diagnosed.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints) {
  // Parse an optional `<type-expr>` suffix. `typeExpr` stays null when the
  // suffix is absent.
  auto parseOptionalTypeConstraint =
      [&](ast::Expr *&typeExpr) -> LogicalResult {
    typeExpr = nullptr;
    if (curToken.isNot(Token::less))
      return success();
    if (typeConstraint) {
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    }
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  llvm::SMRange loc = curToken.getLoc();
  ast::Expr *typeExpr = nullptr;
  switch (curToken.getKind()) {
  case Token::kw_Attr:
    consumeToken(Token::kw_Attr);
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);

  case Token::kw_Op: {
    consumeToken(Token::kw_Op);
    // Without a `<name>`, the constraint matches any operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();
    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }

  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);

  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);

  case Token::kw_Value:
    consumeToken(Token::kw_Value);
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);

  case Token::kw_ValueRange:
    consumeToken(Token::kw_ValueRange);
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);

  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl)
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl);
    if (!cst)
      return emitErrorAndNote(
          loc, "invalid reference to non-constraint", cstDecl->getLoc(),
          "see the definition of `" + constraintName + "` here");
    return ast::ConstraintRef(cst, loc);
  }

  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
//===----------------------------------------------------------------------===//
// Exprs
/// Parse an expression. This is `_`, or a primary expression (attribute,
/// identifier, operation, type, or tuple) followed by any number of
/// `.member` accesses, applied left to right.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  FailureOr<ast::Expr *> expr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    expr = parseAttributeExpr();
    break;
  case Token::identifier:
    expr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    expr = parseOperationExpr();
    break;
  case Token::kw_type:
    expr = parseTypeExpr();
    break;
  case Token::l_paren:
    expr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }
  if (failed(expr))
    return failure();

  // Postfix member accesses: `expr.a.b` parses as `(expr.a).b`.
  while (curToken.is(Token::dot)) {
    expr = parseMemberAccessExpr(*expr);
    if (failed(expr))
      return failure();
  }
  return expr;
}
/// Parse an attribute literal `attr<"...">`. If `attr` is not followed by
/// `<`, the lexer is rewound and `attr` is parsed as a plain identifier.
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  llvm::SMRange attrLoc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  const bool isLiteral = consumeIf(Token::less);
  if (!isLiteral) {
    resetToken(attrLoc);
    return parseIdentifierExpr();
  }

  if (!curToken.isString())
    return emitError("expected string literal containing MLIR attribute");
  std::string attrValue = curToken.getStringValue();
  consumeToken();

  if (failed(parseToken(Token::greater,
                        "expected `>` after attribute literal")))
    return failure();
  return ast::AttributeExpr::create(ctx, attrLoc, attrValue);
}
/// Resolve `name` in the current decl scope and build a reference to it.
/// An unknown name is diagnosed at `loc`.
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
                                                llvm::SMRange loc) {
  if (ast::Decl *decl = curDeclScope->lookup(name))
    return createDeclRefExpr(loc, decl);
  return emitError(loc, "undefined reference to `" + name + "`");
}
/// Parse an identifier expression. A bare name refers to an existing decl.
/// `name : constraints` defines a new variable inline. The current token may
/// also be a keyword (e.g. `attr`/`op` that fell back to identifier parsing).
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
  StringRef name = curToken.getSpelling();
  llvm::SMRange nameLoc = curToken.getLoc();
  consumeToken();

  if (!consumeIf(Token::colon))
    return parseDeclRefExpr(name, nameLoc);

  // Inline variable definition: parse and validate its constraints.
  SmallVector<ast::ConstraintRef> constraints;
  if (failed(parseVariableDeclConstraintList(constraints)))
    return failure();
  ast::Type inferredType;
  if (failed(validateVariableConstraints(constraints, inferredType)))
    return failure();
  return createInlineVariableExpr(inferredType, name, nameLoc, constraints);
}
/// Parse `.member` applied to `parentExpr`. The member name may be an
/// identifier, an integer (e.g. a tuple index), or a keyword.
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
  llvm::SMRange dotLoc = curToken.getLoc();
  consumeToken(Token::dot);

  const bool isValidMemberName =
      curToken.isAny(Token::identifier, Token::integer) ||
      curToken.isKeyword();
  if (!isValidMemberName)
    return emitError(dotLoc, "expected identifier or numeric member name");

  StringRef memberName = curToken.getSpelling();
  consumeToken();
  return createMemberAccessExpr(parentExpr, memberName, dotLoc);
}
/// Parse an operation name of the form `dialect.op(.more)*`.
///
/// The returned name is a view into the source buffer: it starts at the
/// dialect token and is extended over each following token. That is only
/// correct when the pieces are directly adjacent in the buffer. Whitespace
/// between them (e.g. `foo . bar`) used to silently produce a truncated name
/// that included the whitespace, so a gap is now diagnosed (or, after the
/// first op-name piece, ends the name).
///
/// If the current token cannot start a name, this returns an unnamed
/// OpNameDecl when `allowEmptyName` is set, and an error otherwise.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  llvm::SMRange loc = curToken.getLoc();

  // Handle the case of no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, llvm::SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // True if the current token starts exactly where `name` currently ends, so
  // that extending `name` over it covers precisely that token's text.
  auto directlyFollowsName = [&] {
    return curToken.getStartLoc().getPointer() == name.end();
  };

  // Otherwise, this is a literal operation name.
  if (curToken.is(Token::dot) && !directlyFollowsName())
    return emitError("unexpected whitespace within operation name");
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();
  // Extend the name over the `.` that was just consumed.
  name = StringRef(name.data(), name.size() + 1);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");
  if (!directlyFollowsName())
    return emitError("unexpected whitespace within operation name");

  // Greedily absorb adjacent identifier/keyword/`.` tokens into the name.
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while ((curToken.isAny(Token::identifier, Token::dot) ||
            curToken.isKeyword()) &&
           directlyFollowsName());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
/// Parse an optional `<` operation-name `>`. Without a leading `<`, an
/// unnamed OpNameDecl is returned and nothing is consumed. `allowEmptyName`
/// only governs whether the name inside the brackets may be omitted.
FailureOr<ast::OpNameDecl *>
Parser::parseWrappedOperationName(bool allowEmptyName) {
  if (curToken.isNot(Token::less))
    return ast::OpNameDecl::create(ctx, llvm::SMRange());
  consumeToken(Token::less);

  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
  if (failed(opNameDecl) ||
      failed(parseToken(Token::greater, "expected `>` after operation name")))
    return failure();
  return opNameDecl;
}
/// Parse an operation expression:
///   `op` `<` op-name? `>` (`(` operands `)`)? (`{` attributes `}`)?
///   (`->` `(` result-types `)`)?
/// If `op` is not followed by `<`, it is re-parsed as a plain identifier.
/// Every list that is present must contain at least one element.
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
  llvm::SMRange opLoc = curToken.getLoc();
  consumeToken(Token::kw_op);
  if (curToken.isNot(Token::less)) {
    resetToken(opLoc);
    return parseIdentifierExpr();
  }

  // The name may be elided (matching "any" operation, i.e. the difference
  // between `MyOp` and `Operation*`), except within a rewrite, where the
  // operation being built must be named.
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(parserContext != ParserContext::Rewrite);
  if (failed(opNameDecl))
    return failure();

  // Parse `elt (`,` elt)*` using `parseElt`, appending each result to `elts`.
  auto parseCommaList = [&](auto &elts, auto parseElt) -> LogicalResult {
    do {
      auto elt = parseElt();
      if (failed(elt))
        return failure();
      elts.push_back(*elt);
    } while (consumeIf(Token::comma));
    return success();
  };
  auto parseExprElt = [&] { return parseExpr(); };
  auto parseAttrElt = [&] { return parseNamedAttributeDecl(); };

  // Optional operand list.
  SmallVector<ast::Expr *> operands;
  if (consumeIf(Token::l_paren)) {
    if (failed(parseCommaList(operands, parseExprElt)) ||
        failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Optional attribute list.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    if (failed(parseCommaList(attributes, parseAttrElt)) ||
        failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Optional result type list.
  SmallVector<ast::Expr *> resultTypes;
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")) ||
        failed(parseCommaList(resultTypes, parseExprElt)) ||
        failed(parseToken(Token::r_paren,
                          "expected `)` after operation result type list")))
      return failure();
  }

  return createOperationExpr(opLoc, *opNameDecl, operands, attributes,
                             resultTypes);
}
/// Parse a tuple expression `(` (label `=`)? expr (`,` ...)* `)`.
/// Elements may optionally be labeled; a label may be used at most once.
/// Unlabeled elements are recorded with an empty label.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  llvm::SMRange tupleLoc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Label -> location of its first use, for duplicate-label diagnostics.
  DenseMap<StringRef, llvm::SMRange> labelLocs;
  SmallVector<StringRef> labels;
  SmallVector<ast::Expr *> values;

  // `()` is an empty tuple, so there is nothing to parse in that case.
  bool parseAnotherElement = curToken.isNot(Token::r_paren);
  while (parseAnotherElement) {
    StringRef label;
    bool mayBeLabel =
        curToken.is(Token::identifier) || curToken.isDependentKeyword();
    if (mayBeLabel) {
      Token labelTok = curToken;
      consumeToken();
      if (!consumeIf(Token::equal)) {
        // No `=` follows, so the token starts the element's value
        // expression: rewind the lexer to re-parse it.
        resetToken(labelTok.getLoc());
      } else {
        label = labelTok.getSpelling();
        auto insertion = labelLocs.try_emplace(label, labelTok.getLoc());
        if (!insertion.second) {
          return emitErrorAndNote(
              labelTok.getLoc(),
              llvm::formatv("duplicate tuple element label `{0}`", label),
              insertion.first->getSecond(), "see previous label use here");
        }
      }
    }
    labels.push_back(label);

    FailureOr<ast::Expr *> value = parseExpr();
    if (failed(value))
      return failure();
    values.push_back(*value);

    parseAnotherElement = consumeIf(Token::comma);
  }

  tupleLoc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(tupleLoc, values, labels);
}
/// Parse a type literal expression: `type` `<` string-literal `>`.
/// If `type` is not followed by `<`, it is re-lexed and parsed as an
/// identifier. The resulting expression's location spans the whole literal,
/// from the `type` keyword through the closing `>`.
FailureOr<ast::Expr *> Parser::parseTypeExpr() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_type);

  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
  // identifier.
  if (!consumeIf(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  if (!curToken.isString())
    return emitError("expected string literal containing MLIR type");
  std::string typeStr = curToken.getStringValue();
  consumeToken();

  // Extend the location to cover the full expression, up to the `>` (mirrors
  // how tuple expressions extend their range to the closing `)`).
  loc.End = curToken.getEndLoc();
  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
    return failure();
  return ast::TypeExpr::create(ctx, loc, typeStr);
}
/// Parse an inline `_` variable, which must carry a constraint list:
/// `_` `:` constraints. The variable's type is inferred from the constraints.
FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
  // Capture the `_` token's spelling and location before consuming it.
  llvm::SMRange underscoreLoc = curToken.getLoc();
  StringRef underscoreName = curToken.getSpelling();
  consumeToken(Token::underscore);

  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
    return failure();

  SmallVector<ast::ConstraintRef> constraints;
  ast::Type inferredType;
  if (failed(parseVariableDeclConstraintList(constraints)) ||
      failed(validateVariableConstraints(constraints, inferredType)))
    return failure();

  return createInlineVariableExpr(inferredType, underscoreName, underscoreLoc,
                                  constraints);
}
//===----------------------------------------------------------------------===//
// Stmts
/// Parse a single statement, dispatching on the leading keyword; anything
/// else is parsed as an expression statement. When `expectTerminalSemicolon`
/// is set, a trailing `;` is required.
FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
  FailureOr<ast::Stmt *> result = [&]() -> FailureOr<ast::Stmt *> {
    switch (curToken.getKind()) {
    case Token::kw_erase:
      return parseEraseStmt();
    case Token::kw_let:
      return parseLetStmt();
    case Token::kw_replace:
      return parseReplaceStmt();
    case Token::kw_rewrite:
      return parseRewriteStmt();
    default:
      return parseExpr();
    }
  }();

  if (failed(result))
    return failure();
  if (expectTerminalSemicolon &&
      failed(parseToken(Token::semicolon, "expected `;` after statement")))
    return failure();
  return result;
}
/// Parse a `{ ... }` block of statements. Declarations made inside the block
/// live in a fresh declaration scope that is popped on both success and
/// failure.
FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
  llvm::SMLoc openLoc = curToken.getStartLoc();
  consumeToken(Token::l_brace);

  pushDeclScope();
  SmallVector<ast::Stmt *> body;
  while (!curToken.is(Token::r_brace)) {
    FailureOr<ast::Stmt *> nested = parseStmt();
    if (failed(nested)) {
      popDeclScope();
      return failure();
    }
    body.push_back(*nested);
  }
  popDeclScope();

  // The block's range runs from the `{` through the closing `}`.
  llvm::SMRange blockLoc(openLoc, curToken.getEndLoc());
  consumeToken(Token::r_brace);
  return ast::CompoundStmt::create(ctx, blockLoc, body);
}
/// Parse `erase` expr, where the expression names the operation to erase.
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
  llvm::SMRange eraseLoc = curToken.getLoc();
  consumeToken(Token::kw_erase);

  FailureOr<ast::Expr *> opExpr = parseExpr();
  if (succeeded(opExpr))
    return createEraseStmt(eraseLoc, *opExpr);
  return failure();
}
/// Parse a variable definition:
///   `let` name (`:` constraints)? (`=` initializer)?
/// `_` is rejected as a name. When an initializer is present, constraints that
/// carry a type expression are rejected.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // Parse the name of the new variable.
  llvm::SMRange varLoc = curToken.getLoc();
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
    // `_` is a reserved variable name.
    if (curToken.is(Token::underscore)) {
      return emitError(varLoc,
                       "`_` may only be used to define \"inline\" variables");
    }
    return emitError(varLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  // Parse the optional set of constraints.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon) &&
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  // Parse the optional initializer expression.
  ast::Expr *initializer = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> initOrFailure = parseExpr();
    if (failed(initOrFailure))
      return failure();
    initializer = *initOrFailure;

    // Check that the constraints are compatible with having an initializer,
    // e.g. type constraints cannot be used with initializers. Iterate by
    // reference to avoid copying each ConstraintRef.
    for (const ast::ConstraintRef &constraint : constraints) {
      LogicalResult result =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&](const auto *cst) {
                // Only the presence of a type expression matters here.
                if (cst->getTypeExpr()) {
                  return this->emitError(
                      constraint.referenceLoc,
                      "type constraints are not permitted on variables with "
                      "initializers");
                }
                return success();
              })
              .Default(success());
      if (failed(result))
        return failure();
    }
  }

  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(varName, varLoc, initializer, constraints);
  if (failed(varDecl))
    return failure();
  return ast::LetStmt::create(ctx, loc, *varDecl);
}
/// Parse `replace` root `with` (value | `(` value (`,` value)* `)`).
/// The replacement values are parsed within a rewrite context, and an empty
/// parenthesized list is rejected.
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
  llvm::SMRange replaceLoc = curToken.getLoc();
  consumeToken(Token::kw_replace);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp) ||
      failed(
          parseToken(Token::kw_with, "expected `with` after root operation")))
    return failure();

  // Everything after `with` is parsed as rewrite code; the previous context is
  // restored when this function returns.
  llvm::SaveAndRestore<ParserContext> rewriteScope(parserContext,
                                                   ParserContext::Rewrite);

  SmallVector<ast::Expr *> replacements;

  // A single replacement value may be written without parentheses.
  if (!consumeIf(Token::l_paren)) {
    FailureOr<ast::Expr *> single = parseExpr();
    if (failed(single))
      return failure();
    replacements.emplace_back(*single);
    return createReplaceStmt(replaceLoc, *rootOp, replacements);
  }

  if (consumeIf(Token::r_paren)) {
    return emitError(
        replaceLoc, "expected at least one replacement value, consider using "
                    "`erase` if no replacement values are desired");
  }

  do {
    FailureOr<ast::Expr *> next = parseExpr();
    if (failed(next))
      return failure();
    replacements.emplace_back(*next);
  } while (consumeIf(Token::comma));

  if (failed(parseToken(Token::r_paren,
                        "expected `)` after replacement values")))
    return failure();
  return createReplaceStmt(replaceLoc, *rootOp, replacements);
}
/// Parse `rewrite` root `with` `{` body `}`, where the body is parsed within
/// a rewrite context.
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
  llvm::SMRange rewriteLoc = curToken.getLoc();
  consumeToken(Token::kw_rewrite);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp) ||
      failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
    return failure();

  if (!curToken.is(Token::l_brace))
    return emitError("expected `{` to start rewrite body");

  // Switch to a rewrite context for the remainder of this statement.
  llvm::SaveAndRestore<ParserContext> rewriteScope(parserContext,
                                                   ParserContext::Rewrite);
  FailureOr<ast::CompoundStmt *> body = parseCompoundStmt();
  if (succeeded(body))
    return createRewriteStmt(rewriteLoc, *rootOp, *body);
  return failure();
}
//===----------------------------------------------------------------------===//
// Creation+Analysis
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Decls
/// Build a PatternDecl from its parsed pieces; the benefit and bounded
/// recursion flag are taken from `metadata` without further checks.
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                          const ParsedPatternMetadata &metadata,
                          ast::CompoundStmt *body) {
  auto *pattern = ast::PatternDecl::create(
      ctx, loc, name, metadata.benefit, metadata.hasBoundedRecursion, body);
  return pattern;
}
/// Create and define a variable named `name`. Its type is inferred from the
/// constraint list and/or the initializer. If both provide a type, the
/// initializer's type is used to refine the constraint type, or the
/// initializer is converted to that type. Fails if no type can be inferred.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, llvm::SMRange loc,
                           ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Type varType;
  if (failed(validateVariableConstraints(constraints, varType)))
    return failure();

  if (!initializer) {
    // With no initializer, the constraints alone must determine the type.
    if (!varType) {
      return emitErrorAndNote(
          loc, "unable to infer type for variable `" + name + "`", loc,
          "the type of a variable must be inferable from the constraint "
          "list or the initializer");
    }
  } else if (!varType) {
    varType = initializer->getType();
  } else if (ast::Type refined = varType.refineWith(initializer->getType())) {
    varType = refined;
  } else if (failed(convertExpressionTo(initializer, varType))) {
    return failure();
  }

  return defineVariableDecl(name, loc, varType, initializer, constraints);
}
/// Validate each constraint in order, folding the implied types into
/// `inferredType`. Stops at the first constraint that fails.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType) {
  bool anyFailed =
      llvm::any_of(constraints, [&](const ast::ConstraintRef &ref) {
        return failed(validateVariableConstraint(ref, inferredType));
      });
  return failure(anyFailed);
}
/// Validate a single constraint reference and merge the type it implies into
/// `inferredType`. Any nested type-constraint expression is validated first.
/// The merge fails if the implied type cannot refine the type inferred so far.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  ast::Type impliedType;
  if (const auto *attrCst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = attrCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = ast::AttributeType::get(ctx);
  } else if (const auto *opCst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    impliedType = ast::OperationType::get(ctx, opCst->getName());
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    impliedType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    impliedType = typeRangeTy;
  } else if (const auto *valueCst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = valueCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueTy;
  } else if (const auto *rangeCst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = rangeCst->getTypeExpr();
    if (typeExpr && failed(validateTypeRangeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueRangeTy;
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // The first constraint seeds the inferred type.
  if (!inferredType) {
    inferredType = impliedType;
    return success();
  }

  // Later constraints must be refinable with what has been inferred so far.
  ast::Type refined = inferredType.refineWith(impliedType);
  if (!refined) {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   impliedType, inferredType));
  }
  inferredType = refined;
  return success();
}
/// Check that `typeExpr`, used as a type constraint, has the `Type` type.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `Type` in type constraint");
}
/// Check that `typeExpr`, used as a type constraint, has the `TypeRange` type.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeRangeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `TypeRange` in type constraint");
}
//===----------------------------------------------------------------------===//
// Exprs
/// Create a reference to `decl`. Only variable declarations may be referenced
/// here; the reference takes the variable's type.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc,
                                                        ast::Decl *decl) {
  auto *varDecl = dyn_cast<ast::VariableDecl>(decl);
  if (!varDecl) {
    return emitError(loc, "invalid reference to `" +
                              decl->getName()->getName() + "`");
  }
  return ast::DeclRefExpr::create(ctx, loc, decl, varDecl->getType());
}
/// Define an inline variable (e.g. the `_` form) and return a reference to it.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name,
                                 llvm::SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> varDecl =
      defineVariableDecl(name, loc, type, constraints);
  if (succeeded(varDecl))
    return ast::DeclRefExpr::create(ctx, loc, *varDecl, type);
  return failure();
}
/// Create `parentExpr.name`, after validating that `name` is a valid member
/// of the parent's type.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               llvm::SMRange loc) {
  FailureOr<ast::Type> accessType = validateMemberAccess(parentExpr, name, loc);
  if (succeeded(accessType))
    return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name,
                                         *accessType);
  return failure();
}
/// Compute the type of member `name` on `parentExpr`, or emit an error if
/// the member does not exist. Supported forms:
///  - operations: the "all results" member, which yields a ValueRange;
///  - tuples: a numeric index (`t.0`) or an element label (`t.foo`).
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name,
                                                  llvm::SMRange loc) {
  ast::Type parentType = parentExpr->getType();
  if (parentType.isa<ast::OperationType>()) {
    if (name == ast::AllResultsMemberAccessExpr::getMemberName())
      return valueRangeTy;
  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
    // An empty name can never identify a tuple element. Without this guard,
    // `name.front()` would assert, and the label lookup below could match an
    // element that parseTupleExpr recorded with an empty (absent) label.
    if (!name.empty()) {
      // Handle indexed results.
      unsigned index = 0;
      if (llvm::isDigit(name.front()) &&
          !name.getAsInteger(/*Radix=*/10, index) &&
          index < tupleType.size()) {
        return tupleType.getElementTypes()[index];
      }

      // Handle named results.
      auto elementNames = tupleType.getElementNames();
      const auto *it = llvm::find(elementNames, name);
      if (it != elementNames.end())
        return tupleType.getElementTypes()[it - elementNames.begin()];
    }
  }
  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}
/// Build an operation expression after validating its operands, attribute
/// values (which must all be `Attr` expressions), and result types.
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
    llvm::SMRange loc, const ast::OpNameDecl *name,
    MutableArrayRef<ast::Expr *> operands,
    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
    MutableArrayRef<ast::Expr *> results) {
  Optional<StringRef> opName = name->getName();
  if (failed(validateOperationOperands(loc, opName, operands)))
    return failure();

  // Locate the first attribute whose value is not an `Attr` expression.
  auto badAttrIt =
      llvm::find_if(attributes, [](ast::NamedAttributeDecl *attr) {
        return !attr->getValue()->getType().isa<ast::AttributeType>();
      });
  if (badAttrIt != attributes.end()) {
    ast::Expr *badValue = (*badAttrIt)->getValue();
    return emitError(badValue->getLoc(),
                     llvm::formatv("expected `Attr` expression, but got `{0}`",
                                   badValue->getType()));
  }

  if (failed(validateOperationResults(loc, opName, results)))
    return failure();
  return ast::OperationExpr::create(ctx, loc, name, operands, results,
                                    attributes);
}
/// Validate operation operands: each must be a `Value` or a `ValueRange`
/// (or convertible to one).
LogicalResult
Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
                                  MutableArrayRef<ast::Expr *> operands) {
  ast::Type singleOperandTy = valueTy;
  ast::Type operandRangeTy = valueRangeTy;
  return validateOperationOperandsOrResults(loc, name, operands,
                                            singleOperandTy, operandRangeTy);
}
/// Validate operation result types: each must be a `Type` or a `TypeRange`
/// (or convertible to one).
LogicalResult
Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
                                 MutableArrayRef<ast::Expr *> results) {
  ast::Type singleResultTy = typeTy;
  ast::Type resultRangeTy = typeRangeTy;
  return validateOperationOperandsOrResults(loc, name, results, singleResultTy,
                                            resultRangeTy);
}
/// Shared validation for operand and result-type lists.
/// - A single element is converted to `rangeTy`.
/// - Otherwise, each element must already be `singleTy` or `rangeTy`. When
///   validating values (`singleTy == valueTy`), Op-typed elements are also
///   accepted and rewritten in place via convertOpToValue.
LogicalResult Parser::validateOperationOperandsOrResults(
    llvm::SMRange loc, Optional<StringRef> name,
    MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
    ast::Type rangeTy) {
  if (values.size() == 1)
    return failed(convertExpressionTo(values[0], rangeTy)) ? failure()
                                                           : success();

  // Nested operation expressions such as
  // `op<my_dialect.foo>(op<my_dialect.bar>)` are common, so Op-typed elements
  // are allowed when values are expected.
  const bool acceptsOps = singleTy == valueTy;
  for (ast::Expr *&expr : values) {
    ast::Type exprTy = expr->getType();
    if (exprTy == rangeTy || exprTy == singleTy)
      continue;

    if (acceptsOps && exprTy.isa<ast::OperationType>()) {
      expr = convertOpToValue(expr);
      continue;
    }

    return emitError(
        expr->getLoc(),
        llvm::formatv(
            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
            singleTy, rangeTy, exprTy));
  }
  return success();
}
/// Build a tuple expression, rejecting elements whose type is a constraint or
/// another tuple.
FailureOr<ast::TupleExpr *>
Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements,
                        ArrayRef<StringRef> elementNames) {
  for (const ast::Expr *elem : elements) {
    ast::Type elemTy = elem->getType();
    if (!elemTy.isa<ast::ConstraintType, ast::TupleType>())
      continue;
    return emitError(
        elem->getLoc(),
        llvm::formatv("unable to build a tuple with `{0}` element", elemTy));
  }
  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
}
//===----------------------------------------------------------------------===//
// Stmts
/// Build an `erase` statement. The root must be an Op expression; otherwise
/// the diagnostic reports the type that was actually found.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc,
                                                    ast::Expr *rootOp) {
  // Check that root is an Operation. The message keeps the previous text as a
  // prefix and appends the offending type, matching the replace/rewrite checks.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }
  return ast::EraseStmt::create(ctx, loc, rootOp);
}
/// Build a `replace` statement. The root must be an Op expression. Each
/// replacement must be an Op, Value, or ValueRange. When there is more than
/// one replacement, Op replacements are converted in place via
/// convertOpToValue.
FailureOr<ast::ReplaceStmt *>
Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          MutableArrayRef<ast::Expr *> replValues) {
  ast::Type rootTy = rootOp->getType();
  if (!rootTy.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootTy));
  }

  const bool convertOps = replValues.size() > 1;
  for (ast::Expr *&repl : replValues) {
    ast::Type replTy = repl->getType();
    if (replTy.isa<ast::OperationType>()) {
      if (convertOps)
        repl = convertOpToValue(repl);
    } else if (replTy != valueTy && replTy != valueRangeTy) {
      return emitError(repl->getLoc(),
                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
                                     "expression, but got `{0}`",
                                     replTy));
    }
  }
  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}
/// Build a `rewrite` statement over `rewriteBody`; the root must be an Op
/// expression.
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          ast::CompoundStmt *rewriteBody) {
  ast::Type rootTy = rootOp->getType();
  if (rootTy.isa<ast::OperationType>())
    return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
  return emitError(
      rootOp->getLoc(),
      llvm::formatv("expected `Op` expression, but got `{0}`", rootTy));
}
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// Public entry point: parse the main buffer of `sourceMgr` into a PDLL AST
/// module owned by `ctx`, using a parser local to this call.
FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
                                                 llvm::SourceMgr &sourceMgr) {
  return Parser(ctx, sourceMgr).parseModule();
}