forked from OSchip/llvm-project
1706 lines
60 KiB
C++
1706 lines
60 KiB
C++
//===- Parser.cpp ---------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/Parser/Parser.h"
|
|
#include "Lexer.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Tools/PDLL/AST/Context.h"
|
|
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
|
|
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
|
#include "mlir/Tools/PDLL/AST/Types.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/SaveAndRestore.h"
|
|
#include "llvm/Support/ScopedPrinter.h"
|
|
#include <string>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class Parser {
|
|
public:
|
|
Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
|
|
: ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
|
|
curToken(lexer.lexToken()), curDeclScope(nullptr),
|
|
valueTy(ast::ValueType::get(ctx)),
|
|
valueRangeTy(ast::ValueRangeType::get(ctx)),
|
|
typeTy(ast::TypeType::get(ctx)),
|
|
typeRangeTy(ast::TypeRangeType::get(ctx)) {}
|
|
|
|
/// Try to parse a new module. Returns nullptr in the case of failure.
|
|
FailureOr<ast::Module *> parseModule();
|
|
|
|
private:
|
|
/// The current context of the parser. It allows for the parser to know a bit
|
|
/// about the construct it is nested within during parsing. This is used
|
|
/// specifically to provide additional verification during parsing, e.g. to
|
|
/// prevent using rewrites within a match context, matcher constraints within
|
|
/// a rewrite section, etc.
|
|
enum class ParserContext {
|
|
/// The parser is in the global context.
|
|
Global,
|
|
/// The parser is currently within the matcher portion of a Pattern, which
|
|
/// is allows a terminal operation rewrite statement but no other rewrite
|
|
/// transformations.
|
|
PatternMatch,
|
|
/// The parser is currently within a Rewrite, which disallows calls to
|
|
/// constraints, requires operation expressions to have names, etc.
|
|
Rewrite,
|
|
};
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Parsing
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Push a new decl scope onto the lexer.
|
|
ast::DeclScope *pushDeclScope() {
|
|
ast::DeclScope *newScope =
|
|
new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
|
|
return (curDeclScope = newScope);
|
|
}
|
|
void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
|
|
|
|
/// Pop the last decl scope from the lexer.
|
|
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
|
|
|
|
/// Parse the body of an AST module.
|
|
LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
|
|
|
|
/// Try to convert the given expression to `type`. Returns failure and emits
|
|
/// an error if a conversion is not viable. On failure, `noteAttachFn` is
|
|
/// invoked to attach notes to the emitted error diagnostic. On success,
|
|
/// `expr` is updated to the expression used to convert to `type`.
|
|
LogicalResult convertExpressionTo(
|
|
ast::Expr *&expr, ast::Type type,
|
|
function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
|
|
|
|
/// Given an operation expression, convert it to a Value or ValueRange
|
|
/// typed expression.
|
|
ast::Expr *convertOpToValue(const ast::Expr *opExpr);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Directives
|
|
|
|
LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
|
|
LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// This structure contains the set of pattern metadata that may be parsed.
|
|
struct ParsedPatternMetadata {
|
|
Optional<uint16_t> benefit;
|
|
bool hasBoundedRecursion = false;
|
|
};
|
|
|
|
FailureOr<ast::Decl *> parseTopLevelDecl();
|
|
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
|
|
FailureOr<ast::Decl *> parsePatternDecl();
|
|
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
|
|
|
|
/// Check to see if a decl has already been defined with the given name, if
|
|
/// one has emit and error and return failure. Returns success otherwise.
|
|
LogicalResult checkDefineNamedDecl(const ast::Name &name);
|
|
|
|
/// Try to define a variable decl with the given components, returns the
|
|
/// variable on success.
|
|
FailureOr<ast::VariableDecl *>
|
|
defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
|
|
ast::Expr *initExpr,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
FailureOr<ast::VariableDecl *>
|
|
defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
|
|
/// Parse the constraint reference list for a variable decl.
|
|
LogicalResult parseVariableDeclConstraintList(
|
|
SmallVectorImpl<ast::ConstraintRef> &constraints);
|
|
|
|
/// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
|
|
FailureOr<ast::Expr *> parseTypeConstraintExpr();
|
|
|
|
/// Try to parse a single reference to a constraint. `typeConstraint` is the
|
|
/// location of a previously parsed type constraint for the entity that will
|
|
/// be constrained by the parsed constraint. `existingConstraints` are any
|
|
/// existing constraints that have already been parsed for the same entity
|
|
/// that will be constrained by this constraint.
|
|
FailureOr<ast::ConstraintRef>
|
|
parseConstraint(Optional<llvm::SMRange> &typeConstraint,
|
|
ArrayRef<ast::ConstraintRef> existingConstraints);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
FailureOr<ast::Expr *> parseExpr();
|
|
|
|
/// Identifier expressions.
|
|
FailureOr<ast::Expr *> parseAttributeExpr();
|
|
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
|
|
FailureOr<ast::Expr *> parseIdentifierExpr();
|
|
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
|
|
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
|
|
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
|
|
FailureOr<ast::Expr *> parseOperationExpr();
|
|
FailureOr<ast::Expr *> parseTupleExpr();
|
|
FailureOr<ast::Expr *> parseTypeExpr();
|
|
FailureOr<ast::Expr *> parseUnderscoreExpr();
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
|
|
FailureOr<ast::CompoundStmt *> parseCompoundStmt();
|
|
FailureOr<ast::EraseStmt *> parseEraseStmt();
|
|
FailureOr<ast::LetStmt *> parseLetStmt();
|
|
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
|
|
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Creation+Analysis
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Try to create a pattern decl with the given components, returning the
|
|
/// Pattern on success.
|
|
FailureOr<ast::PatternDecl *>
|
|
createPatternDecl(llvm::SMRange loc, const ast::Name *name,
|
|
const ParsedPatternMetadata &metadata,
|
|
ast::CompoundStmt *body);
|
|
|
|
/// Try to create a variable decl with the given components, returning the
|
|
/// Variable on success.
|
|
FailureOr<ast::VariableDecl *>
|
|
createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
|
|
/// Validate the constraints used to constraint a variable decl.
|
|
/// `inferredType` is the type of the variable inferred by the constraints
|
|
/// within the list, and is updated to the most refined type as determined by
|
|
/// the constraints. Returns success if the constraint list is valid, failure
|
|
/// otherwise.
|
|
LogicalResult
|
|
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
|
|
ast::Type &inferredType);
|
|
/// Validate a single reference to a constraint. `inferredType` contains the
|
|
/// currently inferred variabled type and is refined within the type defined
|
|
/// by the constraint. Returns success if the constraint is valid, failure
|
|
/// otherwise.
|
|
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
|
|
ast::Type &inferredType);
|
|
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
|
|
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc,
|
|
ast::Decl *decl);
|
|
FailureOr<ast::DeclRefExpr *>
|
|
createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
FailureOr<ast::MemberAccessExpr *>
|
|
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
|
|
llvm::SMRange loc);
|
|
|
|
/// Validate the member access `name` into the given parent expression. On
|
|
/// success, this also returns the type of the member accessed.
|
|
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
|
|
StringRef name, llvm::SMRange loc);
|
|
FailureOr<ast::OperationExpr *>
|
|
createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
|
|
MutableArrayRef<ast::Expr *> operands,
|
|
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
|
|
MutableArrayRef<ast::Expr *> results);
|
|
LogicalResult
|
|
validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
|
|
MutableArrayRef<ast::Expr *> operands);
|
|
LogicalResult validateOperationResults(llvm::SMRange loc,
|
|
Optional<StringRef> name,
|
|
MutableArrayRef<ast::Expr *> results);
|
|
LogicalResult
|
|
validateOperationOperandsOrResults(llvm::SMRange loc,
|
|
Optional<StringRef> name,
|
|
MutableArrayRef<ast::Expr *> values,
|
|
ast::Type singleTy, ast::Type rangeTy);
|
|
FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc,
|
|
ArrayRef<ast::Expr *> elements,
|
|
ArrayRef<StringRef> elementNames);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
|
|
ast::Expr *rootOp);
|
|
FailureOr<ast::ReplaceStmt *>
|
|
createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
|
MutableArrayRef<ast::Expr *> replValues);
|
|
FailureOr<ast::RewriteStmt *>
|
|
createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
|
ast::CompoundStmt *rewriteBody);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Lexer Utilities
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// If the current token has the specified kind, consume it and return true.
|
|
/// If not, return false.
|
|
bool consumeIf(Token::Kind kind) {
|
|
if (curToken.isNot(kind))
|
|
return false;
|
|
consumeToken(kind);
|
|
return true;
|
|
}
|
|
|
|
/// Advance the current lexer onto the next token.
|
|
void consumeToken() {
|
|
assert(curToken.isNot(Token::eof, Token::error) &&
|
|
"shouldn't advance past EOF or errors");
|
|
curToken = lexer.lexToken();
|
|
}
|
|
|
|
/// Advance the current lexer onto the next token, asserting what the expected
|
|
/// current token is. This is preferred to the above method because it leads
|
|
/// to more self-documenting code with better checking.
|
|
void consumeToken(Token::Kind kind) {
|
|
assert(curToken.is(kind) && "consumed an unexpected token");
|
|
consumeToken();
|
|
}
|
|
|
|
/// Reset the lexer to the location at the given position.
|
|
void resetToken(llvm::SMRange tokLoc) {
|
|
lexer.resetPointer(tokLoc.Start.getPointer());
|
|
curToken = lexer.lexToken();
|
|
}
|
|
|
|
/// Consume the specified token if present and return success. On failure,
|
|
/// output a diagnostic and return failure.
|
|
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
|
|
if (curToken.getKind() != kind)
|
|
return emitError(curToken.getLoc(), msg);
|
|
consumeToken();
|
|
return success();
|
|
}
|
|
LogicalResult emitError(llvm::SMRange loc, const Twine &msg) {
|
|
lexer.emitError(loc, msg);
|
|
return failure();
|
|
}
|
|
LogicalResult emitError(const Twine &msg) {
|
|
return emitError(curToken.getLoc(), msg);
|
|
}
|
|
LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
|
|
llvm::SMRange noteLoc, const Twine ¬e) {
|
|
lexer.emitErrorAndNote(loc, msg, noteLoc, note);
|
|
return failure();
|
|
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Fields
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// The owning AST context.
|
|
ast::Context &ctx;
|
|
|
|
/// The lexer of this parser.
|
|
Lexer lexer;
|
|
|
|
/// The current token within the lexer.
|
|
Token curToken;
|
|
|
|
/// The most recently defined decl scope.
|
|
ast::DeclScope *curDeclScope;
|
|
llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
|
|
|
|
/// The current context of the parser.
|
|
ParserContext parserContext = ParserContext::Global;
|
|
|
|
/// Cached types to simplify verification and expression creation.
|
|
ast::Type valueTy, valueRangeTy;
|
|
ast::Type typeTy, typeRangeTy;
|
|
};
|
|
} // namespace
|
|
|
|
/// Parse the whole input as a single module. All top-level decls are parsed
/// into a fresh decl scope, which is popped again on both success and failure.
FailureOr<ast::Module *> Parser::parseModule() {
  // Remember where the module begins before any token is consumed.
  llvm::SMLoc startLoc = curToken.getStartLoc();

  pushDeclScope();
  SmallVector<ast::Decl *> topLevelDecls;
  LogicalResult bodyResult = parseModuleBody(topLevelDecls);
  popDeclScope();

  if (failed(bodyResult))
    return failure();
  return ast::Module::create(ctx, startLoc, topLevelDecls);
}
|
|
|
|
/// Parse directives and top-level decls until the end of the current buffer,
/// appending every parsed decl to `decls`. Stops at the first failure.
LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
  while (curToken.isNot(Token::eof)) {
    if (curToken.is(Token::directive)) {
      // Directives (e.g. `#include`) append any decls they produce themselves.
      if (failed(parseDirective(decls)))
        return failure();
    } else {
      FailureOr<ast::Decl *> topDecl = parseTopLevelDecl();
      if (failed(topDecl))
        return failure();
      decls.push_back(*topDecl);
    }
  }
  return success();
}
|
|
|
|
/// Wrap `opExpr` in an expression that accesses all of its results, typed as
/// a ValueRange.
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
  llvm::SMRange opLoc = opExpr->getLoc();
  return ast::AllResultsMemberAccessExpr::create(ctx, opLoc, opExpr,
                                                 valueRangeTy);
}
|
|
|
|
/// Convert `expr` to `type`, rewriting `expr` in place when a wrapping
/// expression is needed. Supported conversions:
///  * identical types (no-op);
///  * Operation -> unnamed Operation;
///  * Operation -> Value / ValueRange (accesses all of the op's results);
///  * Value <-> ValueRange and Type <-> TypeRange, in either direction;
///  * Tuple -> Tuple of the same arity, converting element-wise.
/// Anything else emits an "unable to convert" error, to which `noteAttachFn`
/// (if provided) may attach extra notes.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type exprType = expr->getType();
  if (exprType == type)
    return success();

  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      exprType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  if (exprType.dyn_cast<ast::OperationType>()) {
    // Identical operation types were accepted above, so a differing target
    // operation type is only compatible when it is the unnamed (general) one.
    if (auto targetOpType = type.dyn_cast<ast::OperationType>()) {
      if (targetOpType.getName())
        return emitConvertError();
      return success();
    }

    // An operation converts to a Value or ValueRange by accessing all of its
    // results, typed as the requested value kind.
    if (type == valueRangeTy || type == valueTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     type);
      return success();
    }
    return emitConvertError();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. Both directions are currently accepted,
  // but this isn't something really supported in the PDL dialect.
  auto isValueLike = [&](ast::Type t) {
    return t == valueTy || t == valueRangeTy;
  };
  auto isTypeLike = [&](ast::Type t) {
    return t == typeTy || t == typeRangeTy;
  };
  if ((isValueLike(exprType) && isValueLike(type)) ||
      (isTypeLike(exprType) && isTypeLike(type)))
    return success();

  // Only tuple -> tuple conversions remain.
  auto exprTupleType = exprType.dyn_cast<ast::TupleType>();
  if (!exprTupleType)
    return emitConvertError();
  auto tupleType = type.dyn_cast<ast::TupleType>();
  if (!tupleType || tupleType.size() != exprTupleType.size())
    return emitConvertError();

  // Access each element of the source tuple, convert it to the matching target
  // element type, and rebuild a tuple from the converted elements.
  SmallVector<ast::Expr *> convertedElements;
  for (unsigned idx = 0, count = exprTupleType.size(); idx != count; ++idx) {
    ast::Expr *element = ast::MemberAccessExpr::create(
        ctx, expr->getLoc(), expr, llvm::to_string(idx),
        exprTupleType.getElementTypes()[idx]);

    auto attachElementNote = [&](ast::Diagnostic &diag) {
      diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                    idx, exprTupleType));
      if (noteAttachFn)
        noteAttachFn(diag);
    };
    if (failed(convertExpressionTo(element, tupleType.getElementTypes()[idx],
                                   attachElementNote)))
      return failure();
    convertedElements.push_back(element);
  }

  expr = ast::TupleExpr::create(ctx, expr->getLoc(), convertedElements,
                                tupleType.getElementNames());
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Directives
|
|
|
|
/// Dispatch on the spelling of the current directive token. `#include` is the
/// only directive currently supported.
LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
  StringRef directiveName = curToken.getSpelling();
  if (directiveName != "#include")
    return emitError("unknown directive `" + directiveName + "`");
  return parseInclude(decls);
}
|
|
|
|
/// Parse `#include "file.pdll"`. The included file is pushed onto the lexer
/// and its decls are parsed directly into `decls`.
LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
  llvm::SMRange directiveLoc = curToken.getLoc();
  consumeToken(Token::directive);

  // The directive must be followed by a string literal naming the file.
  if (!curToken.isString())
    return emitError(directiveLoc,
                     "expected string file name after `include` directive");
  llvm::SMRange fileLoc = curToken.getLoc();
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Only other PDLL files may be included.
  if (!filename.endswith(".pdll"))
    return emitError(fileLoc, "expected include filename to end with `.pdll`");
  if (failed(lexer.pushInclude(filename)))
    return emitError(fileLoc,
                     "unable to open include file `" + filename + "`");

  // `curToken` already holds the token after the directive in the including
  // file; stash it while the included file is parsed, then put it back.
  Token resumeToken = curToken;
  curToken = lexer.lexToken();
  LogicalResult bodyResult = parseModuleBody(decls);
  curToken = resumeToken;
  return bodyResult;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Parse a single top-level declaration (currently only `Pattern`). If the
