Implement getFunction() helpers on the various value types, and use it to

implement some simple checks in the Verifier.

PiperOrigin-RevId: 211729987
This commit is contained in:
Chris Lattner 2018-09-05 17:45:19 -07:00 committed by jpienaar
parent f884e8da82
commit 2366c58a79
5 changed files with 87 additions and 3 deletions

View File

@ -27,6 +27,7 @@
namespace mlir { namespace mlir {
class BasicBlock; class BasicBlock;
class CFGValue; class CFGValue;
class CFGFunction;
class Instruction; class Instruction;
/// This enum contains all of the SSA value kinds that are valid in a CFG /// This enum contains all of the SSA value kinds that are valid in a CFG
@ -56,6 +57,14 @@ public:
} }
} }
/// Return the function that this CFGValue is defined in.
CFGFunction *getFunction();
/// Return the function that this CFGValue is defined in.
const CFGFunction *getFunction() const {
return const_cast<CFGValue *>(this)->getFunction();
}
protected: protected:
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {} CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
}; };
@ -67,6 +76,9 @@ public:
return value->getKind() == SSAValueKind::BBArgument; return value->getKind() == SSAValueKind::BBArgument;
} }
/// Return the function that this argument is defined in.
CFGFunction *getFunction() const;
BasicBlock *getOwner() { return owner; } BasicBlock *getOwner() { return owner; }
const BasicBlock *getOwner() const { return owner; } const BasicBlock *getOwner() const { return owner; }

View File

@ -63,6 +63,14 @@ public:
} }
} }
/// Return the function that this MLValue is defined in.
MLFunction *getFunction();
/// Return the function that this MLValue is defined in.
const MLFunction *getFunction() const {
return const_cast<MLValue *>(this)->getFunction();
}
protected: protected:
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {} MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
}; };
@ -77,6 +85,11 @@ public:
MLFunction *getOwner() { return owner; } MLFunction *getOwner() { return owner; }
const MLFunction *getOwner() const { return owner; } const MLFunction *getOwner() const { return owner; }
/// Return the function that this MLFuncArgument is defined in.
const MLFunction *getFunction() const { return getOwner(); }
MLFunction *getFunction() { return getOwner(); }
private: private:
friend class MLFunction; // For access to private constructor. friend class MLFunction; // For access to private constructor.
MLFuncArgument(Type *type, MLFunction *owner) MLFuncArgument(Type *type, MLFunction *owner)

View File

@ -28,6 +28,7 @@
#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/PointerIntPair.h"
namespace mlir { namespace mlir {
class Function;
class OperationInst; class OperationInst;
class OperationStmt; class OperationStmt;
class Operation; class Operation;
@ -59,6 +60,14 @@ public:
IRObjectWithUseList::replaceAllUsesWith(newValue); IRObjectWithUseList::replaceAllUsesWith(newValue);
} }
/// Return the function that this SSAValue is defined in.
Function *getFunction();
/// Return the function that this SSAValue is defined in.
const Function *getFunction() const {
return const_cast<SSAValue *>(this)->getFunction();
}
/// If this value is the result of an OperationInst, return the instruction /// If this value is the result of an OperationInst, return the instruction
/// that defines it. /// that defines it.
OperationInst *getDefiningInst(); OperationInst *getDefiningInst();

View File

@ -16,7 +16,9 @@
// ============================================================================= // =============================================================================
#include "mlir/IR/SSAValue.h" #include "mlir/IR/SSAValue.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/Instructions.h" #include "mlir/IR/Instructions.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/StandardOps.h" #include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h" #include "mlir/IR/Statements.h"
using namespace mlir; using namespace mlir;
@ -45,10 +47,51 @@ Operation *SSAValue::getDefiningOperation() {
return nullptr; return nullptr;
} }
/// Return the function that this SSAValue is defined in.
Function *SSAValue::getFunction() {
switch (getKind()) {
case SSAValueKind::BBArgument:
return cast<BBArgument>(this)->getFunction();
case SSAValueKind::InstResult:
return getDefiningInst()->getFunction();
case SSAValueKind::MLFuncArgument:
return cast<MLFuncArgument>(this)->getFunction();
case SSAValueKind::StmtResult:
return getDefiningStmt()->findFunction();
case SSAValueKind::ForStmt:
return cast<ForStmt>(this)->findFunction();
}
}
//===----------------------------------------------------------------------===//
// CFGValue implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this CFGValue is defined in.
CFGFunction *CFGValue::getFunction() {
return cast<CFGFunction>(static_cast<SSAValue *>(this)->getFunction());
}
//===----------------------------------------------------------------------===//
// BBArgument implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
CFGFunction *BBArgument::getFunction() const {
if (auto *owner = getOwner())
return owner->getFunction();
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MLValue implementation. // MLValue implementation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Return the function that this MLValue is defined in.
MLFunction *MLValue::getFunction() {
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
}
// MLValue can be used a a dimension id if it is valid as a symbol, or // MLValue can be used a a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of affine apply operation // it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments. // with dimension id arguments.
@ -62,8 +105,8 @@ bool MLValue::isValidDim() const {
return op->isValidDim(); return op->isValidDim();
return false; return false;
} }
// This value is either a function argument or an induction variable. Both are // This value is either a function argument or an induction variable. Both
// ok. // are ok.
return true; return true;
} }

View File

@ -137,7 +137,14 @@ bool Verifier::verifyOperation(const Operation &op) {
if (op.getOperationFunction() != &fn) if (op.getOperationFunction() != &fn)
return opFailure("operation in the wrong function", op); return opFailure("operation in the wrong function", op);
// TODO: Check that operands are non-nil and structurally ok. // Check that operands are non-nil and structurally ok.
for (const auto *operand : op.getOperands()) {
if (!operand)
return opFailure("null operand found", op);
if (operand->getFunction() != &fn)
return opFailure("reference to operand defined in another function", op);
}
// Verify all attributes are ok. We need to check Function attributes, since // Verify all attributes are ok. We need to check Function attributes, since
// they are actually mutable (the function they refer to can be deleted), and // they are actually mutable (the function they refer to can be deleted), and