[AArch64][SVE] Fold away SETCC if original input was predicate vector.

This adds the following two folds:

Fold 1:
   setcc_merge_zero(
       all_active, extend(nxvNi1 ...), != splat(0))
  -> nxvNi1 ...

Fold 2:
   setcc_merge_zero(
       pred, extend(nxvNi1 ...), != splat(0))
  -> nxvNi1 and(pred, ...)
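Fold 2 shows up in IR via the SVE cmpne intrinsic when the governing
predicate differs from the one that produced the i1 input. This is a
sketch of the sve_cmpne_setcc_different_pred test added below (names are
illustrative only): the outer compare is replaced by an AND of the two
predicates.

   declare <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)

   define <vscale x 16 x i1> @fold2_sketch(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg1, <vscale x 16 x i1> %pg2) {
     %cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg1, <vscale x 16 x i8> %vec, <vscale x 16 x i8> zeroinitializer)
     %cmp1.sext = sext <vscale x 16 x i1> %cmp1 to <vscale x 16 x i8>
     ; The second cmpne under %pg2 becomes and(%pg2, %cmp1): one cmpne plus one and.
     %cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg2, <vscale x 16 x i8> %cmp1.sext, <vscale x 16 x i8> zeroinitializer)
     ret <vscale x 16 x i1> %cmp2
   }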

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D119334
Author: Sander de Smalen
Date:   2022-02-28 12:17:25 +00:00
Parent: b3e8ace198
Commit: eac2638ec1
6 changed files with 126 additions and 50 deletions


@@ -17158,26 +17158,45 @@ static SDValue performSetCCPunpkCombine(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) {
static SDValue
performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
"Unexpected opcode!");
SelectionDAG &DAG = DCI.DAG;
SDValue Pred = N->getOperand(0);
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
// setcc_merge_zero pred (sign_extend (setcc_merge_zero ... pred ...)), 0, ne
// => inner setcc_merge_zero
if (SDValue V = performSetCCPunpkCombine(N, DAG))
return V;
if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
LHS->getOpcode() == ISD::SIGN_EXTEND &&
LHS->getOperand(0)->getValueType(0) == N->getValueType(0) &&
LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
// setcc_merge_zero(
// pred, extend(setcc_merge_zero(pred, ...)), != splat(0))
// => setcc_merge_zero(pred, ...)
if (LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
LHS->getOperand(0)->getOperand(0) == Pred)
return LHS->getOperand(0);
if (SDValue V = performSetCCPunpkCombine(N, DAG))
return V;
// setcc_merge_zero(
// all_active, extend(nxvNi1 ...), != splat(0))
// -> nxvNi1 ...
if (isAllActivePredicate(DAG, Pred))
return LHS->getOperand(0);
// setcc_merge_zero(
// pred, extend(nxvNi1 ...), != splat(0))
// -> nxvNi1 and(pred, ...)
if (DCI.isAfterLegalizeDAG())
// Do this after legalization to allow more folds on setcc_merge_zero
// to be recognized.
return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0),
LHS->getOperand(0), Pred);
}
return SDValue();
}
@@ -18175,7 +18194,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::UZP1:
return performUzpCombine(N, DAG);
case AArch64ISD::SETCC_MERGE_ZERO:
return performSetccMergeZeroCombine(N, DAG);
return performSetccMergeZeroCombine(N, DCI);
case AArch64ISD::GLD1_MERGE_ZERO:
case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
case AArch64ISD::GLD1_UXTW_MERGE_ZERO:


@@ -36,3 +36,28 @@ define <vscale x 16 x i8> @vselect_cmp_ugt(<vscale x 16 x i8> %a, <vscale x 16 x
%d = select <vscale x 16 x i1> %cmp, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c
ret <vscale x 16 x i8> %d
}
; Some folds to remove a redundant icmp if the original input was a predicate vector.
define <vscale x 4 x i1> @fold_away_icmp_ptrue_all(<vscale x 4 x i1> %p) {
; CHECK-LABEL: fold_away_icmp_ptrue_all:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%t0 = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
%t1 = icmp ne <vscale x 4 x i32> %t0, zeroinitializer
ret <vscale x 4 x i1> %t1
}
define <vscale x 4 x i1> @fold_away_icmp_ptrue_vl16(<vscale x 4 x i1> %p) vscale_range(4, 4) {
; CHECK-LABEL: fold_away_icmp_ptrue_vl16:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%t0 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 9) ; VL16 is encoded as 9.
%t1 = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
%t2 = call <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1> %t0, <vscale x 4 x i32> %t1, <vscale x 4 x i32> zeroinitializer)
ret <vscale x 4 x i1> %t2
}
declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
declare <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)


@@ -387,11 +387,10 @@ define void @masked_gather_v8i32(<8 x i32>* %a, <8 x i32*>* %b) #0 {
; VBITS_EQ_256-NEXT: mov z0.s, p2/z, #-1 // =0xffffffffffffffff
; VBITS_EQ_256-NEXT: punpklo p2.h, p2.b
; VBITS_EQ_256-NEXT: ext z0.b, z0.b, z0.b, #16
; VBITS_EQ_256-NEXT: mov z3.d, p2/z, #-1 // =0xffffffffffffffff
; VBITS_EQ_256-NEXT: and p2.b, p2/z, p2.b, p1.b
; VBITS_EQ_256-NEXT: sunpklo z0.d, z0.s
; VBITS_EQ_256-NEXT: cmpne p2.d, p1/z, z3.d, #0
; VBITS_EQ_256-NEXT: cmpne p1.d, p1/z, z0.d, #0
; VBITS_EQ_256-NEXT: ld1w { z2.d }, p2/z, [z2.d]
; VBITS_EQ_256-NEXT: cmpne p1.d, p1/z, z0.d, #0
; VBITS_EQ_256-NEXT: ld1w { z0.d }, p1/z, [z1.d]
; VBITS_EQ_256-NEXT: ptrue p1.s, vl4
; VBITS_EQ_256-NEXT: uzp1 z1.s, z2.s, z2.s


@@ -351,22 +351,21 @@ define void @masked_scatter_v8i32(<8 x i32>* %a, <8 x i32*>* %b) #0 {
; VBITS_EQ_256-NEXT: ptrue p0.s, vl8
; VBITS_EQ_256-NEXT: mov x8, #4
; VBITS_EQ_256-NEXT: ld1w { z0.s }, p0/z, [x0]
; VBITS_EQ_256-NEXT: cmpeq p0.s, p0/z, z0.s, #0
; VBITS_EQ_256-NEXT: punpklo p1.h, p0.b
; VBITS_EQ_256-NEXT: mov z4.s, p0/z, #-1 // =0xffffffffffffffff
; VBITS_EQ_256-NEXT: mov z1.d, p1/z, #-1 // =0xffffffffffffffff
; VBITS_EQ_256-NEXT: ptrue p1.d, vl4
; VBITS_EQ_256-NEXT: ld1d { z2.d }, p1/z, [x1, x8, lsl #3]
; VBITS_EQ_256-NEXT: ld1d { z1.d }, p1/z, [x1, x8, lsl #3]
; VBITS_EQ_256-NEXT: ld1d { z3.d }, p1/z, [x1]
; VBITS_EQ_256-NEXT: ext z4.b, z4.b, z4.b, #16
; VBITS_EQ_256-NEXT: cmpne p0.d, p1/z, z1.d, #0
; VBITS_EQ_256-NEXT: uunpklo z1.d, z0.s
; VBITS_EQ_256-NEXT: sunpklo z4.d, z4.s
; VBITS_EQ_256-NEXT: cmpeq p0.s, p0/z, z0.s, #0
; VBITS_EQ_256-NEXT: uunpklo z4.d, z0.s
; VBITS_EQ_256-NEXT: mov z2.s, p0/z, #-1 // =0xffffffffffffffff
; VBITS_EQ_256-NEXT: punpklo p0.h, p0.b
; VBITS_EQ_256-NEXT: ext z2.b, z2.b, z2.b, #16
; VBITS_EQ_256-NEXT: ext z0.b, z0.b, z0.b, #16
; VBITS_EQ_256-NEXT: cmpne p1.d, p1/z, z4.d, #0
; VBITS_EQ_256-NEXT: sunpklo z2.d, z2.s
; VBITS_EQ_256-NEXT: and p0.b, p0/z, p0.b, p1.b
; VBITS_EQ_256-NEXT: cmpne p1.d, p1/z, z2.d, #0
; VBITS_EQ_256-NEXT: uunpklo z0.d, z0.s
; VBITS_EQ_256-NEXT: st1w { z1.d }, p0, [z3.d]
; VBITS_EQ_256-NEXT: st1w { z0.d }, p1, [z2.d]
; VBITS_EQ_256-NEXT: st1w { z4.d }, p0, [z3.d]
; VBITS_EQ_256-NEXT: st1w { z0.d }, p1, [z1.d]
; VBITS_EQ_256-NEXT: ret
; VBITS_GE_512-LABEL: masked_scatter_v8i32:
; VBITS_GE_512: // %bb.0:


@@ -23,11 +23,10 @@ define <vscale x 8 x i1> @masked_load_sext_i8i16_ptrue_vl(i8* %ap, <vscale x 16
; CHECK-LABEL: masked_load_sext_i8i16_ptrue_vl:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ptrue p1.h, vl32
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.h, vl32
; CHECK-NEXT: cmpne p0.h, p0/z, z0.h, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -45,8 +44,7 @@ define <vscale x 8 x i1> @masked_load_sext_i8i16_parg(i8* %ap, <vscale x 16 x i8
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: ptrue p1.h, vl32
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: cmpne p0.h, p1/z, z0.h, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
%extract = call <vscale x 8 x i1> @llvm.experimental.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -78,12 +76,11 @@ define <vscale x 4 x i1> @masked_load_sext_i8i32_ptrue_vl(i8* %ap, <vscale x 16
; CHECK-LABEL: masked_load_sext_i8i32_ptrue_vl:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ptrue p1.s, vl32
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.s, vl32
; CHECK-NEXT: cmpne p0.s, p0/z, z0.s, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -102,8 +99,7 @@ define <vscale x 4 x i1> @masked_load_sext_i8i32_parg(i8* %ap, <vscale x 16 x i8
; CHECK-NEXT: ptrue p1.s, vl32
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: cmpne p0.s, p1/z, z0.s, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
%extract = call <vscale x 4 x i1> @llvm.experimental.vector.extract.nxv4i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -136,13 +132,12 @@ define <vscale x 2 x i1> @masked_load_sext_i8i64_ptrue_vl(i8* %ap, <vscale x 16
; CHECK-LABEL: masked_load_sext_i8i64_ptrue_vl:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ptrue p1.d, vl32
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.d, vl32
; CHECK-NEXT: cmpne p0.d, p0/z, z0.d, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -162,8 +157,7 @@ define <vscale x 2 x i1> @masked_load_sext_i8i64_parg(i8* %ap, <vscale x 16 x i8
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: cmpne p0.d, p1/z, z0.d, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
%extract = call <vscale x 2 x i1> @llvm.experimental.vector.extract.nxv2i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -178,11 +172,10 @@ define <vscale x 8 x i1> @masked_load_sext_i8i16_ptrue_all(i8* %ap, <vscale x 16
; CHECK-LABEL: masked_load_sext_i8i16_ptrue_all:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ptrue p1.h, vl32
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.h, vl32
; CHECK-NEXT: cmpne p0.h, p0/z, z0.h, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -198,12 +191,11 @@ define <vscale x 4 x i1> @masked_load_sext_i8i32_ptrue_all(i8* %ap, <vscale x 16
; CHECK-LABEL: masked_load_sext_i8i32_ptrue_all:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ptrue p1.s, vl32
; CHECK-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.s, vl32
; CHECK-NEXT: cmpne p0.s, p0/z, z0.s, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -223,9 +215,6 @@ define <vscale x 2 x i1> @masked_load_sext_i8i64_ptrue_all(i8* %ap, <vscale x 16
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: cmpne p0.d, p0/z, z0.d, #0
; CHECK-NEXT: ret
%p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
%cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)


@@ -70,6 +70,51 @@ if.end:
ret void
}
; Fold away the redundant setcc:
; setcc(ne, <all ones>, sext(nxvNi1 ...), splat(0))
; -> nxvNi1 ...
define <vscale x 16 x i1> @sve_cmpne_setcc_all_true_sext(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg) {
; CHECK-LABEL: sve_cmpne_setcc_all_true_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: ret
%alltrue.ins = insertelement <vscale x 16 x i1> poison, i1 true, i32 0
%alltrue = shufflevector <vscale x 16 x i1> %alltrue.ins, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
%pg.sext = sext <vscale x 16 x i1> %pg to <vscale x 16 x i8>
%cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %alltrue, <vscale x 16 x i8> %pg.sext, <vscale x 16 x i8> zeroinitializer)
ret <vscale x 16 x i1> %cmp2
}
; Fold away the redundant setcc:
; setcc(ne, pred, sext(setcc(ne, pred, ..., splat(0))), splat(0))
; -> setcc(ne, pred, ..., splat(0))
define <vscale x 16 x i1> @sve_cmpne_setcc_equal_pred(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg) {
; CHECK-LABEL: sve_cmpne_setcc_equal_pred:
; CHECK: // %bb.0:
; CHECK-NEXT: cmpne p0.b, p0/z, z0.b, #0
; CHECK-NEXT: ret
%cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %vec, <vscale x 16 x i8> zeroinitializer)
%cmp1.sext = sext <vscale x 16 x i1> %cmp1 to <vscale x 16 x i8>
%cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %cmp1.sext, <vscale x 16 x i8> zeroinitializer)
ret <vscale x 16 x i1> %cmp2
}
; Combine:
; setcc(ne, pred1, sext(setcc(ne, pred2, ..., splat(0))), splat(0))
; -> setcc(ne, and(pred1, pred2), ..., splat(0))
define <vscale x 16 x i1> @sve_cmpne_setcc_different_pred(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg1, <vscale x 16 x i1> %pg2) {
; CHECK-LABEL: sve_cmpne_setcc_different_pred:
; CHECK: // %bb.0:
; CHECK-NEXT: cmpne p0.b, p0/z, z0.b, #0
; CHECK-NEXT: and p0.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ret
%cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg1, <vscale x 16 x i8> %vec, <vscale x 16 x i8> zeroinitializer)
%cmp1.sext = sext <vscale x 16 x i1> %cmp1 to <vscale x 16 x i8>
%cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg2, <vscale x 16 x i8> %cmp1.sext, <vscale x 16 x i8> zeroinitializer)
ret <vscale x 16 x i1> %cmp2
}
declare <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
declare i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)