[VENTUS][fix] Legalize vlw12.v instruction for variadic functions

This commit is contained in:
zhoujingya 2023-09-07 16:39:11 +08:00
parent 649f4da46a
commit c25d00552c
8 changed files with 216 additions and 28 deletions

View File

@ -22,6 +22,7 @@
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CodeGenOptions.h"
#include "clang/Basic/DiagnosticFrontend.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CodeGen/CGFunctionInfo.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
@ -351,6 +352,10 @@ static Address emitVoidPtrDirectVAArg(CodeGenFunction &CGF,
// Advance the pointer past the argument, then store that back.
CharUnits FullDirectSize = DirectSize.alignTo(SlotSize);
const TargetInfo &Info = CGF.getTarget();
// In ventus, the stack grow upwards
if(Info.getTargetOpts().CPU == "ventus-gpgpu" && Info.getTriple().isRISCV32())
FullDirectSize = -FullDirectSize;
Address NextPtr =
CGF.Builder.CreateConstInBoundsByteGEP(Addr, FullDirectSize, "argp.next");
CGF.Builder.CreateStore(NextPtr.getPointer(), VAListAddr);

View File

@ -42,6 +42,7 @@ add_llvm_target(RISCVCodeGen
VentusRegextInsertion.cpp
VentusVVInstrConversion.cpp
VentusInsertJoinToVBranch.cpp
VentusLegalizeLoad.cpp
GISel/RISCVCallLowering.cpp
GISel/RISCVInstructionSelector.cpp
GISel/RISCVLegalizerInfo.cpp

View File

@ -72,6 +72,9 @@ void initializeVentusRegextInsertionPass(PassRegistry &);
FunctionPass *createVentusVVInstrConversionPass();
void initializeVentusVVInstrConversionPass(PassRegistry &);
FunctionPass *createVentusLegalizeLoadPass();
void initializeVentusLegalizeLoadPass(PassRegistry &);
FunctionPass *createVentusInsertJoinToVBranchPass();
void initializeVentusInsertJoinToVBranchPass(PassRegistry &);

View File

@ -13394,16 +13394,13 @@ bool RISCVTargetLowering::isSDNodeSourceOfDivergence(
}
case ISD::LOAD: {
const LoadSDNode *L = cast<LoadSDNode>(N);
// If load from varstart store frame index, load action is divergent
if( auto *Base = dyn_cast<LoadSDNode>(L->getBasePtr()))
if(auto *BaseBase = dyn_cast<FrameIndexSDNode>(Base->getOperand(1)))
if(BaseBase->getIndex() == getVastartStoreFrameIndex())
return true;
return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS;
return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
L->getAddressSpace() == RISCVAS::LOCAL_ADDRESS;
}
case ISD::STORE: {
const StoreSDNode *Store= cast<StoreSDNode>(N);
return Store->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
Store->getAddressSpace() == RISCVAS::LOCAL_ADDRESS ||
Store->getPointerInfo().StackID == RISCVStackID::VGPRSpill;
}
case ISD::CALLSEQ_END:

View File

@ -292,7 +292,7 @@ void RISCVPassConfig::addPreRegAlloc() {
if (TM->getOptLevel() != CodeGenOpt::None)
addPass(createRISCVMergeBaseOffsetOptPass());
addPass(createVentusVVInstrConversionPass());
addPass(createVentusLegalizeLoadPass());
}
void RISCVPassConfig::addPostRegAlloc() {

View File

@ -40,17 +40,10 @@ class DivergentPrivateLoadFrag<SDPatternOperator Op> : PatFrag<
(ops node:$src0),
(Op $src0),
[{
const LoadSDNode *L = cast<LoadSDNode>(N);
bool IsDivergent = false;
if( auto *Base = dyn_cast<LoadSDNode>(L->getBasePtr()))
if(auto *BaseBase = dyn_cast<FrameIndexSDNode>(Base->getOperand(1)))
if(BaseBase->getIndex() == CurDAG->getMachineFunction().
getSubtarget<RISCVSubtarget>().getTargetLowering()->getVastartStoreFrameIndex())
IsDivergent = true;
return N->isDivergent() &&
(cast<LoadSDNode>(N)->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
cast<LoadSDNode>(N)->getPointerInfo().StackID == RISCVStackID::VGPRSpill
|| IsDivergent);
);
}]>;
class DivergentNonPrivateLoadFrag<SDPatternOperator Op> : PatFrag<

View File

@ -1230,19 +1230,6 @@ def VFTTA : RVInstIVI<0b000011, (outs VGPR:$vd_wb),
// Ventus vALU divergent execution patterns
//===----------------------------------------------------------------------===//
// ATTENTION: please don't change the pattern order
// Private memory per-thread load/store
def : DivergentPriLdPat<load, VLW>;
def : DivergentPriLdPat<zextloadi16, VLHU>;
def : DivergentPriLdPat<sextloadi16, VLH>;
def : DivergentPriLdPat<extloadi16, VLH>;
def : DivergentPriLdPat<zextloadi8, VLBU>;
def : DivergentPriLdPat<extloadi8, VLB>;
def : DivergentPriLdPat<sextloadi8, VLB>;
def : DivergentPriStPat<store, VSW>;
def : DivergentPriStPat<truncstorei16, VSH>;
def : DivergentPriStPat<truncstorei8, VSB>;
// Non-private memory load/store
// TODO: add store/load test file for testing pattern match
def : DivergentNonPriLdImmPat<load, VLWI12>;
@ -1266,6 +1253,18 @@ def : DivergentNonPriStPat<truncstorei8, VSUXEI8>;
def : DivergentNonPriStPat<truncstorei16, VSUXEI16>;
def : DivergentNonPriStPat<store, VSUXEI32>;
// Private memory per-thread load/store
def : DivergentPriLdPat<load, VLW>;
def : DivergentPriLdPat<zextloadi16, VLHU>;
def : DivergentPriLdPat<sextloadi16, VLH>;
def : DivergentPriLdPat<extloadi16, VLH>;
def : DivergentPriLdPat<zextloadi8, VLBU>;
def : DivergentPriLdPat<extloadi8, VLB>;
def : DivergentPriLdPat<sextloadi8, VLB>;
def : DivergentPriStPat<store, VSW>;
def : DivergentPriStPat<truncstorei16, VSH>;
def : DivergentPriStPat<truncstorei8, VSB>;
// FIXME: check this review: https://reviews.llvm.org/D131729#inline-1269307
// def : PatIntSetCC<[VGPR, VGPR], SETLE, VMSLE_VV>;
// def : PatIntSetCC<[VGPR, GPR], SETLE, VMSLE_VX>;

View File

@ -0,0 +1,190 @@
//===-- VentusLegalizeLoad.cpp - vlw12 instruction legalization ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Legalize vlw12.v instruction, in vararg function codegen, the vararg start
// index will be stored in stack, in ventus, this index will be store in private
// stack, some code example will be as below
// addi sp, sp, 4 | addi sp, sp, -64
// addi tp, tp, 80 | addi s0, sp, 32
// vsw.v v7, -80(v8) | sw a7, 28(s0)
// vsw.v v6, -76(v8) | sw a6, 24(s0)
// vsw.v v5, -72(v8) | sw a5, 20(s0)
// vsw.v v4, -68(v8) | sw a4, 16(s0)
// vsw.v v3, -64(v8) | sw a3, 12(s0)
// vsw.v v2, -60(v8) | sw a2, 8(s0)
// vsw.v v1, -56(v8) | sw a1, 4(s0)
// vsw.v v0, -48(v8) | sw a0, -16(s0)
// addi t0, tp, -56 | addi a0, s0, 4
// vmv.v.x v0, t0 | sw a0, -20(s0)
// vsw.v v0, -44(v8)
//
// left hand code is generated by ventus, right hand code is generated by
// standard riscv, when load from -44(v8), vlw.v v0, -44(v8), if need to
// load from v0, in right now ventus, the codegen will be like this vlw12.v v0,
// 0(v0), but it is actually illegal, it should be vlw.v v0, 0(v0)
//
//===----------------------------------------------------------------------===//
#include "MCTargetDesc/RISCVMCTargetDesc.h"
#include "RISCV.h"
#include "RISCVFrameLowering.h"
#include "RISCVISelLowering.h"
#include "RISCVInstrInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#define VENTUS_LOAD_LEGALIZATION "Ventus load instruction legalization pass"
#define DEBUG_TYPE "vlw12-legalize"
using namespace llvm;
namespace {
class VentusLegalizeLoad : public MachineFunctionPass {
public:
const RISCVInstrInfo *TII;
static char ID;
const RISCVRegisterInfo *MRI;
const RISCVTargetLowering *RTI;
const MachineRegisterInfo *MR;
VentusLegalizeLoad() : MachineFunctionPass(ID) {
initializeVentusLegalizeLoadPass(*PassRegistry::getPassRegistry());
}
bool runOnMachineFunction(MachineFunction &MF) override;
StringRef getPassName() const override { return VENTUS_LOAD_LEGALIZATION; }
private:
bool runOnMachineBasicBlock(MachineBasicBlock &MBB);
bool checkVLWDependency(MachineBasicBlock &MBB, MachineInstr &MI);
bool checkInstructionUse(MachineBasicBlock &MBB, MachineInstr &MI);
// Check whether machine instruction is local memory load instruction or not
bool isLocalMemLoadInstr(MachineInstr &MI) {
switch (MI.getOpcode()) {
case RISCV::VLWI12:
case RISCV::VLBI12:
case RISCV::VLBUI12:
case RISCV::VLHI12:
case RISCV::VLHUI12:
return true;
default:
return false;
}
}
unsigned getRelativePrivateMemLoadInstr(MachineInstr &MI) {
assert(isLocalMemLoadInstr(MI) && "Illegal instruction.");
switch (MI.getOpcode()) {
case RISCV::VLWI12:
return RISCV::VLW;
case RISCV::VLBI12:
return RISCV::VLB;
case RISCV::VLBUI12:
return RISCV::VLBU;
case RISCV::VLHI12:
return RISCV::VLH;
case RISCV::VLHUI12:
return RISCV::VLHU;
default:
llvm_unreachable("Unexpected instruction.");
}
}
};
bool VentusLegalizeLoad::runOnMachineFunction(MachineFunction &MF) {
bool IsChanged = false;
TII = static_cast<const RISCVInstrInfo *>(MF.getSubtarget().getInstrInfo());
MRI = MF.getSubtarget<RISCVSubtarget>().getRegisterInfo();
RTI = MF.getSubtarget<RISCVSubtarget>().getTargetLowering();
MR = &MF.getRegInfo();
for (auto &MBB : MF)
IsChanged |= runOnMachineBasicBlock(MBB);
return IsChanged;
}
bool VentusLegalizeLoad::runOnMachineBasicBlock(MachineBasicBlock &MBB) {
bool IsMBBChanged = false;
// Right now we only legalize vlw12.v instruction, and we need to analyze its
// operand def-use chain
for (auto &MI : MBB) {
if (MI.getOpcode() == RISCV::VLW)
IsMBBChanged |= checkVLWDependency(MBB, MI);
}
return IsMBBChanged;
}
bool VentusLegalizeLoad::checkVLWDependency(MachineBasicBlock &MBB,
MachineInstr &MI) {
bool IsChanged = false;
int StoreFrameIndex = RTI->getVastartStoreFrameIndex();
MachineOperand LoadBaseOperand = MI.getOperand(1);
if (LoadBaseOperand.isFI() && LoadBaseOperand.getIndex() == StoreFrameIndex)
// Check base operand has data dependency with vastart frame or not
{
Register VastartFrameRegister = MI.getOperand(0).getReg();
auto UseList = MR->use_instructions(VastartFrameRegister);
for (auto &Use : UseList) {
// Other operation on vastart frame register
IsChanged |= checkInstructionUse(MBB, Use);
}
}
return IsChanged;
}
bool VentusLegalizeLoad::checkInstructionUse(MachineBasicBlock &MBB, MachineInstr &MI) {
bool IsChanged = false;
if (isLocalMemLoadInstr(MI)) {
// Here need to change vlw12 to vlw
LLVM_DEBUG(dbgs() << "Instruction to be changed: \n");
LLVM_DEBUG(MI.dump());
MI.setDesc(TII->get(getRelativePrivateMemLoadInstr(MI)));
LLVM_DEBUG(dbgs() << "Instruction after changed: \n");
LLVM_DEBUG(MI.dump());
auto UseList = MR->use_instructions(MI.getOperand(0).getReg());
for (auto &Use : UseList) {
checkInstructionUse(MBB, Use);
}
return IsChanged = true;
}
if (MI.getOpcode() == RISCV::VSW)
return IsChanged;
for (auto &Use : MR->use_instructions(MI.getOperand(0).getReg())) {
IsChanged |= checkInstructionUse(MBB, Use);
}
return IsChanged;
}
char VentusLegalizeLoad::ID = 0;
} // end of anonymous namespace
INITIALIZE_PASS(VentusLegalizeLoad, "Ventus load instruction legalization",
VENTUS_LOAD_LEGALIZATION, false, false)
namespace llvm {
FunctionPass *createVentusLegalizeLoadPass() {
return new VentusLegalizeLoad();
}
} // end of namespace llvm