[RISCV] Preserve fixed-length VL on insert_vector_elt in more cases

This patch fixes one case where the fixed-length-vector VL was dropped
(falling back to VLMAX) when inserting vector elements: the lowering
went via ISD::INSERT_VECTOR_ELT (at index 0), which loses the
fixed-length vector information.
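
For illustration, a minimal fixed-length reproducer in the style of the
tests below (the function name is made up, and the usual fixed-length
codegen options, e.g. -riscv-v-vector-bits-min, are assumed):

; The index-0 insert was lowered through plain ISD::INSERT_VECTOR_ELT,
; so the scalar move selected for it had no way to reuse the VL of 4.
define void @insert_elt0_v4i32(<4 x i32>* %x, i32 %y) {
  %a = load <4 x i32>, <4 x i32>* %x
  %b = insertelement <4 x i32> %a, i32 %y, i32 0
  store <4 x i32> %b, <4 x i32>* %x
  ret void
}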

To this end, a custom node, VMV_S_XF_VL, was introduced to carry the VL
operand through to the final instruction. This node wraps the RVV
vmv.s.x and vfmv.s.f instructions, which were already being selected
for insert_vector_elt anyway.
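
The effect is most visible on the non-zero-index path, where the
element is slid down, inserted at position 0, and slid back up. A
sketch mirroring the updated tests below (illustrative function name):

; The intermediate vmv.s.x can now run under the same VL=16 vsetvli as
; the slides, instead of requiring its own VLMAX vsetvli.
define void @insert_elt14_v16i8(<16 x i8>* %x, i8 %y) {
  %a = load <16 x i8>, <16 x i8>* %x
  %b = insertelement <16 x i8> %a, i8 %y, i32 14
  store <16 x i8> %b, <16 x i8>* %x
  ret void
}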

There should be no observable difference in scalable-vector codegen.

There is still one outstanding drop from fixed-length VL to VLMAX: when
an i64 element is inserted into a vector on RV32, the splat (which is
custom-legalized) has no notion of the original fixed-length vector
type.
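
A sketch of that remaining case (illustrative only; the equivalent RV64
lowering does keep the fixed-length VL after this patch):

; On RV32 an i64 scalar cannot be inserted via a single vmv.s.x, so the
; insert is custom-legalized through a splat, which still uses VLMAX.
define void @insert_elt2_v4i64(<4 x i64>* %x, i64 %y) {
  %a = load <4 x i64>, <4 x i64>* %x
  %b = insertelement <4 x i64> %a, i64 %y, i32 2
  store <4 x i64> %b, <4 x i64>* %x
  ret void
}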

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97842
Fraser Cormack 2021-03-03 07:50:49 +00:00
parent 1bdb636661
commit d8e1d2ebf4
5 changed files with 29 additions and 49 deletions


@@ -2210,14 +2210,12 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   //   (slideup vec, (insertelt (slidedown impdef, vec, idx), val, 0), idx),
   if (Subtarget.is64Bit() || Val.getValueType() != MVT::i64) {
     if (isNullConstant(Idx))
-      return Op;
+      return DAG.getNode(RISCVISD::VMV_S_XF_VL, DL, ContainerVT, Vec, Val, VL);
     SDValue Slidedown =
         DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT,
                     DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
     SDValue InsertElt0 =
-        DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Slidedown, Val,
-                    DAG.getConstant(0, DL, Subtarget.getXLenVT()));
+        DAG.getNode(RISCVISD::VMV_S_XF_VL, DL, ContainerVT, Slidedown, Val, VL);
     return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, ContainerVT, Vec, InsertElt0,
                        Idx, Mask, VL);
   }
@@ -5735,6 +5733,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VMV_V_X_VL)
   NODE_NAME_CASE(VFMV_V_F_VL)
   NODE_NAME_CASE(VMV_X_S)
+  NODE_NAME_CASE(VMV_S_XF_VL)
   NODE_NAME_CASE(SPLAT_VECTOR_I64)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)


@@ -100,6 +100,9 @@ enum NodeType : unsigned {
   // VMV_X_S matches the semantics of vmv.x.s. The result is always XLenVT sign
   // extended from the vector element size.
   VMV_X_S,
+  // VMV_S_XF_VL matches the semantics of vmv.s.x/vmv.s.f, depending on the
+  // types of its operands. It carries a VL operand.
+  VMV_S_XF_VL,
   // Splats an i64 scalar to a vector type (with element type i64) where the
   // scalar is a sign-extended i32.
   SPLAT_VECTOR_I64,


@@ -754,47 +754,16 @@ foreach fvti = AllFloatVectors in {
 } // Predicates = [HasStdExtV, HasStdExtF]
 //===----------------------------------------------------------------------===//
-// Vector Element Inserts/Extracts
+// Vector Element Extracts
 //===----------------------------------------------------------------------===//
-// The built-in TableGen 'insertelt' node must return the same type as the
-// vector element type. On RISC-V, XLenVT is the only legal integer type, so
-// for integer inserts we use a custom node which inserts an XLenVT-typed
-// value.
-def riscv_insert_vector_elt
-    : SDNode<"ISD::INSERT_VECTOR_ELT",
-             SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisVT<2, XLenVT>,
-                                  SDTCisPtrTy<3>]>, []>;
-let Predicates = [HasStdExtV] in
-foreach vti = AllIntegerVectors in {
-  def : Pat<(vti.Vector (riscv_insert_vector_elt (vti.Vector vti.RegClass:$merge),
-                                                 vti.ScalarRegClass:$rs1, 0)),
-            (!cast<Instruction>("PseudoVMV_S_X_"#vti.LMul.MX)
-                vti.RegClass:$merge,
-                (vti.Scalar vti.ScalarRegClass:$rs1),
-                vti.AVL, vti.SEW)>;
-}
 let Predicates = [HasStdExtV, HasStdExtF] in
 foreach vti = AllFloatVectors in {
-  defvar MX = vti.LMul.MX;
   defvar vmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_",
                                                       vti.ScalarSuffix,
-                                                      "_S_", MX));
-  defvar vmv_s_f_inst = !cast<Instruction>(!strconcat("PseudoVFMV_S_",
-                                                      vti.ScalarSuffix,
-                                                      "_", vti.LMul.MX));
-  // Only pattern-match insert/extract-element operations where the index is
-  // 0. Any other index will have been custom-lowered to slide the vector
-  // correctly into place (and, in the case of insert, slide it back again
-  // afterwards).
+                                                      "_S_", vti.LMul.MX));
+  // Only pattern-match extract-element operations where the index is 0. Any
+  // other index will have been custom-lowered to slide the vector correctly
+  // into place.
   def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)),
             (vmv_f_s_inst vti.RegClass:$rs2, vti.SEW)>;
-  def : Pat<(vti.Vector (insertelt (vti.Vector vti.RegClass:$merge),
-                                   vti.ScalarRegClass:$rs1, 0)),
-            (vmv_s_f_inst vti.RegClass:$merge,
-                          (vti.Scalar vti.ScalarRegClass:$rs1),
-                          vti.AVL, vti.SEW)>;
 }


@@ -54,6 +54,9 @@ def riscv_vfmv_v_f_vl : SDNode<"RISCVISD::VFMV_V_F_VL",
                                SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisFP<0>,
                                                     SDTCisEltOfVec<1, 0>,
                                                     SDTCisVT<2, XLenVT>]>>;
+def riscv_vmv_s_xf_vl : SDNode<"RISCVISD::VMV_S_XF_VL",
+                               SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
+                                                    SDTCisVT<3, XLenVT>]>>;
 def riscv_vle_vl : SDNode<"RISCVISD::VLE_VL", SDT_RISCVVLE_VL,
                           [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
@@ -941,10 +944,16 @@ foreach mti = AllMasks in {
 } // Predicates = [HasStdExtV]
-// 17.4. Vector Register GAther Instruction
 let Predicates = [HasStdExtV] in {
+// 17.1. Integer Scalar Move Instructions
+// 17.4. Vector Register Gather Instruction
 foreach vti = AllIntegerVectors in {
+  def : Pat<(vti.Vector (riscv_vmv_s_xf_vl (vti.Vector vti.RegClass:$merge),
+                                            (XLenVT vti.ScalarRegClass:$rs1),
+                                            (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVMV_S_X_"#vti.LMul.MX)
+                vti.RegClass:$merge,
+                (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.SEW)>;
   def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
                                               (vti.Mask true_mask),
                                               (XLenVT (VLOp GPR:$vl)))),
@@ -961,7 +970,14 @@ foreach vti = AllIntegerVectors in {
 let Predicates = [HasStdExtV, HasStdExtF] in {
+// 17.2. Floating-Point Scalar Move Instructions
 foreach vti = AllFloatVectors in {
+  def : Pat<(vti.Vector (riscv_vmv_s_xf_vl (vti.Vector vti.RegClass:$merge),
+                                            vti.ScalarRegClass:$rs1,
+                                            (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#"_"#vti.LMul.MX)
+                vti.RegClass:$merge,
+                (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.SEW)>;
   def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
                                               (vti.Mask true_mask),
                                               (XLenVT (VLOp GPR:$vl)))),


@@ -30,7 +30,6 @@ define void @insertelt_v4i64(<4 x i64>* %x, i64 %y) {
 ; RV64-NEXT:    vsetivli a2, 4, e64,m2,ta,mu
 ; RV64-NEXT:    vle64.v v26, (a0)
 ; RV64-NEXT:    vslidedown.vi v28, v26, 3
-; RV64-NEXT:    vsetvli a2, zero, e64,m2,ta,mu
 ; RV64-NEXT:    vmv.s.x v28, a1
 ; RV64-NEXT:    vsetivli a1, 4, e64,m2,tu,mu
 ; RV64-NEXT:    vslideup.vi v26, v28, 3
@@ -101,7 +100,6 @@ define void @insertelt_v16i8(<16 x i8>* %x, i8 %y) {
 ; RV32-NEXT:    vsetivli a2, 16, e8,m1,ta,mu
 ; RV32-NEXT:    vle8.v v25, (a0)
 ; RV32-NEXT:    vslidedown.vi v26, v25, 14
-; RV32-NEXT:    vsetvli a2, zero, e8,m1,ta,mu
 ; RV32-NEXT:    vmv.s.x v26, a1
 ; RV32-NEXT:    vsetivli a1, 16, e8,m1,tu,mu
 ; RV32-NEXT:    vslideup.vi v25, v26, 14
@@ -114,7 +112,6 @@ define void @insertelt_v16i8(<16 x i8>* %x, i8 %y) {
 ; RV64-NEXT:    vsetivli a2, 16, e8,m1,ta,mu
 ; RV64-NEXT:    vle8.v v25, (a0)
 ; RV64-NEXT:    vslidedown.vi v26, v25, 14
-; RV64-NEXT:    vsetvli a2, zero, e8,m1,ta,mu
 ; RV64-NEXT:    vmv.s.x v26, a1
 ; RV64-NEXT:    vsetivli a1, 16, e8,m1,tu,mu
 ; RV64-NEXT:    vslideup.vi v25, v26, 14
@@ -134,7 +131,6 @@ define void @insertelt_v32i16(<32 x i16>* %x, i16 %y, i32 %idx) {
 ; RV32-NEXT:    vsetvli a4, a3, e16,m4,ta,mu
 ; RV32-NEXT:    vle16.v v28, (a0)
 ; RV32-NEXT:    vslidedown.vx v8, v28, a2
-; RV32-NEXT:    vsetvli a4, zero, e16,m4,ta,mu
 ; RV32-NEXT:    vmv.s.x v8, a1
 ; RV32-NEXT:    vsetvli a1, a3, e16,m4,tu,mu
 ; RV32-NEXT:    vslideup.vx v28, v8, a2
@@ -149,7 +145,6 @@ define void @insertelt_v32i16(<32 x i16>* %x, i16 %y, i32 %idx) {
 ; RV64-NEXT:    vle16.v v28, (a0)
 ; RV64-NEXT:    sext.w a2, a2
 ; RV64-NEXT:    vslidedown.vx v8, v28, a2
-; RV64-NEXT:    vsetvli a4, zero, e16,m4,ta,mu
 ; RV64-NEXT:    vmv.s.x v8, a1
 ; RV64-NEXT:    vsetvli a1, a3, e16,m4,tu,mu
 ; RV64-NEXT:    vslideup.vx v28, v8, a2
@@ -168,7 +163,6 @@ define void @insertelt_v8f32(<8 x float>* %x, float %y, i32 %idx) {
 ; RV32-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
 ; RV32-NEXT:    vle32.v v26, (a0)
 ; RV32-NEXT:    vslidedown.vx v28, v26, a1
-; RV32-NEXT:    vsetvli a2, zero, e32,m2,ta,mu
 ; RV32-NEXT:    vfmv.s.f v28, fa0
 ; RV32-NEXT:    vsetivli a2, 8, e32,m2,tu,mu
 ; RV32-NEXT:    vslideup.vx v26, v28, a1
@@ -182,7 +176,6 @@ define void @insertelt_v8f32(<8 x float>* %x, float %y, i32 %idx) {
 ; RV64-NEXT:    vle32.v v26, (a0)
 ; RV64-NEXT:    sext.w a1, a1
 ; RV64-NEXT:    vslidedown.vx v28, v26, a1
-; RV64-NEXT:    vsetvli a2, zero, e32,m2,ta,mu
 ; RV64-NEXT:    vfmv.s.f v28, fa0
 ; RV64-NEXT:    vsetivli a2, 8, e32,m2,tu,mu
 ; RV64-NEXT:    vslideup.vx v26, v28, a1