Use for statement directly as an operand instead of having it pretend to be an induction variable.

PiperOrigin-RevId: 206759180
This commit is contained in:
Tatiana Shpeisman 2018-07-31 07:40:14 -07:00 committed by jpienaar
parent c8b0273f19
commit 43e2a13605
7 changed files with 20 additions and 45 deletions

View File

@ -51,7 +51,7 @@ public:
case SSAValueKind::FnArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::InductionVar:
case SSAValueKind::ForStmt:
return false;
}
}

View File

@ -36,7 +36,7 @@ class ForStmt;
enum class MLValueKind {
FnArgument = (int)SSAValueKind::FnArgument,
StmtResult = (int)SSAValueKind::StmtResult,
InductionVar = (int)SSAValueKind::InductionVar,
ForStmt = (int)SSAValueKind::ForStmt,
};
/// The operand of ML function statement contains an MLValue.
@ -49,7 +49,7 @@ public:
switch (value->getKind()) {
case SSAValueKind::FnArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::InductionVar:
case SSAValueKind::ForStmt:
return true;
case SSAValueKind::BBArgument:
@ -106,26 +106,6 @@ private:
OperationStmt *const owner;
};
/// This is a value defined by a loop induction variable.
class InductionVar : public MLValue {
public:
InductionVar(Type *type, ForStmt *owner)
: MLValue(MLValueKind::InductionVar, type), owner(owner) {}
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::InductionVar;
}
ForStmt *getOwner() { return owner; }
const ForStmt *getOwner() const { return owner; }
private:
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
ForStmt *const owner;
};
} // namespace mlir
#endif

View File

@ -33,11 +33,11 @@ class OperationStmt;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
BBArgument, // basic block argument
InstResult, // instruction result
FnArgument, // ML function argument
StmtResult, // statement result
InductionVar, // for statement induction variable
BBArgument, // basic block argument
InstResult, // instruction result
FnArgument, // ML function argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};
/// This is the common base class for all values in the MLIR system,

View File

@ -187,7 +187,7 @@ private:
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public StmtBlock, private MLValue {
class ForStmt : public Statement, public StmtBlock, public MLValue {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
@ -211,16 +211,13 @@ public:
return block->getStmtBlockKind() == StmtBlockKind::For;
}
// For statement represents induction variable by inheriting
// from MLValue. This design is hidden behind interfaces.
// For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::InductionVar;
return value->getKind() == SSAValueKind::ForStmt;
}
/// MLValue methods
MLValue *getInductionVar() { return this; }
const MLValue *getInductionVar() const { return this; }
private:
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;

View File

@ -552,8 +552,8 @@ protected:
case SSAValueKind::FnArgument:
id = nextFnArgumentID++;
break;
case SSAValueKind::InductionVar:
id = nextInductionVarID++;
case SSAValueKind::ForStmt:
id = nextLoopID++;
break;
}
valueIDs[value] = id;
@ -599,7 +599,7 @@ private:
/// This is the value ID for each SSA value in the current function.
DenseMap<const SSAValue *, unsigned> valueIDs;
unsigned nextValueID = 0;
unsigned nextInductionVarID = 0;
unsigned nextLoopID = 0;
unsigned nextFnArgumentID = 0;
};
} // end anonymous namespace
@ -900,9 +900,7 @@ void MLFunctionPrinter::numberValues() {
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
void visitForStmt(ForStmt *stmt) {
printer->numberValueID(stmt->getInductionVar());
}
void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
MLFunctionPrinter *printer;
};
@ -948,7 +946,7 @@ void MLFunctionPrinter::print(const OperationStmt *stmt) {
void MLFunctionPrinter::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for ";
printOperand(stmt->getInductionVar());
printOperand(stmt);
os << " = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)

View File

@ -198,7 +198,7 @@ OperationStmt *SSAValue::getDefiningStmt() {
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
//===----------------------------------------------------------------------===//

View File

@ -2117,7 +2117,7 @@ ParseResult MLFunctionParser::parseForStmt() {
ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
// Create SSA value definition for the induction variable.
addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
addDefinition({inductionVariableName, 0, loc}, forStmt);
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been