[mlir][scf] Add scf-to-cf lowering for `scf.index_switch`
This patch adds lowering from `scf.index_switch` to `cf.switch. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D136883
This commit is contained in:
parent
144e38f5e5
commit
91effec852
|
@ -290,6 +290,14 @@ struct DoWhileLowering : public OpRewritePattern<WhileOp> {
|
|||
LogicalResult matchAndRewrite(WhileOp whileOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
|
||||
struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IndexSwitchOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||
|
@ -615,10 +623,68 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
// Split the block at the op.
|
||||
Block *condBlock = rewriter.getInsertionBlock();
|
||||
Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
|
||||
|
||||
// Create the arguments on the continue block with which to replace the
|
||||
// results of the op.
|
||||
SmallVector<Value> results;
|
||||
results.reserve(op.getNumResults());
|
||||
for (Type resultType : op.getResultTypes())
|
||||
results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
|
||||
|
||||
// Handle the regions.
|
||||
auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> {
|
||||
Block *block = ®ion.front();
|
||||
|
||||
// Convert the yield terminator to a branch to the continue block.
|
||||
auto yield = cast<scf::YieldOp>(block->getTerminator());
|
||||
rewriter.setInsertionPoint(yield);
|
||||
rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
|
||||
yield.getOperands());
|
||||
|
||||
// Inline the region.
|
||||
rewriter.inlineRegionBefore(region, continueBlock);
|
||||
return block;
|
||||
};
|
||||
|
||||
// Convert the case regions.
|
||||
SmallVector<Block *> caseSuccessors;
|
||||
SmallVector<int32_t> caseValues;
|
||||
caseSuccessors.reserve(op.getCases().size());
|
||||
caseValues.reserve(op.getCases().size());
|
||||
for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
|
||||
FailureOr<Block *> block = convertRegion(region);
|
||||
if (failed(block))
|
||||
return failure();
|
||||
caseSuccessors.push_back(*block);
|
||||
caseValues.push_back(value);
|
||||
}
|
||||
|
||||
// Convert the default region.
|
||||
FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
|
||||
if (failed(defaultBlock))
|
||||
return failure();
|
||||
|
||||
// Create the switch.
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
|
||||
rewriter.create<cf::SwitchOp>(
|
||||
op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
|
||||
rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
|
||||
rewriter.replaceOp(op, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateSCFToControlFlowConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
||||
ExecuteRegionLowering>(patterns.getContext());
|
||||
ExecuteRegionLowering, IndexSwitchLowering>(
|
||||
patterns.getContext());
|
||||
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
||||
}
|
||||
|
||||
|
|
|
@ -473,7 +473,7 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
|
|||
scf.condition(%0) %2, %3 : i64, f64
|
||||
} do {
|
||||
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
|
||||
^bb0(%arg2: i64, %arg3: f64):
|
||||
^bb0(%arg2: i64, %arg3: f64):
|
||||
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
|
||||
scf.yield %c0_i32, %cst : i32, f32
|
||||
}
|
||||
|
@ -620,3 +620,30 @@ func.func @func_execute_region_elim_multi_yield() {
|
|||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||
// CHECK: "test.bar"(%[[z]])
|
||||
// CHECK: return
|
||||
|
||||
// SWITCH-LABEL: @index_switch
|
||||
func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
|
||||
// SWITCH: cf.switch %arg0 : index
|
||||
// SWITCH-NEXT: default: ^bb3
|
||||
// SWITCH-NEXT: 0: ^bb1
|
||||
// SWITCH-NEXT: 1: ^bb2
|
||||
%0 = scf.index_switch %i -> i32
|
||||
// SWITCH: ^bb1:
|
||||
case 0 {
|
||||
// SWITCH-NEXT: llvm.br ^bb4(%arg1
|
||||
scf.yield %a : i32
|
||||
}
|
||||
// SWITCH: ^bb2:
|
||||
case 1 {
|
||||
// SWITCH-NEXT: llvm.br ^bb4(%arg2
|
||||
scf.yield %b : i32
|
||||
}
|
||||
// SWITCH: ^bb3:
|
||||
default {
|
||||
// SWITCH-NEXT: llvm.br ^bb4(%arg3
|
||||
scf.yield %c : i32
|
||||
}
|
||||
// SWITCH: ^bb4(%[[V:.*]]: i32
|
||||
// SWITCH-NEXT: return %[[V]]
|
||||
return %0 : i32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue