[mlir] Improve bitEnumContains methods.
839b436c93
changes the behavior. Based on the discussion, we also want to support
"and" behavior. The revision changes it into two functions, bitEnumContainsAny
and bitEnumContainsAll.
Reviewed By: krzysz00, antiagainst
Differential Revision: https://reviews.llvm.org/D133507
This commit is contained in:
parent
aee094fb8b
commit
aac844a4b1
|
@ -1442,9 +1442,12 @@ inline constexpr MyBitEnum operator~(MyBitEnum bits) {
|
|||
// Ensure only bits that can be present in the enum are set
|
||||
return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
|
||||
}
|
||||
inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
|
||||
inline constexpr bool bitEnumContainsAll(MyBitEnum bits, MyBitEnum bit) {
|
||||
return (bits & bit) == bit;
|
||||
}
|
||||
inline constexpr bool bitEnumContainsAny(MyBitEnum bits, MyBitEnum bit) {
|
||||
return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
|
||||
}
|
||||
inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
|
||||
return bits & ~bit;
|
||||
}
|
||||
|
|
|
@ -260,7 +260,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
|
|||
kMemoryAccessAttrName))
|
||||
return failure();
|
||||
|
||||
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(memoryAccessAttr,
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
// Parse integer attribute for alignment.
|
||||
Attribute alignmentAttr;
|
||||
Type i32Type = parser.getBuilder().getIntegerType(32);
|
||||
|
@ -290,7 +291,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
|
|||
kSourceMemoryAccessAttrName))
|
||||
return failure();
|
||||
|
||||
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(memoryAccessAttr,
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
// Parse integer attribute for alignment.
|
||||
Attribute alignmentAttr;
|
||||
Type i32Type = parser.getBuilder().getIntegerType(32);
|
||||
|
@ -316,7 +318,7 @@ static void printMemoryAccessAttribute(
|
|||
|
||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
||||
|
||||
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
|
||||
// Print integer alignment attribute.
|
||||
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
||||
: memoryOp.alignment())) {
|
||||
|
@ -349,7 +351,7 @@ static void printSourceMemoryAccessAttribute(
|
|||
|
||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
||||
|
||||
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
|
||||
// Print integer alignment attribute.
|
||||
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
||||
: memoryOp.alignment())) {
|
||||
|
@ -407,7 +409,7 @@ static LogicalResult verifyImageOperands(Op imageOp,
|
|||
spirv::ImageOperands::MakeTexelVisible |
|
||||
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
|
||||
|
||||
if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
|
||||
if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
|
||||
llvm_unreachable("unimplemented operands of Image Operands");
|
||||
|
||||
return success();
|
||||
|
@ -491,8 +493,8 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
|||
<< memAccessAttr;
|
||||
}
|
||||
|
||||
if (spirv::bitEnumContains(memAccess.getValue(),
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(memAccess.getValue(),
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
if (!op->getAttr(kAlignmentAttrName)) {
|
||||
return memoryOp.emitOpError("missing alignment value");
|
||||
}
|
||||
|
@ -535,8 +537,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
|||
<< memAccess;
|
||||
}
|
||||
|
||||
if (spirv::bitEnumContains(memAccess.getValue(),
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
if (spirv::bitEnumContainsAll(memAccess.getValue(),
|
||||
spirv::MemoryAccess::Aligned)) {
|
||||
if (!op->getAttr(kSourceAlignmentAttrName)) {
|
||||
return memoryOp.emitOpError("missing alignment value");
|
||||
}
|
||||
|
|
|
@ -162,7 +162,7 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
|
|||
llvm::FastMathFlags ret;
|
||||
auto fmf = op.getFastmathFlags();
|
||||
for (auto it : handlers)
|
||||
if (bitEnumContains(fmf, it.first))
|
||||
if (bitEnumContainsAll(fmf, it.first))
|
||||
(ret.*(it.second))(true);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -138,7 +138,8 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
|
|||
// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
|
||||
// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
|
||||
// inline constexpr <enum-type> operator~(<enum-type> bits);
|
||||
// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
|
||||
// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
|
||||
// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
|
||||
// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
|
||||
// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
|
||||
// bool value=true);
|
||||
|
@ -161,9 +162,12 @@ inline constexpr {0} operator~({0} bits) {{
|
|||
// Ensure only bits that can be present in the enum are set
|
||||
return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
|
||||
}
|
||||
inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
|
||||
inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
|
||||
return (bits & bit) == bit;
|
||||
}
|
||||
inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
|
||||
return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
|
||||
}
|
||||
inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
|
||||
return bits & ~bit;
|
||||
}
|
||||
|
|
|
@ -142,10 +142,10 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
|
|||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedOperator) {
|
||||
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit0));
|
||||
EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit0));
|
||||
EXPECT_TRUE(bitEnumContainsAll(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit0));
|
||||
EXPECT_FALSE(bitEnumContainsAll(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit0));
|
||||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {
|
||||
|
|
Loading…
Reference in New Issue