[mlir] Improve bitEnumContains methods.

839b436c93
changes the behavior. Based on the discussion, we also want to support
"and" behavior. The revision changes it into two functions, bitEnumContainsAny
and bitEnumContainsAll.

Reviewed By: krzysz00, antiagainst

Differential Revision: https://reviews.llvm.org/D133507
This commit is contained in:
Hanhan Wang 2022-09-09 11:56:29 -07:00
parent aee094fb8b
commit aac844a4b1
5 changed files with 26 additions and 17 deletions

View File

@ -1442,9 +1442,12 @@ inline constexpr MyBitEnum operator~(MyBitEnum bits) {
// Ensure only bits that can be present in the enum are set
return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
}
inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
inline constexpr bool bitEnumContainsAll(MyBitEnum bits, MyBitEnum bit) {
return (bits & bit) == bit;
}
inline constexpr bool bitEnumContainsAny(MyBitEnum bits, MyBitEnum bit) {
return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
}
inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
return bits & ~bit;
}

View File

@ -260,7 +260,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
kMemoryAccessAttrName))
return failure();
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(memoryAccessAttr,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
@ -290,7 +291,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
kSourceMemoryAccessAttrName))
return failure();
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(memoryAccessAttr,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
@ -316,7 +318,7 @@ static void printMemoryAccessAttribute(
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.alignment())) {
@ -349,7 +351,7 @@ static void printSourceMemoryAccessAttribute(
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.alignment())) {
@ -407,7 +409,7 @@ static LogicalResult verifyImageOperands(Op imageOp,
spirv::ImageOperands::MakeTexelVisible |
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
llvm_unreachable("unimplemented operands of Image Operands");
return success();
@ -491,8 +493,8 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
<< memAccessAttr;
}
if (spirv::bitEnumContains(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@ -535,8 +537,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
<< memAccess;
}
if (spirv::bitEnumContains(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kSourceAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}

View File

@ -162,7 +162,7 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
llvm::FastMathFlags ret;
auto fmf = op.getFastmathFlags();
for (auto it : handlers)
if (bitEnumContains(fmf, it.first))
if (bitEnumContainsAll(fmf, it.first))
(ret.*(it.second))(true);
return ret;
}

View File

@ -138,7 +138,8 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
// inline constexpr <enum-type> operator~(<enum-type> bits);
// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
// bool value=true);
@ -161,9 +162,12 @@ inline constexpr {0} operator~({0} bits) {{
// Ensure only bits that can be present in the enum are set
return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
}
inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
return (bits & bit) == bit;
}
inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
}
inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
return bits & ~bit;
}

View File

@ -142,10 +142,10 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
}
TEST(EnumsGenTest, GeneratedOperator) {
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
EXPECT_TRUE(bitEnumContainsAll(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
EXPECT_FALSE(bitEnumContainsAll(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
}
TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {