[mlir][Arith] Fix a use-after-free after rewriting ops to unsigned

Just short-circuit when a change was made, the erased value is invalid
after that. Found by asan.

This pass looks like it could use rewrite patterns instead which don't
have this issue, but let's fix the asan build first.
This commit is contained in:
Benjamin Kramer 2022-06-15 10:27:19 +02:00
parent 687e56614f
commit 0886ea902b
1 changed files with 13 additions and 10 deletions

View File

@ -90,7 +90,7 @@ static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
} }
template <typename T, typename U> template <typename T, typename U>
static void rewriteOp(Operation *op, OpBuilder &b) { static bool rewriteOp(Operation *op, OpBuilder &b) {
if (isa<T>(op)) { if (isa<T>(op)) {
OpBuilder::InsertionGuard guard(b); OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op); b.setInsertionPoint(op);
@ -98,28 +98,31 @@ static void rewriteOp(Operation *op, OpBuilder &b) {
op->getOperands(), op->getAttrs()); op->getOperands(), op->getAttrs());
op->replaceAllUsesWith(newOp->getResults()); op->replaceAllUsesWith(newOp->getResults());
op->erase(); op->erase();
return true;
} }
return false;
} }
static void rewriteCmpI(Operation *op, OpBuilder &b) { static bool rewriteCmpI(Operation *op, OpBuilder &b) {
if (auto cmpOp = dyn_cast<CmpIOp>(op)) { if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
cmpOp.setPredicateAttr(CmpIPredicateAttr::get( cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
b.getContext(), toUnsignedPred(cmpOp.getPredicate()))); b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
return true;
} }
return false;
} }
static void rewrite(Operation *root, const OpList &toReplace) { static void rewrite(Operation *root, const OpList &toReplace) {
OpBuilder b(root->getContext()); OpBuilder b(root->getContext());
b.setInsertionPoint(root); b.setInsertionPoint(root);
for (Operation *op : toReplace) { for (Operation *op : toReplace) {
rewriteOp<DivSIOp, DivUIOp>(op, b); rewriteOp<DivSIOp, DivUIOp>(op, b) ||
rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b); rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) ||
rewriteOp<FloorDivSIOp, DivUIOp>(op, b); rewriteOp<FloorDivSIOp, DivUIOp>(op, b) ||
rewriteOp<RemSIOp, RemUIOp>(op, b); rewriteOp<RemSIOp, RemUIOp>(op, b) ||
rewriteOp<MinSIOp, MinUIOp>(op, b); rewriteOp<MinSIOp, MinUIOp>(op, b) ||
rewriteOp<MaxSIOp, MaxUIOp>(op, b); rewriteOp<MaxSIOp, MaxUIOp>(op, b) ||
rewriteOp<ExtSIOp, ExtUIOp>(op, b); rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b);
rewriteCmpI(op, b);
} }
} }