[AArch64] Lowering and legalization of strict FP16
For strict FP16 to work correctly needs some changes in lowering and legalization: * SelectionDAGLegalize::PromoteNode was missing handling for some strict fp opcodes. * Some of the custom lowering of strict fp operations needed to be adjusted to work with FP16. * Custom lowering needed to be added for round-to-int operations. With this, and the previous patches for the rest of the strict fp isel, we can set IsStrictFPEnabled = true. Differential Revision: https://reviews.llvm.org/D115620
This commit is contained in:
parent
d43d9e1d5c
commit
12c1022679
|
@ -4714,6 +4714,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
|
|||
Results.push_back(DAG.getNode(ISD::FP_ROUND, dl, OVT,
|
||||
Tmp3, DAG.getIntPtrConstant(0, dl)));
|
||||
break;
|
||||
case ISD::STRICT_FADD:
|
||||
case ISD::STRICT_FSUB:
|
||||
case ISD::STRICT_FMUL:
|
||||
case ISD::STRICT_FDIV:
|
||||
case ISD::STRICT_FMINNUM:
|
||||
case ISD::STRICT_FMAXNUM:
|
||||
case ISD::STRICT_FREM:
|
||||
case ISD::STRICT_FPOW:
|
||||
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
|
@ -4738,6 +4744,22 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
|
|||
DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3),
|
||||
DAG.getIntPtrConstant(0, dl)));
|
||||
break;
|
||||
case ISD::STRICT_FMA:
|
||||
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
{Node->getOperand(0), Node->getOperand(1)});
|
||||
Tmp2 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
{Node->getOperand(0), Node->getOperand(2)});
|
||||
Tmp3 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
{Node->getOperand(0), Node->getOperand(3)});
|
||||
Tmp4 = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Tmp1.getValue(1),
|
||||
Tmp2.getValue(1), Tmp3.getValue(1));
|
||||
Tmp4 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
|
||||
{Tmp4, Tmp1, Tmp2, Tmp3});
|
||||
Tmp4 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
|
||||
{Tmp4.getValue(1), Tmp4, DAG.getIntPtrConstant(0, dl)});
|
||||
Results.push_back(Tmp4);
|
||||
Results.push_back(Tmp4.getValue(1));
|
||||
break;
|
||||
case ISD::FCOPYSIGN:
|
||||
case ISD::FPOWI: {
|
||||
Tmp1 = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Node->getOperand(0));
|
||||
|
@ -4754,6 +4776,16 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
|
|||
Tmp3, DAG.getIntPtrConstant(isTrunc, dl)));
|
||||
break;
|
||||
}
|
||||
case ISD::STRICT_FPOWI:
|
||||
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
{Node->getOperand(0), Node->getOperand(1)});
|
||||
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
|
||||
{Tmp1.getValue(1), Tmp1, Node->getOperand(2)});
|
||||
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
|
||||
{Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
|
||||
Results.push_back(Tmp3);
|
||||
Results.push_back(Tmp3.getValue(1));
|
||||
break;
|
||||
case ISD::FFLOOR:
|
||||
case ISD::FCEIL:
|
||||
case ISD::FRINT:
|
||||
|
@ -4778,12 +4810,19 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
|
|||
break;
|
||||
case ISD::STRICT_FFLOOR:
|
||||
case ISD::STRICT_FCEIL:
|
||||
case ISD::STRICT_FRINT:
|
||||
case ISD::STRICT_FNEARBYINT:
|
||||
case ISD::STRICT_FROUND:
|
||||
case ISD::STRICT_FROUNDEVEN:
|
||||
case ISD::STRICT_FTRUNC:
|
||||
case ISD::STRICT_FSQRT:
|
||||
case ISD::STRICT_FSIN:
|
||||
case ISD::STRICT_FCOS:
|
||||
case ISD::STRICT_FLOG:
|
||||
case ISD::STRICT_FLOG2:
|
||||
case ISD::STRICT_FLOG10:
|
||||
case ISD::STRICT_FEXP:
|
||||
case ISD::STRICT_FEXP2:
|
||||
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
|
||||
{Node->getOperand(0), Node->getOperand(1)});
|
||||
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
|
||||
|
|
|
@ -539,64 +539,41 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
|
|||
else
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
|
||||
|
||||
setOperationAction(ISD::FREM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FREM, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FREM, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FPOW, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FPOW, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FPOW, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FPOWI, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FPOWI, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FPOWI, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FCOS, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FCOS, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FCOS, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FSIN, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSIN, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FSIN, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FSINCOS, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSINCOS, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FSINCOS, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FEXP, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FEXP, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FEXP, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FEXP2, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FEXP2, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FEXP2, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FLOG, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FLOG, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FLOG, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FLOG2, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FLOG2, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FLOG2, MVT::v8f16, Expand);
|
||||
setOperationAction(ISD::FLOG10, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FLOG10, MVT::v4f16, Expand);
|
||||
setOperationAction(ISD::FLOG10, MVT::v8f16, Expand);
|
||||
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
|
||||
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
|
||||
ISD::FEXP, ISD::FEXP2, ISD::FLOG,
|
||||
ISD::FLOG2, ISD::FLOG10, ISD::STRICT_FREM,
|
||||
ISD::STRICT_FPOW, ISD::STRICT_FPOWI, ISD::STRICT_FCOS,
|
||||
ISD::STRICT_FSIN, ISD::STRICT_FEXP, ISD::STRICT_FEXP2,
|
||||
ISD::STRICT_FLOG, ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
|
||||
setOperationAction(Op, MVT::f16, Promote);
|
||||
setOperationAction(Op, MVT::v4f16, Expand);
|
||||
setOperationAction(Op, MVT::v8f16, Expand);
|
||||
}
|
||||
|
||||
if (!Subtarget->hasFullFP16()) {
|
||||
setOperationAction(ISD::SELECT, MVT::f16, Promote);
|
||||
setOperationAction(ISD::SELECT_CC, MVT::f16, Promote);
|
||||
setOperationAction(ISD::SETCC, MVT::f16, Promote);
|
||||
setOperationAction(ISD::BR_CC, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FADD, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSUB, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMUL, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FDIV, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMA, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FNEG, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FABS, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FCEIL, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FSQRT, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FFLOOR, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FNEARBYINT, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FRINT, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FROUND, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FROUNDEVEN, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FTRUNC, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
|
||||
for (auto Op :
|
||||
{ISD::SELECT, ISD::SELECT_CC, ISD::SETCC,
|
||||
ISD::BR_CC, ISD::FADD, ISD::FSUB,
|
||||
ISD::FMUL, ISD::FDIV, ISD::FMA,
|
||||
ISD::FNEG, ISD::FABS, ISD::FCEIL,
|
||||
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
|
||||
ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
|
||||
ISD::FTRUNC, ISD::FMINNUM, ISD::FMAXNUM,
|
||||
ISD::FMINIMUM, ISD::FMAXIMUM, ISD::STRICT_FADD,
|
||||
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
|
||||
ISD::STRICT_FMA, ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
|
||||
ISD::STRICT_FSQRT, ISD::STRICT_FRINT, ISD::STRICT_FNEARBYINT,
|
||||
ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
|
||||
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
|
||||
ISD::STRICT_FMAXIMUM})
|
||||
setOperationAction(Op, MVT::f16, Promote);
|
||||
|
||||
// Round-to-integer need custom lowering for fp16, as Promote doesn't work
|
||||
// because the result type is integer.
|
||||
for (auto Op : {ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
|
||||
ISD::STRICT_LLRINT})
|
||||
setOperationAction(Op, MVT::f16, Custom);
|
||||
|
||||
// promote v4f16 to v4f32 when that is known to be safe.
|
||||
setOperationAction(ISD::FADD, MVT::v4f16, Promote);
|
||||
|
@ -1402,6 +1379,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
|
|||
}
|
||||
|
||||
PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
|
||||
|
||||
IsStrictFPEnabled = true;
|
||||
}
|
||||
|
||||
void AArch64TargetLowering::addTypeForNEON(MVT VT) {
|
||||
|
@ -2592,7 +2571,18 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
|
|||
bool IsSignaling) {
|
||||
EVT VT = LHS.getValueType();
|
||||
assert(VT != MVT::f128);
|
||||
assert(VT != MVT::f16 && "Lowering of strict fp16 not yet implemented");
|
||||
|
||||
const bool FullFP16 =
|
||||
static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
|
||||
|
||||
if (VT == MVT::f16 && !FullFP16) {
|
||||
LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
|
||||
{Chain, LHS});
|
||||
RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
|
||||
{LHS.getValue(1), RHS});
|
||||
Chain = RHS.getValue(1);
|
||||
VT = MVT::f32;
|
||||
}
|
||||
unsigned Opcode =
|
||||
IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
|
||||
return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
|
||||
|
@ -3468,8 +3458,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
|
|||
MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
|
||||
VT.getVectorNumElements());
|
||||
if (IsStrict) {
|
||||
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
|
||||
{ExtVT, MVT::Other},
|
||||
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {ExtVT, MVT::Other},
|
||||
{Op.getOperand(0), Op.getOperand(1)});
|
||||
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
|
||||
{Ext.getValue(1), Ext.getValue(0)});
|
||||
|
@ -3506,8 +3495,14 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
|
|||
|
||||
// f16 conversions are promoted to f32 when full fp16 is not supported.
|
||||
if (SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
|
||||
assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
|
||||
SDLoc dl(Op);
|
||||
if (IsStrict) {
|
||||
SDValue Ext =
|
||||
DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
|
||||
{Op.getOperand(0), SrcVal});
|
||||
return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other},
|
||||
{Ext.getValue(1), Ext.getValue(0)});
|
||||
}
|
||||
return DAG.getNode(
|
||||
Op.getOpcode(), dl, Op.getValueType(),
|
||||
DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, SrcVal));
|
||||
|
@ -3730,10 +3725,15 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
|
|||
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
|
||||
|
||||
// f16 conversions are promoted to f32 when full fp16 is not supported.
|
||||
if (Op.getValueType() == MVT::f16 &&
|
||||
!Subtarget->hasFullFP16()) {
|
||||
assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
|
||||
if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
|
||||
SDLoc dl(Op);
|
||||
if (IsStrict) {
|
||||
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32, MVT::Other},
|
||||
{Op.getOperand(0), SrcVal});
|
||||
return DAG.getNode(
|
||||
ISD::STRICT_FP_ROUND, dl, {MVT::f16, MVT::Other},
|
||||
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
|
||||
}
|
||||
return DAG.getNode(
|
||||
ISD::FP_ROUND, dl, MVT::f16,
|
||||
DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
|
||||
|
@ -5367,6 +5367,18 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
|
|||
return LowerCTTZ(Op, DAG);
|
||||
case ISD::VECTOR_SPLICE:
|
||||
return LowerVECTOR_SPLICE(Op, DAG);
|
||||
case ISD::STRICT_LROUND:
|
||||
case ISD::STRICT_LLROUND:
|
||||
case ISD::STRICT_LRINT:
|
||||
case ISD::STRICT_LLRINT: {
|
||||
assert(Op.getOperand(1).getValueType() == MVT::f16 &&
|
||||
"Expected custom lowering of rounding operations only for f16");
|
||||
SDLoc DL(Op);
|
||||
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
|
||||
{Op.getOperand(0), Op.getOperand(1)});
|
||||
return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other},
|
||||
{Ext.getValue(1), Ext.getValue(0)});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
|||
; RUN: llc -mtriple=aarch64-none-eabi %s -disable-strictnode-mutation -o - | FileCheck %s
|
||||
; RUN: llc -mtriple=aarch64-none-eabi -global-isel=true -global-isel-abort=2 -disable-strictnode-mutation %s -o - | FileCheck %s
|
||||
; RUN: llc -mtriple=aarch64-none-eabi %s -o - | FileCheck %s
|
||||
; RUN: llc -mtriple=aarch64-none-eabi -global-isel=true -global-isel-abort=2 %s -o - | FileCheck %s
|
||||
|
||||
; Check that constrained fp intrinsics are correctly lowered.
|
||||
|
||||
|
|
Loading…
Reference in New Issue