[LSR] Check if terminating value is safe to expand before transformation

According to report by @JojoR, the assertion error was hit hence we need
to have this check before the actual transformation.

Reviewed By: Meinersbur, #loopoptwg

Differential Revision: https://reviews.llvm.org/D136415
This commit is contained in:
eopXD 2022-10-28 02:07:17 -07:00
parent 8d0b2f09a2
commit c0ef83e3b9
2 changed files with 134 additions and 50 deletions

View File

@ -6614,7 +6614,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
return nullptr;
}
static Optional<std::pair<PHINode *, PHINode *>>
static Optional<std::pair<PHINode *, std::pair<PHINode *, const SCEV *>>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
const LoopInfo &LI) {
if (!L->isInnermost()) {
@ -6699,16 +6699,37 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
// For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
// replace the terminating condition
auto IsToHelpFold = [&](PHINode &PN) -> bool {
auto IsToHelpFold = [&](PHINode &PN) -> std::pair<bool, const SCEV *> {
const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
const SCEV *BECount = SE.getBackedgeTakenCount(L);
const SCEV *TermValueS = SE.getAddExpr(
AddRec->getOperand(0),
SE.getTruncateOrZeroExtend(
SE.getMulExpr(
AddRec->getOperand(1),
SE.getTruncateOrZeroExtend(
SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
AddRec->getOperand(1)->getType())),
AddRec->getOperand(0)->getType()));
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
if (!Expander.isSafeToExpand(TermValueS)) {
LLVM_DEBUG(
dbgs() << "Is not safe to expand terminating value for phi node" << PN
<< "\n");
return {false, nullptr};
}
// TODO: Right now we limit the phi node to help the folding be of a start
// value of getelementptr. We can extend to any kinds of IV as long as it is
// an affine AddRec. Add a switch to cover more types of instructions here
// and down in the actual transformation.
return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
return {isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader)),
TermValueS};
};
PHINode *ToFold = nullptr;
PHINode *ToHelpFold = nullptr;
const SCEV *TermValueS = nullptr;
for (PHINode &PN : L->getHeader()->phis()) {
if (!SE.isSCEVable(PN.getType())) {
@ -6729,8 +6750,10 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (IsToFold(PN))
ToFold = &PN;
else if (IsToHelpFold(PN))
else if (auto P = IsToHelpFold(PN); P.first) {
ToHelpFold = &PN;
TermValueS = P.second;
}
}
LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@ -6746,7 +6769,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (!ToFold || !ToHelpFold)
return None;
return {{ToFold, ToHelpFold}};
return {{ToFold, {ToHelpFold, TermValueS}}};
}
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
@ -6810,11 +6833,14 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
if (AllowTerminatingConditionFoldingAfterLSR) {
auto CanFoldTerminatingCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
if (CanFoldTerminatingCondition) {
Changed = true;
NumTermFold++;
BasicBlock *LoopPreheader = L->getLoopPreheader();
BasicBlock *LoopLatch = L->getLoopLatch();
PHINode *ToFold = CanFoldTerminatingCondition->first;
PHINode *ToHelpFold = CanFoldTerminatingCondition->second;
PHINode *ToHelpFold = CanFoldTerminatingCondition->second.first;
(void)ToFold;
LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
@ -6834,56 +6860,35 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
const SCEV *BECount = SE.getBackedgeTakenCount(L);
const SCEVAddRecExpr *AddRec =
cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
const SCEV *TermValueS = CanFoldTerminatingCondition->second.second;
assert(
Expander.isSafeToExpand(TermValueS) &&
"Terminating value was checked safe in canFoldTerminatingCondition");
// TermValue = Start + Stride * (BackedgeCount + 1)
const SCEV *TermValueS = SE.getAddExpr(
AddRec->getOperand(0),
SE.getTruncateOrZeroExtend(
SE.getMulExpr(
AddRec->getOperand(1),
SE.getTruncateOrZeroExtend(
SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
AddRec->getOperand(1)->getType())),
AddRec->getOperand(0)->getType()));
Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
LoopPreheader->getTerminator());
// NOTE: If this is triggered, we should add this into predicate
if (!Expander.isSafeToExpand(TermValueS)) {
LLVMContext &Ctx = L->getHeader()->getContext();
Ctx.emitError(
"Terminating value is not safe to expand, need to add it to "
"predicate");
} else { // Now we replace the condition with ToHelpFold and remove ToFold
Changed = true;
NumTermFold++;
LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
<< *StartValue << "\n"
<< "Terminating value of new term-cond phi-node:\n"
<< *TermValue << "\n");
Value *TermValue = Expander.expandCodeFor(
TermValueS, PtrTy, LoopPreheader->getTerminator());
// Create new terminating condition at loop latch
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
Value *NewTermCond = LatchBuilder.CreateICmp(
OldTermCond->getPredicate(), LoopValue, TermValue,
"lsr_fold_term_cond.replaced_term_cond");
LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
<< *StartValue << "\n"
<< "Terminating value of new term-cond phi-node:\n"
<< *TermValue << "\n");
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
<< *OldTermCond << "\n"
<< "New term-cond:\b" << *NewTermCond << "\n");
// Create new terminating condition at loop latch
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
Value *NewTermCond = LatchBuilder.CreateICmp(
OldTermCond->getPredicate(), LoopValue, TermValue,
"lsr_fold_term_cond.replaced_term_cond");
BI->setCondition(NewTermCond);
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
<< *OldTermCond << "\n"
<< "New term-cond:\b" << *NewTermCond << "\n");
BI->setCondition(NewTermCond);
OldTermCond->eraseFromParent();
DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
}
OldTermCond->eraseFromParent();
DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
ExpCleaner.markResultUsed();
}

