diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 44a88ad03a21..e5d9aeb94d61 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -29,6 +29,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -36,6 +37,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -4554,7 +4556,9 @@ SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); SDValue FI = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(), getPointerTy(MF.getDataLayout())); - + auto *FrameIndex = cast(Op.getOperand(1)); + assert(FrameIndex && "Not frame index node"); + setVastartStoreFrameIndex(FrameIndex->getIndex()); // vastart just stores the address of the VarArgsFrameIndex slot into the // memory location argument. const Value *SV= cast(Op.getOperand(2))->getValue(); @@ -13390,13 +13394,16 @@ bool RISCVTargetLowering::isSDNodeSourceOfDivergence( } case ISD::LOAD: { const LoadSDNode *L = cast(N); - return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS || - L->getAddressSpace() == RISCVAS::LOCAL_ADDRESS; + // If load from varstart store frame index, load action is divergent + if( auto *Base = dyn_cast(L->getBasePtr())) + if(auto *BaseBase = dyn_cast(Base->getOperand(1))) + if(BaseBase->getIndex() == getVastartStoreFrameIndex()) + return true; + return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS; } case ISD::STORE: { const StoreSDNode *Store= cast(N); return Store->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS || - Store->getAddressSpace() == RISCVAS::LOCAL_ADDRESS || Store->getPointerInfo().StackID == RISCVStackID::VGPRSpill; } case ISD::CALLSEQ_END: diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 0d61c29b492b..6dea791dc854 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -334,10 +334,11 @@ enum NodeType : unsigned { class RISCVTargetLowering : public TargetLowering { const RISCVSubtarget &Subtarget; - + int *VastartStoreFrameIndex = new int; public: explicit RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI); + ~RISCVTargetLowering() { delete VastartStoreFrameIndex; } const RISCVSubtarget &getSubtarget() const { return Subtarget; } @@ -479,6 +480,12 @@ public: return ISD::SIGN_EXTEND; } + int getVastartStoreFrameIndex() const { return *VastartStoreFrameIndex; } + + void setVastartStoreFrameIndex(int Index) const { + *VastartStoreFrameIndex = Index; + } + bool shouldExpandShift(SelectionDAG &DAG, SDNode *N) const override { if (DAG.getMachineFunction().getFunction().hasMinSize()) return false; diff --git a/llvm/lib/Target/RISCV/VentusInstrInfo.td b/llvm/lib/Target/RISCV/VentusInstrInfo.td index c2bd47426fe0..32fbac9151dc 100644 --- a/llvm/lib/Target/RISCV/VentusInstrInfo.td +++ b/llvm/lib/Target/RISCV/VentusInstrInfo.td @@ -40,10 +40,17 @@ class DivergentPrivateLoadFrag : PatFrag< (ops node:$src0), (Op $src0), [{ + const LoadSDNode *L = cast(N); + bool IsDivergent = false; + if( auto *Base = dyn_cast(L->getBasePtr())) + if(auto *BaseBase = dyn_cast(Base->getOperand(1))) + if(BaseBase->getIndex() == CurDAG->getMachineFunction(). +getSubtarget().getTargetLowering()->getVastartStoreFrameIndex()) + IsDivergent = true; return N->isDivergent() && (cast(N)->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS || cast(N)->getPointerInfo().StackID == RISCVStackID::VGPRSpill - ); + || IsDivergent); }]>; class DivergentNonPrivateLoadFrag : PatFrag< diff --git a/llvm/lib/Target/RISCV/VentusInstrInfoV.td b/llvm/lib/Target/RISCV/VentusInstrInfoV.td index db7dfa7ac2a5..67268809b261 100644 --- a/llvm/lib/Target/RISCV/VentusInstrInfoV.td +++ b/llvm/lib/Target/RISCV/VentusInstrInfoV.td @@ -1230,6 +1230,19 @@ def VFTTA : RVInstIVI<0b000011, (outs VGPR:$vd_wb), // Ventus vALU divergent execution patterns //===----------------------------------------------------------------------===// +// ATTENTION: please don't change the pattern order +// Private memory per-thread load/store +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriLdPat; +def : DivergentPriStPat; +def : DivergentPriStPat; +def : DivergentPriStPat; + // Non-private memory load/store // TODO: add store/load test file for testing pattern match def : DivergentNonPriLdImmPat; @@ -1253,18 +1266,6 @@ def : DivergentNonPriStPat; def : DivergentNonPriStPat; def : DivergentNonPriStPat; -// Private memory per-thread load/store -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriLdPat; -def : DivergentPriStPat; -def : DivergentPriStPat; -def : DivergentPriStPat; - // FIXME: check this review: https://reviews.llvm.org/D131729#inline-1269307 // def : PatIntSetCC<[VGPR, VGPR], SETLE, VMSLE_VV>; // def : PatIntSetCC<[VGPR, GPR], SETLE, VMSLE_VX>;