[RISCV] Lower i8/i16 bswap/bitreverse to grevi/greviw with Zbp.

Also add known-bits support for GREV/GREVW so that the zero-extension of
the result can be omitted when the input was already zero-extended.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D103757
This commit is contained in:
Craig Topper 2021-06-07 10:21:53 -07:00
parent c880d5e583
commit f30f8b4f12
3 changed files with 64 additions and 34 deletions

View File

@ -250,12 +250,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtZbp()) {
// Custom lower bswap/bitreverse so we can convert them to GREVI to enable
// more combining.
setOperationAction(ISD::BITREVERSE, XLenVT, Custom);
setOperationAction(ISD::BSWAP, XLenVT, Custom);
setOperationAction(ISD::BITREVERSE, XLenVT, Custom);
setOperationAction(ISD::BSWAP, XLenVT, Custom);
setOperationAction(ISD::BITREVERSE, MVT::i8, Custom);
// BSWAP i8 doesn't exist.
setOperationAction(ISD::BITREVERSE, MVT::i16, Custom);
setOperationAction(ISD::BSWAP, MVT::i16, Custom);
if (Subtarget.is64Bit()) {
setOperationAction(ISD::BITREVERSE, MVT::i32, Custom);
setOperationAction(ISD::BSWAP, MVT::i32, Custom);
setOperationAction(ISD::BSWAP, MVT::i32, Custom);
}
} else {
// With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll
@ -4861,16 +4865,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
}
case ISD::BSWAP:
case ISD::BITREVERSE: {
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
MVT VT = N->getSimpleValueType(0);
MVT XLenVT = Subtarget.getXLenVT();
assert((VT == MVT::i8 || VT == MVT::i16 ||
(VT == MVT::i32 && Subtarget.is64Bit())) &&
Subtarget.hasStdExtZbp() && "Unexpected custom legalisation");
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64,
N->getOperand(0));
unsigned Imm = N->getOpcode() == ISD::BITREVERSE ? 31 : 24;
SDValue GREVIW = DAG.getNode(RISCVISD::GREVW, DL, MVT::i64, NewOp0,
DAG.getConstant(Imm, DL, MVT::i64));
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, N->getOperand(0));
unsigned Imm = VT.getSizeInBits() - 1;
// If this is BSWAP rather than BITREVERSE, clear the lower 3 bits.
if (N->getOpcode() == ISD::BSWAP)
Imm &= ~0x7U;
unsigned Opc = Subtarget.is64Bit() ? RISCVISD::GREVW : RISCVISD::GREV;
SDValue GREVI =
DAG.getNode(Opc, DL, XLenVT, NewOp0, DAG.getConstant(Imm, DL, XLenVT));
// ReplaceNodeResults requires we maintain the same type for the return
// value.
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, GREVIW));
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, GREVI));
break;
}
case ISD::FSHL:
@ -6066,6 +6076,24 @@ bool RISCVTargetLowering::targetShrinkDemandedConstant(
return UseMask(NewMask);
}
/// Permute the bits of \p Src in place according to the control value
/// \p ShAmt, as used for the GREV/GREVW nodes.
///
/// Only the low bits of \p ShAmt that fit below the bit width of \p Src are
/// honoured. Each set control bit I swaps every pair of adjacent groups of
/// (1 << I) bits.
///
/// \param Src   Value to permute; overwritten with the permuted value.
/// \param ShAmt Control value selecting which swap stages are applied.
///
/// NOTE(review): the value is read through getZExtValue(), so this presumably
/// expects Src to be at most 64 bits wide -- confirm against callers.
static void computeGREV(APInt &Src, unsigned ShAmt) {
  // For stage I, the mask selects the lower group of every pair of adjacent
  // (1 << I)-bit groups; its complement selects the upper group.
  static const uint64_t LowGroupMasks[] = {
      0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
      0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL};
  const unsigned NumStages = sizeof(LowGroupMasks) / sizeof(LowGroupMasks[0]);

  // Discard control bits that would swap groups wider than the value itself.
  const unsigned Control = ShAmt & (Src.getBitWidth() - 1);

  uint64_t Bits = Src.getZExtValue();
  for (unsigned Stage = 0; Stage != NumStages; ++Stage) {
    const unsigned Distance = 1u << Stage;
    if ((Control & Distance) == 0)
      continue;
    const uint64_t LowMask = LowGroupMasks[Stage];
    // Move each lower group up and each upper group down by one group width.
    Bits = ((Bits & LowMask) << Distance) | ((Bits & ~LowMask) >> Distance);
  }
  Src = Bits;
}
void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
@ -6128,6 +6156,20 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known.Zero.setBitsFrom(LowBits);
break;
}
// Known bits for GREV/GREVW: permute the operand's known-zero and known-one
// masks the same way the node permutes the value.
case RISCVISD::GREV:
case RISCVISD::GREVW: {
// Only handled when the control operand (operand 1) is a constant; otherwise
// nothing is recorded in Known by this case.
if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
Known = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
// For GREVW, narrow the known bits to the low 32 bits before permuting.
if (Opc == RISCVISD::GREVW)
Known = Known.trunc(32);
unsigned ShAmt = C->getZExtValue();
// Apply the identical permutation to both masks so every known bit ends up
// at the position computeGREV moves it to.
computeGREV(Known.Zero, ShAmt);
computeGREV(Known.One, ShAmt);
// Widen the 32-bit result back to BitWidth using sign extension.
// NOTE(review): presumably because GREVW yields a sign-extended 32-bit
// result -- confirm against the node's definition.
if (Opc == RISCVISD::GREVW)
Known = Known.sext(BitWidth);
}
break;
}
case RISCVISD::READ_VLENB:
// We assume VLENB is at least 16 bytes.
Known.Zero.setLowBits(4);

