[mlir:Bytecode] Add encoding support for a majority of the builtin attributes

This adds support for the non-location, non-elements, non-affine
builtin attributes.

Differential Revision: https://reviews.llvm.org/D132539
This commit is contained in:
River Riddle 2022-08-23 23:39:18 -07:00
parent 937aaead87
commit 2f90764ce8
7 changed files with 399 additions and 13 deletions

View File

@ -65,6 +65,22 @@ x1000000: 49 value bits, the encoding uses 7 bytes
00000000: 64 value bits, the encoding uses 9 bytes
```
##### Signed Variable-Width Integers
Signed variable width integer values are encoded in a similar fashion to
[varints](#variable-width-integers), but employ
[zigzag encoding](https://en.wikipedia.org/wiki/Variable-length_quantity#Zigzag_encoding).
This encoding uses the low bit of the value to indicate the sign, which allows
for more efficiently encoding negative numbers. If a negative value were encoded
using a normal [varint](#variable-width-integers), it would be treated as an
extremely large unsigned value. Using zigzag encoding allows for a smaller
number of active bits in the value, leading to a smaller encoding. Below is the
basic computation for generating a zigzag encoding of a 64-bit signed `value`,
where `>>` is an arithmetic (sign-extending) right shift:
```
(value << 1) ^ (value >> 63)
```
#### Strings
Strings are blobs of characters with an associated length.

View File

@ -78,14 +78,14 @@ public:
return readList(attrs, [this](T &attr) { return readAttribute(attr); });
}
/// Read a reference to an attribute and cast it to the derived attribute
/// kind `T`. On success, `result` holds the attribute. Returns failure if the
/// underlying attribute could not be read; if an attribute was read but is
/// not a `T`, an error naming the expected kind and the attribute that was
/// actually read is emitted.
template <typename T>
LogicalResult readAttribute(T &result) {
  Attribute baseResult;
  if (failed(readAttribute(baseResult)))
    return failure();
  // `dyn_cast` yields a null attribute on a kind mismatch, which is what the
  // condition below tests.
  if ((result = baseResult.dyn_cast<T>()))
    return success();
  return emitError() << "expected " << llvm::getTypeName<T>()
                     << ", but got: " << baseResult;
}
/// Read a reference to the given type.
@ -94,15 +94,35 @@ public:
LogicalResult readTypes(SmallVectorImpl<T> &types) {
return readList(types, [this](T &type) { return readType(type); });
}
/// Read a reference to a type and cast it to the derived type kind `T`.
/// Returns failure if the underlying type could not be read; if a type was
/// read but is not a `T`, an error naming the expected kind and the type
/// actually read is emitted.
template <typename T>
LogicalResult readType(T &result) {
  Type untyped;
  if (failed(readType(untyped)))
    return failure();

  // A failed cast leaves `result` null, which selects the diagnostic path.
  result = untyped.dyn_cast<T>();
  if (result)
    return success();

  return emitError() << "expected " << llvm::getTypeName<T>()
                     << ", but got: " << untyped;
}
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
/// Read a variable width integer into `result`. For signed values, use
/// `readSignedVarInt` instead.
virtual LogicalResult readVarInt(uint64_t &result) = 0;
/// Read a signed variable width integer into `result`.
virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
/// Read an APInt that is known to have been encoded with the given width.
/// The width is supplied by the caller (e.g. taken from an already-read
/// type). Returns failure if the value could not be read.
virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
/// Read an APFloat that is known to have been encoded with the given
/// semantics. Returns failure if the value could not be read.
virtual FailureOr<APFloat>
readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
/// Read a string from the bytecode into `result`.
virtual LogicalResult readString(StringRef &result) = 0;
};
@ -153,9 +173,25 @@ public:
/// Write a variable width integer to the output stream. This should be the
/// preferred method for emitting integers whenever possible. For signed
/// values, use `writeSignedVarInt` instead.
virtual void writeVarInt(uint64_t value) = 0;
/// Write a signed variable width integer to the output stream. This should be
/// the preferred method for emitting signed integers whenever possible.
virtual void writeSignedVarInt(int64_t value) = 0;
/// Write an APInt to the bytecode stream whose bitwidth will be known
/// externally at read time. This method is useful for encoding APInt values
/// when the width is known via external means, such as via a type. This
/// method should generally only be invoked if you need an APInt, otherwise
/// use the varint methods above. APInt values are generally encoded using
/// zigzag encoding, to enable more efficient encodings for negative values.
/// The value is read back with
/// `DialectBytecodeReader::readAPIntWithKnownWidth`, passing the same width.
virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
/// Write an APFloat to the bytecode stream whose semantics will be known
/// externally at read time. This method is useful for encoding APFloat values
/// when the semantics are known via external means, such as via a type.
/// The value is read back with
/// `DialectBytecodeReader::readAPFloatWithKnownSemantics`, passing the same
/// semantics.
virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
/// Write a string to the bytecode, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process. This should
/// only be called if such a guarantee can be made, such as when the string is

View File

@ -128,6 +128,17 @@ public:
return parseMultiByteVarInt(result);
}
/// Parse a zigzag-encoded signed varint from the byte stream. The value is
/// first parsed as a regular varint; its low bit carries the sign and the
/// remaining bits carry the magnitude. The decoded two's complement bit
/// pattern is stored in `result`.
LogicalResult parseSignedVarInt(uint64_t &result) {
  if (failed(parseVarInt(result)))
    return failure();

  // Undo the zigzag transform while staying in unsigned arithmetic: the low
  // bit expands to an all-zeros or all-ones mask (0 - bit) that is xor'ed
  // onto the shifted-down magnitude bits.
  const uint64_t signMask = static_cast<uint64_t>(0) - (result & 1);
  result = (result >> 1) ^ signMask;
  return success();
}
/// Parse a variable length encoded integer whose low bit is used to encode an
/// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
@ -511,6 +522,52 @@ public:
return reader.parseVarInt(result);
}
/// Read a signed varint by decoding it through the encoding reader and
/// reinterpreting the decoded 64-bit pattern as a signed value. `result` is
/// only written when parsing succeeds.
LogicalResult readSignedVarInt(int64_t &result) override {
  uint64_t decodedBits;
  const LogicalResult status = reader.parseSignedVarInt(decodedBits);
  if (failed(status))
    return failure();

  result = static_cast<int64_t>(decodedBits);
  return success();
}
/// Read an APInt of the given bit width. The encoding depends on the width:
///   * up to 8 bits:  a single raw byte,
///   * up to 64 bits: a single signed (zigzag) varint,
///   * wider:         a varint count of active words followed by that many
///                    signed varints, one per 64-bit word.
/// Returns failure if any component cannot be parsed, or if the encoded word
/// count is larger than a value of `bitWidth` bits can hold.
FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
  // Small values are encoded using a single byte.
  if (bitWidth <= 8) {
    uint8_t value;
    if (failed(reader.parseByte(value)))
      return failure();
    return APInt(bitWidth, value);
  }

  // Large values up to 64 bits are encoded using a single varint.
  if (bitWidth <= 64) {
    uint64_t value;
    if (failed(reader.parseSignedVarInt(value)))
      return failure();
    return APInt(bitWidth, value);
  }

  // Otherwise, for really big values we encode the array of active words in
  // the value.
  uint64_t numActiveWords;
  if (failed(reader.parseVarInt(numActiveWords)))
    return failure();

  // The word count comes straight from the input and is used to size an
  // allocation below. Reject counts that cannot belong to a value of this
  // width before allocating, so malformed input cannot request an arbitrarily
  // large buffer.
  const uint64_t maxWords = APInt::getNumWords(bitWidth);
  if (numActiveWords > maxWords) {
    emitError("expected at most ")
        << maxWords << " active words for an integer of width " << bitWidth
        << ", but got " << numActiveWords;
    return failure();
  }

  SmallVector<uint64_t, 4> words(numActiveWords);
  for (uint64_t i = 0; i < numActiveWords; ++i)
    if (failed(reader.parseSignedVarInt(words[i])))
      return failure();
  return APInt(bitWidth, words);
}
/// Read an APFloat with the given semantics. The float is stored as an
/// integer whose width is the bit size of the semantics; that integer is read
/// first and then used to construct the float.
FailureOr<APFloat>
readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
  const unsigned storageWidth = APFloat::getSizeInBits(semantics);
  FailureOr<APInt> rawBits = readAPIntWithKnownWidth(storageWidth);
  if (failed(rawBits))
    return failure();
  return APFloat(semantics, *rawBits);
}
/// Read a string into `result` by delegating to `stringReader.parseString`,
/// which is handed the encoding reader to parse from.
LogicalResult readString(StringRef &result) override {
return stringReader.parseString(reader, result);
}

View File

@ -85,6 +85,14 @@ public:
emitMultiByteVarInt(value);
}
/// Emit a signed variable length integer. The value is zigzag encoded before
/// being emitted as a regular varint, which moves the sign into the low bit
/// and keeps the number of active bits small for negative values. `value`
/// carries the two's complement bit pattern of the signed integer.
void emitSignedVarInt(uint64_t value) {
  // An arithmetic shift by 63 smears the sign bit across all 64 bits, giving
  // all-ones for negative values and all-zeros otherwise.
  const int64_t signSmear = static_cast<int64_t>(value) >> 63;
  const uint64_t zigzag = (value << 1) ^ static_cast<uint64_t>(signSmear);
  emitVarInt(zigzag);
}
/// Emit a variable length integer whose low bit is used to encode the
/// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0).
void emitVarIntWithFlag(uint64_t value, bool flag) {
@ -384,6 +392,37 @@ public:
/// Write an unsigned variable width integer by forwarding it to the emitter.
void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
/// Write a signed variable width integer. The emitter takes the value as an
/// unsigned 64-bit pattern, so the signed value is converted first.
void writeSignedVarInt(int64_t value) override {
  const auto bitPattern = static_cast<uint64_t>(value);
  emitter.emitSignedVarInt(bitPattern);
}
void writeAPIntWithKnownWidth(const APInt &value) override {
size_t bitWidth = value.getBitWidth();
// If the value is a single byte, just emit it directly without going
// through a varint.
if (bitWidth <= 8)
return emitter.emitByte(value.getLimitedValue());
// If the value fits within a single varint, emit it directly.
if (bitWidth <= 64)
return emitter.emitSignedVarInt(value.getLimitedValue());
// Otherwise, we need to encode a variable number of active words. We use
// active words instead of the number of total words under the observation
// that smaller values will be more common.
unsigned numActiveWords = value.getActiveWords();
emitter.emitVarInt(numActiveWords);
const uint64_t *rawValueData = value.getRawData();
for (unsigned i = 0; i < numActiveWords; ++i)
emitter.emitSignedVarInt(rawValueData[i]);
}
void writeAPFloatWithKnownSemantics(const APFloat &value) override {
writeAPIntWithKnownWidth(value.bitcastToAPInt());
}
/// Write a caller-owned string: the string is inserted into the string
/// section and the value returned by that insertion is emitted as a varint.
void writeOwnedString(StringRef str) override {
  // NOTE(review): presumably the returned value identifies the string within
  // the string section; confirm against `stringSection.insert`.
  const auto stringId = stringSection.insert(str);
  emitter.emitVarInt(stringId);
}

View File

@ -27,6 +27,9 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
/// Stubbed out methods that are not used for numbering.
/// Each override below has an empty body, so the primitive value passed to it
/// is ignored.
void writeVarInt(uint64_t) override {}
void writeSignedVarInt(int64_t value) override {}
void writeAPIntWithKnownWidth(const APInt &value) override {}
void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
void writeOwnedString(StringRef) override {
// TODO: It might be nice to prenumber strings and sort by the number of
// references. This could potentially be useful for optimizing things like

View File

@ -38,9 +38,49 @@ enum AttributeCode {
kDictionaryAttr = 1,
/// StringAttr {
/// string
/// value: string
/// }
kStringAttr = 2,
/// StringAttrWithType {
/// value: string,
/// type: Type
/// }
/// A variant of StringAttr with a type.
kStringAttrWithType = 3,
/// FlatSymbolRefAttr {
/// rootReference: StringAttr
/// }
/// A variant of SymbolRefAttr with no leaf references.
kFlatSymbolRefAttr = 4,
/// SymbolRefAttr {
/// rootReference: StringAttr,
/// leafReferences: FlatSymbolRefAttr[]
/// }
kSymbolRefAttr = 5,
/// TypeAttr {
/// value: Type
/// }
kTypeAttr = 6,
/// UnitAttr {
/// }
kUnitAttr = 7,
/// IntegerAttr {
/// type: Type,
/// value: APInt
/// }
kIntegerAttr = 8,
/// FloatAttr {
/// type: FloatType,
/// value: APFloat
/// }
kFloatAttr = 9,
};
/// This enum contains marker codes used to indicate which type is currently
@ -86,13 +126,22 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
Attribute readAttribute(DialectBytecodeReader &reader) const override;
ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
StringAttr readStringAttr(DialectBytecodeReader &reader) const;
FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
bool hasNestedRefs) const;
TypeAttr readTypeAttr(DialectBytecodeReader &reader) const;
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const override;
void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
void write(StringAttr attr, DialectBytecodeWriter &writer) const;
void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
//===--------------------------------------------------------------------===//
// Types
@ -126,7 +175,21 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
case builtin_encoding::kDictionaryAttr:
return readDictionaryAttr(reader);
case builtin_encoding::kStringAttr:
return readStringAttr(reader);
return readStringAttr(reader, /*hasType=*/false);
case builtin_encoding::kStringAttrWithType:
return readStringAttr(reader, /*hasType=*/true);
case builtin_encoding::kFlatSymbolRefAttr:
return readSymbolRefAttr(reader, /*hasNestedRefs=*/false);
case builtin_encoding::kSymbolRefAttr:
return readSymbolRefAttr(reader, /*hasNestedRefs=*/true);
case builtin_encoding::kTypeAttr:
return readTypeAttr(reader);
case builtin_encoding::kUnitAttr:
return UnitAttr::get(getContext());
case builtin_encoding::kIntegerAttr:
return readIntegerAttr(reader);
case builtin_encoding::kFloatAttr:
return readFloatAttr(reader);
default:
reader.emitError() << "unknown builtin attribute code: " << code;
return Attribute();
@ -157,12 +220,75 @@ DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
return DictionaryAttr::get(getContext(), attrs);
}
StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
/// Read a FloatAttr: first its float type, then the value, which is decoded
/// using the float semantics of that type. Returns a null attribute if either
/// part fails to read.
FloatAttr BuiltinDialectBytecodeInterface::readFloatAttr(
    DialectBytecodeReader &reader) const {
  FloatType floatType;
  if (failed(reader.readType(floatType)))
    return FloatAttr();

  // The value has no self-describing format; the type's semantics determine
  // how it is decoded.
  const auto &semantics = floatType.getFloatSemantics();
  FailureOr<APFloat> parsed = reader.readAPFloatWithKnownSemantics(semantics);
  if (failed(parsed))
    return FloatAttr();

  return FloatAttr::get(floatType, *parsed);
}
/// Read an IntegerAttr: first its type, then the value, which is decoded
/// using the storage width implied by that type. Only integer and index
/// types are accepted; any other type produces an error. Returns a null
/// attribute on failure.
IntegerAttr BuiltinDialectBytecodeInterface::readIntegerAttr(
    DialectBytecodeReader &reader) const {
  Type type;
  if (failed(reader.readType(type)))
    return IntegerAttr();

  // Reject anything that is neither an integer type nor the index type.
  const auto intType = type.dyn_cast<IntegerType>();
  if (!intType && !type.isa<IndexType>()) {
    reader.emitError()
        << "expected integer or index type for IntegerAttr, but got: " << type;
    return IntegerAttr();
  }

  // Extract the value storage width from the type.
  const unsigned bitWidth =
      intType ? intType.getWidth() : IndexType::kInternalStorageBitWidth;

  FailureOr<APInt> parsed = reader.readAPIntWithKnownWidth(bitWidth);
  if (failed(parsed))
    return IntegerAttr();
  return IntegerAttr::get(type, *parsed);
}
/// Read a StringAttr: the string value, followed by a type when `hasType` is
/// set. When no type was encoded, the attribute is given `NoneType`. Returns
/// a null attribute if the string or the type fails to read.
StringAttr
BuiltinDialectBytecodeInterface::readStringAttr(DialectBytecodeReader &reader,
                                                bool hasType) const {
  StringRef string;
  if (failed(reader.readString(string)))
    return StringAttr();

  // Read the type if present. The writer omits the type when it is NoneType,
  // so that is the type to reconstruct when none was encoded.
  Type type;
  if (!hasType)
    type = NoneType::get(getContext());
  else if (failed(reader.readType(type)))
    return StringAttr();
  return StringAttr::get(string, type);
}
/// Read a SymbolRefAttr: the root reference, followed by the list of nested
/// references when `hasNestedRefs` is set. Without nested references the
/// result is built from the root alone. Returns a null attribute if any part
/// fails to read.
SymbolRefAttr BuiltinDialectBytecodeInterface::readSymbolRefAttr(
    DialectBytecodeReader &reader, bool hasNestedRefs) const {
  StringAttr root;
  if (failed(reader.readAttribute(root)))
    return SymbolRefAttr();

  // The nested list is only present in the encoding for the non-flat form.
  SmallVector<FlatSymbolRefAttr> leaves;
  if (hasNestedRefs) {
    if (failed(reader.readAttributes(leaves)))
      return SymbolRefAttr();
  }

  return SymbolRefAttr::get(root, leaves);
}
/// Read a TypeAttr, which consists solely of the type it wraps. Returns a
/// null attribute if the type fails to read.
TypeAttr BuiltinDialectBytecodeInterface::readTypeAttr(
    DialectBytecodeReader &reader) const {
  Type wrapped;
  return failed(reader.readType(wrapped)) ? TypeAttr()
                                          : TypeAttr::get(wrapped);
}
//===----------------------------------------------------------------------===//
@ -171,10 +297,15 @@ StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
/// Write a builtin attribute to the bytecode. Attribute kinds with a
/// dedicated `write` overload are dispatched to it; UnitAttr carries no
/// payload, so only its code is written. Returns failure for any attribute
/// kind that has no encoding here.
LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
    Attribute attr, DialectBytecodeWriter &writer) const {
  return TypeSwitch<Attribute, LogicalResult>(attr)
      .Case<ArrayAttr, DictionaryAttr, FloatAttr, IntegerAttr, StringAttr,
            SymbolRefAttr, TypeAttr>([&](auto attr) {
        write(attr, writer);
        return success();
      })
      .Case([&](UnitAttr) {
        writer.writeVarInt(builtin_encoding::kUnitAttr);
        return success();
      })
      .Default([&](Attribute) { return failure(); });
}
@ -193,12 +324,52 @@ void BuiltinDialectBytecodeInterface::write(
});
}
/// Encode a FloatAttr as the kFloatAttr code, the attribute's type, and then
/// the value. The type is written before the value because `readFloatAttr`
/// uses the type's float semantics to decode the value.
void BuiltinDialectBytecodeInterface::write(
FloatAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kFloatAttr);
writer.writeType(attr.getType());
writer.writeAPFloatWithKnownSemantics(attr.getValue());
}
/// Encode an IntegerAttr as the kIntegerAttr code, the attribute's type, and
/// then the value. The type is written before the value because
/// `readIntegerAttr` derives the value's bit width from the type.
void BuiltinDialectBytecodeInterface::write(
IntegerAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kIntegerAttr);
writer.writeType(attr.getType());
writer.writeAPIntWithKnownWidth(attr.getValue());
}
/// Encode a StringAttr. Two codes are used: kStringAttr for the common case
/// of a NoneType string, where only the value is written, and
/// kStringAttrWithType otherwise, where the type follows the value.
void BuiltinDialectBytecodeInterface::write(
    StringAttr attr, DialectBytecodeWriter &writer) const {
  // We only encode the type if it isn't NoneType, which is significantly less
  // common.
  const Type type = attr.getType();
  const bool hasType = !type.isa<NoneType>();

  writer.writeVarInt(hasType ? builtin_encoding::kStringAttrWithType
                             : builtin_encoding::kStringAttr);
  writer.writeOwnedString(attr.getValue());
  if (hasType)
    writer.writeType(type);
}
/// Encode a SymbolRefAttr. A reference without nested references is written
/// with the kFlatSymbolRefAttr code and only its root; otherwise the
/// kSymbolRefAttr code is used and the nested references follow the root.
void BuiltinDialectBytecodeInterface::write(
    SymbolRefAttr attr, DialectBytecodeWriter &writer) const {
  const ArrayRef<FlatSymbolRefAttr> leaves = attr.getNestedReferences();
  const bool isFlat = leaves.empty();

  if (isFlat)
    writer.writeVarInt(builtin_encoding::kFlatSymbolRefAttr);
  else
    writer.writeVarInt(builtin_encoding::kSymbolRefAttr);

  writer.writeAttribute(attr.getRootReference());
  if (isFlat)
    return;
  writer.writeAttributes(leaves);
}
/// Encode a TypeAttr as the kTypeAttr code followed by the wrapped type.
void BuiltinDialectBytecodeInterface::write(
TypeAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kTypeAttr);
writer.writeType(attr.getValue());
}
//===----------------------------------------------------------------------===//
// Types: Reader

View File

@ -3,14 +3,78 @@
// Bytecode currently does not support big-endian platforms
// UNSUPPORTED: s390x-
//===----------------------------------------------------------------------===//
// ArrayAttr
//===----------------------------------------------------------------------===//
// An array attribute with a single `unit` element; the attribute is expected
// to print back unchanged.
// CHECK-LABEL: @TestArray
module @TestArray attributes {
// CHECK: bytecode.array = [unit]
bytecode.array = [unit]
} {}
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
// Float attributes of several widths. Floats are written as their integer bit
// pattern, so f64 and bf16 fit the single-varint form (at most 64 bits) while
// f80 and f128 need the multi-word form.
// CHECK-LABEL: @TestFloat
module @TestFloat attributes {
// CHECK: bytecode.float = 1.000000e+01 : f64
// CHECK: bytecode.float1 = 0.10000{{.*}} : f80
// CHECK: bytecode.float2 = 0.10000{{.*}} : f128
// CHECK: bytecode.float3 = -5.000000e-01 : bf16
bytecode.float = 10.0 : f64,
bytecode.float1 = 0.1 : f80,
bytecode.float2 = 0.1 : f128,
bytecode.float3 = -0.5 : bf16
} {}
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
// Integer attributes covering each width-dependent integer encoding: values
// of at most 8 bits (i8, and `false`, presumably an i1) use a single byte,
// ui64 uses a single varint, and i128 uses the multi-word form.
// CHECK-LABEL: @TestInt
module @TestInt attributes {
// CHECK: bytecode.int = false
// CHECK: bytecode.int1 = -1 : i8
// CHECK: bytecode.int2 = 800 : ui64
// CHECK: bytecode.int3 = 90000000000000000300000000000000000001 : i128
bytecode.int = false,
bytecode.int1 = -1 : i8,
bytecode.int2 = 800 : ui64,
bytecode.int3 = 90000000000000000300000000000000000001 : i128
} {}
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
// String attributes in both encoded forms: an untyped string, and a string
// carrying an explicit type.
// CHECK-LABEL: @TestString
module @TestString attributes {
  // CHECK: bytecode.string = "hello"
  // CHECK: bytecode.string2 = "hello" : i32
  bytecode.string = "hello",
  bytecode.string2 = "hello" : i32
} {}
//===----------------------------------------------------------------------===//
// SymbolRefAttr
//===----------------------------------------------------------------------===//
// Symbol references in both encoded forms: a flat reference with only a root,
// and a reference with nested references.
// CHECK-LABEL: @TestSymbolRef
module @TestSymbolRef attributes {
// CHECK: bytecode.ref = @foo
// CHECK: bytecode.ref2 = @foo::@bar::@foo
bytecode.ref = @foo,
bytecode.ref2 = @foo::@bar::@foo
} {}
//===----------------------------------------------------------------------===//
// TypeAttr
//===----------------------------------------------------------------------===//
// A type attribute wrapping an integer type; the attribute is expected to
// print back unchanged.
// CHECK-LABEL: @TestType
module @TestType attributes {
// CHECK: bytecode.type = i178
bytecode.type = i178
} {}