[mlir][NFC] Replace some nested if with logical and.

This patch replaces some nested if statement with logical and to reduce the nesting depth.

Differential Revision: https://reviews.llvm.org/D126050
This commit is contained in:
jacquesguan 2022-05-20 08:48:52 +00:00
parent c11051a400
commit 10c9ecce9f
1 changed files with 27 additions and 32 deletions

View File

@ -1296,20 +1296,16 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero())) {
if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
// extsi(%x : i1 -> iN) != 0 -> %x
if (getPredicate() == arith::CmpIPredicate::ne) {
return extOp.getOperand();
}
}
// extsi(%x : i1 -> iN) != 0 -> %x
if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
// extui(%x : i1 -> iN) != 0 -> %x
if (getPredicate() == arith::CmpIPredicate::ne) {
return extOp.getOperand();
}
}
// extui(%x : i1 -> iN) != 0 -> %x
if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
}
@ -1733,24 +1729,24 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
return failure();
// select %x, c1, %c0 => extui %arg
if (matchPattern(op.getTrueValue(), m_One()))
if (matchPattern(op.getFalseValue(), m_Zero())) {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
op.getCondition());
return success();
}
if (matchPattern(op.getTrueValue(), m_One()) &&
matchPattern(op.getFalseValue(), m_Zero())) {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
op.getCondition());
return success();
}
// select %x, c0, %c1 => extui (xor %arg, true)
if (matchPattern(op.getTrueValue(), m_Zero()))
if (matchPattern(op.getFalseValue(), m_One())) {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
op, op.getType(),
rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
rewriter.create<arith::ConstantIntOp>(
op.getLoc(), 1, op.getCondition().getType())));
return success();
}
if (matchPattern(op.getTrueValue(), m_Zero()) &&
matchPattern(op.getFalseValue(), m_One())) {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
op, op.getType(),
rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
rewriter.create<arith::ConstantIntOp>(
op.getLoc(), 1, op.getCondition().getType())));
return success();
}
return failure();
}
@ -1778,10 +1774,9 @@ OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
return falseVal;
// select %x, true, false => %x
if (getType().isInteger(1))
if (matchPattern(getTrueValue(), m_One()))
if (matchPattern(getFalseValue(), m_Zero()))
return condition;
if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
matchPattern(getFalseValue(), m_Zero()))
return condition;
if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
auto pred = cmp.getPredicate();