View File

@ -2353,14 +2353,12 @@ define zeroext i16 @bswap_i16(i16 zeroext %a) nounwind {
;
; RV32IB-LABEL: bswap_i16:
; RV32IB: # %bb.0:
; RV32IB-NEXT: rev8 a0, a0
; RV32IB-NEXT: srli a0, a0, 16
; RV32IB-NEXT: rev8.h a0, a0
; RV32IB-NEXT: ret
;
; RV32IBP-LABEL: bswap_i16:
; RV32IBP: # %bb.0:
; RV32IBP-NEXT: rev8 a0, a0
; RV32IBP-NEXT: srli a0, a0, 16
; RV32IBP-NEXT: rev8.h a0, a0
; RV32IBP-NEXT: ret
%1 = tail call i16 @llvm.bswap.i16(i16 %a)
ret i16 %1
@ -2467,14 +2465,12 @@ define zeroext i8 @bitreverse_i8(i8 zeroext %a) nounwind {
;
; RV32IB-LABEL: bitreverse_i8:
; RV32IB: # %bb.0:
; RV32IB-NEXT: rev a0, a0
; RV32IB-NEXT: srli a0, a0, 24
; RV32IB-NEXT: rev.b a0, a0
; RV32IB-NEXT: ret
;
; RV32IBP-LABEL: bitreverse_i8:
; RV32IBP: # %bb.0:
; RV32IBP-NEXT: rev a0, a0
; RV32IBP-NEXT: srli a0, a0, 24
; RV32IBP-NEXT: rev.b a0, a0
; RV32IBP-NEXT: ret
%1 = tail call i8 @llvm.bitreverse.i8(i8 %a)
ret i8 %1
@ -2519,14 +2515,12 @@ define zeroext i16 @bitreverse_i16(i16 zeroext %a) nounwind {
;
; RV32IB-LABEL: bitreverse_i16:
; RV32IB: # %bb.0:
; RV32IB-NEXT: rev a0, a0
; RV32IB-NEXT: srli a0, a0, 16
; RV32IB-NEXT: rev.h a0, a0
; RV32IB-NEXT: ret
;
; RV32IBP-LABEL: bitreverse_i16:
; RV32IBP: # %bb.0:
; RV32IBP-NEXT: rev a0, a0
; RV32IBP-NEXT: srli a0, a0, 16
; RV32IBP-NEXT: rev.h a0, a0
; RV32IBP-NEXT: ret
%1 = tail call i16 @llvm.bitreverse.i16(i16 %a)
ret i16 %1

View File

@ -2679,14 +2679,12 @@ define zeroext i16 @bswap_i16(i16 zeroext %a) nounwind {
;
; RV64IB-LABEL: bswap_i16:
; RV64IB: # %bb.0:
; RV64IB-NEXT: rev8 a0, a0
; RV64IB-NEXT: srli a0, a0, 48
; RV64IB-NEXT: greviw a0, a0, 8
; RV64IB-NEXT: ret
;
; RV64IBP-LABEL: bswap_i16:
; RV64IBP: # %bb.0:
; RV64IBP-NEXT: rev8 a0, a0
; RV64IBP-NEXT: srli a0, a0, 48
; RV64IBP-NEXT: greviw a0, a0, 8
; RV64IBP-NEXT: ret
%1 = tail call i16 @llvm.bswap.i16(i16 %a)
ret i16 %1
@ -2832,14 +2830,12 @@ define zeroext i8 @bitreverse_i8(i8 zeroext %a) nounwind {
;
; RV64IB-LABEL: bitreverse_i8:
; RV64IB: # %bb.0:
; RV64IB-NEXT: rev a0, a0
; RV64IB-NEXT: srli a0, a0, 56
; RV64IB-NEXT: greviw a0, a0, 7
; RV64IB-NEXT: ret
;
; RV64IBP-LABEL: bitreverse_i8:
; RV64IBP: # %bb.0:
; RV64IBP-NEXT: rev a0, a0
; RV64IBP-NEXT: srli a0, a0, 56
; RV64IBP-NEXT: greviw a0, a0, 7
; RV64IBP-NEXT: ret
%1 = tail call i8 @llvm.bitreverse.i8(i8 %a)
ret i8 %1
@ -2884,14 +2880,12 @@ define zeroext i16 @bitreverse_i16(i16 zeroext %a) nounwind {
;
; RV64IB-LABEL: bitreverse_i16:
; RV64IB: # %bb.0:
; RV64IB-NEXT: rev a0, a0
; RV64IB-NEXT: srli a0, a0, 48
; RV64IB-NEXT: greviw a0, a0, 15
; RV64IB-NEXT: ret
;
; RV64IBP-LABEL: bitreverse_i16:
; RV64IBP: # %bb.0:
; RV64IBP-NEXT: rev a0, a0
; RV64IBP-NEXT: srli a0, a0, 48
; RV64IBP-NEXT: greviw a0, a0, 15
; RV64IBP-NEXT: ret
%1 = tail call i16 @llvm.bitreverse.i16(i16 %a)
ret i16 %1