[ARM] Transform a floating-point to fixed-point conversion to a VCVT_fix

Much like fixed-point to floating-point conversion, the converse can
also be transformed into a fixed-point VCVT. This patch transforms
multiplications of floating point numbers by 2^n into a VCVT_fix. The
exception is that a float to fixed conversion with 1 fractional bit
ends up being an FADD (FADD(x, x) emulates FMUL(x, 2)) rather than an FMUL so there is a special case for that. This patch also moves the code from https://reviews.llvm.org/D103903 into a separate function as fixed to float and float to fixed are very similar.

Differential Revision: https://reviews.llvm.org/D104793
This commit is contained in:
Sam Tebbs 2021-06-21 16:00:11 +01:00
parent c475efe916
commit 24d76419d6
3 changed files with 1185 additions and 81 deletions

View File

@ -197,6 +197,10 @@ private:
bool tryT2IndexedLoad(SDNode *N);
bool tryMVEIndexedLoad(SDNode *N);
bool tryFMULFixed(SDNode *N, SDLoc dl);
bool tryFP_TO_INT(SDNode *N, SDLoc dl);
bool transformFixedFloatingPointConversion(SDNode *N, SDNode *FMul,
bool IsUnsigned,
bool FixedToFloat);
/// SelectVLD - Select NEON load intrinsics. NumVecs should be
/// 1, 2, 3 or 4. The opcode arrays specify the instructions used for
@ -3150,6 +3154,154 @@ bool ARMDAGToDAGISel::tryInsertVectorElt(SDNode *N) {
return false;
}
bool ARMDAGToDAGISel::transformFixedFloatingPointConversion(SDNode *N,
SDNode *FMul,
bool IsUnsigned,
bool FixedToFloat) {
auto Type = N->getValueType(0);
unsigned ScalarBits = Type.getScalarSizeInBits();
if (ScalarBits > 32)
return false;
SDNodeFlags FMulFlags = FMul->getFlags();
// The fixed-point vcvt and vcvt+vmul are not always equivalent if inf is
// allowed in 16 bit unsigned floats
if (ScalarBits == 16 && !FMulFlags.hasNoInfs() && IsUnsigned)
return false;
SDValue ImmNode = FMul->getOperand(1);
SDValue VecVal = FMul->getOperand(0);
if (VecVal->getOpcode() == ISD::UINT_TO_FP ||
VecVal->getOpcode() == ISD::SINT_TO_FP)
VecVal = VecVal->getOperand(0);
if (VecVal.getValueType().getScalarSizeInBits() != ScalarBits)
return false;
if (ImmNode.getOpcode() == ISD::BITCAST) {
if (ImmNode.getValueType().getScalarSizeInBits() != ScalarBits)
return false;
ImmNode = ImmNode.getOperand(0);
}
if (ImmNode.getValueType().getScalarSizeInBits() != ScalarBits)
return false;
APFloat ImmAPF(0.0f);
switch (ImmNode.getOpcode()) {
case ARMISD::VMOVIMM:
case ARMISD::VDUP: {
if (!isa<ConstantSDNode>(ImmNode.getOperand(0)))
return false;
unsigned Imm = ImmNode.getConstantOperandVal(0);
if (ImmNode.getOpcode() == ARMISD::VMOVIMM)
Imm = ARM_AM::decodeVMOVModImm(Imm, ScalarBits);
ImmAPF =
APFloat(ScalarBits == 32 ? APFloat::IEEEsingle() : APFloat::IEEEhalf(),
APInt(ScalarBits, Imm));
break;
}
case ARMISD::VMOVFPIMM: {
ImmAPF = APFloat(ARM_AM::getFPImmFloat(ImmNode.getConstantOperandVal(0)));
break;
}
default:
return false;
}
// Where n is the number of fractional bits, multiplying by 2^n will convert
// from float to fixed and multiplying by 2^-n will convert from fixed to
// float. Taking log2 of the factor (after taking the inverse in the case of
// float to fixed) will give n.
APFloat ToConvert = ImmAPF;
if (FixedToFloat) {
if (!ImmAPF.getExactInverse(&ToConvert))
return false;
}
APSInt Converted(64, 0);
bool IsExact;
ToConvert.convertToInteger(Converted, llvm::RoundingMode::NearestTiesToEven,
&IsExact);
if (!IsExact || !Converted.isPowerOf2())
return false;
unsigned FracBits = Converted.logBase2();
if (FracBits > ScalarBits)
return false;
SmallVector<SDValue, 3> Ops{
VecVal, CurDAG->getConstant(FracBits, SDLoc(N), MVT::i32)};
AddEmptyMVEPredicateToOps(Ops, SDLoc(N), Type);
unsigned int Opcode;
switch (ScalarBits) {
case 16:
if (FixedToFloat)
Opcode = IsUnsigned ? ARM::MVE_VCVTf16u16_fix : ARM::MVE_VCVTf16s16_fix;
else
Opcode = IsUnsigned ? ARM::MVE_VCVTu16f16_fix : ARM::MVE_VCVTs16f16_fix;
break;
case 32:
if (FixedToFloat)
Opcode = IsUnsigned ? ARM::MVE_VCVTf32u32_fix : ARM::MVE_VCVTf32s32_fix;
else
Opcode = IsUnsigned ? ARM::MVE_VCVTu32f32_fix : ARM::MVE_VCVTs32f32_fix;
break;
default:
llvm_unreachable("unexpected number of scalar bits");
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opcode, SDLoc(N), Type, Ops));
return true;
}
bool ARMDAGToDAGISel::tryFP_TO_INT(SDNode *N, SDLoc dl) {
// Transform a floating-point to fixed-point conversion to a VCVT
if (!Subtarget->hasMVEFloatOps())
return false;
EVT Type = N->getValueType(0);
if (!Type.isVector())
return false;
unsigned int ScalarBits = Type.getScalarSizeInBits();
bool IsUnsigned = N->getOpcode() == ISD::FP_TO_UINT;
SDNode *Node = N->getOperand(0).getNode();
// floating-point to fixed-point with one fractional bit gets turned into an
// FP_TO_[U|S]INT(FADD (x, x)) rather than an FP_TO_[U|S]INT(FMUL (x, y))
if (Node->getOpcode() == ISD::FADD) {
if (Node->getOperand(0) != Node->getOperand(1))
return false;
SDNodeFlags Flags = Node->getFlags();
// The fixed-point vcvt and vcvt+vmul are not always equivalent if inf is
// allowed in 16 bit unsigned floats
if (ScalarBits == 16 && !Flags.hasNoInfs() && IsUnsigned)
return false;
unsigned Opcode;
switch (ScalarBits) {
case 16:
Opcode = IsUnsigned ? ARM::MVE_VCVTu16f16_fix : ARM::MVE_VCVTs16f16_fix;
break;
case 32:
Opcode = IsUnsigned ? ARM::MVE_VCVTu32f32_fix : ARM::MVE_VCVTs32f32_fix;
break;
}
SmallVector<SDValue, 3> Ops{Node->getOperand(0),
CurDAG->getConstant(1, dl, MVT::i32)};
AddEmptyMVEPredicateToOps(Ops, dl, Type);
ReplaceNode(N, CurDAG->getMachineNode(Opcode, dl, Type, Ops));
return true;
}
if (Node->getOpcode() != ISD::FMUL)
return false;
return transformFixedFloatingPointConversion(N, Node, IsUnsigned, false);
}
bool ARMDAGToDAGISel::tryFMULFixed(SDNode *N, SDLoc dl) {
// Transform a fixed-point to floating-point conversion to a VCVT
if (!Subtarget->hasMVEFloatOps())
@ -3158,91 +3310,12 @@ bool ARMDAGToDAGISel::tryFMULFixed(SDNode *N, SDLoc dl) {
if (!Type.isVector())
return false;
auto ScalarType = Type.getVectorElementType();
unsigned ScalarBits = ScalarType.getSizeInBits();
auto LHS = N->getOperand(0);
auto RHS = N->getOperand(1);
if (ScalarBits > 32)
return false;
if (RHS.getOpcode() == ISD::BITCAST) {
if (RHS.getValueType().getVectorElementType().getSizeInBits() != ScalarBits)
return false;
RHS = RHS.getOperand(0);
}
if (RHS.getValueType().getVectorElementType().getSizeInBits() != ScalarBits)
return false;
if (LHS.getOpcode() != ISD::SINT_TO_FP && LHS.getOpcode() != ISD::UINT_TO_FP)
return false;
bool IsUnsigned = LHS.getOpcode() == ISD::UINT_TO_FP;
SDNodeFlags FMulFlags = N->getFlags();
// The fixed-point vcvt and vcvt+vmul are not always equivalent if inf is
// allowed in 16 bit unsigned floats
if (ScalarBits == 16 && !FMulFlags.hasNoInfs() && IsUnsigned)
return false;
APFloat ImmAPF(0.0f);
switch (RHS.getOpcode()) {
case ARMISD::VMOVIMM:
case ARMISD::VDUP: {
if (!isa<ConstantSDNode>(RHS.getOperand(0)))
return false;
unsigned Imm = RHS.getConstantOperandVal(0);
if (RHS.getOpcode() == ARMISD::VMOVIMM)
Imm = ARM_AM::decodeVMOVModImm(Imm, ScalarBits);
ImmAPF =
APFloat(ScalarBits == 32 ? APFloat::IEEEsingle() : APFloat::IEEEhalf(),
APInt(ScalarBits, Imm));
break;
}
case ARMISD::VMOVFPIMM: {
ImmAPF = APFloat(ARM_AM::getFPImmFloat(RHS.getConstantOperandVal(0)));
break;
}
default:
return false;
}
// Multiplying by a factor of 2^(-n) will convert from fixed point to
// floating point, where n is the number of fractional bits in the fixed
// point number. Taking the inverse and log2 of the factor will give n
APFloat Inverse(0.0f);
if (!ImmAPF.getExactInverse(&Inverse))
return false;
APSInt Converted(64, 0);
bool IsExact;
Inverse.convertToInteger(Converted, llvm::RoundingMode::NearestTiesToEven,
&IsExact);
if (!IsExact || !Converted.isPowerOf2())
return false;
unsigned FracBits = Converted.logBase2();
if (FracBits > ScalarBits)
return false;
auto SintToFpOperand = LHS.getOperand(0);
SmallVector<SDValue, 3> Ops{SintToFpOperand,
CurDAG->getConstant(FracBits, dl, MVT::i32)};
AddEmptyMVEPredicateToOps(Ops, dl, Type);
unsigned int Opcode;
switch (ScalarBits) {
case 16:
Opcode = IsUnsigned ? ARM::MVE_VCVTf16u16_fix : ARM::MVE_VCVTf16s16_fix;
break;
case 32:
Opcode = IsUnsigned ? ARM::MVE_VCVTf32u32_fix : ARM::MVE_VCVTf32s32_fix;
break;
default:
llvm_unreachable("unexpected number of scalar bits");
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opcode, dl, Type, Ops));
return true;
return transformFixedFloatingPointConversion(
N, N, LHS.getOpcode() == ISD::UINT_TO_FP, true);
}
bool ARMDAGToDAGISel::tryV6T2BitfieldExtractOp(SDNode *N, bool isSigned) {
@ -3680,6 +3753,11 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
if (tryV6T2BitfieldExtractOp(N, true))
return;
break;
case ISD::FP_TO_UINT:
case ISD::FP_TO_SINT:
if (tryFP_TO_INT(N, dl))
return;
break;
case ISD::FMUL:
if (tryFMULFixed(N, dl))
return;

File diff suppressed because it is too large Load Diff