[CUDA, NVPTX] Added basic __bf16 support for NVPTX.
Recent Clang changes expose _bf16 types for SSE2-enabled host compilations and that makes those types visible furing GPU-side compilation, where it currently fails with Sema complaining that __bf16 is not supported. Considering that __bf16 is a storage-only type, enabling it for NVPTX if it's enabled on the host should pose no issues, correctness-wise. Recent NVIDIA GPUs have introduced bf16 support, so we'll likely grow better support for __bf16 on NVPTX going forward. Differential Revision: https://reviews.llvm.org/D136311
This commit is contained in:
parent
fd5a2bfaad
commit
0e8a414ab3
|
@ -52,6 +52,9 @@ NVPTXTargetInfo::NVPTXTargetInfo(const llvm::Triple &Triple,
|
|||
VLASupported = false;
|
||||
AddrSpaceMap = &NVPTXAddrSpaceMap;
|
||||
UseAddrSpaceMapMangling = true;
|
||||
// __bf16 is always available as a load/store only type.
|
||||
BFloat16Width = BFloat16Align = 16;
|
||||
BFloat16Format = &llvm::APFloat::BFloat();
|
||||
|
||||
// Define available target features
|
||||
// These must be defined in sorted order!
|
||||
|
|
|
@ -177,6 +177,8 @@ public:
|
|||
}
|
||||
|
||||
bool hasBitIntType() const override { return true; }
|
||||
bool hasBFloat16Type() const override { return true; }
|
||||
const char *getBFloat16Mangling() const override { return "u6__bf16"; };
|
||||
};
|
||||
} // namespace targets
|
||||
} // namespace clang
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
// REQUIRES: nvptx-registered-target
|
||||
// REQUIRES: x86-registered-target
|
||||
|
||||
// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \
|
||||
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -S -o - %s | FileCheck %s
|
||||
|
||||
#include "Inputs/cuda.h"
|
||||
|
||||
// CHECK-LABEL: .visible .func _Z8test_argPu6__bf16u6__bf16(
|
||||
// CHECK: .param .b64 _Z8test_argPu6__bf16u6__bf16_param_0,
|
||||
// CHECK: .param .b16 _Z8test_argPu6__bf16u6__bf16_param_1
|
||||
//
|
||||
__device__ void test_arg(__bf16 *out, __bf16 in) {
|
||||
// CHECK: ld.param.b16 %{{h.*}}, [_Z8test_argPu6__bf16u6__bf16_param_1];
|
||||
__bf16 bf16 = in;
|
||||
*out = bf16;
|
||||
// CHECK: st.b16
|
||||
// CHECK: ret;
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retu6__bf16(
|
||||
// CHECK: .param .b16 _Z8test_retu6__bf16_param_0
|
||||
__device__ __bf16 test_ret( __bf16 in) {
|
||||
// CHECK: ld.param.b16 %h{{.*}}, [_Z8test_retu6__bf16_param_0];
|
||||
return in;
|
||||
// CHECK: st.param.b16 [func_retval0+0], %h
|
||||
// CHECK: ret;
|
||||
}
|
||||
|
||||
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callu6__bf16(
|
||||
// CHECK: .param .b16 _Z9test_callu6__bf16_param_0
|
||||
__device__ __bf16 test_call( __bf16 in) {
|
||||
// CHECK: ld.param.b16 %h{{.*}}, [_Z9test_callu6__bf16_param_0];
|
||||
// CHECK: st.param.b16 [param0+0], %h2;
|
||||
// CHECK: .param .b32 retval0;
|
||||
// CHECK: call.uni (retval0),
|
||||
// CHECK-NEXT: _Z8test_retu6__bf16,
|
||||
// CHECK-NEXT: (
|
||||
// CHECK-NEXT: param0
|
||||
// CHECK-NEXT );
|
||||
// CHECK: ld.param.b16 %h{{.*}}, [retval0+0];
|
||||
return test_ret(in);
|
||||
// CHECK: st.param.b16 [func_retval0+0], %h
|
||||
// CHECK: ret;
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
// REQUIRES: nvptx-registered-target
|
||||
// REQUIRES: x86-registered-target
|
||||
|
||||
// RUN: %clang_cc1 "-triple" "x86_64-unknown-linux-gnu" "-aux-triple" "nvptx64-nvidia-cuda" \
|
||||
// RUN: "-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s
|
||||
// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \
|
||||
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s
|
||||
|
||||
#include "Inputs/cuda.h"
|
||||
|
||||
__device__ void test(bool b, __bf16 *out, __bf16 in) {
|
||||
__bf16 bf16 = in; // No error on using the type itself.
|
||||
|
||||
bf16 + bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
|
||||
bf16 - bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
|
||||
bf16 * bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
|
||||
bf16 / bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
|
||||
|
||||
__fp16 fp16;
|
||||
|
||||
bf16 + fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
|
||||
fp16 + bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
|
||||
bf16 - fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
|
||||
fp16 - bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
|
||||
bf16 * fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
|
||||
fp16 * bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
|
||||
bf16 / fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
|
||||
fp16 / bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
|
||||
bf16 = fp16; // scalar-error {{assigning to '__bf16' from incompatible type '__fp16'}}
|
||||
fp16 = bf16; // scalar-error {{assigning to '__fp16' from incompatible type '__bf16'}}
|
||||
bf16 + (b ? fp16 : bf16); // scalar-error {{incompatible operand types ('__fp16' and '__bf16')}}
|
||||
*out = bf16;
|
||||
}
|
|
@ -1831,6 +1831,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
|
|||
break;
|
||||
|
||||
case Type::HalfTyID:
|
||||
case Type::BFloatTyID:
|
||||
case Type::FloatTyID:
|
||||
case Type::DoubleTyID:
|
||||
AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
|
||||
|
|
|
@ -823,8 +823,10 @@ static Optional<unsigned> pickOpcodeForVT(
|
|||
case MVT::i64:
|
||||
return Opcode_i64;
|
||||
case MVT::f16:
|
||||
case MVT::bf16:
|
||||
return Opcode_f16;
|
||||
case MVT::v2f16:
|
||||
case MVT::v2bf16:
|
||||
return Opcode_f16x2;
|
||||
case MVT::f32:
|
||||
return Opcode_f32;
|
||||
|
@ -835,6 +837,21 @@ static Optional<unsigned> pickOpcodeForVT(
|
|||
}
|
||||
}
|
||||
|
||||
static int getLdStRegType(EVT VT) {
|
||||
if (VT.isFloatingPoint())
|
||||
switch (VT.getSimpleVT().SimpleTy) {
|
||||
case MVT::f16:
|
||||
case MVT::bf16:
|
||||
case MVT::v2f16:
|
||||
case MVT::v2bf16:
|
||||
return NVPTX::PTXLdStInstCode::Untyped;
|
||||
default:
|
||||
return NVPTX::PTXLdStInstCode::Float;
|
||||
}
|
||||
else
|
||||
return NVPTX::PTXLdStInstCode::Unsigned;
|
||||
}
|
||||
|
||||
bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
||||
SDLoc dl(N);
|
||||
MemSDNode *LD = cast<MemSDNode>(N);
|
||||
|
@ -891,19 +908,16 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
|||
// Vector Setting
|
||||
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
|
||||
if (SimpleVT.isVector()) {
|
||||
assert(LoadedVT == MVT::v2f16 && "Unexpected vector type");
|
||||
// v2f16 is loaded using ld.b32
|
||||
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
|
||||
"Unexpected vector type");
|
||||
// v2f16/v2bf16 is loaded using ld.b32
|
||||
fromTypeWidth = 32;
|
||||
}
|
||||
|
||||
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
|
||||
fromType = NVPTX::PTXLdStInstCode::Signed;
|
||||
else if (ScalarVT.isFloatingPoint())
|
||||
// f16 uses .b16 as its storage type.
|
||||
fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
fromType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
fromType = getLdStRegType(ScalarVT);
|
||||
|
||||
// Create the machine instruction DAG
|
||||
SDValue Chain = N->getOperand(0);
|
||||
|
@ -1033,11 +1047,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
|
|||
N->getOperand(N->getNumOperands() - 1))->getZExtValue();
|
||||
if (ExtensionType == ISD::SEXTLOAD)
|
||||
FromType = NVPTX::PTXLdStInstCode::Signed;
|
||||
else if (ScalarVT.isFloatingPoint())
|
||||
FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
FromType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
FromType = getLdStRegType(ScalarVT);
|
||||
|
||||
unsigned VecType;
|
||||
|
||||
|
@ -1057,7 +1068,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
|
|||
// v8f16 is a special case. PTX doesn't have ld.v8.f16
|
||||
// instruction. Instead, we split the vector into v2f16 chunks and
|
||||
// load them with ld.v4.b32.
|
||||
if (EltVT == MVT::v2f16) {
|
||||
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
|
||||
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
|
||||
EltVT = MVT::i32;
|
||||
FromType = NVPTX::PTXLdStInstCode::Untyped;
|
||||
|
@ -1745,18 +1756,13 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
|||
MVT ScalarVT = SimpleVT.getScalarType();
|
||||
unsigned toTypeWidth = ScalarVT.getSizeInBits();
|
||||
if (SimpleVT.isVector()) {
|
||||
assert(StoreVT == MVT::v2f16 && "Unexpected vector type");
|
||||
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
|
||||
"Unexpected vector type");
|
||||
// v2f16 is stored using st.b32
|
||||
toTypeWidth = 32;
|
||||
}
|
||||
|
||||
unsigned int toType;
|
||||
if (ScalarVT.isFloatingPoint())
|
||||
// f16 uses .b16 as its storage type.
|
||||
toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
toType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
unsigned int toType = getLdStRegType(ScalarVT);
|
||||
|
||||
// Create the machine instruction DAG
|
||||
SDValue Chain = ST->getChain();
|
||||
|
@ -1896,12 +1902,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
|
|||
assert(StoreVT.isSimple() && "Store value is not simple");
|
||||
MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
|
||||
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
|
||||
unsigned ToType;
|
||||
if (ScalarVT.isFloatingPoint())
|
||||
ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
|
||||
: NVPTX::PTXLdStInstCode::Float;
|
||||
else
|
||||
ToType = NVPTX::PTXLdStInstCode::Unsigned;
|
||||
unsigned ToType = getLdStRegType(ScalarVT);
|
||||
|
||||
SmallVector<SDValue, 12> StOps;
|
||||
SDValue N2;
|
||||
|
@ -1929,7 +1930,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
|
|||
// v8f16 is a special case. PTX doesn't have st.v8.f16
|
||||
// instruction. Instead, we split the vector into v2f16 chunks and
|
||||
// store them with st.v4.b32.
|
||||
if (EltVT == MVT::v2f16) {
|
||||
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
|
||||
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
|
||||
EltVT = MVT::i32;
|
||||
ToType = NVPTX::PTXLdStInstCode::Untyped;
|
||||
|
|
|
@ -133,6 +133,9 @@ static bool IsPTXVectorType(MVT VT) {
|
|||
case MVT::v2f16:
|
||||
case MVT::v4f16:
|
||||
case MVT::v8f16: // <4 x f16x2>
|
||||
case MVT::v2bf16:
|
||||
case MVT::v4bf16:
|
||||
case MVT::v8bf16: // <4 x bf16x2>
|
||||
case MVT::v2f32:
|
||||
case MVT::v4f32:
|
||||
case MVT::v2f64:
|
||||
|
@ -190,8 +193,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
|
|||
// Vectors with an even number of f16 elements will be passed to
|
||||
// us as an array of v2f16 elements. We must match this so we
|
||||
// stay in sync with Ins/Outs.
|
||||
if (EltVT == MVT::f16 && NumElts % 2 == 0) {
|
||||
EltVT = MVT::v2f16;
|
||||
if ((EltVT == MVT::f16 || EltVT == MVT::f16) && NumElts % 2 == 0) {
|
||||
EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
|
||||
NumElts /= 2;
|
||||
}
|
||||
for (unsigned j = 0; j != NumElts; ++j) {
|
||||
|
@ -400,6 +403,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
|||
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
|
||||
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
|
||||
addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
|
||||
addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
|
||||
addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
|
||||
|
||||
// Conversion to/from FP16/FP16x2 is always legal.
|
||||
setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
|
||||
|
@ -495,6 +500,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
|||
setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
|
||||
setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
|
||||
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
|
||||
setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
|
||||
|
||||
// TRAP can be lowered to PTX trap
|
||||
setOperationAction(ISD::TRAP, MVT::Other, Legal);
|
||||
|
@ -2334,14 +2340,17 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
|
|||
case MVT::v2i32:
|
||||
case MVT::v2i64:
|
||||
case MVT::v2f16:
|
||||
case MVT::v2bf16:
|
||||
case MVT::v2f32:
|
||||
case MVT::v2f64:
|
||||
case MVT::v4i8:
|
||||
case MVT::v4i16:
|
||||
case MVT::v4i32:
|
||||
case MVT::v4f16:
|
||||
case MVT::v4bf16:
|
||||
case MVT::v4f32:
|
||||
case MVT::v8f16: // <4 x f16x2>
|
||||
case MVT::v8bf16: // <4 x bf16x2>
|
||||
// This is a "native" vector type
|
||||
break;
|
||||
}
|
||||
|
@ -2386,7 +2395,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
|
|||
// v8f16 is a special case. PTX doesn't have st.v8.f16
|
||||
// instruction. Instead, we split the vector into v2f16 chunks and
|
||||
// store them with st.v4.b32.
|
||||
assert(EltVT == MVT::f16 && "Wrong type for the vector.");
|
||||
assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
|
||||
"Wrong type for the vector.");
|
||||
Opcode = NVPTXISD::StoreV4;
|
||||
StoreF16x2 = true;
|
||||
break;
|
||||
|
@ -4987,11 +4997,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
|
|||
// v8f16 is a special case. PTX doesn't have ld.v8.f16
|
||||
// instruction. Instead, we split the vector into v2f16 chunks and
|
||||
// load them with ld.v4.b32.
|
||||
assert(EltVT == MVT::f16 && "Unsupported v8 vector type.");
|
||||
assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
|
||||
"Unsupported v8 vector type.");
|
||||
LoadF16x2 = true;
|
||||
Opcode = NVPTXISD::LoadV4;
|
||||
EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16,
|
||||
MVT::Other};
|
||||
EVT VVT = (EltVT == MVT::f16) ? MVT::v2f16 : MVT::v2bf16;
|
||||
EVT ListVTs[] = {VVT, VVT, VVT, VVT, MVT::Other};
|
||||
LdResVTs = DAG.getVTList(ListVTs);
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -172,6 +172,30 @@ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
|
|||
def useShortPtr : Predicate<"useShortPointers()">;
|
||||
def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
|
||||
|
||||
// Helper class to aid conversion between ValueType and a matching RegisterClass.
|
||||
|
||||
class ValueToRegClass<ValueType T> {
|
||||
string name = !cast<string>(T);
|
||||
NVPTXRegClass ret = !cond(
|
||||
!eq(name, "i1"): Int1Regs,
|
||||
!eq(name, "i16"): Int16Regs,
|
||||
!eq(name, "i32"): Int32Regs,
|
||||
!eq(name, "i64"): Int64Regs,
|
||||
!eq(name, "f16"): Float16Regs,
|
||||
!eq(name, "v2f16"): Float16x2Regs,
|
||||
!eq(name, "bf16"): Float16Regs,
|
||||
!eq(name, "v2bf16"): Float16x2Regs,
|
||||
!eq(name, "f32"): Float32Regs,
|
||||
!eq(name, "f64"): Float64Regs,
|
||||
!eq(name, "ai32"): Int32ArgRegs,
|
||||
!eq(name, "ai64"): Int64ArgRegs,
|
||||
!eq(name, "af32"): Float32ArgRegs,
|
||||
!eq(name, "if64"): Float64ArgRegs,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Some Common Instruction Class Templates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -277,26 +301,26 @@ multiclass F3<string OpcStr, SDNode OpNode> {
|
|||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def f16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
def f16x2rr_ftz :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def f16x2rr :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math]>;
|
||||
}
|
||||
|
||||
|
@ -351,26 +375,26 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
|
|||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math, allowFMA, doF32FTZ]>;
|
||||
def f16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math, allowFMA]>;
|
||||
|
||||
def f16x2rr_ftz :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set (v2f16 Float16x2Regs:$dst), (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math, allowFMA, doF32FTZ]>;
|
||||
def f16x2rr :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math, allowFMA]>;
|
||||
|
||||
// These have strange names so we don't perturb existing mir tests.
|
||||
|
@ -414,25 +438,25 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
|
|||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math, noFMA, doF32FTZ]>;
|
||||
def _rnf16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
[(set Float16Regs:$dst, (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>,
|
||||
Requires<[useFP16Math, noFMA]>;
|
||||
def _rnf16x2rr_ftz :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.ftz.f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math, noFMA, doF32FTZ]>;
|
||||
def _rnf16x2rr :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
|
||||
Requires<[useFP16Math, noFMA]>;
|
||||
}
|
||||
|
||||
|
@ -924,15 +948,15 @@ defm FSQRT : F2<"sqrt.rn", fsqrt>;
|
|||
//
|
||||
// F16 NEG
|
||||
//
|
||||
class FNEG_F16_F16X2<string OpcStr, RegisterClass RC, Predicate Pred> :
|
||||
class FNEG_F16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> :
|
||||
NVPTXInst<(outs RC:$dst), (ins RC:$src),
|
||||
!strconcat(OpcStr, " \t$dst, $src;"),
|
||||
[(set RC:$dst, (fneg RC:$src))]>,
|
||||
[(set RC:$dst, (fneg (T RC:$src)))]>,
|
||||
Requires<[useFP16Math, hasPTX60, hasSM53, Pred]>;
|
||||
def FNEG16_ftz : FNEG_F16_F16X2<"neg.ftz.f16", Float16Regs, doF32FTZ>;
|
||||
def FNEG16 : FNEG_F16_F16X2<"neg.f16", Float16Regs, True>;
|
||||
def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", Float16x2Regs, doF32FTZ>;
|
||||
def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", Float16x2Regs, True>;
|
||||
def FNEG16_ftz : FNEG_F16_F16X2<"neg.ftz.f16", f16, Float16Regs, doF32FTZ>;
|
||||
def FNEG16 : FNEG_F16_F16X2<"neg.f16", f16, Float16Regs, True>;
|
||||
def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
|
||||
def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Float16x2Regs, True>;
|
||||
|
||||
//
|
||||
// F64 division
|
||||
|
@ -1105,17 +1129,17 @@ multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred>
|
|||
Requires<[Pred]>;
|
||||
}
|
||||
|
||||
multiclass FMA_F16<string OpcStr, RegisterClass RC, Predicate Pred> {
|
||||
multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
|
||||
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
|
||||
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
|
||||
[(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
|
||||
[(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>,
|
||||
Requires<[useFP16Math, Pred]>;
|
||||
}
|
||||
|
||||
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", Float16Regs, doF32FTZ>;
|
||||
defm FMA16 : FMA_F16<"fma.rn.f16", Float16Regs, True>;
|
||||
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", Float16x2Regs, doF32FTZ>;
|
||||
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", Float16x2Regs, True>;
|
||||
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Float16Regs, doF32FTZ>;
|
||||
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Float16Regs, True>;
|
||||
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
|
||||
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Float16x2Regs, True>;
|
||||
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
|
||||
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
|
||||
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
|
||||
|
@ -1569,52 +1593,57 @@ let hasSideEffects = false in {
|
|||
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
|
||||
}
|
||||
|
||||
multiclass SELP_PATTERN<string TypeStr, RegisterClass RC, Operand ImmCls,
|
||||
SDNode ImmNode> {
|
||||
multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
|
||||
Operand ImmCls, SDNode ImmNode> {
|
||||
def rr :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins RC:$a, RC:$b, Int1Regs:$p),
|
||||
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
|
||||
[(set RC:$dst, (select Int1Regs:$p, RC:$a, RC:$b))]>;
|
||||
[(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>;
|
||||
def ri :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins RC:$a, ImmCls:$b, Int1Regs:$p),
|
||||
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
|
||||
[(set RC:$dst, (select Int1Regs:$p, RC:$a, ImmNode:$b))]>;
|
||||
[(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>;
|
||||
def ir :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins ImmCls:$a, RC:$b, Int1Regs:$p),
|
||||
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
|
||||
[(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, RC:$b))]>;
|
||||
[(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>;
|
||||
def ii :
|
||||
NVPTXInst<(outs RC:$dst),
|
||||
(ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
|
||||
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
|
||||
[(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
|
||||
[(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
|
||||
}
|
||||
}
|
||||
|
||||
// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
|
||||
// good.
|
||||
defm SELP_b16 : SELP_PATTERN<"b16", Int16Regs, i16imm, imm>;
|
||||
defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
|
||||
defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
|
||||
defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
|
||||
defm SELP_b32 : SELP_PATTERN<"b32", Int32Regs, i32imm, imm>;
|
||||
defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
|
||||
defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
|
||||
defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
|
||||
defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>;
|
||||
defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
|
||||
defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
|
||||
defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
|
||||
defm SELP_f16 : SELP_PATTERN<"b16", Float16Regs, f16imm, fpimm>;
|
||||
defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>;
|
||||
defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>;
|
||||
defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
|
||||
|
||||
defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
|
||||
defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
|
||||
|
||||
// This does not work as tablegen fails to infer the type of 'imm'.
|
||||
//def v2f16imm : Operand<v2f16>;
|
||||
//defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>;
|
||||
|
||||
def SELP_f16x2rr :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p),
|
||||
"selp.b32 \t$dst, $a, $b, $p;",
|
||||
[(set Float16x2Regs:$dst,
|
||||
(select Int1Regs:$p, Float16x2Regs:$a, Float16x2Regs:$b))]>;
|
||||
(select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
|
||||
|
||||
//-----------------------------------
|
||||
// Data Movement (Load / Store, Move)
|
||||
|
@ -1847,22 +1876,22 @@ def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)),
|
|||
|
||||
multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
|
||||
// f16 -> pred
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
|
||||
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
|
||||
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
|
||||
(SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
def : Pat<(i1 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
|
||||
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math,doF32FTZ]>;
|
||||
def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
def : Pat<(i1 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
|
||||
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
|
@ -1892,22 +1921,22 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
|
|||
(SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>;
|
||||
|
||||
// f16 -> i32
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
def : Pat<(i32 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
|
||||
(SET_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)),
|
||||
def : Pat<(i32 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
|
||||
(SET_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
def : Pat<(i32 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
|
||||
(SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)),
|
||||
def : Pat<(i32 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
|
||||
(SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
|
||||
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)),
|
||||
def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
|
||||
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
|
@ -2329,10 +2358,10 @@ def DeclareScalarRegInst :
|
|||
".reg .b$size param$a;",
|
||||
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
|
||||
|
||||
class MoveParamInst<NVPTXRegClass regclass, string asmstr> :
|
||||
class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
|
||||
NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
|
||||
!strconcat("mov", asmstr, " \t$dst, $src;"),
|
||||
[(set regclass:$dst, (MoveParam regclass:$src))]>;
|
||||
[(set (T regclass:$dst), (MoveParam (T regclass:$src)))]>;
|
||||
|
||||
class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty,
|
||||
string asmstr> :
|
||||
|
@ -2340,8 +2369,8 @@ class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty,
|
|||
!strconcat("mov", asmstr, " \t$dst, $src;"),
|
||||
[(set regclass:$dst, (MoveParam texternalsym:$src))]>;
|
||||
|
||||
def MoveParamI64 : MoveParamInst<Int64Regs, ".b64">;
|
||||
def MoveParamI32 : MoveParamInst<Int32Regs, ".b32">;
|
||||
def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
|
||||
def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
|
||||
|
||||
def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, ".b64">;
|
||||
def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, ".b32">;
|
||||
|
@ -2350,9 +2379,9 @@ def MoveParamI16 :
|
|||
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
|
||||
"cvt.u16.u32 \t$dst, $src;",
|
||||
[(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>;
|
||||
def MoveParamF64 : MoveParamInst<Float64Regs, ".f64">;
|
||||
def MoveParamF32 : MoveParamInst<Float32Regs, ".f32">;
|
||||
def MoveParamF16 : MoveParamInst<Float16Regs, ".f16">;
|
||||
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
|
||||
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
|
||||
def MoveParamF16 : MoveParamInst<f16, Float16Regs, ".f16">;
|
||||
|
||||
class PseudoUseParamInst<NVPTXRegClass regclass> :
|
||||
NVPTXInst<(outs), (ins regclass:$src),
|
||||
|
@ -2365,20 +2394,22 @@ def PseudoUseParamI16 : PseudoUseParamInst<Int16Regs>;
|
|||
def PseudoUseParamF64 : PseudoUseParamInst<Float64Regs>;
|
||||
def PseudoUseParamF32 : PseudoUseParamInst<Float32Regs>;
|
||||
|
||||
class ProxyRegInst<string SzStr, NVPTXRegClass regclass> :
|
||||
class ProxyRegInst<string SzStr, ValueType T, NVPTXRegClass regclass> :
|
||||
NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
|
||||
!strconcat("mov.", SzStr, " \t$dst, $src;"),
|
||||
[(set regclass:$dst, (ProxyReg regclass:$src))]>;
|
||||
[(set (T regclass:$dst), (ProxyReg (T regclass:$src)))]>;
|
||||
|
||||
let isCodeGenOnly=1, isPseudo=1 in {
|
||||
def ProxyRegI1 : ProxyRegInst<"pred", Int1Regs>;
|
||||
def ProxyRegI16 : ProxyRegInst<"b16", Int16Regs>;
|
||||
def ProxyRegI32 : ProxyRegInst<"b32", Int32Regs>;
|
||||
def ProxyRegI64 : ProxyRegInst<"b64", Int64Regs>;
|
||||
def ProxyRegF16 : ProxyRegInst<"b16", Float16Regs>;
|
||||
def ProxyRegF32 : ProxyRegInst<"f32", Float32Regs>;
|
||||
def ProxyRegF64 : ProxyRegInst<"f64", Float64Regs>;
|
||||
def ProxyRegF16x2 : ProxyRegInst<"b32", Float16x2Regs>;
|
||||
def ProxyRegI1 : ProxyRegInst<"pred", i1, Int1Regs>;
|
||||
def ProxyRegI16 : ProxyRegInst<"b16", i16, Int16Regs>;
|
||||
def ProxyRegI32 : ProxyRegInst<"b32", i32, Int32Regs>;
|
||||
def ProxyRegI64 : ProxyRegInst<"b64", i64, Int64Regs>;
|
||||
def ProxyRegF16 : ProxyRegInst<"b16", f16, Float16Regs>;
|
||||
def ProxyRegBF16 : ProxyRegInst<"b16", bf16, Float16Regs>;
|
||||
def ProxyRegF32 : ProxyRegInst<"f32", f32, Float32Regs>;
|
||||
def ProxyRegF64 : ProxyRegInst<"f64", f64, Float64Regs>;
|
||||
def ProxyRegF16x2 : ProxyRegInst<"b32", v2f16, Float16x2Regs>;
|
||||
def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, Float16x2Regs>;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -2669,22 +2700,29 @@ let mayStore=1, hasSideEffects=0 in {
|
|||
|
||||
//---- Conversion ----
|
||||
|
||||
class F_BITCONVERT<string SzStr, NVPTXRegClass regclassIn,
|
||||
NVPTXRegClass regclassOut> :
|
||||
class F_BITCONVERT<string SzStr, ValueType TIn, ValueType TOut,
|
||||
NVPTXRegClass regclassIn = ValueToRegClass<TIn>.ret,
|
||||
NVPTXRegClass regclassOut = ValueToRegClass<TOut>.ret> :
|
||||
NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a),
|
||||
!strconcat("mov.b", SzStr, " \t$d, $a;"),
|
||||
[(set regclassOut:$d, (bitconvert regclassIn:$a))]>;
|
||||
[(set (TOut regclassOut:$d), (bitconvert (TIn regclassIn:$a)))]>;
|
||||
|
||||
def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>;
|
||||
def BITCONVERT_16_F2I : F_BITCONVERT<"16", Float16Regs, Int16Regs>;
|
||||
def BITCONVERT_32_I2F : F_BITCONVERT<"32", Int32Regs, Float32Regs>;
|
||||
def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>;
|
||||
def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>;
|
||||
def BITCONVERT_64_F2I : F_BITCONVERT<"64", Float64Regs, Int64Regs>;
|
||||
def BITCONVERT_32_I2F16x2 : F_BITCONVERT<"32", Int32Regs, Float16x2Regs>;
|
||||
def BITCONVERT_32_F16x22I : F_BITCONVERT<"32", Float16x2Regs, Int32Regs>;
|
||||
def BITCONVERT_32_F2F16x2 : F_BITCONVERT<"32", Float32Regs, Float16x2Regs>;
|
||||
def BITCONVERT_32_F16x22F : F_BITCONVERT<"32", Float16x2Regs, Float32Regs>;
|
||||
def BITCONVERT_16_I2F : F_BITCONVERT<"16", i16, f16>;
|
||||
def BITCONVERT_16_F2I : F_BITCONVERT<"16", f16, i16>;
|
||||
def BITCONVERT_16_I2BF : F_BITCONVERT<"16", i16, bf16>;
|
||||
def BITCONVERT_16_BF2I : F_BITCONVERT<"16", bf16, i16>;
|
||||
def BITCONVERT_32_I2F : F_BITCONVERT<"32", i32, f32>;
|
||||
def BITCONVERT_32_F2I : F_BITCONVERT<"32", f32, i32>;
|
||||
def BITCONVERT_64_I2F : F_BITCONVERT<"64", i64, f64>;
|
||||
def BITCONVERT_64_F2I : F_BITCONVERT<"64", f64, i64>;
|
||||
def BITCONVERT_32_I2F16x2 : F_BITCONVERT<"32", i32, v2f16>;
|
||||
def BITCONVERT_32_F16x22I : F_BITCONVERT<"32", v2f16, i32>;
|
||||
def BITCONVERT_32_F2F16x2 : F_BITCONVERT<"32", f32, v2f16>;
|
||||
def BITCONVERT_32_F16x22F : F_BITCONVERT<"32", v2f16, f32>;
|
||||
def BITCONVERT_32_I2BF16x2 : F_BITCONVERT<"32", i32, v2bf16>;
|
||||
def BITCONVERT_32_BF16x22I : F_BITCONVERT<"32", v2bf16, i32>;
|
||||
def BITCONVERT_32_F2BF16x2 : F_BITCONVERT<"32", f32, v2bf16>;
|
||||
def BITCONVERT_32_BF16x22F : F_BITCONVERT<"32", v2bf16, f32>;
|
||||
|
||||
// NOTE: pred->fp are currently sub-optimal due to an issue in TableGen where
|
||||
// we cannot specify floating-point literals in isel patterns. Therefore, we
|
||||
|
@ -2752,23 +2790,23 @@ def : Pat<(f64 (uint_to_fp Int64Regs:$a)),
|
|||
|
||||
|
||||
// f16 -> sint
|
||||
def : Pat<(i1 (fp_to_sint Float16Regs:$a)),
|
||||
def : Pat<(i1 (fp_to_sint (f16 Float16Regs:$a))),
|
||||
(SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>;
|
||||
def : Pat<(i16 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s16_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i32 (fp_to_sint Float16Regs:$a)),
|
||||
(CVT_s32_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i64 (fp_to_sint Float16Regs:$a)),
|
||||
def : Pat<(i16 (fp_to_sint (f16 Float16Regs:$a))),
|
||||
(CVT_s16_f16 (f16 Float16Regs:$a), CvtRZI)>;
|
||||
def : Pat<(i32 (fp_to_sint (f16 Float16Regs:$a))),
|
||||
(CVT_s32_f16 (f16 Float16Regs:$a), CvtRZI)>;
|
||||
def : Pat<(i64 (fp_to_sint (f16 Float16Regs:$a))),
|
||||
(CVT_s64_f16 Float16Regs:$a, CvtRZI)>;
|
||||
|
||||
// f16 -> uint
|
||||
def : Pat<(i1 (fp_to_uint Float16Regs:$a)),
|
||||
def : Pat<(i1 (fp_to_uint (f16 Float16Regs:$a))),
|
||||
(SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>;
|
||||
def : Pat<(i16 (fp_to_uint Float16Regs:$a)),
|
||||
def : Pat<(i16 (fp_to_uint (f16 Float16Regs:$a))),
|
||||
(CVT_u16_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i32 (fp_to_uint Float16Regs:$a)),
|
||||
def : Pat<(i32 (fp_to_uint (f16 Float16Regs:$a))),
|
||||
(CVT_u32_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(i64 (fp_to_uint Float16Regs:$a)),
|
||||
def : Pat<(i64 (fp_to_uint (f16 Float16Regs:$a))),
|
||||
(CVT_u64_f16 Float16Regs:$a, CvtRZI)>;
|
||||
|
||||
// f32 -> sint
|
||||
|
@ -2915,7 +2953,7 @@ def : Pat<(select Int32Regs:$pred, Int32Regs:$a, Int32Regs:$b),
|
|||
def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b),
|
||||
(SELP_b64rr Int64Regs:$a, Int64Regs:$b,
|
||||
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
|
||||
def : Pat<(select Int32Regs:$pred, Float16Regs:$a, Float16Regs:$b),
|
||||
def : Pat<(select Int32Regs:$pred, (f16 Float16Regs:$a), (f16 Float16Regs:$b)),
|
||||
(SELP_f16rr Float16Regs:$a, Float16Regs:$b,
|
||||
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
|
||||
def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b),
|
||||
|
@ -2980,7 +3018,7 @@ let hasSideEffects = false in {
|
|||
def BuildF16x2 : NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
"mov.b32 \t$dst, {{$a, $b}};",
|
||||
[(set Float16x2Regs:$dst,
|
||||
[(set (v2f16 Float16x2Regs:$dst),
|
||||
(build_vector (f16 Float16Regs:$a), (f16 Float16Regs:$b)))]>;
|
||||
|
||||
// Directly initializing underlying the b32 register is one less SASS
|
||||
|
@ -3079,13 +3117,13 @@ def : Pat<(f32 (fpround Float64Regs:$a)),
|
|||
(CVT_f32_f64 Float64Regs:$a, CvtRN)>;
|
||||
|
||||
// fpextend f16 -> f32
|
||||
def : Pat<(f32 (fpextend Float16Regs:$a)),
|
||||
def : Pat<(f32 (fpextend (f16 Float16Regs:$a))),
|
||||
(CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f32 (fpextend Float16Regs:$a)),
|
||||
def : Pat<(f32 (fpextend (f16 Float16Regs:$a))),
|
||||
(CVT_f32_f16 Float16Regs:$a, CvtNONE)>;
|
||||
|
||||
// fpextend f16 -> f64
|
||||
def : Pat<(f64 (fpextend Float16Regs:$a)),
|
||||
def : Pat<(f64 (fpextend (f16 Float16Regs:$a))),
|
||||
(CVT_f64_f16 Float16Regs:$a, CvtNONE)>;
|
||||
|
||||
// fpextend f32 -> f64
|
||||
|
@ -3100,7 +3138,7 @@ def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
|
|||
// fceil, ffloor, froundeven, ftrunc.
|
||||
|
||||
multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
|
||||
def : Pat<(OpNode Float16Regs:$a),
|
||||
def : Pat<(OpNode (f16 Float16Regs:$a)),
|
||||
(CVT_f16_f16 Float16Regs:$a, Mode)>;
|
||||
def : Pat<(OpNode Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>;
|
||||
|
|
|
@ -75,6 +75,8 @@ bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
|
|||
case NVPTX::ProxyRegI64:
|
||||
case NVPTX::ProxyRegF16:
|
||||
case NVPTX::ProxyRegF16x2:
|
||||
case NVPTX::ProxyRegBF16:
|
||||
case NVPTX::ProxyRegBF16x2:
|
||||
case NVPTX::ProxyRegF32:
|
||||
case NVPTX::ProxyRegF64:
|
||||
replaceMachineInstructionUsage(MF, MI);
|
||||
|
|
|
@ -60,8 +60,8 @@ def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
|
|||
def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
|
||||
def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>;
|
||||
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
|
||||
def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
|
||||
def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
|
||||
def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>;
|
||||
def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
|
||||
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
|
||||
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
|
||||
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
; RUN: llc < %s -march=nvptx | FileCheck %s
|
||||
; RUN: %if ptxas %{ llc < %s -march=nvptx | %ptxas-verify %}
|
||||
|
||||
; LDST: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
@"bfloat_array" = addrspace(1) constant [4 x bfloat]
|
||||
[bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807]
|
||||
|
||||
define void @test_load_store(bfloat addrspace(1)* %in, bfloat addrspace(1)* %out) {
|
||||
; CHECK-LABEL: @test_load_store
|
||||
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
|
||||
%val = load bfloat, bfloat addrspace(1)* %in
|
||||
store bfloat %val, bfloat addrspace(1) * %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @test_bitcast_from_bfloat(bfloat addrspace(1)* %in, i16 addrspace(1)* %out) {
|
||||
; CHECK-LABEL: @test_bitcast_from_bfloat
|
||||
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
|
||||
%val = load bfloat, bfloat addrspace(1) * %in
|
||||
%val_int = bitcast bfloat %val to i16
|
||||
store i16 %val_int, i16 addrspace(1)* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @test_bitcast_to_bfloat(bfloat addrspace(1)* %out, i16 addrspace(1)* %in) {
|
||||
; CHECK-LABEL: @test_bitcast_to_bfloat
|
||||
; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}]
|
||||
; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]]
|
||||
%val = load i16, i16 addrspace(1)* %in
|
||||
%val_fp = bitcast i16 %val to bfloat
|
||||
store bfloat %val_fp, bfloat addrspace(1)* %out
|
||||
ret void
|
||||
}
|
Loading…
Reference in New Issue