Finish parser/printer support for AffineMapOp, implement operand iterators on

VariadicOperands, tidy up some code in the asmprinter, fill out more
verification logic in for LoadOp.

PiperOrigin-RevId: 206443020
This commit is contained in:
Chris Lattner 2018-07-28 09:36:25 -07:00 committed by jpienaar
parent c77f39f55c
commit 9128a4aa87
7 changed files with 191 additions and 75 deletions

View File

@ -318,6 +318,28 @@ public:
void setOperand(unsigned i, SSAValue *value) {
this->getOperation()->setOperand(i, value);
}
// Support non-const operand iteration.
using operand_iterator = Operation::operand_iterator;
operand_iterator operand_begin() {
return this->getOperation()->operand_begin();
}
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
llvm::iterator_range<operand_iterator> getOperands() {
return this->getOperands();
}
// Support const operand iteration.
using const_operand_iterator = Operation::const_operand_iterator;
const_operand_iterator operand_begin() const {
return this->getOperation()->operand_begin();
}
const_operand_iterator operand_end() const {
return this->getOperation()->operand_end();
}
llvm::iterator_range<const_operand_iterator> getOperands() const {
return this->getOperands();
}
};
/// This class provides return value APIs for ops that are known to have a

View File

@ -49,7 +49,12 @@ public:
/// Print a comma separated list of operands.
template <typename ContainerType>
void printOperands(const ContainerType &container) {
auto it = container.begin(), end = container.end();
printOperands(container.begin(), container.end());
}
/// Print a comma separated list of operands.
template <typename IteratorType>
void printOperands(IteratorType it, IteratorType end) {
if (it == end)
return;
printOperand(*it);
@ -198,9 +203,16 @@ public:
/// These are the supported delimeters around operand lists, used by
/// parseOperandList.
enum Delimeter {
/// Zero or more operands with no delimeters.
NoDelimeter,
/// Parens surrounding zero or more operands.
ParenDelimeter,
/// Square brackets surrounding zero or more operands.
SquareDelimeter,
/// Parens supporting zero or more operands, or nothing.
OptionalParenDelimeter,
/// Square brackets supporting zero or more ops, or nothing.
OptionalSquareDelimeter,
};
/// Parse zero or more SSA comma-separated operand references with a specified

View File

@ -51,6 +51,42 @@ private:
explicit AddFOp(const Operation *state) : Base(state) {}
};
/// The "affine_apply" operation applies an affine map to a list of operands,
/// yielding a list of results. The operand and result list sizes must be the
/// same. All operands and results are of type 'AffineInt'. This operation
/// requires a single affine map attribute named "map".
/// For example:
///
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
/// (affineint) -> (affineint)
///
/// equivalently:
///
/// #map42 = (d0)->(d0+1)
/// %y = affine_apply #map42(%x)
///
class AffineApplyOp
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
OpImpl::VariadicResults> {
public:
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
static OpAsmParserResult parse(OpAsmParser *parser);
void print(OpAsmPrinter *p) const;
const char *verify() const;
private:
friend class Operation;
explicit AffineApplyOp(const Operation *state) : Base(state) {}
};
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
@ -152,36 +188,6 @@ private:
explicit LoadOp(const Operation *state) : Base(state) {}
};
/// The "affine_apply" operation applies an affine map to a list of operands,
/// yielding a list of results. The operand and result list sizes must be the
/// same. All operands and results are of type 'AffineInt'. This operation
/// requires a single affine map attribute named "map".
/// For example:
///
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
/// (affineint) -> (affineint)
///
class AffineApplyOp
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
OpImpl::VariadicResults> {
public:
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
const char *verify() const;
void print(OpAsmPrinter *p) const;
private:
friend class Operation;
explicit AffineApplyOp(const Operation *state) : Base(state) {}
};
/// Install the standard operations in the specified operation set.
void registerStandardOperations(OperationSet &opSet);

View File

