From 0780087477cdc210bebc9ff5b3085205f47d884d Mon Sep 17 00:00:00 2001 From: zhoujing Date: Wed, 26 Jul 2023 17:31:13 +0800 Subject: [PATCH] [VENTUS][RISCV][fix] Fix the register usage calculation of VGPR/GPR --- .../include/llvm/CodeGen/TargetRegisterInfo.h | 5 ++++ llvm/lib/CodeGen/VirtRegMap.cpp | 2 +- llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp | 26 +++++------------ llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 3 +- llvm/lib/Target/RISCV/RISCVInstrInfo.h | 2 +- llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp | 29 ++++++++++++------- llvm/lib/Target/RISCV/RISCVRegisterInfo.h | 8 ++--- llvm/lib/Target/RISCV/RISCVSubtarget.h | 5 ++++ llvm/lib/Target/RISCV/VentusProgramInfo.h | 8 ++--- 9 files changed, 46 insertions(+), 42 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h index 0ab88b360213..2fadd5b4572c 100644 --- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h +++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h @@ -16,6 +16,7 @@ #define LLVM_CODEGEN_TARGETREGISTERINFO_H #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" @@ -548,6 +549,10 @@ public: return false; } + /// Analyze register usage information + virtual void analyzeRegisterUsage(DenseSet RewriteRegs, + MachineFunction *MF) const {} + /// Returns true if PhysReg is unallocatable and constant throughout the /// function. Used by MachineRegisterInfo::isConstantPhysReg(). virtual bool isConstantPhysReg(MCRegister PhysReg) const { return false; } diff --git a/llvm/lib/CodeGen/VirtRegMap.cpp b/llvm/lib/CodeGen/VirtRegMap.cpp index 069aca742da0..64af503421fc 100644 --- a/llvm/lib/CodeGen/VirtRegMap.cpp +++ b/llvm/lib/CodeGen/VirtRegMap.cpp @@ -639,7 +639,7 @@ void VirtRegRewriter::rewrite() { } } } - + TRI->analyzeRegisterUsage(RewriteRegs, MF); RewriteRegs.clear(); } diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp index a31cac9b9893..2090fb4f76fb 100644 --- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp +++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp @@ -49,7 +49,6 @@ namespace { class RISCVAsmPrinter : public AsmPrinter { const MCSubtargetInfo *MCSTI; const RISCVSubtarget *STI; - VentusProgramInfo CurrentProgramInfo; public: explicit RISCVAsmPrinter(TargetMachine &TM, @@ -68,7 +67,7 @@ public: const char *ExtraCode, raw_ostream &OS) override; void EmitToStreamer(MCStreamer &S, const MCInst &Inst); - void getVentusProgramInfo(VentusProgramInfo &Out, const MachineFunction &MF); + bool emitPseudoExpansionLowering(MCStreamer &OutStreamer, const MachineInstr *MI); @@ -198,16 +197,16 @@ bool RISCVAsmPrinter::runOnMachineFunction(MachineFunction &MF) { NewSTI.setFeatureBits(MF.getSubtarget().getFeatureBits()); MCSTI = &NewSTI; STI = &MF.getSubtarget(); - CurrentProgramInfo = VentusProgramInfo(); + auto *CurrentProgramInfo = const_cast( + STI->getVentusProgramInfo()); if (MF.getInfo()->isEntryFunction()) { - getVentusProgramInfo(CurrentProgramInfo, MF); MCSectionELF *ResourceSection = OutContext.getELFSection( ".rodata.ventus.resource", ELF::SHT_PROGBITS, ELF::SHF_WRITE); OutStreamer->switchSection(ResourceSection); - OutStreamer->emitInt16(CurrentProgramInfo.VGPRUsage); - OutStreamer->emitInt16(CurrentProgramInfo.SGPRUsage); - OutStreamer->emitInt16(CurrentProgramInfo.LDSMemory); - OutStreamer->emitInt16(CurrentProgramInfo.PDSMemory); + OutStreamer->emitInt16(CurrentProgramInfo->VGPRUsage); + OutStreamer->emitInt16(CurrentProgramInfo->SGPRUsage); + OutStreamer->emitInt16(CurrentProgramInfo->LDSMemory); + OutStreamer->emitInt16(CurrentProgramInfo->PDSMemory); } SetupMachineFunction(MF); @@ -225,17 +224,6 @@ void RISCVAsmPrinter::emitStartOfAsmFile(Module &M) { emitAttributes(); } -void RISCVAsmPrinter::getVentusProgramInfo(VentusProgramInfo &Out, - const MachineFunction &MF) { - const RISCVSubtarget &ST = MF.getSubtarget(); - const RISCVRegisterInfo *RI = ST.getRegisterInfo(); - Out.VGPRUsage = - RI->getUsedRegistersNum(MF.getRegInfo(), &RISCV::VGPRRegClass, MF); - Out.SGPRUsage = - RI->getUsedRegistersNum(MF.getRegInfo(), &RISCV::GPRRegClass, MF); - // TODO:: Add LDS/PDS calculation -} - void RISCVAsmPrinter::emitEndOfAsmFile(Module &M) { RISCVTargetStreamer &RTS = static_cast(*OutStreamer->getTargetStreamer()); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index d306de4f709b..d9eb6b89e29a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -98,7 +98,7 @@ unsigned RISCVInstrInfo::isLoadFromStackSlot(const MachineInstr &MI, return 0; } -bool RISCVInstrInfo::isPrivateMemoryAccess(const MachineInstr &MI) const { +bool RISCVInstrInfo::isVGPRMemoryAccess(const MachineInstr &MI) const { switch (MI.getOpcode()) { default: return false; @@ -110,6 +110,7 @@ bool RISCVInstrInfo::isPrivateMemoryAccess(const MachineInstr &MI) const { case RISCV::VSW: case RISCV::VSH: case RISCV::VSB: + case RISCV::VSWI12: return true; } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index c97c9414de09..611e3ceaac23 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -55,7 +55,7 @@ public: MCInst getNop() const override; const MCInstrDesc &getBrCond(RISCVCC::CondCode CC) const; - bool isPrivateMemoryAccess(const MachineInstr &MI) const; + bool isVGPRMemoryAccess(const MachineInstr &MI) const; unsigned isLoadFromStackSlot(const MachineInstr &MI, int &FrameIndex) const override; diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp index 7491e2220412..8129a4273922 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp @@ -170,16 +170,19 @@ MCRegister RISCVRegisterInfo::findUnusedRegister(const MachineRegisterInfo &MRI, return MCRegister(); } -uint32_t RISCVRegisterInfo::getUsedRegistersNum(const MachineRegisterInfo &MRI, - const TargetRegisterClass *RC, - const MachineFunction &MF) const { - auto TotalRegNum = std::distance(RC->begin(), RC->end()); - unsigned UsedRegNum = 0; - for (MCRegister Reg : *RC) - if (MRI.isPhysRegUsed(Reg)) - UsedRegNum++; - assert(UsedRegNum <= TotalRegNum && "Register using overflow!"); - return UsedRegNum; +void RISCVRegisterInfo::analyzeRegisterUsage(DenseSet RewriteRegs, + MachineFunction *MF) const { + auto CurrentProgramInfo = const_cast( + MF->getSubtarget().getVentusProgramInfo()); + MachineRegisterInfo &MRI = MF->getRegInfo(); + for(auto Reg : RewriteRegs) { + if(!isSGPRReg(MRI, Reg)) + CurrentProgramInfo->VGPRUsage++; + else + CurrentProgramInfo->SGPRUsage++; + } + // FIXME: need to add one more because of ra, how to simplify this? + CurrentProgramInfo->SGPRUsage++; } bool RISCVRegisterInfo::isSGPRReg(const MachineRegisterInfo &MRI, @@ -337,6 +340,10 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II, Register FrameReg; StackOffset Offset = // FIXME: The FrameReg and Offset should be depended on divergency route. getFrameLowering(MF)->getFrameIndexReference(MF, FrameIndex, FrameReg); + // TODO: finish + // if(!RII->isVGPRMemoryAccess(MI)) + // Offset -= StackOffset::getFixed( + // MF.getInfo()->getVarArgsSaveSize() - 4); int64_t Lo11 = Offset.getFixed(); Offset += StackOffset::getFixed(MI.getOperand(FIOperandNum + 1).getImm()); @@ -373,7 +380,7 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II, } - if(RII->isPrivateMemoryAccess(MI)) { + if(RII->isVGPRMemoryAccess(MI)) { MI.getOperand(FIOperandNum).ChangeToRegister(getPrivateMemoryBaseRegister(MF), /*IsDef*/false, /*IsImp*/false, diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h index f0dd8a79714b..c3810406510e 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h @@ -122,11 +122,9 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo { const TargetRegisterClass *RC, const MachineFunction &MF, bool ReserveHighestVGPR = false) const; - - uint32_t getUsedRegistersNum(const MachineRegisterInfo &MRI, - const TargetRegisterClass *RC, - const MachineFunction &MF) const; - + + void analyzeRegisterUsage(DenseSet RewriteRegs, + MachineFunction *MF) const override; unsigned getRegisterCostTableIndex(const MachineFunction &MF) const override; bool getRegAllocationHints(Register VirtReg, ArrayRef Order, diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 4359c23e84e8..67bc6b81f34f 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -17,6 +17,7 @@ #include "RISCVFrameLowering.h" #include "RISCVISelLowering.h" #include "RISCVInstrInfo.h" +#include "VentusProgramInfo.h" #include "llvm/CodeGen/GlobalISel/CallLowering.h" #include "llvm/CodeGen/GlobalISel/InstructionSelector.h" #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" @@ -123,6 +124,7 @@ private: RISCVRegisterInfo RegInfo; RISCVTargetLowering TLInfo; SelectionDAGTargetInfo TSInfo; + VentusProgramInfo CurrentProgramInfo = VentusProgramInfo(); /// Initializes using the passed in CPU and feature strings so that we can /// use initializer lists for subtarget initialization. @@ -145,6 +147,9 @@ public: return &FrameLowering; } const RISCVInstrInfo *getInstrInfo() const override { return &InstrInfo; } + const VentusProgramInfo *getVentusProgramInfo() const { + return &CurrentProgramInfo; + } const RISCVRegisterInfo *getRegisterInfo() const override { return &RegInfo; } const RISCVTargetLowering *getTargetLowering() const override { return &TLInfo; diff --git a/llvm/lib/Target/RISCV/VentusProgramInfo.h b/llvm/lib/Target/RISCV/VentusProgramInfo.h index d188fa7585a4..b5d1a10f4702 100644 --- a/llvm/lib/Target/RISCV/VentusProgramInfo.h +++ b/llvm/lib/Target/RISCV/VentusProgramInfo.h @@ -19,10 +19,10 @@ namespace llvm { struct VentusProgramInfo { - uint32_t VGPRUsage = 256; //The number of VGPRS which has been used - uint32_t SGPRUsage = 64; //The number of SGPRS which has been used - uint32_t LDSMemory = 1 << 12; //The number of VGPRS which has been used - uint32_t PDSMemory = 1 << 10; //The number of VGPRS which has been used + uint32_t VGPRUsage = 0; // The number of VGPRS which has been used + uint32_t SGPRUsage = 0; // The number of SGPRS which has been used + uint32_t LDSMemory = 1 << 12; // Used local memory size + uint32_t PDSMemory = 1 << 10; // Used private memory size VentusProgramInfo() = default;