/// decl is named, it is registered in the current scope after checking that
/// the name is not already taken.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  if (curToken.isNot(Token::kw_Pattern))
    return emitError("expected top-level declaration, such as a `Pattern`");

  FailureOr<ast::Decl *> decl = parsePatternDecl();
  if (failed(decl))
    return failure();

  const ast::Name *name = (*decl)->getName();
  if (!name)
    return decl;
  if (failed(checkDefineNamedDecl(*name)))
    return failure();
  curDeclScope->add(*decl);
  return decl;
}
|
|
|
|
/// Parse `name` or `name = <expr>`, where the name is a string literal, an
/// identifier, or a keyword. A bare name is given a `unit` attribute value.
FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
  std::string attrNameStr;
  if (curToken.isString()) {
    attrNameStr = curToken.getStringValue();
  } else if (curToken.is(Token::identifier) || curToken.isKeyword()) {
    attrNameStr = curToken.getSpelling().str();
  } else {
    return emitError("expected identifier or string attribute name");
  }
  const ast::Name &name =
      ast::Name::create(ctx, attrNameStr, curToken.getLoc());
  consumeToken();

  ast::Expr *attrValue;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> valueExpr = parseExpr();
    if (failed(valueExpr))
      return failure();
    attrValue = *valueExpr;
  } else {
    // No explicit value: represent the attribute as a UnitAttr.
    attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
  }
  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}