@ -93,29 +93,25 @@ private:
// TODO Support visiting other types/instructions when implemented.
void ModuleState::visitType(const Type *type) {
if (type->getKind() == Type::Kind::Function) {
if (auto *funcType = dyn_cast<FunctionType>(type)) {
// Visit input and result types for functions.
auto *funcType = cast<FunctionType>(type);
for (auto *input : funcType->getInputs()) {
for (auto *input : funcType->getInputs())
visitType(input);
}
for (auto *result : funcType->getResults()) {
for (auto *result : funcType->getResults())
visitType(result);
}
} else if (type->getKind() == Type::Kind::MemRef) {
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
// Visit affine maps in memref type.
auto *memref = cast<MemRefType>(type);
for (AffineMap *map : memref->getAffineMaps()) {
for (auto *map : memref->getAffineMaps()) {
recordAffineMapReference(map);
}
}
}
void ModuleState::visitAttribute(const Attribute *attr) {
if (isa<AffineMapAttr>(attr)) {
recordAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
} else if (isa<ArrayAttr>(attr)) {
for (auto elt : cast<ArrayAttr>(attr)->getValue()) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
recordAffineMapReference(mapAttr->getValue());
} else if (auto *array = dyn_cast<ArrayAttr>(attr)) {
for (auto elt : array->getValue()) {
visitAttribute(elt);
}
}
@ -535,7 +531,7 @@ public:
ModulePrinter::printAttribute(attr);
}
void printAffineMap(const AffineMap *map) {
return ModulePrinter::printAffineMap(map);
return ModulePrinter::printAffineMapReference(map);
}
void printAffineExpr(const AffineExpr *expr) {
return ModulePrinter::printAffineExpr(expr);

View File

@ -52,6 +52,75 @@ const char *AddFOp::verify() const {
return nullptr;
}
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 2> opInfos;
SmallVector<SSAValue *, 4> operands;
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
if (parser->parseAttribute(mapAttr) ||
parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimeter::ParenDelimeter))
return {};
unsigned numDims = opInfos.size();
if (parser->parseOperandList(
opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return {};
auto *map = mapAttr->getValue();
if (map->getNumDims() != numDims ||
numDims + map->getNumSymbols() != opInfos.size()) {
parser->emitError(parser->getNameLoc(),
"dimension or symbol index mismatch");
return {};
}
SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
return OpAsmParserResult(
operands, resultTypes,
NamedAttribute(builder.getIdentifier("map"), mapAttr));
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
auto opit = operand_begin();
*p << '(';
p->printOperands(opit, opit + map->getNumDims());
*p << ')';
if (map->getNumSymbols()) {
*p << '[';
p->printOperands(opit + map->getNumDims(), operand_end());
*p << ']';
}
}
const char *AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return "requires an affine map.";
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
return "operand count and affine map dimension and symbol count must match";
// Verify that result count matches affine map result count.
if (getNumResults() != map->getNumResults())
return "result count and affine map result count must match";
return nullptr;
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {
@ -151,37 +220,26 @@ OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
}
const char *LoadOp::verify() const {
// TODO: Check load
return nullptr;
}
if (getNumOperands() == 0)
return "expected a memref to load from";
void AffineApplyOp::print(OpAsmPrinter *p) const {
// TODO: Print operands etc.
*p << "affine_apply map: " << *getAffineMap();
}
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
if (!memRefType)
return "first operand must be a memref";
const char *AffineApplyOp::verify() const {
// Check that affine map attribute was specified
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return "requires an affine map.";
for (auto *idx : getIndices())
if (!idx->getType()->isAffineInt())
return "index to load must have 'affineint' type";
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
return "operand count and affine map dimension and symbol count must match";
// Verify that result count matches affine map result count.
if (getNumResults() != map->getNumResults())
return "result count and affine map result count must match";
// TODO: Verify we have the right number of indices.
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
// result of an affine_apply.
return nullptr;
}
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, ConstantOp, DimOp, LoadOp, AffineApplyOp>(
opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
/*prefix=*/"");
}

View File

@ -1674,10 +1674,18 @@ public:
switch (delimeter) {
case Delimeter::NoDelimeter:
break;
case Delimeter::OptionalParenDelimeter:
if (parser.getToken().isNot(Token::l_paren))
return false;
LLVM_FALLTHROUGH;
case Delimeter::ParenDelimeter:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true;
break;
case Delimeter::OptionalSquareDelimeter:
if (parser.getToken().isNot(Token::l_square))
return false;
LLVM_FALLTHROUGH;
case Delimeter::SquareDelimeter:
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true;
@ -1694,14 +1702,17 @@ public:
} while (parser.consumeIf(Token::comma));
}
// Handle delimeters.
// Handle delimeters. If we reach here, the optional delimiters were
// present, so we need to parse their closing one.
switch (delimeter) {
case Delimeter::NoDelimeter:
break;
case Delimeter::OptionalParenDelimeter:
case Delimeter::ParenDelimeter:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true;
break;
case Delimeter::OptionalSquareDelimeter:
case Delimeter::SquareDelimeter:
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true;

View File

@ -1,8 +1,12 @@
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), (d1 + 2))
// CHECK: #map0 = (d0) -> ((d0 + 1))
// CHECK: #map1 = (d0, d1) -> ((d0 + 1), (d1 + 2))
#map5 = (d0, d1) -> (d0 + 1, d1 + 2)
#id2 = (i,j)->(i,j)
// CHECK: #map2 = (d0, d1) [s0, s1] -> ((d0 + s1), (d1 + s0))
// CHECK: #map3 = () [s0] -> ((s0 + 1))
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
cfgfunc @cfgfunc_with_ops(f32) {
@ -47,13 +51,20 @@ bb0:
%i = "constant"() {value: 0} : () -> affineint
%j = "constant"() {value: 1} : () -> affineint
// CHECK: affine_apply map: (d0) -> ((d0 + 1))
%x = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
// CHECK: affine_apply #map0(%0)
%a = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
// CHECK: affine_apply map: (d0, d1) -> ((d0 + 1), (d1 + 2))
%y = "affine_apply" (%i, %j) { map: #map5 } :
// CHECK: affine_apply #map1(%0, %1)
%b = "affine_apply" (%i, %j) { map: #map5 } :
(affineint, affineint) -> (affineint, affineint)
// CHECK: affine_apply #map2(%0, %1)[%1, %0]
%c = affine_apply (i,j)[m,n] -> (i+n, j+m)(%i, %j)[%j, %i]
// CHECK: affine_apply #map3()[%0]
%d = affine_apply ()[x] -> (x+1)()[%i]
return
}