diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 1f75241ace44..03535ff07d7d 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -473,10 +473,10 @@ public: } ArrayRef getTrueInstOperands() const { - return {&operands[0], &operands[0] + getNumTrueOperands()}; + return const_cast(this)->getTrueInstOperands(); } MutableArrayRef getTrueInstOperands() { - return {&operands[0], &operands[0] + getNumTrueOperands()}; + return {operands.data(), operands.data() + getNumTrueOperands()}; } InstOperand &getTrueInstOperand(unsigned idx) { return operands[idx]; } @@ -526,12 +526,11 @@ public: } ArrayRef getFalseInstOperands() const { - return {&operands[0] + getNumTrueOperands(), - &operands[0] + getNumOperands()}; + return const_cast(this)->getFalseInstOperands(); } MutableArrayRef getFalseInstOperands() { - return {&operands[0] + getNumTrueOperands(), - &operands[0] + getNumOperands()}; + return {operands.data() + getNumTrueOperands(), + operands.data() + getNumOperands()}; } InstOperand &getFalseInstOperand(unsigned idx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 21dd4c69c99a..28a97ea9b850 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1147,6 +1147,10 @@ void CFGFunctionPrinter::print(const BasicBlock *block) { } void CFGFunctionPrinter::print(const Instruction *inst) { + if (!inst) { + os << "<>\n"; + return; + } switch (inst->getKind()) { case Instruction::Kind::Operation: return print(cast(inst)); diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp index 6c79b8cdd74e..6e93bf5a2a9e 100644 --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -178,19 +178,12 @@ struct CFGFuncVerifier : public Verifier { bool verify(); bool verifyBlock(const BasicBlock &block); bool verifyTerminator(const TerminatorInst &term); + + bool verifyBBArguments(ArrayRef operands, + const BasicBlock *destBB, const TerminatorInst &term); bool verifyReturn(const ReturnInst &inst); bool verifyBranch(const BranchInst &inst); bool verifyCondBranch(const CondBranchInst &inst); - - // Given a list of "operands" and "arguments" that are the same length, verify - // that the types of operands pointwise match argument types. The iterator - // types must expose the "getType()" function when dereferenced twice; that - // is, the iterator's value_type must be equivalent to SSAValue*. - template - bool verifyOperandsMatchArguments(OperandIteratorTy opBegin, - OperandIteratorTy opEnd, - ArgumentIteratorTy argBegin, - const Instruction &instContext); }; } // end anonymous namespace @@ -251,8 +244,20 @@ bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) { if (term.getFunction() != &fn) return failure("terminator in the wrong function", term); - // TODO: Check that operands are structurally ok. - // TODO: Check that successors are in the right function. + // Check that operands are non-nil and structurally ok. + for (const auto *operand : term.getOperands()) { + if (!operand) + return failure("null operand found", term); + + if (operand->getFunction() != &fn) + return failure("reference to operand defined in another function", term); + } + + // Check that successors are in the right function. + for (auto *succ : term.getBlock()->getSuccessors()) { + if (succ->getFunction() != &fn) + return failure("reference to block defined in another function", term); + } if (auto *ret = dyn_cast(&term)) return verifyReturn(*ret); @@ -266,6 +271,24 @@ bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) { return false; } +/// Check a set of basic block arguments against the expected list in in the +/// destination basic block. +bool CFGFuncVerifier::verifyBBArguments(ArrayRef operands, + const BasicBlock *destBB, + const TerminatorInst &term) { + if (operands.size() != destBB->getNumArguments()) + return failure("branch has " + Twine(operands.size()) + + " operands, but target block has " + + Twine(destBB->getNumArguments()), + term); + + for (unsigned i = 0, e = operands.size(); i != e; ++i) + if (operands[i].get()->getType() != destBB->getArgument(i)->getType()) + return failure("type mismatch in bb argument #" + Twine(i), term); + + return false; +} + bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) { // Verify that the return operands match the results of the function. auto results = fn.getType()->getResults(); @@ -287,63 +310,20 @@ bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) { bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) { // Verify that the number of operands lines up with the number of BB arguments // in the successor. - auto dest = inst.getDest(); - if (inst.getNumOperands() != dest->getNumArguments()) - return failure("branch has " + Twine(inst.getNumOperands()) + - " operands, but target block has " + - Twine(dest->getNumArguments()), - inst); + if (verifyBBArguments(inst.getInstOperands(), inst.getDest(), inst)) + return true; - for (unsigned i = 0, e = inst.getNumOperands(); i != e; ++i) - if (inst.getOperand(i)->getType() != dest->getArgument(i)->getType()) - return failure("type of branch operand " + Twine(i) + - " doesn't match target bb argument type", - inst); - - return false; -} - -template -bool CFGFuncVerifier::verifyOperandsMatchArguments( - OperandIteratorTy opBegin, OperandIteratorTy opEnd, - ArgumentIteratorTy argBegin, const Instruction &instContext) { - OperandIteratorTy opIt = opBegin; - ArgumentIteratorTy argIt = argBegin; - for (; opIt != opEnd; ++opIt, ++argIt) { - if ((*opIt)->getType() != (*argIt)->getType()) - return failure("type of operand " + Twine(std::distance(opBegin, opIt)) + - " doesn't match argument type", - instContext); - } return false; } bool CFGFuncVerifier::verifyCondBranch(const CondBranchInst &inst) { // Verify that the number of operands lines up with the number of BB arguments // in the true successor. - auto trueDest = inst.getTrueDest(); - if (inst.getNumTrueOperands() != trueDest->getNumArguments()) - return failure("branch has " + Twine(inst.getNumTrueOperands()) + - " true operands, but true target block has " + - Twine(trueDest->getNumArguments()), - inst); - - if (verifyOperandsMatchArguments(inst.true_operand_begin(), - inst.true_operand_end(), - trueDest->args_begin(), inst)) + if (verifyBBArguments(inst.getTrueInstOperands(), inst.getTrueDest(), inst)) return true; // And the false successor. - auto falseDest = inst.getFalseDest(); - if (inst.getNumFalseOperands() != falseDest->getNumArguments()) - return failure("branch has " + Twine(inst.getNumFalseOperands()) + - " false operands, but false target block has " + - Twine(falseDest->getNumArguments()), - inst); - - if (verifyOperandsMatchArguments(inst.false_operand_begin(), - inst.false_operand_end(), - falseDest->args_begin(), inst)) + if (verifyBBArguments(inst.getFalseInstOperands(), inst.getFalseDest(), inst)) return true; if (inst.getCondition()->getType() != Type::getInteger(1, fn.getContext())) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 4d18d5763e42..fe77602d2986 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -187,7 +187,18 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { pass->runOnModule(module.get()); delete pass; - module->verify(); + + // Verify that the result of the pass is still valid. + std::string errorResult; + module->verify(&errorResult); + + // We don't have location information for general verifier errors, so emit + // the error with an unknown location. + if (!errorResult.empty()) { + context->emitDiagnostic(UnknownLoc::get(context), errorResult, + MLIRContext::DiagnosticKind::Error); + return OptFailure; + } } // Print the output.