|
|
|
|
/// Parse `Pattern [name] [with <metadata>] { ... }`. The body is parsed in the
/// PatternMatch context and must end with exactly one operation rewrite
/// statement, with nothing after it.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  // The previous parser context is restored when this function returns.
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::PatternMatch);

  const ast::Name *name = nullptr;
  if (curToken.is(Token::identifier)) {
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
    return failure();

  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start pattern body");
  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
  if (failed(bodyResult))
    return failure();
  ast::CompoundStmt *body = *bodyResult;

  // Find the first operation rewrite statement in the body.
  auto rewriteIt = body->begin();
  auto bodyEnd = body->end();
  while (rewriteIt != bodyEnd && !isa<ast::OpRewriteStmt>(*rewriteIt))
    ++rewriteIt;

  if (rewriteIt == bodyEnd)
    return emitError(loc,
                     "expected Pattern body to terminate with an operation "
                     "rewrite statement, such as `erase`");

  auto trailingIt = std::next(rewriteIt);
  if (trailingIt != bodyEnd)
    return emitError((*trailingIt)->getLoc(),
                     "Pattern body was terminated by an operation "
                     "rewrite statement, but found trailing statements");

  return createPatternDecl(loc, name, metadata, body);
}
|
|
|
|
/// Parse a comma-separated list of pattern metadata entries:
///   benefit(<integer>)   -- a benefit that must fit in 16 bits
///   recursion            -- marks the pattern as having bounded recursion
/// Each entry may appear at most once.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  // Locations of entries already seen, for duplicate diagnostics.
  Optional<llvm::SMRange> benefitLoc;
  Optional<llvm::SMRange> hasBoundedRecursionLoc;

  do {
    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef metadataStr = curToken.getSpelling();
    llvm::SMRange metadataLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    if (metadataStr == "benefit") {
      if (benefitLoc)
        return emitErrorAndNote(metadataLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      uint16_t benefitValue = 0;
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = benefitValue;
      benefitLoc = metadataLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
    } else if (metadataStr == "recursion") {
      if (hasBoundedRecursionLoc)
        return emitErrorAndNote(
            metadataLoc,
            "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = metadataLoc;
    } else {
      return emitError(metadataLoc, "unknown pattern metadata");
    }
  } while (consumeIf(Token::comma));

  return success();
}
|
|
|
|
/// Parse `< expr >`; the caller has verified that the current token is `<`.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr))
    return failure();
  if (failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}
|
|
|
|
/// Fail with an error (and a note at the earlier definition) if `name` is
/// already visible through the current decl scope.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");
  ast::Decl *existing = curDeclScope->lookup(name.getName());
  if (!existing)
    return success();
  return emitErrorAndNote(
      name.getLoc(), "`" + name.getName() + "` has already been defined",
      existing->getName()->getLoc(), "see previous definition here");
}
|
|
|
|
/// Create a variable decl and, unless it is anonymous (`""`) or the wildcard
/// `_`, register it in the current scope after a redefinition check.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type, ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // Special variables are local to their definition point and never enter the
  // scope, so they also skip the redefinition check.
  bool isScoped = !name.empty() && name != "_";
  if (isScoped && failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  if (isScoped)
    curDeclScope->add(varDecl);
  return varDecl;
}
|
|
|
|
/// Convenience overload for variables defined without an initializer.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Expr *noInitializer = nullptr;
  return defineVariableDecl(name, nameLoc, type, noInitializer, constraints);
}
|
|
|
|
/// Parse either a single constraint or a bracketed, comma-separated list of
/// constraints, appending each one to `constraints`.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  // Shared across the whole list so that a second type constraint is caught.
  Optional<llvm::SMRange> typeConstraint;
  auto parseSingleConstraint = [&]() -> LogicalResult {
    FailureOr<ast::ConstraintRef> constraint =
        parseConstraint(typeConstraint, constraints);
    if (failed(constraint))
      return failure();
    constraints.push_back(*constraint);
    return success();
  };

  if (curToken.isNot(Token::l_square))
    return parseSingleConstraint();
  consumeToken(Token::l_square);

  for (;;) {
    if (failed(parseSingleConstraint()))
      return failure();
    if (!consumeIf(Token::comma))
      break;
  }
  return parseToken(Token::r_square, "expected `]` after constraint list");
}
|
|
|
|
/// Parse one constraint reference: a builtin (`Attr`, `Op`, `Type`,
/// `TypeRange`, `Value`, `ValueRange`) or an identifier naming a user
/// constraint. `typeConstraint` tracks the location of a type constraint
/// already applied to the same variable so duplicates can be diagnosed.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints) {
  // Parse an optional `<type-expr>` suffix into `typeExpr` (left null when the
  // suffix is absent). Only one type constraint is allowed per variable.
  auto parseOptionalTypeConstraint =
      [&](ast::Expr *&typeExpr) -> LogicalResult {
    typeExpr = nullptr;
    if (curToken.isNot(Token::less))
      return success();
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  llvm::SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);
    // The operation name may be omitted, meaning "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();
    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl)
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");

    auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl);
    if (!cst)
      return emitErrorAndNote(
          loc, "invalid reference to non-constraint", cstDecl->getLoc(),
          "see the definition of `" + constraintName + "` here");
    return ast::ConstraintRef(cst, loc);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
