[LSR] Check if terminating value is safe to expand before transformation
As reported by @JojoR, an assertion failure was hit, so we need to perform this check before the actual transformation. Reviewed By: Meinersbur, #loopoptwg Differential Revision: https://reviews.llvm.org/D136415
This commit is contained in:
parent
8d0b2f09a2
commit
c0ef83e3b9
|
@ -6614,7 +6614,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
static Optional<std::pair<PHINode *, PHINode *>>
|
||||
static Optional<std::pair<PHINode *, std::pair<PHINode *, const SCEV *>>>
|
||||
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
|
||||
const LoopInfo &LI) {
|
||||
if (!L->isInnermost()) {
|
||||
|
@ -6699,16 +6699,37 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
|
|||
|
||||
// For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
|
||||
// replace the terminating condition
|
||||
auto IsToHelpFold = [&](PHINode &PN) -> bool {
|
||||
auto IsToHelpFold = [&](PHINode &PN) -> std::pair<bool, const SCEV *> {
|
||||
const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
|
||||
const SCEV *BECount = SE.getBackedgeTakenCount(L);
|
||||
const SCEV *TermValueS = SE.getAddExpr(
|
||||
AddRec->getOperand(0),
|
||||
SE.getTruncateOrZeroExtend(
|
||||
SE.getMulExpr(
|
||||
AddRec->getOperand(1),
|
||||
SE.getTruncateOrZeroExtend(
|
||||
SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
|
||||
AddRec->getOperand(1)->getType())),
|
||||
AddRec->getOperand(0)->getType()));
|
||||
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
|
||||
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
|
||||
if (!Expander.isSafeToExpand(TermValueS)) {
|
||||
LLVM_DEBUG(
|
||||
dbgs() << "Is not safe to expand terminating value for phi node" << PN
|
||||
<< "\n");
|
||||
return {false, nullptr};
|
||||
}
|
||||
// TODO: Right now we limit the phi node to help the folding be of a start
|
||||
// value of getelementptr. We can extend to any kinds of IV as long as it is
|
||||
// an affine AddRec. Add a switch to cover more types of instructions here
|
||||
// and down in the actual transformation.
|
||||
return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
|
||||
return {isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader)),
|
||||
TermValueS};
|
||||
};
|
||||
|
||||
PHINode *ToFold = nullptr;
|
||||
PHINode *ToHelpFold = nullptr;
|
||||
const SCEV *TermValueS = nullptr;
|
||||
|
||||
for (PHINode &PN : L->getHeader()->phis()) {
|
||||
if (!SE.isSCEVable(PN.getType())) {
|
||||
|
@ -6729,8 +6750,10 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
|
|||
|
||||
if (IsToFold(PN))
|
||||
ToFold = &PN;
|
||||
else if (IsToHelpFold(PN))
|
||||
else if (auto P = IsToHelpFold(PN); P.first) {
|
||||
ToHelpFold = &PN;
|
||||
TermValueS = P.second;
|
||||
}
|
||||
}
|
||||
|
||||
LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
|
||||
|
@ -6746,7 +6769,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
|
|||
|
||||
if (!ToFold || !ToHelpFold)
|
||||
return None;
|
||||
return {{ToFold, ToHelpFold}};
|
||||
return {{ToFold, {ToHelpFold, TermValueS}}};
|
||||
}
|
||||
|
||||
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
|
||||
|
@ -6810,11 +6833,14 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
|
|||
if (AllowTerminatingConditionFoldingAfterLSR) {
|
||||
auto CanFoldTerminatingCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
|
||||
if (CanFoldTerminatingCondition) {
|
||||
Changed = true;
|
||||
NumTermFold++;
|
||||
|
||||
BasicBlock *LoopPreheader = L->getLoopPreheader();
|
||||
BasicBlock *LoopLatch = L->getLoopLatch();
|
||||
|
||||
PHINode *ToFold = CanFoldTerminatingCondition->first;
|
||||
PHINode *ToHelpFold = CanFoldTerminatingCondition->second;
|
||||
PHINode *ToHelpFold = CanFoldTerminatingCondition->second.first;
|
||||
|
||||
(void)ToFold;
|
||||
LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
|
||||
|
@ -6834,56 +6860,35 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
|
|||
GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
|
||||
Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
|
||||
|
||||
const SCEV *BECount = SE.getBackedgeTakenCount(L);
|
||||
const SCEVAddRecExpr *AddRec =
|
||||
cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
|
||||
const SCEV *TermValueS = CanFoldTerminatingCondition->second.second;
|
||||
assert(
|
||||
Expander.isSafeToExpand(TermValueS) &&
|
||||
"Terminating value was checked safe in canFoldTerminatingCondition");
|
||||
|
||||
// TermValue = Start + Stride * (BackedgeCount + 1)
|
||||
const SCEV *TermValueS = SE.getAddExpr(
|
||||
AddRec->getOperand(0),
|
||||
SE.getTruncateOrZeroExtend(
|
||||
SE.getMulExpr(
|
||||
AddRec->getOperand(1),
|
||||
SE.getTruncateOrZeroExtend(
|
||||
SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
|
||||
AddRec->getOperand(1)->getType())),
|
||||
AddRec->getOperand(0)->getType()));
|
||||
Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
|
||||
LoopPreheader->getTerminator());
|
||||
|
||||
// NOTE: If this is triggered, we should add this into predicate
|
||||
if (!Expander.isSafeToExpand(TermValueS)) {
|
||||
LLVMContext &Ctx = L->getHeader()->getContext();
|
||||
Ctx.emitError(
|
||||
"Terminating value is not safe to expand, need to add it to "
|
||||
"predicate");
|
||||
} else { // Now we replace the condition with ToHelpFold and remove ToFold
|
||||
Changed = true;
|
||||
NumTermFold++;
|
||||
LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
|
||||
<< *StartValue << "\n"
|
||||
<< "Terminating value of new term-cond phi-node:\n"
|
||||
<< *TermValue << "\n");
|
||||
|
||||
Value *TermValue = Expander.expandCodeFor(
|
||||
TermValueS, PtrTy, LoopPreheader->getTerminator());
|
||||
// Create new terminating condition at loop latch
|
||||
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
|
||||
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
|
||||
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
|
||||
Value *NewTermCond = LatchBuilder.CreateICmp(
|
||||
OldTermCond->getPredicate(), LoopValue, TermValue,
|
||||
"lsr_fold_term_cond.replaced_term_cond");
|
||||
|
||||
LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
|
||||
<< *StartValue << "\n"
|
||||
<< "Terminating value of new term-cond phi-node:\n"
|
||||
<< *TermValue << "\n");
|
||||
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
|
||||
<< *OldTermCond << "\n"
|
||||
<< "New term-cond:\b" << *NewTermCond << "\n");
|
||||
|
||||
// Create new terminating condition at loop latch
|
||||
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
|
||||
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
|
||||
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
|
||||
Value *NewTermCond = LatchBuilder.CreateICmp(
|
||||
OldTermCond->getPredicate(), LoopValue, TermValue,
|
||||
"lsr_fold_term_cond.replaced_term_cond");
|
||||
BI->setCondition(NewTermCond);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
|
||||
<< *OldTermCond << "\n"
|
||||
<< "New term-cond:\b" << *NewTermCond << "\n");
|
||||
|
||||
BI->setCondition(NewTermCond);
|
||||
|
||||
OldTermCond->eraseFromParent();
|
||||
DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
|
||||
}
|
||||
OldTermCond->eraseFromParent();
|
||||
DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
|
||||
|
||||
ExpCleaner.markResultUsed();
|
||||
}
|
||||
|
|
|
@ -158,3 +158,82 @@ for.body: ; preds = %for.body, %entry
|
|||
for.end: ; preds = %for.body
|
||||
ret void
|
||||
}
|
||||
|
||||
; The test case is reduced from FFmpeg/libavfilter/ebur128.c
|
||||
; Testing check if terminating value is safe to expand
|
||||
%struct.FFEBUR128State = type { i32, ptr, i64, i64 }
|
||||
|
||||
@histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8
|
||||
|
||||
define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) {
|
||||
; CHECK: Is not safe to expand terminating value for phi node %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
|
||||
entry:
|
||||
%0 = load i32, ptr %st, align 8
|
||||
%conv = zext i32 %0 to i64
|
||||
%cmp28.not = icmp eq i32 %0, 0
|
||||
br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph
|
||||
|
||||
for.cond2.preheader.lr.ph: ; preds = %entry
|
||||
%audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3
|
||||
%1 = load i64, ptr %audio_data_index, align 8
|
||||
%div = udiv i64 %1, %conv
|
||||
%cmp525.not = icmp ult i64 %1, %conv
|
||||
%audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1
|
||||
%umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1)
|
||||
br label %for.cond2.preheader
|
||||
|
||||
for.cond2.preheader: ; preds = %for.cond2.preheader.lr.ph, %for.inc11
|
||||
%channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, %for.inc11 ]
|
||||
%c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ]
|
||||
br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph
|
||||
|
||||
for.body7.lr.ph: ; preds = %for.cond2.preheader
|
||||
%2 = load ptr, ptr %audio_data, align 8
|
||||
br label %for.body7
|
||||
|
||||
for.body7: ; preds = %for.body7.lr.ph, %for.body7
|
||||
%channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ]
|
||||
%i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
|
||||
%mul = mul i64 %i.026, %conv
|
||||
%add = add i64 %mul, %c.029
|
||||
%arrayidx = getelementptr inbounds double, ptr %2, i64 %add
|
||||
%3 = load double, ptr %arrayidx, align 8
|
||||
%add10 = fadd double %channel_sum.127, %3
|
||||
%inc = add nuw i64 %i.026, 1
|
||||
%exitcond.not = icmp eq i64 %inc, %umax
|
||||
br i1 %exitcond.not, label %for.inc11, label %for.body7
|
||||
|
||||
for.inc11: ; preds = %for.body7, %for.cond2.preheader
|
||||
%channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ]
|
||||
%inc12 = add nuw nsw i64 %c.029, 1
|
||||
%exitcond32.not = icmp eq i64 %inc12, %conv
|
||||
br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader
|
||||
|
||||
for.end13: ; preds = %for.inc11, %entry
|
||||
%channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ]
|
||||
%add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00
|
||||
store double %add14, ptr %optional_output, align 8
|
||||
ret void
|
||||
}
|
||||
|
||||
declare i64 @llvm.umax.i64(i64, i64)
|
||||
|
||||
%struct.PAKT_INFO = type { i32, i32, i32, [0 x i32] }
|
||||
|
||||
define i64 @alac_seek(ptr %0) {
|
||||
; CHECK: Is not safe to expand terminating value for phi node %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
|
||||
entry:
|
||||
%div = udiv i64 1, 0
|
||||
br label %for.body.i
|
||||
|
||||
for.body.i: ; preds = %for.body.i, %entry
|
||||
%indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
|
||||
%arrayidx.i = getelementptr %struct.PAKT_INFO, ptr %0, i64 0, i32 3, i64 %indvars.iv.i
|
||||
%1 = load i32, ptr %arrayidx.i, align 4
|
||||
%indvars.iv.next.i = add i64 %indvars.iv.i, 1
|
||||
%exitcond.not.i = icmp eq i64 %indvars.iv.i, %div
|
||||
br i1 %exitcond.not.i, label %alac_pakt_block_offset.exit, label %for.body.i
|
||||
|
||||
alac_pakt_block_offset.exit: ; preds = %for.body.i
|
||||
ret i64 0
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue