[SCEV] Make SCEVUnionPredicate externally immutable [NFC]

This is the last major stepping stone before being able to allocate the node via the folding set allocator.  That will in turn allow more general SCEV predicate expression trees.
This commit is contained in:
Philip Reames 2022-02-09 13:15:17 -08:00
parent a7b5e5b413
commit d334fec140
3 changed files with 29 additions and 22 deletions

View File

@ -425,16 +425,16 @@ private:
/// Maps SCEVs to predicates for quick look-ups.
PredicateMap SCEVToPreds;
/// Adds a predicate to this union.
void add(const SCEVPredicate *N);
public:
SCEVUnionPredicate();
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
return Preds;
}
/// Adds a predicate to this union.
void add(const SCEVPredicate *N);
/// Returns a reference to a vector containing all predicates which apply to
/// \p Expr.
ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
@ -2254,7 +2254,7 @@ private:
/// The SCEVPredicate that forms our context. We will rewrite all
/// expressions assuming that this predicate true.
SCEVUnionPredicate Preds;
std::unique_ptr<SCEVUnionPredicate> Preds;
/// Marks the version of the SCEV predicate used. When rewriting a SCEV
/// expression we mark it with the version of the predicate. We use this to

View File

@ -5489,8 +5489,8 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
!Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
return false;
return true;
};
@ -12818,9 +12818,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (!isa<SCEVCouldNotCompute>(PBT)) {
OS << "Predicated backedge-taken count is " << *PBT << "\n";
OS << " Predicates:\n";
SCEVUnionPredicate Dedup;
for (auto *P : Preds)
Dedup.add(P);
SCEVUnionPredicate Dedup(Preds);
Dedup.print(OS, 4);
} else {
OS << "Unpredictable predicated backedge-taken count. ";
@ -13807,8 +13805,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
SCEVUnionPredicate::SCEVUnionPredicate()
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (auto *P : Preds)
add(P);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
return all_of(Preds,
@ -13864,7 +13865,10 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {}
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
}
void ScalarEvolution::registerUser(const SCEV *User,
ArrayRef<const SCEV *> Ops) {
@ -13889,7 +13893,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
if (Entry.second)
Expr = Entry.second;
const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
Entry = {Generation, NewSCEV};
return NewSCEV;
@ -13906,14 +13910,18 @@ const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds.implies(&Pred))
if (Preds->implies(&Pred))
return;
Preds.add(&Pred);
auto &OldPreds = Preds->getPredicates();
SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
NewPreds.push_back(&Pred);
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
updateGeneration();
}
const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
return Preds;
return *Preds;
}
void PredicatedScalarEvolution::updateGeneration() {
@ -13921,7 +13929,7 @@ void PredicatedScalarEvolution::updateGeneration() {
if (++Generation == 0) {
for (auto &II : RewriteMap) {
const SCEV *Rewritten = II.second.second;
II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
}
}
}
@ -13975,8 +13983,9 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}

View File

@ -942,7 +942,6 @@ TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstants) {
// Make sure that SCEV doesn't blow up
ScalarEvolution SE = buildSE(*F);
SCEVUnionPredicate Preds;
const SCEV *Expr = SE.getSCEV(Phi);
EXPECT_NE(nullptr, Expr);
EXPECT_TRUE(isa<SCEVUnknown>(Expr));
@ -1000,7 +999,6 @@ TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstantAccum) {
// Make sure that SCEV doesn't blow up
ScalarEvolution SE = buildSE(*F);
SCEVUnionPredicate Preds;
const SCEV *Expr = SE.getSCEV(Phi);
EXPECT_NE(nullptr, Expr);
EXPECT_TRUE(isa<SCEVUnknown>(Expr));