/// Parse an expression: a primary expression followed by any number of
/// `.member` accesses. A leading `_` is parsed on its own.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  FailureOr<ast::Expr *> expr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    expr = parseAttributeExpr();
    break;
  case Token::identifier:
    expr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    expr = parseOperationExpr();
    break;
  case Token::kw_type:
    expr = parseTypeExpr();
    break;
  case Token::l_paren:
    expr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }

  // Fold trailing member accesses left to right, stopping at the first
  // failure.
  while (succeeded(expr) && curToken.is(Token::dot))
    expr = parseMemberAccessExpr(*expr);
  return expr;
}
|
|
|
|
/// Parse `attr<"...">`. If `attr` is not followed by `<`, the lexer is rewound
/// and `attr` is parsed as an ordinary identifier expression instead.
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }
  consumeToken(Token::less);

  if (!curToken.isString())
    return emitError("expected string literal containing MLIR attribute");
  std::string attrExpr = curToken.getStringValue();
  consumeToken();

  if (failed(
          parseToken(Token::greater, "expected `>` after attribute literal")))
    return failure();
  return ast::AttributeExpr::create(ctx, loc, attrExpr);
}
|
|
|
|
/// Build a reference expression to the declaration named `name`, as visible
/// from the current declaration scope. Emits an error at `loc` if no such
/// declaration exists.
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
                                                llvm::SMRange loc) {
  if (ast::Decl *decl = curDeclScope->lookup(name))
    return createDeclRefExpr(loc, decl);
  return emitError(loc, "undefined reference to `" + name + "`");
}
|
|
|
|
/// Parse an identifier expression. A bare name references an existing
/// declaration; a name followed by `:` and a constraint list defines a new
/// variable inline (e.g. `name: Value`).
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
  StringRef name = curToken.getSpelling();
  llvm::SMRange nameLoc = curToken.getLoc();
  consumeToken();

  // No `:` means this is just a reference to something already declared.
  if (!consumeIf(Token::colon))
    return parseDeclRefExpr(name, nameLoc);

  // Otherwise the constraint list determines the type of the new variable.
  SmallVector<ast::ConstraintRef> constraints;
  if (failed(parseVariableDeclConstraintList(constraints)))
    return failure();
  ast::Type type;
  if (failed(validateVariableConstraints(constraints, type)))
    return failure();
  return createInlineVariableExpr(type, name, nameLoc, constraints);
}
|
|
|
|
/// Parse a `.member` access applied to `parentExpr`. The member may be named
/// by an identifier, an integer (e.g. a tuple index), or a keyword. The
/// resulting expression is located at the `.` token.
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
  llvm::SMRange dotLoc = curToken.getLoc();
  consumeToken(Token::dot);

  bool isMemberNameTok = curToken.isAny(Token::identifier, Token::integer) ||
                         curToken.isKeyword();
  if (!isMemberNameTok)
    return emitError(dotLoc, "expected identifier or numeric member name");

  StringRef memberName = curToken.getSpelling();
  consumeToken();
  return createMemberAccessExpr(parentExpr, memberName, dotLoc);
}
|
|
|
|
/// Parse an operation name of the form `dialect.op` (which may contain further
/// `.`-separated parts). If the current token cannot begin a name and
/// `allowEmptyName` is set, an unnamed OpNameDecl is returned instead of an
/// error.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  llvm::SMRange nameLoc = curToken.getLoc();

  // The dialect namespace may be an identifier or a keyword.
  bool startsName = curToken.is(Token::identifier) || curToken.isKeyword();
  if (!startsName) {
    if (!allowEmptyName)
      return emitError("expected dialect namespace");
    return ast::OpNameDecl::create(ctx, llvm::SMRange());
  }
  StringRef opName = curToken.getSpelling();
  consumeToken();

  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();
  if (!curToken.is(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Grow `opName` in place over the source buffer: first over the `.` just
  // consumed, then over every following identifier/keyword/`.` token.
  // NOTE(review): this assumes the tokens are adjacent in the source (no
  // intervening whitespace); otherwise the resulting span would be wrong.
  opName = StringRef(opName.data(), opName.size() + 1);
  do {
    opName = StringRef(opName.data(),
                       opName.size() + curToken.getSpelling().size());
    nameLoc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, opName, nameLoc));
}
|
|
|
|
/// Parse an operation name wrapped in angle brackets, e.g. `<dialect.op>`.
/// If there is no leading `<`, nothing is consumed and an unnamed OpNameDecl
/// is returned regardless of `allowEmptyName`.
FailureOr<ast::OpNameDecl *>
Parser::parseWrappedOperationName(bool allowEmptyName) {
  if (!consumeIf(Token::less))
    return ast::OpNameDecl::create(ctx, llvm::SMRange());

  // Parse the name itself, then require the closing `>`.
  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
  if (failed(opNameDecl) ||
      failed(parseToken(Token::greater, "expected `>` after operation name")))
    return failure();
  return opNameDecl;
}
|
|
|
|
/// Parse an operation expression of the form:
///   `op<dialect.name>(operands) {attributes} -> (result types)`
/// The operand list, attribute dictionary, and result type list are each
/// optional. If `op` is not followed by `<`, the lexer is rewound and `op` is
/// parsed as an ordinary identifier expression.
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
  llvm::SMRange opLoc = curToken.getLoc();
  consumeToken(Token::kw_op);

  if (curToken.isNot(Token::less)) {
    resetToken(opLoc);
    return parseIdentifierExpr();
  }

  // An elided name refers to "any" operation (i.e. the difference between
  // `MyOp` and `Operation*`). Names are required within a rewrite context.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();

  // Parses one or more comma-separated expressions, appending them to
  // `exprs`. Used for both the operand list and the result type list.
  auto parseExprList = [&](SmallVector<ast::Expr *> &exprs) -> LogicalResult {
    do {
      FailureOr<ast::Expr *> expr = parseExpr();
      if (failed(expr))
        return failure();
      exprs.push_back(*expr);
    } while (consumeIf(Token::comma));
    return success();
  };

  // Optional operand list: `(a, b, ...)`.
  SmallVector<ast::Expr *> operands;
  if (consumeIf(Token::l_paren) &&
      (failed(parseExprList(operands)) ||
       failed(parseToken(Token::r_paren,
                         "expected `)` after operation operand list"))))
    return failure();

  // Optional attribute dictionary: `{name = value, ...}`.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> attrDecl = parseNamedAttributeDecl();
      if (failed(attrDecl))
        return failure();
      attributes.emplace_back(*attrDecl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Optional result type list: `-> (T1, T2, ...)`.
  SmallVector<ast::Expr *> resultTypes;
  if (consumeIf(Token::arrow) &&
      (failed(parseToken(Token::l_paren,
                         "expected `(` before operation result type list")) ||
       failed(parseExprList(resultTypes)) ||
       failed(parseToken(Token::r_paren,
                         "expected `)` after operation result type list"))))
    return failure();

  return createOperationExpr(opLoc, *opNameDecl, operands, attributes,
                             resultTypes);
}
|
|
|
|
/// Parse a tuple expression `(elt, label = elt, ...)`. Each element may be
/// prefixed with an optional `label =`; labels must be unique within the
/// tuple. Unlabeled elements are recorded with an empty label.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  llvm::SMRange tupleLoc = curToken.getLoc();
  consumeToken(Token::l_paren);

  DenseMap<StringRef, llvm::SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Speculatively consume a name-like token as a label. If it is not
      // followed by `=`, rewind so it is parsed as part of the element value.
      StringRef label;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token labelTok = curToken;
        consumeToken();

        if (!consumeIf(Token::equal)) {
          resetToken(labelTok.getLoc());
        } else {
          label = labelTok.getSpelling();

          // Reject a label that was already used earlier in this tuple.
          auto insertResult =
              usedNames.try_emplace(label, labelTok.getLoc());
          if (!insertResult.second) {
            return emitErrorAndNote(
                labelTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`", label),
                insertResult.first->getSecond(),
                "see previous label use here");
          }
        }
      }
      elementNames.push_back(label);

      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }

  // The tuple's location spans through the closing `)`.
  tupleLoc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(tupleLoc, elements, elementNames);
}
|
|
|
|
/// Parse a type literal of the form `type<"...">`. If `type` is not followed
/// by `<`, the lexer is rewound and `type` is parsed as an ordinary identifier
/// expression instead.
FailureOr<ast::Expr *> Parser::parseTypeExpr() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_type);

  // Not a type literal: treat `type` as a plain name.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }
  consumeToken(Token::less);

  // The body of the literal is a string holding the MLIR type syntax.
  if (!curToken.isString())
    return emitError("expected string literal containing MLIR type");
  std::string typeStr = curToken.getStringValue();
  consumeToken();

  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
    return failure();
  return ast::TypeExpr::create(ctx, loc, typeStr);
}
|
|
|
|
/// Parse an inline `_` variable, e.g. `_: Value`. Unlike named inline
/// variables, `_` must always be followed by `:` and a constraint list, from
/// which the variable's type is inferred.
FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
  StringRef name = curToken.getSpelling();
  llvm::SMRange nameLoc = curToken.getLoc();
  consumeToken(Token::underscore);

  SmallVector<ast::ConstraintRef> constraints;
  ast::Type type;
  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")) ||
      failed(parseVariableDeclConstraintList(constraints)) ||
      failed(validateVariableConstraints(constraints, type)))
    return failure();
  return createInlineVariableExpr(type, name, nameLoc, constraints);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
/// Parse a single statement: `erase`, `let`, `replace`, `rewrite`, or an
/// expression statement. If `expectTerminalSemicolon` is set, a trailing `;`
/// is also required.
FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
  // Dispatch on the leading token; anything unrecognized is an expression.
  FailureOr<ast::Stmt *> stmt = [&]() -> FailureOr<ast::Stmt *> {
    switch (curToken.getKind()) {
    case Token::kw_erase:
      return parseEraseStmt();
    case Token::kw_let:
      return parseLetStmt();
    case Token::kw_replace:
      return parseReplaceStmt();
    case Token::kw_rewrite:
      return parseRewriteStmt();
    default:
      return parseExpr();
    }
  }();

  if (failed(stmt))
    return failure();
  if (expectTerminalSemicolon &&
      failed(parseToken(Token::semicolon, "expected `;` after statement")))
    return failure();
  return stmt;
}
|
|
|
|
/// Parse a brace-delimited block of statements, `{ stmt; ... }`. The
/// statements are parsed in a fresh declaration scope, and the resulting
/// location spans both braces.
FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
  llvm::SMLoc startLoc = curToken.getStartLoc();
  consumeToken(Token::l_brace);

  // The nested scope is popped on both the failure and the success path.
  pushDeclScope();
  SmallVector<ast::Stmt *> statements;
  while (curToken.isNot(Token::r_brace)) {
    FailureOr<ast::Stmt *> statement = parseStmt();
    if (failed(statement)) {
      popDeclScope();
      return failure();
    }
    statements.push_back(*statement);
  }
  popDeclScope();

  llvm::SMRange location(startLoc, curToken.getEndLoc());
  consumeToken(Token::r_brace);
  return ast::CompoundStmt::create(ctx, location, statements);
}
|
|
|
|
/// Parse an `erase <root>` statement, where `<root>` is the operation
/// expression to erase.
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
  llvm::SMRange eraseLoc = curToken.getLoc();
  consumeToken(Token::kw_erase);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (succeeded(rootOp))
    return createEraseStmt(eraseLoc, *rootOp);
  return failure();
}
|
|
|
|
/// Parse a `let` statement: `let name[: constraints] [= initializer]`.
/// Any trailing `;` is consumed by the caller (e.g. parseStmt). Constraints
/// that carry a type expression are rejected when an initializer is present.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  llvm::SMRange letLoc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // The variable name; `_` is reserved for inline variables.
  llvm::SMRange varLoc = curToken.getLoc();
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
    if (curToken.is(Token::underscore)) {
      return emitError(varLoc,
                       "`_` may only be used to define \"inline\" variables");
    }
    return emitError(varLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  // Optional constraint list introduced by `:`.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon) &&
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  // Optional initializer introduced by `=`.
  ast::Expr *initializer = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> initOrFailure = parseExpr();
    if (failed(initOrFailure))
      return failure();
    initializer = *initOrFailure;

    // Constraints with a type expression cannot be combined with an
    // initializer.
    for (const ast::ConstraintRef &constraint : constraints) {
      LogicalResult cstResult =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&](const auto *cst) {
                if (!cst->getTypeExpr())
                  return success();
                return this->emitError(
                    constraint.referenceLoc,
                    "type constraints are not permitted on variables with "
                    "initializers");
              })
              .Default(success());
      if (failed(cstResult))
        return failure();
    }
  }

  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(varName, varLoc, initializer, constraints);
  if (failed(varDecl))
    return failure();
  return ast::LetStmt::create(ctx, letLoc, *varDecl);
}
|
|
|
|
/// Parse a `replace <root> with <values>` statement. The replacement is either
/// a single expression or a non-empty parenthesized, comma-separated list.
/// Replacement values are parsed in the rewrite context.
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_replace);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp))
    return failure();
  if (failed(
          parseToken(Token::kw_with, "expected `with` after root operation")))
    return failure();

  // Everything after `with` is parsed in the rewrite context; the previous
  // context is restored when this function returns.
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::Rewrite);

  // Parses one replacement value and appends it to `replValues`.
  SmallVector<ast::Expr *> replValues;
  auto parseReplValue = [&]() -> LogicalResult {
    FailureOr<ast::Expr *> replExpr = parseExpr();
    if (failed(replExpr))
      return failure();
    replValues.emplace_back(*replExpr);
    return success();
  };

  if (!consumeIf(Token::l_paren)) {
    // A single, unparenthesized replacement value.
    if (failed(parseReplValue()))
      return failure();
  } else {
    // A parenthesized list, which must not be empty.
    if (consumeIf(Token::r_paren)) {
      return emitError(
          loc, "expected at least one replacement value, consider using "
               "`erase` if no replacement values are desired");
    }
    do {
      if (failed(parseReplValue()))
        return failure();
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after replacement values")))
      return failure();
  }

  return createReplaceStmt(loc, *rootOp, replValues);
}
|
|
|
|
/// Parse a `rewrite <root> with { ... }` statement. The body is a compound
/// statement parsed under ParserContext::Rewrite (in which, for example,
/// `op<...>` expressions must be named).
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
  llvm::SMRange rewriteLoc = curToken.getLoc();
  consumeToken(Token::kw_rewrite);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (failed(rootOp) ||
      failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
    return failure();
  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start rewrite body");

  // The previous parser context is restored when this function returns.
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::Rewrite);
  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
  if (failed(rewriteBody))
    return failure();
  return createRewriteStmt(rewriteLoc, *rootOp, *rewriteBody);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Creation+Analysis
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Build the PatternDecl for a parsed pattern, forwarding the benefit and the
/// bounded-recursion flag collected in `metadata` along with its body.
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                          const ParsedPatternMetadata &metadata,
                          ast::CompoundStmt *body) {
  return ast::PatternDecl::create(ctx, loc, name,
                                  metadata.benefit,
                                  metadata.hasBoundedRecursion,
                                  body);
}
|
|
|
|
/// Create and define a variable named `name`. Its type is inferred from the
/// constraint list and/or the initializer. When both provide a type, the
/// initializer's type is used to refine the constraint type, or the
/// initializer is converted to that type. Emits an error if no type can be
/// inferred.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, llvm::SMRange loc,
                           ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  // Type implied by the constraints; may still be null afterwards.
  ast::Type type;
  if (failed(validateVariableConstraints(constraints, type)))
    return failure();

  if (!initializer) {
    // Without an initializer, the constraints alone must fix the type.
    if (!type) {
      return emitErrorAndNote(
          loc, "unable to infer type for variable `" + name + "`", loc,
          "the type of a variable must be inferable from the constraint "
          "list or the initializer");
    }
  } else if (!type) {
    type = initializer->getType();
  } else if (ast::Type mergedType = type.refineWith(initializer->getType())) {
    type = mergedType;
  } else if (failed(convertExpressionTo(initializer, type))) {
    return failure();
  }

  FailureOr<ast::VariableDecl *> varDecl =
      defineVariableDecl(name, loc, type, initializer, constraints);
  return varDecl;
}
|
|
|
|
/// Validate each constraint in `constraints` in order, refining
/// `inferredType` with the type each one implies. Stops at, and fails on, the
/// first invalid constraint.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType) {
  for (const ast::ConstraintRef &constraintRef : constraints) {
    if (failed(validateVariableConstraint(constraintRef, inferredType)))
      return failure();
  }
  return success();
}
|
|
|
|
/// Validate a single constraint reference and merge the type it implies into
/// `inferredType`. Any type expression attached to an Attr/Value/ValueRange
/// constraint is checked as well. Emits an error if the constraint's type is
/// incompatible with the type inferred so far.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  ast::Type constraintType;
  if (const auto *attrCst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = attrCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *opCst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    constraintType = ast::OperationType::get(ctx, opCst->getName());
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *valueCst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = valueCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    constraintType = valueTy;
  } else if (const auto *valueRangeCst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = valueRangeCst->getTypeExpr();
    if (typeExpr && failed(validateTypeRangeConstraintExpr(typeExpr)))
      return failure();
    constraintType = valueRangeTy;
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // First constraint: adopt its type outright.
  if (!inferredType) {
    inferredType = constraintType;
    return success();
  }
  // Later constraints must be compatible with what was inferred so far.
  if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
    return success();
  }
  return emitError(ref.referenceLoc,
                   llvm::formatv("constraint type `{0}` is incompatible "
                                 "with the previously inferred type `{1}`",
                                 constraintType, inferredType));
}
|
|
|
|
/// Check that `typeExpr`, used as a type constraint, is a `Type` expression.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `Type` in type constraint");
}
|
|
|
|
/// Check that `typeExpr`, used as a type constraint on a range, is a
/// `TypeRange` expression.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeRangeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `TypeRange` in type constraint");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
/// Create a reference expression to `decl`. Only variable declarations may
/// currently be referenced; any other kind of decl produces an error at
/// `loc`.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc,
                                                        ast::Decl *decl) {
  auto *varDecl = dyn_cast<ast::VariableDecl>(decl);
  if (!varDecl) {
    return emitError(loc, "invalid reference to `" +
                              decl->getName()->getName() + "`");
  }
  return ast::DeclRefExpr::create(ctx, loc, decl, varDecl->getType());
}
|
|
|
|
/// Define a new variable `name` of `type` in place (no initializer) and
/// return a reference expression to it.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name,
                                 llvm::SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> varDecl =
      defineVariableDecl(name, loc, type, constraints);
  if (succeeded(varDecl))
    return ast::DeclRefExpr::create(ctx, loc, *varDecl, type);
  return failure();
}
|
|
|
|
/// Create an access of member `name` on `parentExpr`, after validating that
/// the member exists and determining its type.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               llvm::SMRange loc) {
  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
  if (succeeded(memberType)) {
    return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name,
                                         *memberType);
  }
  return failure();
}
|
|
|
|
/// Validate an access of member `name` on `parentExpr`. Returns the type of
/// the accessed member, or emits an error at `loc` if the access is invalid.
/// Supported accesses are:
///   * On `Op` expressions: the all-results member, yielding a `ValueRange`.
///   * On tuple expressions: a decimal element index, or an element label.
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name,
                                                  llvm::SMRange loc) {
  ast::Type parentType = parentExpr->getType();
  if (parentType.isa<ast::OperationType>()) {
    if (name == ast::AllResultsMemberAccessExpr::getMemberName())
      return valueRangeTy;
  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
    // An empty name is never a valid tuple member, so reject it up front.
    // Reading `name[0]` on an empty StringRef is out of bounds. Also,
    // unlabeled tuple elements appear to be stored with an empty label
    // (parseTupleExpr pushes an empty StringRef for them), so an empty name
    // must not reach the label lookup below.
    if (!name.empty()) {
      // Handle indexed results.
      unsigned index = 0;
      if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
          index < tupleType.size()) {
        return tupleType.getElementTypes()[index];
      }

      // Handle named results.
      auto elementNames = tupleType.getElementNames();
      const auto *it = llvm::find(elementNames, name);
      if (it != elementNames.end())
        return tupleType.getElementTypes()[it - elementNames.begin()];
    }
  }

  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}
