[mlir] Fix dumping invalid ops
This patch fixes the crash when printing some ops (like affine.for and scf.for) when they are dumped in invalid state, e.g. during pattern application. Now the AsmState constructor verifies the operation first and switches to generic operation printing when the verification fails. Also operations are now printed in generic form when emitting diagnostics and the severity level is Error. Reviewed By: rriddle, mehdi_amini Differential Revision: https://reviews.llvm.org/D117834
This commit is contained in:
parent
54d6b5b67f
commit
27df7158fe
|
@ -107,6 +107,18 @@ op->emitError() << "Compose an interesting error: " << fooAttr << ", " << fooTyp
|
|||
"Compose an interesting error: @foo, i32, (0, 1, 2)"
|
||||
```
|
||||
|
||||
Operations attached to a diagnostic will be printed in generic form if the
|
||||
severity level is `Error`, otherwise custom operation printers will be used.
|
||||
```c++
|
||||
// `anotherOp` will be printed in generic form,
|
||||
// e.g. %3 = "arith.addf"(%arg4, %2) : (f32, f32) -> f32
|
||||
op->emitError() << anotherOp;
|
||||
|
||||
// `anotherOp` will be printed using the custom printer,
|
||||
// e.g. %3 = arith.addf %arg4, %2 : f32
|
||||
op->emitRemark() << anotherOp;
|
||||
```
|
||||
|
||||
### Attaching notes
|
||||
|
||||
Unlike many other compiler frameworks, notes in MLIR cannot be emitted directly.
|
||||
|
|
|
@ -601,6 +601,15 @@ Note that the second phase will be run after the operations in the region are
|
|||
verified. Verifiers further down the order can rely on certain invariants being
|
||||
verified by a previous verifier and do not need to re-verify them.
|
||||
|
||||
#### Emitting diagnostics in custom verifiers
|
||||
|
||||
Custom verifiers should avoid printing operations using custom operation
|
||||
printers, because they require the printed operation (and sometimes its parent
|
||||
operation) to be verified first. In particular, when emitting diagnostics,
|
||||
custom verifiers should use the `Error` severity level, which prints operations
|
||||
in generic form by default, and avoid using lower severity levels (`Note`,
|
||||
`Remark`, `Warning`).
|
||||
|
||||
### Declarative Assembly Format
|
||||
|
||||
The custom assembly form of the operation may be specified in a declarative
|
||||
|
|
|
@ -726,6 +726,9 @@ public:
|
|||
/// Always print operations in the generic form.
|
||||
OpPrintingFlags &printGenericOpForm();
|
||||
|
||||
/// Do not verify the operation when using custom operation printers.
|
||||
OpPrintingFlags &assumeVerified();
|
||||
|
||||
/// Use local scope when printing the operation. This allows for using the
|
||||
/// printer in a more localized and thread-safe setting, but may not
|
||||
/// necessarily be identical to what the IR will look like when dumping
|
||||
|
@ -747,6 +750,9 @@ public:
|
|||
/// Return if operations should be printed in the generic form.
|
||||
bool shouldPrintGenericOpForm() const;
|
||||
|
||||
/// Return if operation verification should be skipped.
|
||||
bool shouldAssumeVerified() const;
|
||||
|
||||
/// Return if the printer should use local scope when dumping the IR.
|
||||
bool shouldUseLocalScope() const;
|
||||
|
||||
|
@ -762,6 +768,9 @@ private:
|
|||
/// Print operations in the generic form.
|
||||
bool printGenericOpFormFlag : 1;
|
||||
|
||||
/// Skip operation verification.
|
||||
bool assumeVerifiedFlag : 1;
|
||||
|
||||
/// Print operations with numberings local to the current operation.
|
||||
bool printLocalScope : 1;
|
||||
};
|
||||
|
|
|
@ -24,6 +24,7 @@ class Block;
|
|||
class BlockArgument;
|
||||
class Operation;
|
||||
class OpOperand;
|
||||
class OpPrintingFlags;
|
||||
class OpResult;
|
||||
class Region;
|
||||
class Value;
|
||||
|
@ -215,6 +216,7 @@ public:
|
|||
// Utilities
|
||||
|
||||
void print(raw_ostream &os);
|
||||
void print(raw_ostream &os, const OpPrintingFlags &flags);
|
||||
void print(raw_ostream &os, AsmState &state);
|
||||
void dump();
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SubElementInterfaces.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
@ -40,6 +41,7 @@
|
|||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SaveAndRestore.h"
|
||||
#include "llvm/Support/Threading.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
|
@ -141,6 +143,11 @@ struct AsmPrinterOptions {
|
|||
"mlir-print-op-generic", llvm::cl::init(false),
|
||||
llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
|
||||
|
||||
llvm::cl::opt<bool> assumeVerifiedOpt{
|
||||
"mlir-print-assume-verified", llvm::cl::init(false),
|
||||
llvm::cl::desc("Skip op verification when using custom printers"),
|
||||
llvm::cl::Hidden};
|
||||
|
||||
llvm::cl::opt<bool> printLocalScopeOpt{
|
||||
"mlir-print-local-scope", llvm::cl::init(false),
|
||||
llvm::cl::desc("Print with local scope and inline information (eliding "
|
||||
|
@ -160,7 +167,8 @@ void mlir::registerAsmPrinterCLOptions() {
|
|||
/// Initialize the printing flags with default supplied by the cl::opts above.
|
||||
OpPrintingFlags::OpPrintingFlags()
|
||||
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
|
||||
printGenericOpFormFlag(false), printLocalScope(false) {
|
||||
printGenericOpFormFlag(false), assumeVerifiedFlag(false),
|
||||
printLocalScope(false) {
|
||||
// Initialize based upon command line options, if they are available.
|
||||
if (!clOptions.isConstructed())
|
||||
return;
|
||||
|
@ -169,6 +177,7 @@ OpPrintingFlags::OpPrintingFlags()
|
|||
printDebugInfoFlag = clOptions->printDebugInfoOpt;
|
||||
printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
|
||||
printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
|
||||
assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
|
||||
printLocalScope = clOptions->printLocalScopeOpt;
|
||||
}
|
||||
|
||||
|
@ -196,6 +205,12 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Do not verify the operation when using custom operation printers.
|
||||
OpPrintingFlags &OpPrintingFlags::assumeVerified() {
|
||||
assumeVerifiedFlag = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Use local scope when printing the operation. This allows for using the
|
||||
/// printer in a more localized and thread-safe setting, but may not necessarily
|
||||
/// be identical of what the IR will look like when dumping the full module.
|
||||
|
@ -231,6 +246,11 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
|
|||
return printGenericOpFormFlag;
|
||||
}
|
||||
|
||||
/// Return if operation verification should be skipped.
|
||||
bool OpPrintingFlags::shouldAssumeVerified() const {
|
||||
return assumeVerifiedFlag;
|
||||
}
|
||||
|
||||
/// Return if the printer should use local scope when dumping the IR.
|
||||
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
|
||||
|
||||
|
@ -1245,9 +1265,31 @@ private:
|
|||
} // namespace detail
|
||||
} // namespace mlir
|
||||
|
||||
/// Verifies the operation and switches to generic op printing if verification
|
||||
/// fails. We need to do this because custom print functions may fail for
|
||||
/// invalid ops.
|
||||
static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
|
||||
OpPrintingFlags printerFlags) {
|
||||
if (printerFlags.shouldPrintGenericOpForm() ||
|
||||
printerFlags.shouldAssumeVerified())
|
||||
return printerFlags;
|
||||
|
||||
// Ignore errors emitted by the verifier. We check the thread id to avoid
|
||||
// consuming other threads' errors.
|
||||
auto parentThreadId = llvm::get_threadid();
|
||||
ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) {
|
||||
return success(parentThreadId == llvm::get_threadid());
|
||||
});
|
||||
if (failed(verify(op)))
|
||||
printerFlags.printGenericOpForm();
|
||||
|
||||
return printerFlags;
|
||||
}
|
||||
|
||||
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
|
||||
LocationMap *locationMap)
|
||||
: impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
|
||||
: impl(std::make_unique<AsmStateImpl>(
|
||||
op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
|
||||
AsmState::~AsmState() = default;
|
||||
|
||||
const OpPrintingFlags &AsmState::getPrinterFlags() const {
|
||||
|
@ -2853,14 +2895,15 @@ void IntegerSet::print(raw_ostream &os) const {
|
|||
AsmPrinter::Impl(os).printIntegerSet(*this);
|
||||
}
|
||||
|
||||
void Value::print(raw_ostream &os) {
|
||||
void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
|
||||
void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
|
||||
if (!impl) {
|
||||
os << "<<NULL VALUE>>";
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->print(os);
|
||||
return op->print(os, flags);
|
||||
// TODO: Improve BlockArgument print'ing.
|
||||
BlockArgument arg = this->cast<BlockArgument>();
|
||||
os << "<block argument> of type '" << arg.getType()
|
||||
|
|
|
@ -121,6 +121,17 @@ Diagnostic &Diagnostic::operator<<(OperationName val) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Adjusts operation printing flags used in diagnostics for the given severity
|
||||
/// level.
|
||||
static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
|
||||
DiagnosticSeverity severity) {
|
||||
flags.useLocalScope();
|
||||
flags.elideLargeElementsAttrs();
|
||||
if (severity == DiagnosticSeverity::Error)
|
||||
flags.printGenericOpForm();
|
||||
return flags;
|
||||
}
|
||||
|
||||
/// Stream in an Operation.
|
||||
Diagnostic &Diagnostic::operator<<(Operation &val) {
|
||||
return appendOp(val, OpPrintingFlags());
|
||||
|
@ -128,8 +139,7 @@ Diagnostic &Diagnostic::operator<<(Operation &val) {
|
|||
Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
val.print(os,
|
||||
OpPrintingFlags(flags).useLocalScope().elideLargeElementsAttrs());
|
||||
val.print(os, adjustPrintingFlags(flags, severity));
|
||||
return *this << os.str();
|
||||
}
|
||||
|
||||
|
@ -137,7 +147,7 @@ Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
|
|||
Diagnostic &Diagnostic::operator<<(Value val) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
val.print(os);
|
||||
val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
|
||||
return *this << os.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -1097,6 +1097,8 @@ LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) {
|
|||
// Check that any value that is used by an operation is defined in the
|
||||
// same region as either an operation result.
|
||||
auto *operandRegion = operand.getParentRegion();
|
||||
if (!operandRegion)
|
||||
return op.emitError("operation's operand is unlinked");
|
||||
if (!region.isAncestor(operandRegion)) {
|
||||
return op.emitOpError("using value defined outside the region")
|
||||
.attachNote(isolatedOp->getLoc())
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
// # RUN: mlir-opt -test-print-invalid %s | FileCheck %s
|
||||
// # RUN: mlir-opt -test-print-invalid %s --mlir-print-assume-verified | FileCheck %s --check-prefix=ASSUME-VERIFIED
|
||||
|
||||
// The pass creates some ops and prints them to stdout, the input is just an
|
||||
// empty module.
|
||||
module {}
|
||||
|
||||
// The operation is invalid because the body does not have a terminator, print
|
||||
// the generic form.
|
||||
// CHECK: Invalid operation:
|
||||
// CHECK-NEXT: "builtin.func"() ({
|
||||
// CHECK-NEXT: ^bb0:
|
||||
// CHECK-NEXT: })
|
||||
// CHECK-SAME: sym_name = "test"
|
||||
|
||||
// The operation is valid because the body has a terminator, print the custom
|
||||
// form.
|
||||
// CHECK: Valid operation:
|
||||
// CHECK-NEXT: func @test() {
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// With --mlir-print-assume-verified the custom form is printed in both cases.
|
||||
// This works in this particular case, but may crash in general.
|
||||
|
||||
// ASSUME-VERIFIED: Invalid operation:
|
||||
// ASSUME-VERIFIED-NEXT: func @test() {
|
||||
// ASSUME-VERIFIED-NEXT: }
|
||||
|
||||
// ASSUME-VERIFIED: Valid operation:
|
||||
// ASSUME-VERIFIED-NEXT: func @test() {
|
||||
// ASSUME-VERIFIED-NEXT: return
|
||||
// ASSUME-VERIFIED-NEXT: }
|
|
@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
|
|||
TestOpaqueLoc.cpp
|
||||
TestOperationEquals.cpp
|
||||
TestPrintDefUse.cpp
|
||||
TestPrintInvalid.cpp
|
||||
TestPrintNesting.cpp
|
||||
TestSideEffects.cpp
|
||||
TestSlicing.cpp
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
//===- TestPrintInvalid.cpp - Test printing invalid ops -------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This pass creates and prints to the standard output an invalid operation and
|
||||
// a valid operation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestPrintInvalidPass
|
||||
: public PassWrapper<TestPrintInvalidPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-print-invalid"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test printing invalid ops.";
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const {
|
||||
registry.insert<func::FuncDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Location loc = getOperation().getLoc();
|
||||
OpBuilder builder(getOperation().body());
|
||||
auto funcOp = builder.create<FuncOp>(
|
||||
loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
|
||||
funcOp.addEntryBlock();
|
||||
// The created function is invalid because there is no return op.
|
||||
llvm::outs() << "Invalid operation:\n" << funcOp << "\n";
|
||||
builder.setInsertionPointToEnd(&funcOp.getBody().front());
|
||||
builder.create<func::ReturnOp>(loc);
|
||||
// Now this function is valid.
|
||||
llvm::outs() << "Valid operation:\n" << funcOp << "\n";
|
||||
funcOp.erase();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestPrintInvalidPass() {
|
||||
PassRegistration<TestPrintInvalidPass>{};
|
||||
}
|
||||
} // namespace mlir
|
|
@ -45,6 +45,7 @@ void registerTestLoopPermutationPass();
|
|||
void registerTestMatchers();
|
||||
void registerTestOperationEqualPass();
|
||||
void registerTestPrintDefUsePass();
|
||||
void registerTestPrintInvalidPass();
|
||||
void registerTestPrintNestingPass();
|
||||
void registerTestReducer();
|
||||
void registerTestSpirvEntryPointABIPass();
|
||||
|
@ -132,6 +133,7 @@ void registerTestPasses() {
|
|||
registerTestMatchers();
|
||||
registerTestOperationEqualPass();
|
||||
registerTestPrintDefUsePass();
|
||||
registerTestPrintInvalidPass();
|
||||
registerTestPrintNestingPass();
|
||||
registerTestReducer();
|
||||
registerTestSpirvEntryPointABIPass();
|
||||
|
|
Loading…
Reference in New Issue