[AArch64] Add support for various operations on nxv1i1 types.

The supported operations are:
* Logical operations (and, or, xor, bic)
* Logical reductions (and, or, xor, [us]min, [us]max)
* Conversions to/from svbool_t
* Predicate count (CNTP)

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D128835
This commit is contained in:
Sander de Smalen 2022-07-06 15:00:38 +00:00
parent e7db82d701
commit 95e08824fa
6 changed files with 169 additions and 2 deletions

View File

@ -1086,8 +1086,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// SVE and SME.
if (Subtarget->hasSVE() || Subtarget->hasSME()) {
for (auto VT :
{MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1})
{MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
}
}
if (Subtarget->hasSVE()) {
@ -1176,7 +1178,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
@ -4361,6 +4362,8 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
int Pattern) {
if (VT == MVT::nxv1i1 && Pattern == AArch64SVEPredPattern::all)
return DAG.getConstant(1, DL, MVT::nxv1i1);
return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
DAG.getTargetConstant(Pattern, DL, MVT::i32));
}
@ -21104,6 +21107,11 @@ SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
case ISD::VECREDUCE_XOR: {
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
if (OpVT == MVT::nxv1i1) {
// Emulate a CNTP on .Q using .D and a different governing predicate.
Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Pg);
Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Op);
}
SDValue Cntp =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
return DAG.getAnyExtOrTrunc(Cntp, DL, VT);

View File

