Use for statement directly as an operand instead of having it pretend to be an induction variable.
PiperOrigin-RevId: 206759180
This commit is contained in:
parent
c8b0273f19
commit
43e2a13605
|
@ -51,7 +51,7 @@ public:
|
|||
|
||||
case SSAValueKind::FnArgument:
|
||||
case SSAValueKind::StmtResult:
|
||||
case SSAValueKind::InductionVar:
|
||||
case SSAValueKind::ForStmt:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ class ForStmt;
|
|||
enum class MLValueKind {
|
||||
FnArgument = (int)SSAValueKind::FnArgument,
|
||||
StmtResult = (int)SSAValueKind::StmtResult,
|
||||
InductionVar = (int)SSAValueKind::InductionVar,
|
||||
ForStmt = (int)SSAValueKind::ForStmt,
|
||||
};
|
||||
|
||||
/// The operand of ML function statement contains an MLValue.
|
||||
|
@ -49,7 +49,7 @@ public:
|
|||
switch (value->getKind()) {
|
||||
case SSAValueKind::FnArgument:
|
||||
case SSAValueKind::StmtResult:
|
||||
case SSAValueKind::InductionVar:
|
||||
case SSAValueKind::ForStmt:
|
||||
return true;
|
||||
|
||||
case SSAValueKind::BBArgument:
|
||||
|
@ -106,26 +106,6 @@ private:
|
|||
OperationStmt *const owner;
|
||||
};
|
||||
|
||||
/// This is a value defined by a loop induction variable.
|
||||
class InductionVar : public MLValue {
|
||||
public:
|
||||
InductionVar(Type *type, ForStmt *owner)
|
||||
: MLValue(MLValueKind::InductionVar, type), owner(owner) {}
|
||||
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::InductionVar;
|
||||
}
|
||||
|
||||
ForStmt *getOwner() { return owner; }
|
||||
const ForStmt *getOwner() const { return owner; }
|
||||
|
||||
private:
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
ForStmt *const owner;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
|
|
@ -33,11 +33,11 @@ class OperationStmt;
|
|||
|
||||
/// This enumerates all of the SSA value kinds in the MLIR system.
|
||||
enum class SSAValueKind {
|
||||
BBArgument, // basic block argument
|
||||
InstResult, // instruction result
|
||||
FnArgument, // ML function argument
|
||||
StmtResult, // statement result
|
||||
InductionVar, // for statement induction variable
|
||||
BBArgument, // basic block argument
|
||||
InstResult, // instruction result
|
||||
FnArgument, // ML function argument
|
||||
StmtResult, // statement result
|
||||
ForStmt, // for statement induction variable
|
||||
};
|
||||
|
||||
/// This is the common base class for all values in the MLIR system,
|
||||
|
|
|
@ -187,7 +187,7 @@ private:
|
|||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public StmtBlock, private MLValue {
|
||||
class ForStmt : public Statement, public StmtBlock, public MLValue {
|
||||
public:
|
||||
// TODO: lower and upper bounds should be affine maps with
|
||||
// dimension and symbol use lists.
|
||||
|
@ -211,16 +211,13 @@ public:
|
|||
return block->getStmtBlockKind() == StmtBlockKind::For;
|
||||
}
|
||||
|
||||
// For statement represents induction variable by inheriting
|
||||
// from MLValue. This design is hidden behind interfaces.
|
||||
// For statement represents implicitly represents induction variable by
|
||||
// inheriting from MLValue class. Whenever you need to refer to the loop
|
||||
// induction variable, just use the for statement itself.
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::InductionVar;
|
||||
return value->getKind() == SSAValueKind::ForStmt;
|
||||
}
|
||||
|
||||
/// MLValue methods
|
||||
MLValue *getInductionVar() { return this; }
|
||||
const MLValue *getInductionVar() const { return this; }
|
||||
|
||||
private:
|
||||
AffineConstantExpr *lowerBound;
|
||||
AffineConstantExpr *upperBound;
|
||||
|
|
|
@ -552,8 +552,8 @@ protected:
|
|||
case SSAValueKind::FnArgument:
|
||||
id = nextFnArgumentID++;
|
||||
break;
|
||||
case SSAValueKind::InductionVar:
|
||||
id = nextInductionVarID++;
|
||||
case SSAValueKind::ForStmt:
|
||||
id = nextLoopID++;
|
||||
break;
|
||||
}
|
||||
valueIDs[value] = id;
|
||||
|
@ -599,7 +599,7 @@ private:
|
|||
/// This is the value ID for each SSA value in the current function.
|
||||
DenseMap<const SSAValue *, unsigned> valueIDs;
|
||||
unsigned nextValueID = 0;
|
||||
unsigned nextInductionVarID = 0;
|
||||
unsigned nextLoopID = 0;
|
||||
unsigned nextFnArgumentID = 0;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -900,9 +900,7 @@ void MLFunctionPrinter::numberValues() {
|
|||
if (stmt->getNumResults() != 0)
|
||||
printer->numberValueID(stmt->getResult(0));
|
||||
}
|
||||
void visitForStmt(ForStmt *stmt) {
|
||||
printer->numberValueID(stmt->getInductionVar());
|
||||
}
|
||||
void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
|
||||
MLFunctionPrinter *printer;
|
||||
};
|
||||
|
||||
|
@ -948,7 +946,7 @@ void MLFunctionPrinter::print(const OperationStmt *stmt) {
|
|||
|
||||
void MLFunctionPrinter::print(const ForStmt *stmt) {
|
||||
os.indent(numSpaces) << "for ";
|
||||
printOperand(stmt->getInductionVar());
|
||||
printOperand(stmt);
|
||||
os << " = " << *stmt->getLowerBound();
|
||||
os << " to " << *stmt->getUpperBound();
|
||||
if (stmt->getStep()->getValue() != 1)
|
||||
|
|
|
@ -198,7 +198,7 @@ OperationStmt *SSAValue::getDefiningStmt() {
|
|||
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step, MLIRContext *context)
|
||||
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
|
||||
MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
|
||||
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
|
||||
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2117,7 +2117,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
|
||||
|
||||
// Create SSA value definition for the induction variable.
|
||||
addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
|
||||
addDefinition({inductionVariableName, 0, loc}, forStmt);
|
||||
|
||||
// If parsing of the for statement body fails,
|
||||
// MLIR contains for statement with those nested statements that have been
|
||||
|
|
Loading…
Reference in New Issue