[VENTUS][fix] Legalize vlw12.v instruction for variadic functions
This commit is contained in:
parent
649f4da46a
commit
c25d00552c
|
@ -22,6 +22,7 @@
|
|||
#include "clang/Basic/Builtins.h"
|
||||
#include "clang/Basic/CodeGenOptions.h"
|
||||
#include "clang/Basic/DiagnosticFrontend.h"
|
||||
#include "clang/Basic/TargetInfo.h"
|
||||
#include "clang/CodeGen/CGFunctionInfo.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -351,6 +352,10 @@ static Address emitVoidPtrDirectVAArg(CodeGenFunction &CGF,
|
|||
|
||||
// Advance the pointer past the argument, then store that back.
|
||||
CharUnits FullDirectSize = DirectSize.alignTo(SlotSize);
|
||||
const TargetInfo &Info = CGF.getTarget();
|
||||
// In ventus, the stack grow upwards
|
||||
if(Info.getTargetOpts().CPU == "ventus-gpgpu" && Info.getTriple().isRISCV32())
|
||||
FullDirectSize = -FullDirectSize;
|
||||
Address NextPtr =
|
||||
CGF.Builder.CreateConstInBoundsByteGEP(Addr, FullDirectSize, "argp.next");
|
||||
CGF.Builder.CreateStore(NextPtr.getPointer(), VAListAddr);
|
||||
|
|
|
@ -42,6 +42,7 @@ add_llvm_target(RISCVCodeGen
|
|||
VentusRegextInsertion.cpp
|
||||
VentusVVInstrConversion.cpp
|
||||
VentusInsertJoinToVBranch.cpp
|
||||
VentusLegalizeLoad.cpp
|
||||
GISel/RISCVCallLowering.cpp
|
||||
GISel/RISCVInstructionSelector.cpp
|
||||
GISel/RISCVLegalizerInfo.cpp
|
||||
|
|
|
@ -72,6 +72,9 @@ void initializeVentusRegextInsertionPass(PassRegistry &);
|
|||
FunctionPass *createVentusVVInstrConversionPass();
|
||||
void initializeVentusVVInstrConversionPass(PassRegistry &);
|
||||
|
||||
FunctionPass *createVentusLegalizeLoadPass();
|
||||
void initializeVentusLegalizeLoadPass(PassRegistry &);
|
||||
|
||||
FunctionPass *createVentusInsertJoinToVBranchPass();
|
||||
void initializeVentusInsertJoinToVBranchPass(PassRegistry &);
|
||||
|
||||
|
|
|
@ -13394,16 +13394,13 @@ bool RISCVTargetLowering::isSDNodeSourceOfDivergence(
|
|||
}
|
||||
case ISD::LOAD: {
|
||||
const LoadSDNode *L = cast<LoadSDNode>(N);
|
||||
// If load from varstart store frame index, load action is divergent
|
||||
if( auto *Base = dyn_cast<LoadSDNode>(L->getBasePtr()))
|
||||
if(auto *BaseBase = dyn_cast<FrameIndexSDNode>(Base->getOperand(1)))
|
||||
if(BaseBase->getIndex() == getVastartStoreFrameIndex())
|
||||
return true;
|
||||
return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS;
|
||||
return L->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
|
||||
L->getAddressSpace() == RISCVAS::LOCAL_ADDRESS;
|
||||
}
|
||||
case ISD::STORE: {
|
||||
const StoreSDNode *Store= cast<StoreSDNode>(N);
|
||||
return Store->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
|
||||
Store->getAddressSpace() == RISCVAS::LOCAL_ADDRESS ||
|
||||
Store->getPointerInfo().StackID == RISCVStackID::VGPRSpill;
|
||||
}
|
||||
case ISD::CALLSEQ_END:
|
||||
|
|
|
@ -292,7 +292,7 @@ void RISCVPassConfig::addPreRegAlloc() {
|
|||
if (TM->getOptLevel() != CodeGenOpt::None)
|
||||
addPass(createRISCVMergeBaseOffsetOptPass());
|
||||
addPass(createVentusVVInstrConversionPass());
|
||||
|
||||
addPass(createVentusLegalizeLoadPass());
|
||||
}
|
||||
|
||||
void RISCVPassConfig::addPostRegAlloc() {
|
||||
|
|
|
@ -40,17 +40,10 @@ class DivergentPrivateLoadFrag<SDPatternOperator Op> : PatFrag<
|
|||
(ops node:$src0),
|
||||
(Op $src0),
|
||||
[{
|
||||
const LoadSDNode *L = cast<LoadSDNode>(N);
|
||||
bool IsDivergent = false;
|
||||
if( auto *Base = dyn_cast<LoadSDNode>(L->getBasePtr()))
|
||||
if(auto *BaseBase = dyn_cast<FrameIndexSDNode>(Base->getOperand(1)))
|
||||
if(BaseBase->getIndex() == CurDAG->getMachineFunction().
|
||||
getSubtarget<RISCVSubtarget>().getTargetLowering()->getVastartStoreFrameIndex())
|
||||
IsDivergent = true;
|
||||
return N->isDivergent() &&
|
||||
(cast<LoadSDNode>(N)->getAddressSpace() == RISCVAS::PRIVATE_ADDRESS ||
|
||||
cast<LoadSDNode>(N)->getPointerInfo().StackID == RISCVStackID::VGPRSpill
|
||||
|| IsDivergent);
|
||||
);
|
||||
}]>;
|
||||
|
||||
class DivergentNonPrivateLoadFrag<SDPatternOperator Op> : PatFrag<
|
||||
|
|
|
@ -1230,19 +1230,6 @@ def VFTTA : RVInstIVI<0b000011, (outs VGPR:$vd_wb),
|
|||
// Ventus vALU divergent execution patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ATTENTION: please don't change the pattern order
|
||||
// Private memory per-thread load/store
|
||||
def : DivergentPriLdPat<load, VLW>;
|
||||
def : DivergentPriLdPat<zextloadi16, VLHU>;
|
||||
def : DivergentPriLdPat<sextloadi16, VLH>;
|
||||
def : DivergentPriLdPat<extloadi16, VLH>;
|
||||
def : DivergentPriLdPat<zextloadi8, VLBU>;
|
||||
def : DivergentPriLdPat<extloadi8, VLB>;
|
||||
def : DivergentPriLdPat<sextloadi8, VLB>;
|
||||
def : DivergentPriStPat<store, VSW>;
|
||||
def : DivergentPriStPat<truncstorei16, VSH>;
|
||||
def : DivergentPriStPat<truncstorei8, VSB>;
|
||||
|
||||
// Non-private memory load/store
|
||||
// TODO: add store/load test file for testing pattern match
|
||||
def : DivergentNonPriLdImmPat<load, VLWI12>;
|
||||
|
@ -1266,6 +1253,18 @@ def : DivergentNonPriStPat<truncstorei8, VSUXEI8>;
|
|||
def : DivergentNonPriStPat<truncstorei16, VSUXEI16>;
|
||||
def : DivergentNonPriStPat<store, VSUXEI32>;
|
||||
|
||||
// Private memory per-thread load/store
|
||||
def : DivergentPriLdPat<load, VLW>;
|
||||
def : DivergentPriLdPat<zextloadi16, VLHU>;
|
||||
def : DivergentPriLdPat<sextloadi16, VLH>;
|
||||
def : DivergentPriLdPat<extloadi16, VLH>;
|
||||
def : DivergentPriLdPat<zextloadi8, VLBU>;
|
||||
def : DivergentPriLdPat<extloadi8, VLB>;
|
||||
def : DivergentPriLdPat<sextloadi8, VLB>;
|
||||
def : DivergentPriStPat<store, VSW>;
|
||||
def : DivergentPriStPat<truncstorei16, VSH>;
|
||||
def : DivergentPriStPat<truncstorei8, VSB>;
|
||||
|
||||
// FIXME: check this review: https://reviews.llvm.org/D131729#inline-1269307
|
||||
// def : PatIntSetCC<[VGPR, VGPR], SETLE, VMSLE_VV>;
|
||||
// def : PatIntSetCC<[VGPR, GPR], SETLE, VMSLE_VX>;
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
//===-- VentusLegalizeLoad.cpp - vlw12 instruction legalization ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Legalize vlw12.v instruction, in vararg function codegen, the vararg start
|
||||
// index will be stored in stack, in ventus, this index will be store in private
|
||||
// stack, some code example will be as below
|
||||
// addi sp, sp, 4 | addi sp, sp, -64
|
||||
// addi tp, tp, 80 | addi s0, sp, 32
|
||||
// vsw.v v7, -80(v8) | sw a7, 28(s0)
|
||||
// vsw.v v6, -76(v8) | sw a6, 24(s0)
|
||||
// vsw.v v5, -72(v8) | sw a5, 20(s0)
|
||||
// vsw.v v4, -68(v8) | sw a4, 16(s0)
|
||||
// vsw.v v3, -64(v8) | sw a3, 12(s0)
|
||||
// vsw.v v2, -60(v8) | sw a2, 8(s0)
|
||||
// vsw.v v1, -56(v8) | sw a1, 4(s0)
|
||||
// vsw.v v0, -48(v8) | sw a0, -16(s0)
|
||||
// addi t0, tp, -56 | addi a0, s0, 4
|
||||
// vmv.v.x v0, t0 | sw a0, -20(s0)
|
||||
// vsw.v v0, -44(v8)
|
||||
//
|
||||
// left hand code is generated by ventus, right hand code is generated by
|
||||
// standard riscv, when load from -44(v8), vlw.v v0, -44(v8), if need to
|
||||
// load from v0, in right now ventus, the codegen will be like this vlw12.v v0,
|
||||
// 0(v0), but it is actually illegal, it should be vlw.v v0, 0(v0)
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "MCTargetDesc/RISCVMCTargetDesc.h"
|
||||
#include "RISCV.h"
|
||||
#include "RISCVFrameLowering.h"
|
||||
#include "RISCVISelLowering.h"
|
||||
#include "RISCVInstrInfo.h"
|
||||
#include "RISCVSubtarget.h"
|
||||
#include "llvm/CodeGen/MachineFunctionPass.h"
|
||||
#include "llvm/CodeGen/MachineInstr.h"
|
||||
#include "llvm/CodeGen/MachineInstrBuilder.h"
|
||||
#include "llvm/CodeGen/MachineOperand.h"
|
||||
#include "llvm/CodeGen/Register.h"
|
||||
#include "llvm/IR/DebugLoc.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cassert>
|
||||
|
||||
#define VENTUS_LOAD_LEGALIZATION "Ventus load instruction legalization pass"
|
||||
#define DEBUG_TYPE "vlw12-legalize"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
|
||||
class VentusLegalizeLoad : public MachineFunctionPass {
|
||||
public:
|
||||
const RISCVInstrInfo *TII;
|
||||
static char ID;
|
||||
const RISCVRegisterInfo *MRI;
|
||||
const RISCVTargetLowering *RTI;
|
||||
const MachineRegisterInfo *MR;
|
||||
|
||||
VentusLegalizeLoad() : MachineFunctionPass(ID) {
|
||||
initializeVentusLegalizeLoadPass(*PassRegistry::getPassRegistry());
|
||||
}
|
||||
|
||||
bool runOnMachineFunction(MachineFunction &MF) override;
|
||||
|
||||
StringRef getPassName() const override { return VENTUS_LOAD_LEGALIZATION; }
|
||||
|
||||
private:
|
||||
bool runOnMachineBasicBlock(MachineBasicBlock &MBB);
|
||||
|
||||
bool checkVLWDependency(MachineBasicBlock &MBB, MachineInstr &MI);
|
||||
|
||||
bool checkInstructionUse(MachineBasicBlock &MBB, MachineInstr &MI);
|
||||
|
||||
// Check whether machine instruction is local memory load instruction or not
|
||||
bool isLocalMemLoadInstr(MachineInstr &MI) {
|
||||
switch (MI.getOpcode()) {
|
||||
case RISCV::VLWI12:
|
||||
case RISCV::VLBI12:
|
||||
case RISCV::VLBUI12:
|
||||
case RISCV::VLHI12:
|
||||
case RISCV::VLHUI12:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getRelativePrivateMemLoadInstr(MachineInstr &MI) {
|
||||
assert(isLocalMemLoadInstr(MI) && "Illegal instruction.");
|
||||
switch (MI.getOpcode()) {
|
||||
case RISCV::VLWI12:
|
||||
return RISCV::VLW;
|
||||
case RISCV::VLBI12:
|
||||
return RISCV::VLB;
|
||||
case RISCV::VLBUI12:
|
||||
return RISCV::VLBU;
|
||||
case RISCV::VLHI12:
|
||||
return RISCV::VLH;
|
||||
case RISCV::VLHUI12:
|
||||
return RISCV::VLHU;
|
||||
default:
|
||||
llvm_unreachable("Unexpected instruction.");
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
bool VentusLegalizeLoad::runOnMachineFunction(MachineFunction &MF) {
|
||||
bool IsChanged = false;
|
||||
TII = static_cast<const RISCVInstrInfo *>(MF.getSubtarget().getInstrInfo());
|
||||
MRI = MF.getSubtarget<RISCVSubtarget>().getRegisterInfo();
|
||||
RTI = MF.getSubtarget<RISCVSubtarget>().getTargetLowering();
|
||||
MR = &MF.getRegInfo();
|
||||
for (auto &MBB : MF)
|
||||
IsChanged |= runOnMachineBasicBlock(MBB);
|
||||
return IsChanged;
|
||||
}
|
||||
|
||||
bool VentusLegalizeLoad::runOnMachineBasicBlock(MachineBasicBlock &MBB) {
|
||||
bool IsMBBChanged = false;
|
||||
// Right now we only legalize vlw12.v instruction, and we need to analyze its
|
||||
// operand def-use chain
|
||||
for (auto &MI : MBB) {
|
||||
if (MI.getOpcode() == RISCV::VLW)
|
||||
IsMBBChanged |= checkVLWDependency(MBB, MI);
|
||||
}
|
||||
return IsMBBChanged;
|
||||
}
|
||||
|
||||
bool VentusLegalizeLoad::checkVLWDependency(MachineBasicBlock &MBB,
|
||||
MachineInstr &MI) {
|
||||
bool IsChanged = false;
|
||||
|
||||
int StoreFrameIndex = RTI->getVastartStoreFrameIndex();
|
||||
MachineOperand LoadBaseOperand = MI.getOperand(1);
|
||||
if (LoadBaseOperand.isFI() && LoadBaseOperand.getIndex() == StoreFrameIndex)
|
||||
// Check base operand has data dependency with vastart frame or not
|
||||
{
|
||||
Register VastartFrameRegister = MI.getOperand(0).getReg();
|
||||
auto UseList = MR->use_instructions(VastartFrameRegister);
|
||||
for (auto &Use : UseList) {
|
||||
// Other operation on vastart frame register
|
||||
IsChanged |= checkInstructionUse(MBB, Use);
|
||||
}
|
||||
}
|
||||
|
||||
return IsChanged;
|
||||
}
|
||||
|
||||
bool VentusLegalizeLoad::checkInstructionUse(MachineBasicBlock &MBB, MachineInstr &MI) {
|
||||
bool IsChanged = false;
|
||||
if (isLocalMemLoadInstr(MI)) {
|
||||
// Here need to change vlw12 to vlw
|
||||
LLVM_DEBUG(dbgs() << "Instruction to be changed: \n");
|
||||
LLVM_DEBUG(MI.dump());
|
||||
MI.setDesc(TII->get(getRelativePrivateMemLoadInstr(MI)));
|
||||
LLVM_DEBUG(dbgs() << "Instruction after changed: \n");
|
||||
LLVM_DEBUG(MI.dump());
|
||||
auto UseList = MR->use_instructions(MI.getOperand(0).getReg());
|
||||
for (auto &Use : UseList) {
|
||||
checkInstructionUse(MBB, Use);
|
||||
}
|
||||
return IsChanged = true;
|
||||
}
|
||||
if (MI.getOpcode() == RISCV::VSW)
|
||||
return IsChanged;
|
||||
for (auto &Use : MR->use_instructions(MI.getOperand(0).getReg())) {
|
||||
IsChanged |= checkInstructionUse(MBB, Use);
|
||||
}
|
||||
return IsChanged;
|
||||
}
|
||||
|
||||
char VentusLegalizeLoad::ID = 0;
|
||||
|
||||
} // end of anonymous namespace
|
||||
|
||||
INITIALIZE_PASS(VentusLegalizeLoad, "Ventus load instruction legalization",
|
||||
VENTUS_LOAD_LEGALIZATION, false, false)
|
||||
|
||||
namespace llvm {
|
||||
FunctionPass *createVentusLegalizeLoadPass() {
|
||||
return new VentusLegalizeLoad();
|
||||
}
|
||||
} // end of namespace llvm
|
Loading…
Reference in New Issue