From 9128a4aa87b42e1737f0be6f4fc2b44b4559b28e Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 28 Jul 2018 09:36:25 -0700 Subject: [PATCH] Finish parser/printer support for AffineMapOp, implement operand iterators on VariadicOperands, tidy up some code in the asmprinter, fill out more verification logic in for LoadOp. PiperOrigin-RevId: 206443020 --- mlir/include/mlir/IR/OpDefinition.h | 22 +++++ mlir/include/mlir/IR/OpImplementation.h | 14 +++- mlir/include/mlir/IR/StandardOps.h | 66 ++++++++------- mlir/lib/IR/AsmPrinter.cpp | 24 +++--- mlir/lib/IR/StandardOps.cpp | 104 ++++++++++++++++++------ mlir/lib/Parser/Parser.cpp | 13 ++- mlir/test/IR/core-ops.mlir | 23 ++++-- 7 files changed, 191 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 95d95ddaadff..3ec2c5e9596d 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -318,6 +318,28 @@ public: void setOperand(unsigned i, SSAValue *value) { this->getOperation()->setOperand(i, value); } + + // Support non-const operand iteration. + using operand_iterator = Operation::operand_iterator; + operand_iterator operand_begin() { + return this->getOperation()->operand_begin(); + } + operand_iterator operand_end() { return this->getOperation()->operand_end(); } + llvm::iterator_range getOperands() { + return this->getOperands(); + } + + // Support const operand iteration. + using const_operand_iterator = Operation::const_operand_iterator; + const_operand_iterator operand_begin() const { + return this->getOperation()->operand_begin(); + } + const_operand_iterator operand_end() const { + return this->getOperation()->operand_end(); + } + llvm::iterator_range getOperands() const { + return this->getOperands(); + } }; /// This class provides return value APIs for ops that are known to have a diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 1ef4846bfcab..fc154220e23c 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -49,7 +49,12 @@ public: /// Print a comma separated list of operands. template void printOperands(const ContainerType &container) { - auto it = container.begin(), end = container.end(); + printOperands(container.begin(), container.end()); + } + + /// Print a comma separated list of operands. + template + void printOperands(IteratorType it, IteratorType end) { if (it == end) return; printOperand(*it); @@ -198,9 +203,16 @@ public: /// These are the supported delimeters around operand lists, used by /// parseOperandList. enum Delimeter { + /// Zero or more operands with no delimeters. NoDelimeter, + /// Parens surrounding zero or more operands. ParenDelimeter, + /// Square brackets surrounding zero or more operands. SquareDelimeter, + /// Parens supporting zero or more operands, or nothing. + OptionalParenDelimeter, + /// Square brackets supporting zero or more ops, or nothing. + OptionalSquareDelimeter, }; /// Parse zero or more SSA comma-separated operand references with a specified diff --git a/mlir/include/mlir/IR/StandardOps.h b/mlir/include/mlir/IR/StandardOps.h index 73494e4c550f..2a4f9c6e775f 100644 --- a/mlir/include/mlir/IR/StandardOps.h +++ b/mlir/include/mlir/IR/StandardOps.h @@ -51,6 +51,42 @@ private: explicit AddFOp(const Operation *state) : Base(state) {} }; +/// The "affine_apply" operation applies an affine map to a list of operands, +/// yielding a list of results. The operand and result list sizes must be the +/// same. All operands and results are of type 'AffineInt'. This operation +/// requires a single affine map attribute named "map". +/// For example: +/// +/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } : +/// (affineint) -> (affineint) +/// +/// equivalently: +/// +/// #map42 = (d0)->(d0+1) +/// %y = affine_apply #map42(%x) +/// +class AffineApplyOp + : public OpImpl::Base { +public: + // Returns the affine map to be applied by this operation. + AffineMap *getAffineMap() const { + return getAttrOfType("map")->getValue(); + } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static StringRef getOperationName() { return "affine_apply"; } + + // Hooks to customize behavior of this op. + static OpAsmParserResult parse(OpAsmParser *parser); + void print(OpAsmPrinter *p) const; + const char *verify() const; + +private: + friend class Operation; + explicit AffineApplyOp(const Operation *state) : Base(state) {} +}; + /// The "constant" operation requires a single attribute named "value". /// It returns its value as an SSA value. For example: /// @@ -152,36 +188,6 @@ private: explicit LoadOp(const Operation *state) : Base(state) {} }; -/// The "affine_apply" operation applies an affine map to a list of operands, -/// yielding a list of results. The operand and result list sizes must be the -/// same. All operands and results are of type 'AffineInt'. This operation -/// requires a single affine map attribute named "map". -/// For example: -/// -/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } : -/// (affineint) -> (affineint) -/// -class AffineApplyOp - : public OpImpl::Base { -public: - // Returns the affine map to be applied by this operation. - AffineMap *getAffineMap() const { - return getAttrOfType("map")->getValue(); - } - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static StringRef getOperationName() { return "affine_apply"; } - - // Hooks to customize behavior of this op. - const char *verify() const; - void print(OpAsmPrinter *p) const; - -private: - friend class Operation; - explicit AffineApplyOp(const Operation *state) : Base(state) {} -}; - /// Install the standard operations in the specified operation set. void registerStandardOperations(OperationSet &opSet); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 28149fc73004..a3ea465e7c17 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -93,29 +93,25 @@ private: // TODO Support visiting other types/instructions when implemented. void ModuleState::visitType(const Type *type) { - if (type->getKind() == Type::Kind::Function) { + if (auto *funcType = dyn_cast(type)) { // Visit input and result types for functions. - auto *funcType = cast(type); - for (auto *input : funcType->getInputs()) { + for (auto *input : funcType->getInputs()) visitType(input); - } - for (auto *result : funcType->getResults()) { + for (auto *result : funcType->getResults()) visitType(result); - } - } else if (type->getKind() == Type::Kind::MemRef) { + } else if (auto *memref = dyn_cast(type)) { // Visit affine maps in memref type. - auto *memref = cast(type); - for (AffineMap *map : memref->getAffineMaps()) { + for (auto *map : memref->getAffineMaps()) { recordAffineMapReference(map); } } } void ModuleState::visitAttribute(const Attribute *attr) { - if (isa(attr)) { - recordAffineMapReference(cast(attr)->getValue()); - } else if (isa(attr)) { - for (auto elt : cast(attr)->getValue()) { + if (auto *mapAttr = dyn_cast(attr)) { + recordAffineMapReference(mapAttr->getValue()); + } else if (auto *array = dyn_cast(attr)) { + for (auto elt : array->getValue()) { visitAttribute(elt); } } @@ -535,7 +531,7 @@ public: ModulePrinter::printAttribute(attr); } void printAffineMap(const AffineMap *map) { - return ModulePrinter::printAffineMap(map); + return ModulePrinter::printAffineMapReference(map); } void printAffineExpr(const AffineExpr *expr) { return ModulePrinter::printAffineExpr(expr); diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp index 104ff2efe5dc..55c0c252899b 100644 --- a/mlir/lib/IR/StandardOps.cpp +++ b/mlir/lib/IR/StandardOps.cpp @@ -52,6 +52,75 @@ const char *AddFOp::verify() const { return nullptr; } +OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) { + SmallVector opInfos; + SmallVector operands; + + auto &builder = parser->getBuilder(); + auto *affineIntTy = builder.getAffineIntType(); + + AffineMapAttr *mapAttr; + if (parser->parseAttribute(mapAttr) || + parser->parseOperandList(opInfos, -1, + OpAsmParser::Delimeter::ParenDelimeter)) + return {}; + unsigned numDims = opInfos.size(); + + if (parser->parseOperandList( + opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) || + parser->resolveOperands(opInfos, affineIntTy, operands)) + return {}; + + auto *map = mapAttr->getValue(); + if (map->getNumDims() != numDims || + numDims + map->getNumSymbols() != opInfos.size()) { + parser->emitError(parser->getNameLoc(), + "dimension or symbol index mismatch"); + return {}; + } + + SmallVector resultTypes(map->getNumResults(), affineIntTy); + return OpAsmParserResult( + operands, resultTypes, + NamedAttribute(builder.getIdentifier("map"), mapAttr)); +} + +void AffineApplyOp::print(OpAsmPrinter *p) const { + auto *map = getAffineMap(); + *p << "affine_apply " << *map; + + auto opit = operand_begin(); + *p << '('; + p->printOperands(opit, opit + map->getNumDims()); + *p << ')'; + + if (map->getNumSymbols()) { + *p << '['; + p->printOperands(opit + map->getNumDims(), operand_end()); + *p << ']'; + } +} + +const char *AffineApplyOp::verify() const { + // Check that affine map attribute was specified. + auto *affineMapAttr = getAttrOfType("map"); + if (!affineMapAttr) + return "requires an affine map."; + + // Check input and output dimensions match. + auto *map = affineMapAttr->getValue(); + + // Verify that operand count matches affine map dimension and symbol count. + if (getNumOperands() != map->getNumDims() + map->getNumSymbols()) + return "operand count and affine map dimension and symbol count must match"; + + // Verify that result count matches affine map result count. + if (getNumResults() != map->getNumResults()) + return "result count and affine map result count must match"; + + return nullptr; +} + /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. const char *ConstantOp::verify() const { @@ -151,37 +220,26 @@ OpAsmParserResult LoadOp::parse(OpAsmParser *parser) { } const char *LoadOp::verify() const { - // TODO: Check load - return nullptr; -} + if (getNumOperands() == 0) + return "expected a memref to load from"; -void AffineApplyOp::print(OpAsmPrinter *p) const { - // TODO: Print operands etc. - *p << "affine_apply map: " << *getAffineMap(); -} + auto *memRefType = dyn_cast(getMemRef()->getType()); + if (!memRefType) + return "first operand must be a memref"; -const char *AffineApplyOp::verify() const { - // Check that affine map attribute was specified - auto affineMapAttr = getAttrOfType("map"); - if (!affineMapAttr) - return "requires an affine map."; + for (auto *idx : getIndices()) + if (!idx->getType()->isAffineInt()) + return "index to load must have 'affineint' type"; - // Check input and output dimensions match. - auto *map = affineMapAttr->getValue(); - - // Verify that operand count matches affine map dimension and symbol count. - if (getNumOperands() != map->getNumDims() + map->getNumSymbols()) - return "operand count and affine map dimension and symbol count must match"; - - // Verify that result count matches affine map result count. - if (getNumResults() != map->getNumResults()) - return "result count and affine map result count must match"; + // TODO: Verify we have the right number of indices. + // TODO: in MLFunction verify that the indices are parameters, IV's, or the + // result of an affine_apply. return nullptr; } /// Install the standard operations in the specified operation set. void mlir::registerStandardOperations(OperationSet &opSet) { - opSet.addOperations( + opSet.addOperations( /*prefix=*/""); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index b2c9bb9710e9..654db988dfc5 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1674,10 +1674,18 @@ public: switch (delimeter) { case Delimeter::NoDelimeter: break; + case Delimeter::OptionalParenDelimeter: + if (parser.getToken().isNot(Token::l_paren)) + return false; + LLVM_FALLTHROUGH; case Delimeter::ParenDelimeter: if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) return true; break; + case Delimeter::OptionalSquareDelimeter: + if (parser.getToken().isNot(Token::l_square)) + return false; + LLVM_FALLTHROUGH; case Delimeter::SquareDelimeter: if (parser.parseToken(Token::l_square, "expected '[' in operand list")) return true; @@ -1694,14 +1702,17 @@ public: } while (parser.consumeIf(Token::comma)); } - // Handle delimeters. + // Handle delimeters. If we reach here, the optional delimiters were + // present, so we need to parse their closing one. switch (delimeter) { case Delimeter::NoDelimeter: break; + case Delimeter::OptionalParenDelimeter: case Delimeter::ParenDelimeter: if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) return true; break; + case Delimeter::OptionalSquareDelimeter: case Delimeter::SquareDelimeter: if (parser.parseToken(Token::r_square, "expected ']' in operand list")) return true; diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 37d6ec13f760..c5ecb8e071ca 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -1,8 +1,12 @@ // RUN: %S/../../mlir-opt %s -o - | FileCheck %s -// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), (d1 + 2)) +// CHECK: #map0 = (d0) -> ((d0 + 1)) + +// CHECK: #map1 = (d0, d1) -> ((d0 + 1), (d1 + 2)) #map5 = (d0, d1) -> (d0 + 1, d1 + 2) -#id2 = (i,j)->(i,j) + +// CHECK: #map2 = (d0, d1) [s0, s1] -> ((d0 + s1), (d1 + s0)) +// CHECK: #map3 = () [s0] -> ((s0 + 1)) // CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) { cfgfunc @cfgfunc_with_ops(f32) { @@ -47,13 +51,20 @@ bb0: %i = "constant"() {value: 0} : () -> affineint %j = "constant"() {value: 1} : () -> affineint - // CHECK: affine_apply map: (d0) -> ((d0 + 1)) - %x = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } : + // CHECK: affine_apply #map0(%0) + %a = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } : (affineint) -> (affineint) - // CHECK: affine_apply map: (d0, d1) -> ((d0 + 1), (d1 + 2)) - %y = "affine_apply" (%i, %j) { map: #map5 } : + // CHECK: affine_apply #map1(%0, %1) + %b = "affine_apply" (%i, %j) { map: #map5 } : (affineint, affineint) -> (affineint, affineint) + + // CHECK: affine_apply #map2(%0, %1)[%1, %0] + %c = affine_apply (i,j)[m,n] -> (i+n, j+m)(%i, %j)[%j, %i] + + // CHECK: affine_apply #map3()[%0] + %d = affine_apply ()[x] -> (x+1)()[%i] + return }