@ -2139,6 +2139,8 @@ let Predicates = [HasSVEorSME] in {
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
let AddedComplexity = 1 in {
class LD1RPat<ValueType vt, SDPatternOperator operator,
@ -2404,6 +2406,9 @@ let Predicates = [HasSVEorSME] in {
(AND_PPzPP (PTRUE_S 31), PPR:$Ps1, PPR:$Ps2)>;
def : Pat<(nxv2i1 (and PPR:$Ps1, PPR:$Ps2)),
(AND_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>;
// Emulate .Q operation using a PTRUE_D when the other lanes don't matter.
def : Pat<(nxv1i1 (and PPR:$Ps1, PPR:$Ps2)),
(AND_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>;
// Add more complex addressing modes here as required
multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,

View File

@ -1691,6 +1691,9 @@ multiclass sve_int_pred_log<bits<4> opc, string asm, SDPatternOperator op,
!cast<Instruction>(NAME), PTRUE_S>;
def : SVE_2_Op_AllActive_Pat<nxv2i1, op_nopred, nxv2i1, nxv2i1,
!cast<Instruction>(NAME), PTRUE_D>;
// Emulate .Q operation using a PTRUE_D when the other lanes don't matter.
def : SVE_2_Op_AllActive_Pat<nxv1i1, op_nopred, nxv1i1, nxv1i1,
!cast<Instruction>(NAME), PTRUE_D>;
}
// An instance of sve_int_pred_log_and but uses op_nopred's first operand as the
@ -1706,6 +1709,9 @@ multiclass sve_int_pred_log_v2<bits<4> opc, string asm, SDPatternOperator op,
(!cast<Instruction>(NAME) $Op1, $Op1, $Op2)>;
def : Pat<(nxv2i1 (op_nopred nxv2i1:$Op1, nxv2i1:$Op2)),
(!cast<Instruction>(NAME) $Op1, $Op1, $Op2)>;
// Emulate .Q operation using a PTRUE_D when the other lanes don't matter.
def : Pat<(nxv1i1 (op_nopred nxv1i1:$Op1, nxv1i1:$Op2)),
(!cast<Instruction>(NAME) $Op1, $Op1, $Op2)>;
}
//===----------------------------------------------------------------------===//

View File

@ -46,6 +46,15 @@ define <vscale x 16 x i8> @and_b_zero(<vscale x 16 x i8> %a) {
ret <vscale x 16 x i8> %res
}
define <vscale x 1 x i1> @and_pred_q(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b) {
; CHECK-LABEL: and_pred_q:
; CHECK: // %bb.0:
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%res = and <vscale x 1 x i1> %a, %b
ret <vscale x 1 x i1> %res
}
define <vscale x 2 x i1> @and_pred_d(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) {
; CHECK-LABEL: and_pred_d:
; CHECK: // %bb.0:
@ -126,6 +135,17 @@ define <vscale x 16 x i8> @bic_b(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
ret <vscale x 16 x i8> %res
}
define <vscale x 1 x i1> @bic_pred_q(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b) {
; CHECK-LABEL: bic_pred_q:
; CHECK: // %bb.0:
; CHECK-NEXT: bic p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%allones = shufflevector <vscale x 1 x i1> insertelement(<vscale x 1 x i1> undef, i1 true, i32 0), <vscale x 1 x i1> undef, <vscale x 1 x i32> zeroinitializer
%not_b = xor <vscale x 1 x i1> %b, %allones
%res = and <vscale x 1 x i1> %a, %not_b
ret <vscale x 1 x i1> %res
}
define <vscale x 2 x i1> @bic_pred_d(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) {
; CHECK-LABEL: bic_pred_d:
; CHECK: // %bb.0:
@ -214,6 +234,15 @@ define <vscale x 16 x i8> @or_b_zero(<vscale x 16 x i8> %a) {
ret <vscale x 16 x i8> %res
}
define <vscale x 1 x i1> @or_pred_q(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b) {
; CHECK-LABEL: or_pred_q:
; CHECK: // %bb.0:
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: ret
%res = or <vscale x 1 x i1> %a, %b
ret <vscale x 1 x i1> %res
}
define <vscale x 2 x i1> @or_pred_d(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) {
; CHECK-LABEL: or_pred_d:
; CHECK: // %bb.0:
@ -294,6 +323,16 @@ define <vscale x 16 x i8> @xor_b_zero(<vscale x 16 x i8> %a) {
ret <vscale x 16 x i8> %res
}
define <vscale x 1 x i1> @xor_pred_q(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b) {
; CHECK-LABEL: xor_pred_q:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p2.d
; CHECK-NEXT: eor p0.b, p2/z, p0.b, p1.b
; CHECK-NEXT: ret
%res = xor <vscale x 1 x i1> %a, %b
ret <vscale x 1 x i1> %res
}
define <vscale x 2 x i1> @xor_pred_d(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) {
; CHECK-LABEL: xor_pred_d:
; CHECK: // %bb.0:

View File

@ -51,6 +51,19 @@ define i1 @reduce_and_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_and_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_and_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: eor p0.b, p1/z, p0.b, p2.b
; CHECK-NEXT: ptest p2, p0.b
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.and.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; ORV
define i1 @reduce_or_nxv16i1(<vscale x 16 x i1> %vec) {
@ -93,6 +106,16 @@ define i1 @reduce_or_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_or_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_or_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptest p0, p0.b
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.or.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; XORV
define i1 @reduce_xor_nxv16i1(<vscale x 16 x i1> %vec) {
@ -139,6 +162,18 @@ define i1 @reduce_xor_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_xor_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_xor_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p1.h, p1.b
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.xor.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; SMAXV
define i1 @reduce_smax_nxv16i1(<vscale x 16 x i1> %vec) {
@ -189,6 +224,19 @@ define i1 @reduce_smax_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_smax_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_smax_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: eor p0.b, p1/z, p0.b, p2.b
; CHECK-NEXT: ptest p2, p0.b
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.smax.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; SMINV
define i1 @reduce_smin_nxv16i1(<vscale x 16 x i1> %vec) {
@ -231,6 +279,16 @@ define i1 @reduce_smin_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_smin_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_smin_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptest p0, p0.b
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.smin.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; UMAXV
define i1 @reduce_umax_nxv16i1(<vscale x 16 x i1> %vec) {
@ -273,6 +331,16 @@ define i1 @reduce_umax_nxv2i1(<vscale x 2 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_umax_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_umax_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptest p0, p0.b
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.umax.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
; UMINV
define i1 @reduce_umin_nxv16i1(<vscale x 16 x i1> %vec) {
@ -311,6 +379,19 @@ define i1 @reduce_umin_nxv4i1(<vscale x 4 x i1> %vec) {
ret i1 %res
}
define i1 @reduce_umin_nxv1i1(<vscale x 1 x i1> %vec) {
; CHECK-LABEL: reduce_umin_nxv1i1:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: eor p0.b, p1/z, p0.b, p2.b
; CHECK-NEXT: ptest p2, p0.b
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%res = call i1 @llvm.vector.reduce.umin.i1.nxv1i1(<vscale x 1 x i1> %vec)
ret i1 %res
}
define i1 @reduce_umin_nxv2i1(<vscale x 2 x i1> %vec) {
; CHECK-LABEL: reduce_umin_nxv2i1:
; CHECK: // %bb.0:
@ -327,33 +408,40 @@ declare i1 @llvm.vector.reduce.and.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.and.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.and.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.and.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.and.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.or.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.or.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.or.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.or.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.or.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.xor.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.xor.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.xor.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.xor.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.xor.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.smin.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.smin.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.smin.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.smin.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.smin.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.smax.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.smax.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.smax.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.smax.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.smax.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.umin.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.umin.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.umin.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.umin.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.umin.i1.nxv1i1(<vscale x 1 x i1> %vec)
declare i1 @llvm.vector.reduce.umax.i1.nxv16i1(<vscale x 16 x i1> %vec)
declare i1 @llvm.vector.reduce.umax.i1.nxv8i1(<vscale x 8 x i1> %vec)
declare i1 @llvm.vector.reduce.umax.i1.nxv4i1(<vscale x 4 x i1> %vec)
declare i1 @llvm.vector.reduce.umax.i1.nxv2i1(<vscale x 2 x i1> %vec)
declare i1 @llvm.vector.reduce.umax.i1.nxv1i1(<vscale x 1 x i1> %vec)

View File

@ -44,6 +44,17 @@ define <vscale x 16 x i1> @reinterpret_bool_from_d(<vscale x 2 x i1> %pg) {
ret <vscale x 16 x i1> %out
}
define <vscale x 16 x i1> @reinterpret_bool_from_q(<vscale x 1 x i1> %arg) {
; CHECK-LABEL: reinterpret_bool_from_q:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p1.h, p1.b
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%res = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %arg)
ret <vscale x 16 x i1> %res
}
;
; Converting from svbool_t
;
@ -80,6 +91,14 @@ define <vscale x 2 x i1> @reinterpret_bool_to_d(<vscale x 16 x i1> %pg) {
ret <vscale x 2 x i1> %out
}
define <vscale x 1 x i1> @reinterpret_bool_to_q(<vscale x 16 x i1> %pg) {
; CHECK-LABEL: reinterpret_bool_to_q:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%out = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %pg)
ret <vscale x 1 x i1> %out
}
; Reinterpreting a ptrue should not introduce an `and` instruction.
define <vscale x 16 x i1> @reinterpret_ptrue() {
; CHECK-LABEL: reinterpret_ptrue:
@ -124,8 +143,10 @@ declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
declare <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1>)