[SLP] Refactoring isLegalBroadcastLoad() to use `ElementCount`.

Replacing `unsigned` with `ElementCount` in the argument of `isLegalBroadcastLoad()`.
This helps reduce the diff of a future SLP patch for AArch64.
This commit is contained in:
Vasileios Porpodas 2022-04-20 08:17:29 -07:00
parent f296b4c444
commit 889588ee97
6 changed files with 12 additions and 10 deletions

View File

@ -660,7 +660,7 @@ public:
/// \Returns true if the target supports broadcasting a load to a vector of /// \Returns true if the target supports broadcasting a load to a vector of
/// type <NumElements x ElementTy>. /// type <NumElements x ElementTy>.
bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const; bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const;
/// Return true if the target supports masked scatter. /// Return true if the target supports masked scatter.
bool isLegalMaskedScatter(Type *DataType, Align Alignment) const; bool isLegalMaskedScatter(Type *DataType, Align Alignment) const;
@ -1560,7 +1560,7 @@ public:
virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0; virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0; virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0;
virtual bool isLegalBroadcastLoad(Type *ElementTy, virtual bool isLegalBroadcastLoad(Type *ElementTy,
unsigned NumElements) const = 0; ElementCount NumElements) const = 0;
virtual bool isLegalMaskedScatter(Type *DataType, Align Alignment) = 0; virtual bool isLegalMaskedScatter(Type *DataType, Align Alignment) = 0;
virtual bool isLegalMaskedGather(Type *DataType, Align Alignment) = 0; virtual bool isLegalMaskedGather(Type *DataType, Align Alignment) = 0;
virtual bool forceScalarizeMaskedGather(VectorType *DataType, virtual bool forceScalarizeMaskedGather(VectorType *DataType,
@ -1968,7 +1968,7 @@ public:
return Impl.isLegalNTLoad(DataType, Alignment); return Impl.isLegalNTLoad(DataType, Alignment);
} }
bool isLegalBroadcastLoad(Type *ElementTy, bool isLegalBroadcastLoad(Type *ElementTy,
unsigned NumElements) const override { ElementCount NumElements) const override {
return Impl.isLegalBroadcastLoad(ElementTy, NumElements); return Impl.isLegalBroadcastLoad(ElementTy, NumElements);
} }
bool isLegalMaskedScatter(Type *DataType, Align Alignment) override { bool isLegalMaskedScatter(Type *DataType, Align Alignment) override {

View File

@ -256,7 +256,7 @@ public:
return Alignment >= DataSize && isPowerOf2_32(DataSize); return Alignment >= DataSize && isPowerOf2_32(DataSize);
} }
bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const { bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const {
return false; return false;
} }

View File

@ -397,7 +397,7 @@ bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const {
} }
bool TargetTransformInfo::isLegalBroadcastLoad(Type *ElementTy, bool TargetTransformInfo::isLegalBroadcastLoad(Type *ElementTy,
unsigned NumElements) const { ElementCount NumElements) const {
return TTIImpl->isLegalBroadcastLoad(ElementTy, NumElements); return TTIImpl->isLegalBroadcastLoad(ElementTy, NumElements);
} }

View File

@ -1558,7 +1558,7 @@ InstructionCost X86TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
if (const auto *Entry = if (const auto *Entry =
CostTableLookup(SSE3BroadcastLoadTbl, Kind, LT.second)) { CostTableLookup(SSE3BroadcastLoadTbl, Kind, LT.second)) {
assert(isLegalBroadcastLoad(BaseTp->getElementType(), assert(isLegalBroadcastLoad(BaseTp->getElementType(),
LT.second.getVectorNumElements()) && LT.second.getVectorElementCount()) &&
"Table entry missing from isLegalBroadcastLoad()"); "Table entry missing from isLegalBroadcastLoad()");
return LT.first * Entry->Cost; return LT.first * Entry->Cost;
} }
@ -5137,9 +5137,10 @@ bool X86TTIImpl::isLegalNTStore(Type *DataType, Align Alignment) {
} }
bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy, bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
unsigned NumElements) const { ElementCount NumElements) const {
// movddup // movddup
return ST->hasSSE3() && NumElements == 2 && return ST->hasSSE3() && !NumElements.isScalable() &&
NumElements.getFixedValue() == 2 &&
ElementTy == Type::getDoubleTy(ElementTy->getContext()); ElementTy == Type::getDoubleTy(ElementTy->getContext());
} }

View File

@ -232,7 +232,7 @@ public:
bool isLegalMaskedStore(Type *DataType, Align Alignment); bool isLegalMaskedStore(Type *DataType, Align Alignment);
bool isLegalNTLoad(Type *DataType, Align Alignment); bool isLegalNTLoad(Type *DataType, Align Alignment);
bool isLegalNTStore(Type *DataType, Align Alignment); bool isLegalNTStore(Type *DataType, Align Alignment);
bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const; bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const;
bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment); bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment);
bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) { bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
return forceScalarizeMaskedGather(VTy, Alignment); return forceScalarizeMaskedGather(VTy, Alignment);

View File

@ -1188,7 +1188,8 @@ public:
return AllUsersVectorized(V1) && AllUsersVectorized(V2); return AllUsersVectorized(V1) && AllUsersVectorized(V2);
}; };
// A broadcast of a load can be cheaper on some targets. // A broadcast of a load can be cheaper on some targets.
if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) && if (R.TTI->isLegalBroadcastLoad(V1->getType(),
ElementCount::getFixed(NumLanes)) &&
((int)V1->getNumUses() == NumLanes || ((int)V1->getNumUses() == NumLanes ||
AllUsersAreInternal(V1, V2))) AllUsersAreInternal(V1, V2)))
return VLOperands::ScoreSplatLoads; return VLOperands::ScoreSplatLoads;