|
|
|
|
/// Create an operation expression after validating its pieces:
///   * the input operands must be `Value`/`ValueRange` (or convertible),
///   * every attribute value must be an `Attr` expression,
///   * the result types must be `Type`/`TypeRange` expressions.
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
    llvm::SMRange loc, const ast::OpNameDecl *name,
    MutableArrayRef<ast::Expr *> operands,
    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
    MutableArrayRef<ast::Expr *> results) {
  Optional<StringRef> opNameRef = name->getName();

  if (failed(validateOperationOperands(loc, opNameRef, operands)))
    return failure();

  for (ast::NamedAttributeDecl *attr : attributes) {
    auto *attrValue = attr->getValue();
    ast::Type attrType = attrValue->getType();
    if (attrType.isa<ast::AttributeType>())
      continue;
    return emitError(
        attrValue->getLoc(),
        llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
  }

  if (failed(validateOperationResults(loc, opNameRef, results)))
    return failure();

  return ast::OperationExpr::create(ctx, loc, name, operands, results,
                                    attributes);
}
|
|
|
|
/// Validate the operand list of an operation expression: operands must be
/// `Value` or `ValueRange` expressions (or convertible to them).
LogicalResult
Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
                                  MutableArrayRef<ast::Expr *> operands) {
  return validateOperationOperandsOrResults(loc, name, operands,
                                            /*singleTy=*/valueTy,
                                            /*rangeTy=*/valueRangeTy);
}
|
|
|
|
/// Validate the result type list of an operation expression: entries must
/// be `Type` or `TypeRange` expressions (or convertible to them).
LogicalResult
Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
                                 MutableArrayRef<ast::Expr *> results) {
  return validateOperationOperandsOrResults(loc, name, results,
                                            /*singleTy=*/typeTy,
                                            /*rangeTy=*/typeRangeTy);
}
|
|
|
|
/// Shared validation for operand and result lists. A lone value is accepted
/// if it converts to `rangeTy`. Otherwise each value must already be
/// `singleTy` or `rangeTy`. As a special case, when validating values,
/// `Op` expressions are implicitly converted to their value form, which
/// rewrites the entry in `values`.
LogicalResult Parser::validateOperationOperandsOrResults(
    llvm::SMRange loc, Optional<StringRef> name,
    MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
    ast::Type rangeTy) {
  // Every operation accepts a single range parameter.
  if (values.size() == 1)
    return convertExpressionTo(values[0], rangeTy);

  for (ast::Expr *&valueExpr : values) {
    ast::Type valueExprType = valueExpr->getType();
    if (valueExprType == rangeTy || valueExprType == singleTy)
      continue;

    // Allow an Operation where a Value is expected. This situation arises
    // often with nested operation expressions, e.g.
    // `op<my_dialect.foo>(op<my_dialect.bar>)`.
    if (singleTy == valueTy && valueExprType.isa<ast::OperationType>()) {
      valueExpr = convertOpToValue(valueExpr);
      continue;
    }

    return emitError(
        valueExpr->getLoc(),
        llvm::formatv(
            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
            singleTy, rangeTy, valueExprType));
  }
  return success();
}
|
|
|
|
/// Create a tuple expression from `elements` and their (possibly empty)
/// labels. Constraint- and tuple-typed elements are rejected.
FailureOr<ast::TupleExpr *>
Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements,
                        ArrayRef<StringRef> elementNames) {
  for (const ast::Expr *element : elements) {
    ast::Type elementTy = element->getType();
    if (!elementTy.isa<ast::ConstraintType, ast::TupleType>())
      continue;
    return emitError(
        element->getLoc(),
        llvm::formatv("unable to build a tuple with `{0}` element", elementTy));
  }
  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
/// Create an `erase` statement for `rootOp`, which must be an `Op`
/// expression.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc,
                                                    ast::Expr *rootOp) {
  // Check that root is an Operation. The diagnostic names the offending type,
  // matching the diagnostics emitted for `replace` and `rewrite` roots.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }

  return ast::EraseStmt::create(ctx, loc, rootOp);
}
|
|
|
|
/// Create a `replace` statement. `rootOp` must be an `Op` expression, and each
/// replacement must be an `Op`, `Value`, or `ValueRange` expression. When
/// there is more than one replacement, `Op` replacements are converted in
/// place to their value form.
FailureOr<ast::ReplaceStmt *>
Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          MutableArrayRef<ast::Expr *> replValues) {
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }

  const bool convertOps = replValues.size() > 1;
  for (ast::Expr *&replExpr : replValues) {
    ast::Type replType = replExpr->getType();
    if (replType.isa<ast::OperationType>()) {
      if (convertOps)
        replExpr = convertOpToValue(replExpr);
    } else if (replType != valueTy && replType != valueRangeTy) {
      return emitError(replExpr->getLoc(),
                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
                                     "expression, but got `{0}`",
                                     replType));
    }
  }

  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}
|
|
|
|
/// Create a `rewrite` statement whose root `rootOp` must be an `Op`
/// expression.
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          ast::CompoundStmt *rewriteBody) {
  ast::Type rootType = rootOp->getType();
  if (rootType.isa<ast::OperationType>())
    return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
  return emitError(
      rootOp->getLoc(),
      llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse the PDLL source held by `sourceMgr` into an AST module created in
/// `ctx`, or return failure if the module could not be parsed.
FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
                                                 llvm::SourceMgr &sourceMgr) {
  return Parser(ctx, sourceMgr).parseModule();
}
|