View File

@ -158,3 +158,82 @@ for.body: ; preds = %for.body, %entry
for.end: ; preds = %for.body
ret void
}
; The test case is reduced from FFmpeg/libavfilter/ebur128.c
; Testing check if terminating value is safe to expand
%struct.FFEBUR128State = type { i32, ptr, i64, i64 }
@histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8
define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) {
; CHECK: Is not safe to expand terminating value for phi node %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
entry:
%0 = load i32, ptr %st, align 8
%conv = zext i32 %0 to i64
%cmp28.not = icmp eq i32 %0, 0
br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph
for.cond2.preheader.lr.ph: ; preds = %entry
%audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3
%1 = load i64, ptr %audio_data_index, align 8
%div = udiv i64 %1, %conv
%cmp525.not = icmp ult i64 %1, %conv
%audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1
%umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1)
br label %for.cond2.preheader
for.cond2.preheader: ; preds = %for.cond2.preheader.lr.ph, %for.inc11
%channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, %for.inc11 ]
%c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ]
br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph
for.body7.lr.ph: ; preds = %for.cond2.preheader
%2 = load ptr, ptr %audio_data, align 8
br label %for.body7
for.body7: ; preds = %for.body7.lr.ph, %for.body7
%channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ]
%i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
%mul = mul i64 %i.026, %conv
%add = add i64 %mul, %c.029
%arrayidx = getelementptr inbounds double, ptr %2, i64 %add
%3 = load double, ptr %arrayidx, align 8
%add10 = fadd double %channel_sum.127, %3
%inc = add nuw i64 %i.026, 1
%exitcond.not = icmp eq i64 %inc, %umax
br i1 %exitcond.not, label %for.inc11, label %for.body7
for.inc11: ; preds = %for.body7, %for.cond2.preheader
%channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ]
%inc12 = add nuw nsw i64 %c.029, 1
%exitcond32.not = icmp eq i64 %inc12, %conv
br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader
for.end13: ; preds = %for.inc11, %entry
%channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ]
%add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00
store double %add14, ptr %optional_output, align 8
ret void
}
declare i64 @llvm.umax.i64(i64, i64)
%struct.PAKT_INFO = type { i32, i32, i32, [0 x i32] }
define i64 @alac_seek(ptr %0) {
; CHECK: Is not safe to expand terminating value for phi node %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
entry:
%div = udiv i64 1, 0
br label %for.body.i
for.body.i: ; preds = %for.body.i, %entry
%indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
%arrayidx.i = getelementptr %struct.PAKT_INFO, ptr %0, i64 0, i32 3, i64 %indvars.iv.i
%1 = load i32, ptr %arrayidx.i, align 4
%indvars.iv.next.i = add i64 %indvars.iv.i, 1
%exitcond.not.i = icmp eq i64 %indvars.iv.i, %div
br i1 %exitcond.not.i, label %alac_pakt_block_offset.exit, label %for.body.i
alac_pakt_block_offset.exit: ; preds = %for.body.i
ret i64 0
}