Implement getFunction() helpers on the various value types, and use it to
implement some simple checks in the Verifier. PiperOrigin-RevId: 211729987
This commit is contained in:
parent
f884e8da82
commit
2366c58a79
|
@ -27,6 +27,7 @@
|
|||
namespace mlir {
|
||||
class BasicBlock;
|
||||
class CFGValue;
|
||||
class CFGFunction;
|
||||
class Instruction;
|
||||
|
||||
/// This enum contains all of the SSA value kinds that are valid in a CFG
|
||||
|
@ -56,6 +57,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the function that this CFGValue is defined in.
|
||||
CFGFunction *getFunction();
|
||||
|
||||
/// Return the function that this CFGValue is defined in.
|
||||
const CFGFunction *getFunction() const {
|
||||
return const_cast<CFGValue *>(this)->getFunction();
|
||||
}
|
||||
|
||||
protected:
|
||||
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
@ -67,6 +76,9 @@ public:
|
|||
return value->getKind() == SSAValueKind::BBArgument;
|
||||
}
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
CFGFunction *getFunction() const;
|
||||
|
||||
BasicBlock *getOwner() { return owner; }
|
||||
const BasicBlock *getOwner() const { return owner; }
|
||||
|
||||
|
|
|
@ -63,6 +63,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
MLFunction *getFunction();
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
const MLFunction *getFunction() const {
|
||||
return const_cast<MLValue *>(this)->getFunction();
|
||||
}
|
||||
|
||||
protected:
|
||||
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
@ -77,6 +85,11 @@ public:
|
|||
MLFunction *getOwner() { return owner; }
|
||||
const MLFunction *getOwner() const { return owner; }
|
||||
|
||||
/// Return the function that this MLFuncArgument is defined in.
|
||||
const MLFunction *getFunction() const { return getOwner(); }
|
||||
|
||||
MLFunction *getFunction() { return getOwner(); }
|
||||
|
||||
private:
|
||||
friend class MLFunction; // For access to private constructor.
|
||||
MLFuncArgument(Type *type, MLFunction *owner)
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "llvm/ADT/PointerIntPair.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
class OperationInst;
|
||||
class OperationStmt;
|
||||
class Operation;
|
||||
|
@ -59,6 +60,14 @@ public:
|
|||
IRObjectWithUseList::replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
Function *getFunction();
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
const Function *getFunction() const {
|
||||
return const_cast<SSAValue *>(this)->getFunction();
|
||||
}
|
||||
|
||||
/// If this value is the result of an OperationInst, return the instruction
|
||||
/// that defines it.
|
||||
OperationInst *getDefiningInst();
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/Instructions.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
using namespace mlir;
|
||||
|
@ -45,10 +47,51 @@ Operation *SSAValue::getDefiningOperation() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
Function *SSAValue::getFunction() {
|
||||
switch (getKind()) {
|
||||
case SSAValueKind::BBArgument:
|
||||
return cast<BBArgument>(this)->getFunction();
|
||||
case SSAValueKind::InstResult:
|
||||
return getDefiningInst()->getFunction();
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
return cast<MLFuncArgument>(this)->getFunction();
|
||||
case SSAValueKind::StmtResult:
|
||||
return getDefiningStmt()->findFunction();
|
||||
case SSAValueKind::ForStmt:
|
||||
return cast<ForStmt>(this)->findFunction();
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CFGValue implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this CFGValue is defined in.
|
||||
CFGFunction *CFGValue::getFunction() {
|
||||
return cast<CFGFunction>(static_cast<SSAValue *>(this)->getFunction());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BBArgument implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
CFGFunction *BBArgument::getFunction() const {
|
||||
if (auto *owner = getOwner())
|
||||
return owner->getFunction();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MLValue implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
MLFunction *MLValue::getFunction() {
|
||||
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
|
||||
}
|
||||
|
||||
// MLValue can be used a a dimension id if it is valid as a symbol, or
|
||||
// it is an induction variable, or it is a result of affine apply operation
|
||||
// with dimension id arguments.
|
||||
|
@ -62,8 +105,8 @@ bool MLValue::isValidDim() const {
|
|||
return op->isValidDim();
|
||||
return false;
|
||||
}
|
||||
// This value is either a function argument or an induction variable. Both are
|
||||
// ok.
|
||||
// This value is either a function argument or an induction variable. Both
|
||||
// are ok.
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -137,7 +137,14 @@ bool Verifier::verifyOperation(const Operation &op) {
|
|||
if (op.getOperationFunction() != &fn)
|
||||
return opFailure("operation in the wrong function", op);
|
||||
|
||||
// TODO: Check that operands are non-nil and structurally ok.
|
||||
// Check that operands are non-nil and structurally ok.
|
||||
for (const auto *operand : op.getOperands()) {
|
||||
if (!operand)
|
||||
return opFailure("null operand found", op);
|
||||
|
||||
if (operand->getFunction() != &fn)
|
||||
return opFailure("reference to operand defined in another function", op);
|
||||
}
|
||||
|
||||
// Verify all attributes are ok. We need to check Function attributes, since
|
||||
// they are actually mutable (the function they refer to can be deleted), and
|
||||
|
|
Loading…
Reference in New Issue