[SCEV] Add missing cache queries

Calculating SCEVs can be cumbersome, and may take very long time (even
hours, for very long expressions). To prevent recalculating expressions
over and over again, we cache them.
This change add cache queries to key positions, to prevent recalculation
of the expressions.

Fix PR43571.

Differential Revision: https://reviews.llvm.org/D70097
This commit is contained in:
Ehud Katz 2020-03-13 15:32:43 +02:00
parent a0c15ed460
commit fcc2238b8b
2 changed files with 39 additions and 12 deletions

View File

@ -1899,7 +1899,7 @@ private:
/// otherwise. The second component is the `FoldingSetNodeID` that was /// otherwise. The second component is the `FoldingSetNodeID` that was
/// constructed to look up the SCEV and the third component is the insertion /// constructed to look up the SCEV and the third component is the insertion
/// point. /// point.
std::tuple<const SCEV *, FoldingSetNodeID, void *> std::tuple<SCEV *, FoldingSetNodeID, void *>
findExistingSCEVInCache(int SCEVType, ArrayRef<const SCEV *> Ops); findExistingSCEVInCache(int SCEVType, ArrayRef<const SCEV *> Ops);
FoldingSet<SCEV> UniqueSCEVs; FoldingSet<SCEV> UniqueSCEVs;

View File

@ -2452,6 +2452,11 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
if (Depth > MaxArithDepth || hasHugeExpression(Ops)) if (Depth > MaxArithDepth || hasHugeExpression(Ops))
return getOrCreateAddExpr(Ops, Flags); return getOrCreateAddExpr(Ops, Flags);
if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) {
static_cast<SCEVAddExpr *>(S)->setNoWrapFlags(Flags);
return S;
}
// Okay, check to see if the same value occurs in the operand list more than // Okay, check to see if the same value occurs in the operand list more than
// once. If so, merge them together into an multiply expression. Since we // once. If so, merge them together into an multiply expression. Since we
// sorted the list, these values are required to be adjacent. // sorted the list, these values are required to be adjacent.
@ -2931,6 +2936,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
if (Depth > MaxArithDepth || hasHugeExpression(Ops)) if (Depth > MaxArithDepth || hasHugeExpression(Ops))
return getOrCreateMulExpr(Ops, Flags); return getOrCreateMulExpr(Ops, Flags);
if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) {
static_cast<SCEVMulExpr *>(S)->setNoWrapFlags(Flags);
return S;
}
// If there are any constants, fold them together. // If there are any constants, fold them together.
unsigned Idx = 0; unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@ -3193,6 +3203,14 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
getEffectiveSCEVType(RHS->getType()) && getEffectiveSCEVType(RHS->getType()) &&
"SCEVUDivExpr operand types don't match!"); "SCEVUDivExpr operand types don't match!");
FoldingSetNodeID ID;
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne()) if (RHSC->getValue()->isOne())
return LHS; // X udiv 1 --> x return LHS; // X udiv 1 --> x
@ -3239,9 +3257,24 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
AR->getLoop(), SCEV::FlagAnyWrap)) { AR->getLoop(), SCEV::FlagAnyWrap)) {
const APInt &StartInt = StartC->getAPInt(); const APInt &StartInt = StartC->getAPInt();
const APInt &StartRem = StartInt.urem(StepInt); const APInt &StartRem = StartInt.urem(StepInt);
if (StartRem != 0) if (StartRem != 0) {
LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, const SCEV *NewLHS =
AR->getLoop(), SCEV::FlagNW); getAddRecExpr(getConstant(StartInt - StartRem), Step,
AR->getLoop(), SCEV::FlagNW);
if (LHS != NewLHS) {
LHS = NewLHS;
// Reset the ID to include the new LHS, and check if it is
// already cached.
ID.clear();
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}
}
} }
} }
// (A*B)/C --> A*(B/C) if safe and B/C can be folded. // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
@ -3306,12 +3339,6 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
} }
} }
FoldingSetNodeID ID;
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
LHS, RHS); LHS, RHS);
UniqueSCEVs.InsertNode(S, IP); UniqueSCEVs.InsertNode(S, IP);
@ -3537,7 +3564,7 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
return getAddExpr(BaseExpr, TotalOffset, Wrap); return getAddExpr(BaseExpr, TotalOffset, Wrap);
} }
std::tuple<const SCEV *, FoldingSetNodeID, void *> std::tuple<SCEV *, FoldingSetNodeID, void *>
ScalarEvolution::findExistingSCEVInCache(int SCEVType, ScalarEvolution::findExistingSCEVInCache(int SCEVType,
ArrayRef<const SCEV *> Ops) { ArrayRef<const SCEV *> Ops) {
FoldingSetNodeID ID; FoldingSetNodeID ID;
@ -3545,7 +3572,7 @@ ScalarEvolution::findExistingSCEVInCache(int SCEVType,
ID.AddInteger(SCEVType); ID.AddInteger(SCEVType);
for (unsigned i = 0, e = Ops.size(); i != e; ++i) for (unsigned i = 0, e = Ops.size(); i != e; ++i)
ID.AddPointer(Ops[i]); ID.AddPointer(Ops[i]);
return std::tuple<const SCEV *, FoldingSetNodeID, void *>( return std::tuple<SCEV *, FoldingSetNodeID, void *>(
UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP);
} }