[mlir] Introduce `replaceUsesOfWith` to `RewriterBase`
Finding uses of a value and replacing them with a new one is a common method. I have not seen an safe and easy shortcut that does that. This revision attempts to address that by intoroducing `replaceUsesOfWith` to `RewriterBase`. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D138110
This commit is contained in:
parent
6f48e68d39
commit
5ce68f4284
|
@ -502,6 +502,11 @@ public:
|
|||
finalizeRootUpdate(root);
|
||||
}
|
||||
|
||||
/// Find uses of `from` and replace it with `to`. It also marks every modified
|
||||
/// uses and notifies the rewriter that an in-place operation modification is
|
||||
/// about to happen.
|
||||
void replaceAllUsesWith(Value from, Value to);
|
||||
|
||||
/// Used to notify the rewriter that the IR failed to be rewritten because of
|
||||
/// a match failure, and provide a callback to populate a diagnostic with the
|
||||
/// reason why the failure occurred. This method allows for derived rewriters
|
||||
|
|
|
@ -246,15 +246,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
|
|||
sourceBlock.getOperations());
|
||||
|
||||
// Step 5. RAUW thread indices to thread ops.
|
||||
for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
|
||||
Value val = bvm.lookup(blockIdx);
|
||||
SmallVector<OpOperand *> uses;
|
||||
for (OpOperand &use : blockIdx.getUses())
|
||||
uses.push_back(&use);
|
||||
for (OpOperand *operand : uses) {
|
||||
Operation *op = operand->getOwner();
|
||||
rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
|
||||
}
|
||||
for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
|
||||
Value blockIdx = bvm.lookup(loopIndex);
|
||||
rewriter.replaceAllUsesWith(loopIndex, blockIdx);
|
||||
}
|
||||
|
||||
// Step 6. Erase old op.
|
||||
|
@ -492,15 +486,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
|
|||
sourceBlock.getOperations());
|
||||
|
||||
// Step 6. RAUW thread indices to thread ops.
|
||||
for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
|
||||
Value val = bvm.lookup(threadIdx);
|
||||
SmallVector<OpOperand *> uses;
|
||||
for (OpOperand &use : threadIdx.getUses())
|
||||
uses.push_back(&use);
|
||||
for (OpOperand *operand : uses) {
|
||||
Operation *op = operand->getOwner();
|
||||
rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
|
||||
}
|
||||
for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
|
||||
Value threadIdx = bvm.lookup(loopIndex);
|
||||
rewriter.replaceAllUsesWith(loopIndex, threadIdx);
|
||||
}
|
||||
|
||||
// Step 7. syncthreads.
|
||||
|
|
|
@ -309,6 +309,14 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
|
|||
source->erase();
|
||||
}
|
||||
|
||||
/// Find uses of `from` and replace it with `to`
|
||||
void RewriterBase::replaceAllUsesWith(Value from, Value to) {
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
|
||||
Operation *op = operand.getOwner();
|
||||
updateRootInPlace(op, [&]() { operand.set(to); });
|
||||
}
|
||||
}
|
||||
|
||||
// Merge the operations of block 'source' before the operation 'op'. Source
|
||||
// block should not have existing predecessors or successors.
|
||||
void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
|
||||
|
|
Loading…
Reference in New Issue