Finish parser/printer support for AffineMapOp, implement operand iterators on
VariadicOperands, tidy up some code in the asmprinter, fill out more verification logic in for LoadOp. PiperOrigin-RevId: 206443020
This commit is contained in:
parent
c77f39f55c
commit
9128a4aa87
|
@ -318,6 +318,28 @@ public:
|
|||
void setOperand(unsigned i, SSAValue *value) {
|
||||
this->getOperation()->setOperand(i, value);
|
||||
}
|
||||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = Operation::operand_iterator;
|
||||
operand_iterator operand_begin() {
|
||||
return this->getOperation()->operand_begin();
|
||||
}
|
||||
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
|
||||
llvm::iterator_range<operand_iterator> getOperands() {
|
||||
return this->getOperands();
|
||||
}
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator = Operation::const_operand_iterator;
|
||||
const_operand_iterator operand_begin() const {
|
||||
return this->getOperation()->operand_begin();
|
||||
}
|
||||
const_operand_iterator operand_end() const {
|
||||
return this->getOperation()->operand_end();
|
||||
}
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const {
|
||||
return this->getOperands();
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides return value APIs for ops that are known to have a
|
||||
|
|
|
@ -49,7 +49,12 @@ public:
|
|||
/// Print a comma separated list of operands.
|
||||
template <typename ContainerType>
|
||||
void printOperands(const ContainerType &container) {
|
||||
auto it = container.begin(), end = container.end();
|
||||
printOperands(container.begin(), container.end());
|
||||
}
|
||||
|
||||
/// Print a comma separated list of operands.
|
||||
template <typename IteratorType>
|
||||
void printOperands(IteratorType it, IteratorType end) {
|
||||
if (it == end)
|
||||
return;
|
||||
printOperand(*it);
|
||||
|
@ -198,9 +203,16 @@ public:
|
|||
/// These are the supported delimeters around operand lists, used by
|
||||
/// parseOperandList.
|
||||
enum Delimeter {
|
||||
/// Zero or more operands with no delimeters.
|
||||
NoDelimeter,
|
||||
/// Parens surrounding zero or more operands.
|
||||
ParenDelimeter,
|
||||
/// Square brackets surrounding zero or more operands.
|
||||
SquareDelimeter,
|
||||
/// Parens supporting zero or more operands, or nothing.
|
||||
OptionalParenDelimeter,
|
||||
/// Square brackets supporting zero or more ops, or nothing.
|
||||
OptionalSquareDelimeter,
|
||||
};
|
||||
|
||||
/// Parse zero or more SSA comma-separated operand references with a specified
|
||||
|
|
|
@ -51,6 +51,42 @@ private:
|
|||
explicit AddFOp(const Operation *state) : Base(state) {}
|
||||
};
|
||||
|
||||
/// The "affine_apply" operation applies an affine map to a list of operands,
|
||||
/// yielding a list of results. The operand and result list sizes must be the
|
||||
/// same. All operands and results are of type 'AffineInt'. This operation
|
||||
/// requires a single affine map attribute named "map".
|
||||
/// For example:
|
||||
///
|
||||
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
|
||||
/// (affineint) -> (affineint)
|
||||
///
|
||||
/// equivalently:
|
||||
///
|
||||
/// #map42 = (d0)->(d0+1)
|
||||
/// %y = affine_apply #map42(%x)
|
||||
///
|
||||
class AffineApplyOp
|
||||
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
|
||||
OpImpl::VariadicResults> {
|
||||
public:
|
||||
// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "affine_apply"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static OpAsmParserResult parse(OpAsmParser *parser);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
const char *verify() const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit AffineApplyOp(const Operation *state) : Base(state) {}
|
||||
};
|
||||
|
||||
/// The "constant" operation requires a single attribute named "value".
|
||||
/// It returns its value as an SSA value. For example:
|
||||
///
|
||||
|
@ -152,36 +188,6 @@ private:
|
|||
explicit LoadOp(const Operation *state) : Base(state) {}
|
||||
};
|
||||
|
||||
/// The "affine_apply" operation applies an affine map to a list of operands,
|
||||
/// yielding a list of results. The operand and result list sizes must be the
|
||||
/// same. All operands and results are of type 'AffineInt'. This operation
|
||||
/// requires a single affine map attribute named "map".
|
||||
/// For example:
|
||||
///
|
||||
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
|
||||
/// (affineint) -> (affineint)
|
||||
///
|
||||
class AffineApplyOp
|
||||
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
|
||||
OpImpl::VariadicResults> {
|
||||
public:
|
||||
// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "affine_apply"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
const char *verify() const;
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit AffineApplyOp(const Operation *state) : Base(state) {}
|
||||
};
|
||||
|
||||
/// Install the standard operations in the specified operation set.
|
||||
void registerStandardOperations(OperationSet &opSet);
|
||||
|
||||
|
|
|
@ -93,29 +93,25 @@ private:
|
|||
|
||||
// TODO Support visiting other types/instructions when implemented.
|
||||
void ModuleState::visitType(const Type *type) {
|
||||
if (type->getKind() == Type::Kind::Function) {
|
||||
if (auto *funcType = dyn_cast<FunctionType>(type)) {
|
||||
// Visit input and result types for functions.
|
||||
auto *funcType = cast<FunctionType>(type);
|
||||
for (auto *input : funcType->getInputs()) {
|
||||
for (auto *input : funcType->getInputs())
|
||||
visitType(input);
|
||||
}
|
||||
for (auto *result : funcType->getResults()) {
|
||||
for (auto *result : funcType->getResults())
|
||||
visitType(result);
|
||||
}
|
||||
} else if (type->getKind() == Type::Kind::MemRef) {
|
||||
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
|
||||
// Visit affine maps in memref type.
|
||||
auto *memref = cast<MemRefType>(type);
|
||||
for (AffineMap *map : memref->getAffineMaps()) {
|
||||
for (auto *map : memref->getAffineMaps()) {
|
||||
recordAffineMapReference(map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ModuleState::visitAttribute(const Attribute *attr) {
|
||||
if (isa<AffineMapAttr>(attr)) {
|
||||
recordAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
|
||||
} else if (isa<ArrayAttr>(attr)) {
|
||||
for (auto elt : cast<ArrayAttr>(attr)->getValue()) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
|
||||
recordAffineMapReference(mapAttr->getValue());
|
||||
} else if (auto *array = dyn_cast<ArrayAttr>(attr)) {
|
||||
for (auto elt : array->getValue()) {
|
||||
visitAttribute(elt);
|
||||
}
|
||||
}
|
||||
|
@ -535,7 +531,7 @@ public:
|
|||
ModulePrinter::printAttribute(attr);
|
||||
}
|
||||
void printAffineMap(const AffineMap *map) {
|
||||
return ModulePrinter::printAffineMap(map);
|
||||
return ModulePrinter::printAffineMapReference(map);
|
||||
}
|
||||
void printAffineExpr(const AffineExpr *expr) {
|
||||
return ModulePrinter::printAffineExpr(expr);
|
||||
|
|
|
@ -52,6 +52,75 @@ const char *AddFOp::verify() const {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfos;
|
||||
SmallVector<SSAValue *, 4> operands;
|
||||
|
||||
auto &builder = parser->getBuilder();
|
||||
auto *affineIntTy = builder.getAffineIntType();
|
||||
|
||||
AffineMapAttr *mapAttr;
|
||||
if (parser->parseAttribute(mapAttr) ||
|
||||
parser->parseOperandList(opInfos, -1,
|
||||
OpAsmParser::Delimeter::ParenDelimeter))
|
||||
return {};
|
||||
unsigned numDims = opInfos.size();
|
||||
|
||||
if (parser->parseOperandList(
|
||||
opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
|
||||
parser->resolveOperands(opInfos, affineIntTy, operands))
|
||||
return {};
|
||||
|
||||
auto *map = mapAttr->getValue();
|
||||
if (map->getNumDims() != numDims ||
|
||||
numDims + map->getNumSymbols() != opInfos.size()) {
|
||||
parser->emitError(parser->getNameLoc(),
|
||||
"dimension or symbol index mismatch");
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
|
||||
return OpAsmParserResult(
|
||||
operands, resultTypes,
|
||||
NamedAttribute(builder.getIdentifier("map"), mapAttr));
|
||||
}
|
||||
|
||||
void AffineApplyOp::print(OpAsmPrinter *p) const {
|
||||
auto *map = getAffineMap();
|
||||
*p << "affine_apply " << *map;
|
||||
|
||||
auto opit = operand_begin();
|
||||
*p << '(';
|
||||
p->printOperands(opit, opit + map->getNumDims());
|
||||
*p << ')';
|
||||
|
||||
if (map->getNumSymbols()) {
|
||||
*p << '[';
|
||||
p->printOperands(opit + map->getNumDims(), operand_end());
|
||||
*p << ']';
|
||||
}
|
||||
}
|
||||
|
||||
const char *AffineApplyOp::verify() const {
|
||||
// Check that affine map attribute was specified.
|
||||
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||
if (!affineMapAttr)
|
||||
return "requires an affine map.";
|
||||
|
||||
// Check input and output dimensions match.
|
||||
auto *map = affineMapAttr->getValue();
|
||||
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
|
||||
return "operand count and affine map dimension and symbol count must match";
|
||||
|
||||
// Verify that result count matches affine map result count.
|
||||
if (getNumResults() != map->getNumResults())
|
||||
return "result count and affine map result count must match";
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// The constant op requires an attribute, and furthermore requires that it
|
||||
/// matches the return type.
|
||||
const char *ConstantOp::verify() const {
|
||||
|
@ -151,37 +220,26 @@ OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
|
|||
}
|
||||
|
||||
const char *LoadOp::verify() const {
|
||||
// TODO: Check load
|
||||
return nullptr;
|
||||
}
|
||||
if (getNumOperands() == 0)
|
||||
return "expected a memref to load from";
|
||||
|
||||
void AffineApplyOp::print(OpAsmPrinter *p) const {
|
||||
// TODO: Print operands etc.
|
||||
*p << "affine_apply map: " << *getAffineMap();
|
||||
}
|
||||
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
||||
if (!memRefType)
|
||||
return "first operand must be a memref";
|
||||
|
||||
const char *AffineApplyOp::verify() const {
|
||||
// Check that affine map attribute was specified
|
||||
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||
if (!affineMapAttr)
|
||||
return "requires an affine map.";
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType()->isAffineInt())
|
||||
return "index to load must have 'affineint' type";
|
||||
|
||||
// Check input and output dimensions match.
|
||||
auto *map = affineMapAttr->getValue();
|
||||
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
|
||||
return "operand count and affine map dimension and symbol count must match";
|
||||
|
||||
// Verify that result count matches affine map result count.
|
||||
if (getNumResults() != map->getNumResults())
|
||||
return "result count and affine map result count must match";
|
||||
// TODO: Verify we have the right number of indices.
|
||||
|
||||
// TODO: in MLFunction verify that the indices are parameters, IV's, or the
|
||||
// result of an affine_apply.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Install the standard operations in the specified operation set.
|
||||
void mlir::registerStandardOperations(OperationSet &opSet) {
|
||||
opSet.addOperations<AddFOp, ConstantOp, DimOp, LoadOp, AffineApplyOp>(
|
||||
opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
|
||||
/*prefix=*/"");
|
||||
}
|
||||
|
|
|
@ -1674,10 +1674,18 @@ public:
|
|||
switch (delimeter) {
|
||||
case Delimeter::NoDelimeter:
|
||||
break;
|
||||
case Delimeter::OptionalParenDelimeter:
|
||||
if (parser.getToken().isNot(Token::l_paren))
|
||||
return false;
|
||||
LLVM_FALLTHROUGH;
|
||||
case Delimeter::ParenDelimeter:
|
||||
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
|
||||
return true;
|
||||
break;
|
||||
case Delimeter::OptionalSquareDelimeter:
|
||||
if (parser.getToken().isNot(Token::l_square))
|
||||
return false;
|
||||
LLVM_FALLTHROUGH;
|
||||
case Delimeter::SquareDelimeter:
|
||||
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
|
||||
return true;
|
||||
|
@ -1694,14 +1702,17 @@ public:
|
|||
} while (parser.consumeIf(Token::comma));
|
||||
}
|
||||
|
||||
// Handle delimeters.
|
||||
// Handle delimeters. If we reach here, the optional delimiters were
|
||||
// present, so we need to parse their closing one.
|
||||
switch (delimeter) {
|
||||
case Delimeter::NoDelimeter:
|
||||
break;
|
||||
case Delimeter::OptionalParenDelimeter:
|
||||
case Delimeter::ParenDelimeter:
|
||||
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
|
||||
return true;
|
||||
break;
|
||||
case Delimeter::OptionalSquareDelimeter:
|
||||
case Delimeter::SquareDelimeter:
|
||||
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
|
||||
return true;
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
|
||||
|
||||
// CHECK: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), (d1 + 2))
|
||||
// CHECK: #map0 = (d0) -> ((d0 + 1))
|
||||
|
||||
// CHECK: #map1 = (d0, d1) -> ((d0 + 1), (d1 + 2))
|
||||
#map5 = (d0, d1) -> (d0 + 1, d1 + 2)
|
||||
#id2 = (i,j)->(i,j)
|
||||
|
||||
// CHECK: #map2 = (d0, d1) [s0, s1] -> ((d0 + s1), (d1 + s0))
|
||||
// CHECK: #map3 = () [s0] -> ((s0 + 1))
|
||||
|
||||
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
|
||||
cfgfunc @cfgfunc_with_ops(f32) {
|
||||
|
@ -47,13 +51,20 @@ bb0:
|
|||
%i = "constant"() {value: 0} : () -> affineint
|
||||
%j = "constant"() {value: 1} : () -> affineint
|
||||
|
||||
// CHECK: affine_apply map: (d0) -> ((d0 + 1))
|
||||
%x = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
|
||||
// CHECK: affine_apply #map0(%0)
|
||||
%a = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
|
||||
(affineint) -> (affineint)
|
||||
|
||||
// CHECK: affine_apply map: (d0, d1) -> ((d0 + 1), (d1 + 2))
|
||||
%y = "affine_apply" (%i, %j) { map: #map5 } :
|
||||
// CHECK: affine_apply #map1(%0, %1)
|
||||
%b = "affine_apply" (%i, %j) { map: #map5 } :
|
||||
(affineint, affineint) -> (affineint, affineint)
|
||||
|
||||
// CHECK: affine_apply #map2(%0, %1)[%1, %0]
|
||||
%c = affine_apply (i,j)[m,n] -> (i+n, j+m)(%i, %j)[%j, %i]
|
||||
|
||||
// CHECK: affine_apply #map3()[%0]
|
||||
%d = affine_apply ()[x] -> (x+1)()[%i]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue