Implement support for branch instruction operands.

PiperOrigin-RevId: 205666777
This commit is contained in:
Chris Lattner 2018-07-23 08:42:19 -07:00 committed by jpienaar
parent 3de07e5c53
commit 21ede32ff5
12 changed files with 217 additions and 26 deletions

View File

@ -59,9 +59,12 @@ public:
reverse_args_iterator args_rend() const { return getArguments().rend(); }
bool args_empty() const { return arguments.empty(); }
/// Add one value to the operand list.
BBArgument *addArgument(Type *type);
llvm::iterator_range<BBArgListType::iterator>
addArguments(ArrayRef<Type *> types);
/// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
unsigned getNumArguments() const { return arguments.size(); }
BBArgument *getArgument(unsigned i) { return arguments[i]; }

View File

@ -186,7 +186,7 @@ public:
return op;
}
// Creates for statement. When step is not specified, it is set to 1.
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step = nullptr);

View File

@ -260,7 +260,31 @@ public:
return dest;
}
// TODO: need to take operands to specify BB arguments
unsigned getNumOperands() const { return operands.size(); }
// TODO: Add a getOperands() custom sequence that provides a value projection
// of the operand list.
CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
const CFGValue *getOperand(unsigned idx) const {
return getInstOperand(idx).get();
}
ArrayRef<InstOperand> getInstOperands() const { return operands; }
MutableArrayRef<InstOperand> getInstOperands() { return operands; }
InstOperand &getInstOperand(unsigned idx) { return operands[idx]; }
const InstOperand &getInstOperand(unsigned idx) const {
return operands[idx];
}
/// Add one value to the operand list.
void addOperand(CFGValue *value);
/// Add a list of values to the operand list.
void addOperands(ArrayRef<CFGValue *> values);
/// Erase a specific argument from the arg list.
// TODO: void eraseArgument(int Index);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
@ -272,6 +296,7 @@ private:
: TerminatorInst(Kind::Branch), dest(dest) {}
BasicBlock *dest;
std::vector<InstOperand> operands;
};

View File

@ -37,6 +37,8 @@ public:
// FIXME: wrong representation and API.
// TODO(someone): This should switch to llvm::iplist<Function>.
// TODO(someone): we also need a symbol table for function names +
// autorenaming like LLVM does.
std::vector<Function*> functionList;
/// Perform (potentially expensive) checks of invariants, used to detect

View File

@ -61,6 +61,17 @@ public:
/// of the SSA machinery.
SSAOperand *getNextOperandUsingThisValue() { return nextUse; }
/// We support a move constructor so SSAOperands can be in vectors, but this
/// shouldn't be used by general clients.
SSAOperand(SSAOperand &&other) {
other.removeFromCurrent();
value = other.value;
other.value = nullptr;
nextUse = nullptr;
back = nullptr;
insertIntoCurrent();
}
private:
/// The value used as this operand. This can be null when in a
/// "dropAllUses" state.
@ -116,6 +127,11 @@ public:
/// Return which operand this is in the operand list of the User.
// TODO: unsigned getOperandNumber() const;
/// We support a move constructor so SSAOperands can be in vectors, but this
/// shouldn't be used by general clients.
SSAOperandImpl(SSAOperandImpl &&other)
: SSAOperand(std::move(other)), owner(other.owner) {}
private:
/// The owner of this operand.
SSAOwnerTy *const owner;

View File

@ -706,9 +706,23 @@ void CFGFunctionPrinter::print(const Instruction *inst) {
void CFGFunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
void CFGFunctionPrinter::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
if (inst->getNumOperands() != 0) {
os << '(';
// TODO: Use getOperands() when we have it.
interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
printValueID(operand.get());
});
os << ") : ";
interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
ModulePrinter::print(operand.get()->getType());
});
}
}
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << " return";

View File

@ -35,6 +35,31 @@ void BasicBlock::eraseFromFunction() {
getFunction()->getBlocks().erase(this);
}
//===----------------------------------------------------------------------===//
// Argument list management.
//===----------------------------------------------------------------------===//
BBArgument *BasicBlock::addArgument(Type *type) {
auto *arg = new BBArgument(type, this);
arguments.push_back(arg);
return arg;
}
/// Add one argument to the argument list for each type specified in the list.
auto BasicBlock::addArguments(ArrayRef<Type *> types)
-> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
for (auto *type : types) {
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
}
//===----------------------------------------------------------------------===//
// Terminator management
//===----------------------------------------------------------------------===//
void BasicBlock::setTerminator(TerminatorInst *inst) {
// If we already had a terminator, abandon it.
if (terminator)
@ -46,6 +71,10 @@ void BasicBlock::setTerminator(TerminatorInst *inst) {
inst->block = this;
}
//===----------------------------------------------------------------------===//
// ilist_traits for BasicBlock
//===----------------------------------------------------------------------===//
mlir::CFGFunction *
llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
size_t Offset(
@ -86,17 +115,3 @@ transferNodesFromList(ilist_traits<BasicBlock> &otherList,
for (; first != last; ++first)
first->function = curParent;
}
BBArgument *BasicBlock::addArgument(Type *type) {
arguments.push_back(new BBArgument(type, this));
return arguments.back();
}
llvm::iterator_range<BasicBlock::BBArgListType::iterator>
BasicBlock::addArguments(ArrayRef<Type *> types) {
auto initial_size = arguments.size();
for (auto *type : types) {
addArgument(type);
}
return {arguments.data() + initial_size, arguments.data() + arguments.size()};
}

View File

@ -207,3 +207,15 @@ ReturnInst::~ReturnInst() {
for (auto &operand : getInstOperands())
operand.~InstOperand();
}
/// Add one value to the operand list.
void BranchInst::addOperand(CFGValue *value) {
operands.emplace_back(InstOperand(this, value));
}
/// Add a list of values to the operand list.
void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
operands.reserve(operands.size() + values.size());
for (auto *value : values)
addOperand(value);
}

View File

@ -100,6 +100,7 @@ public:
bool verifyOperation(const OperationInst &inst);
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
bool verifyBranch(const BranchInst &inst);
};
} // end anonymous namespace
@ -163,6 +164,9 @@ bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) {
if (auto *ret = dyn_cast<ReturnInst>(&term))
return verifyReturn(*ret);
if (auto *br = dyn_cast<BranchInst>(&term))
return verifyBranch(*br);
return false;
}
@ -175,6 +179,31 @@ bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
Twine(results.size()),
inst);
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (inst.getOperand(i)->getType() != results[i])
return failure("type of return operand " + Twine(i) +
" doesn't match result function result type",
inst);
return false;
}
bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
// Verify that the number of operands lines up with the number of BB arguments
// in the successor.
auto dest = inst.getDest();
if (inst.getNumOperands() != dest->getNumArguments())
return failure("branch has " + Twine(inst.getNumOperands()) +
" operands, but target block has " +
Twine(dest->getNumArguments()),
inst);
for (unsigned i = 0, e = inst.getNumOperands(); i != e; ++i)
if (inst.getOperand(i)->getType() != dest->getArgument(i)->getType())
return failure("type of branch operand " + Twine(i) +
" doesn't match target bb argument type",
inst);
return false;
}

View File

@ -162,6 +162,7 @@ public:
Type *parseMemRefType();
Type *parseFunctionType();
Type *parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
// Attribute parsing.
@ -516,12 +517,27 @@ Type *Parser::parseType() {
}
}
/// Parse a list of types without an enclosing parenthesis. The list must have
/// at least one member.
///
/// type-list-no-parens ::= type (`,` type)*
///
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
auto parseElt = [&]() -> ParseResult {
auto elt = parseType();
elements.push_back(elt);
return elt ? ParseSuccess : ParseFailure;
};
return parseCommaSeparatedList(parseElt);
}
/// Parse a "type list", which is a singular type, or a parenthesized list of
/// types.
///
/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)`
/// | `(` type (`,` type)* `)`
/// | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
auto parseElt = [&]() -> ParseResult {
@ -1706,7 +1722,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-and-type-list? `)`
/// branch-use-list ::= `(` ssa-use-list `)` ':' type-list-no-parens
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
@ -1730,7 +1746,40 @@ TerminatorInst *CFGFunctionParser::parseTerminator() {
auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
return builder.createBranchInst(destBB);
auto branch = builder.createBranchInst(destBB);
// Parse the use list.
if (!consumeIf(Token::l_paren))
return branch;
SmallVector<SSAUseInfo, 4> valueIDs;
if (parseOptionalSSAUseList(valueIDs))
return nullptr;
if (!consumeIf(Token::r_paren))
return (emitError("expected ')' in branch argument list"), nullptr);
if (!consumeIf(Token::colon))
return (emitError("expected ':' in branch argument list"), nullptr);
auto typeLoc = getToken().getLoc();
SmallVector<Type *, 4> types;
if (parseTypeListNoParens(types))
return nullptr;
if (types.size() != valueIDs.size())
return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
" types to match operand list"),
nullptr);
SmallVector<CFGValue *, 4> values;
values.reserve(valueIDs.size());
for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
values.push_back(cast<CFGValue>(value));
else
return nullptr;
}
branch->addOperands(values);
return branch;
}
// TODO: cond_br.
}

View File

@ -243,3 +243,16 @@ cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc mu
bb42(%0: f32):
return
}
// -----
cfgfunc @br_mismatch() { // expected-error {{branch has 2 operands, but target block has 1}}
bb0: // CHECK: bb0:
// CHECK: %0 = "foo"() : () -> (i1, i17)
%0 = "foo"() : () -> (i1, i17)
br bb1(%0#1, %0#0) : i17, i1
bb1(%x: i17):
return
}

View File

@ -66,16 +66,16 @@ extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -
// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
cfgfunc @simpleCFG(i32, f32) {
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {
cfgfunc @simpleCFG(i32, f32) -> i1 {
// CHECK: bb0(%0: i32, %1: f32):
bb42 (%0: i32, %f: f32):
// CHECK: %2 = "foo"() : () -> i64
%1 = "foo"() : ()->i64
// CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
%2 = "bar"(%1) : (i64) -> (i1,i1,i1)
// CHECK: return
return
// CHECK: return %3#1
return %2#1 : i1
// CHECK: }
}
@ -208,4 +208,17 @@ bb2: // CHECK: bb2:
// CHECK: %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
%2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
br bb1
}
}
// CHECK-LABEL: cfgfunc @bbargs() -> (i16, i8) {
cfgfunc @bbargs() -> (i16, i8) {
bb0: // CHECK: bb0:
// CHECK: %0 = "foo"() : () -> (i1, i17)
%0 = "foo"() : () -> (i1, i17)
br bb1(%0#1, %0#0) : i17, i1
bb1(%x: i17, %y: i1): // CHECK: bb1(%1: i17, %2: i1):
// CHECK: %3 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
%1 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
return %1#0 : i16, %1#1 : i8
}