[mlir][quant] Initial bytecode encoding for quantized types
Add bytecode encoding for quantized types. These mostly follow the storage representation of these. Differential Revision: https://reviews.llvm.org/D136004
This commit is contained in:
parent
dc8035bddd
commit
7732c97f52
|
@ -1,4 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRQuantDialect
|
||||
QuantDialectBytecode.h
|
||||
QuantDialectBytecode.cpp
|
||||
QuantOps.cpp
|
||||
QuantTypes.cpp
|
||||
TypeDetail.h
|
||||
|
|
|
@ -0,0 +1,299 @@
|
|||
//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation
|
||||
//------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "QuantDialectBytecode.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::quant;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
namespace quant_encoding {
|
||||
/// This enum contains marker codes used to indicate which type is currently
|
||||
/// being decoded, and how it should be decoded. The order of these codes should
|
||||
/// generally be unchanged, as any changes will inevitably break compatibility
|
||||
/// with older bytecode.
|
||||
enum TypeCode {
|
||||
/// AnyQuantizedType {
|
||||
/// flags: varint
|
||||
/// storageType: Type
|
||||
/// storageTypeMin: svarint
|
||||
/// storageTypeMax: svarint
|
||||
/// }
|
||||
///
|
||||
kAnyQuantizedType = 1,
|
||||
|
||||
/// AnyQuantizedType {
|
||||
/// flags: varint
|
||||
/// storageType: Type
|
||||
/// expressedType: Type
|
||||
/// storageTypeMin: svarint
|
||||
/// storageTypeMax: svarint
|
||||
/// }
|
||||
///
|
||||
kAnyQuantizedTypeWithExpressedType = 2,
|
||||
|
||||
/// CalibratedQuantizedType {
|
||||
/// expressedType: Type
|
||||
/// min: APFloat
|
||||
/// max: APFloat
|
||||
/// }
|
||||
///
|
||||
kCalibratedQuantizedType = 3,
|
||||
|
||||
/// UniformQuantizedType {
|
||||
/// flags: varint
|
||||
/// storageType: Type
|
||||
/// expressedType: Type
|
||||
/// scale: APFloat
|
||||
/// zeroPoint: svarint
|
||||
/// storageTypeMin: svarint
|
||||
/// storageTypeMax: svarint
|
||||
/// }
|
||||
///
|
||||
kUniformQuantizedType = 4,
|
||||
|
||||
/// UniformQuantizedPerAxisType {
|
||||
/// flags: varint
|
||||
/// storageType: Type
|
||||
/// expressedType: Type
|
||||
/// quantizedDimension: varint
|
||||
/// storageTypeMin: svarint
|
||||
/// storageTypeMax: svarint
|
||||
/// scale: APFloat[]
|
||||
/// zeroPoint: svarint[]
|
||||
/// }
|
||||
///
|
||||
kUniformQuantizedPerAxisType = 5,
|
||||
};
|
||||
|
||||
} // namespace quant_encoding
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// QuantDialectBytecodeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class implements the bytecode interface for the Quant dialect.
|
||||
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
|
||||
QuantDialectBytecodeInterface(Dialect *dialect)
|
||||
: BytecodeDialectInterface(dialect) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Types
|
||||
|
||||
Type readType(DialectBytecodeReader &reader) const override;
|
||||
LogicalResult writeType(Type type,
|
||||
DialectBytecodeWriter &writer) const override;
|
||||
|
||||
AnyQuantizedType readAnyQuantizedType(bool withExpressedType,
|
||||
DialectBytecodeReader &reader) const;
|
||||
void write(AnyQuantizedType type, DialectBytecodeWriter &writer) const;
|
||||
|
||||
CalibratedQuantizedType
|
||||
readCalibratedQuantizedType(DialectBytecodeReader &reader) const;
|
||||
void write(CalibratedQuantizedType type, DialectBytecodeWriter &writer) const;
|
||||
|
||||
UniformQuantizedType
|
||||
readUniformQuantizedType(DialectBytecodeReader &reader) const;
|
||||
void write(UniformQuantizedType type, DialectBytecodeWriter &writer) const;
|
||||
|
||||
UniformQuantizedPerAxisType
|
||||
readUniformQuantizedPerAxisType(DialectBytecodeReader &reader) const;
|
||||
void write(UniformQuantizedPerAxisType type,
|
||||
DialectBytecodeWriter &writer) const;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
|
||||
dialect->addInterfaces<QuantDialectBytecodeInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type QuantDialectBytecodeInterface::readType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t code;
|
||||
if (failed(reader.readVarInt(code)))
|
||||
return Type();
|
||||
|
||||
switch (code) {
|
||||
case quant_encoding::kAnyQuantizedType:
|
||||
return readAnyQuantizedType(/*withExpressedType=*/false, reader);
|
||||
case quant_encoding::kAnyQuantizedTypeWithExpressedType:
|
||||
return readAnyQuantizedType(/*withExpressedType=*/true, reader);
|
||||
case quant_encoding::kCalibratedQuantizedType:
|
||||
return readCalibratedQuantizedType(reader);
|
||||
case quant_encoding::kUniformQuantizedType:
|
||||
return readUniformQuantizedType(reader);
|
||||
case quant_encoding::kUniformQuantizedPerAxisType:
|
||||
return readUniformQuantizedPerAxisType(reader);
|
||||
|
||||
default:
|
||||
reader.emitError() << "unknown builtin type code: " << code;
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
QuantDialectBytecodeInterface::writeType(Type type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
return TypeSwitch<Type, LogicalResult>(type)
|
||||
.Case<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType>(
|
||||
[&](auto attr) { return write(attr, writer), success(); })
|
||||
.Default([&](Type) { return failure(); });
|
||||
}
|
||||
|
||||
AnyQuantizedType QuantDialectBytecodeInterface::readAnyQuantizedType(
|
||||
bool withExpressedType, DialectBytecodeReader &reader) const {
|
||||
uint64_t flags;
|
||||
Type storageType, expressedType;
|
||||
int64_t storageTypeMin, storageTypeMax;
|
||||
if (failed(reader.readVarInt(flags)) ||
|
||||
failed(reader.readType(storageType)) ||
|
||||
(withExpressedType && failed(reader.readType(expressedType))) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMin)) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMax)))
|
||||
return reader.emitError("invalid AnyQuantizedType"), AnyQuantizedType();
|
||||
return AnyQuantizedType::get(flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax);
|
||||
}
|
||||
void QuantDialectBytecodeInterface::write(AnyQuantizedType type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
if (type.getExpressedType())
|
||||
writer.writeVarInt(quant_encoding::kAnyQuantizedTypeWithExpressedType);
|
||||
else
|
||||
writer.writeVarInt(quant_encoding::kAnyQuantizedType);
|
||||
|
||||
writer.writeVarInt(type.getFlags());
|
||||
writer.writeType(type.getStorageType());
|
||||
if (type.getExpressedType())
|
||||
writer.writeType(type.getExpressedType());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMin());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMax());
|
||||
}
|
||||
|
||||
CalibratedQuantizedType
|
||||
QuantDialectBytecodeInterface::readCalibratedQuantizedType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
Type expressedType;
|
||||
FailureOr<APFloat> min, max;
|
||||
if (failed(reader.readType(expressedType)) ||
|
||||
failed(min = reader.readAPFloatWithKnownSemantics(
|
||||
llvm::APFloat::IEEEdouble())) ||
|
||||
failed(max = reader.readAPFloatWithKnownSemantics(
|
||||
llvm::APFloat::IEEEdouble())))
|
||||
return reader.emitError("invalid CalibratedQuantizedType"),
|
||||
CalibratedQuantizedType();
|
||||
return CalibratedQuantizedType::get(expressedType,
|
||||
min.value().convertToDouble(),
|
||||
max.value().convertToDouble());
|
||||
}
|
||||
void QuantDialectBytecodeInterface::write(CalibratedQuantizedType type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(quant_encoding::kCalibratedQuantizedType);
|
||||
writer.writeType(type.getExpressedType());
|
||||
writer.writeAPFloatWithKnownSemantics(APFloat(type.getMin()));
|
||||
writer.writeAPFloatWithKnownSemantics(APFloat(type.getMax()));
|
||||
}
|
||||
|
||||
UniformQuantizedType QuantDialectBytecodeInterface::readUniformQuantizedType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t flags;
|
||||
Type storageType, expressedType;
|
||||
FailureOr<APFloat> scale;
|
||||
int64_t zeroPoint, storageTypeMin, storageTypeMax;
|
||||
if (failed(reader.readVarInt(flags)) ||
|
||||
failed(reader.readType(storageType)) ||
|
||||
failed(reader.readType(expressedType)) ||
|
||||
failed(scale = reader.readAPFloatWithKnownSemantics(
|
||||
llvm::APFloat::IEEEdouble())) ||
|
||||
failed(reader.readSignedVarInt(zeroPoint)) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMin)) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMax)))
|
||||
return reader.emitError("invalid UniformQuantizedType"),
|
||||
UniformQuantizedType();
|
||||
return UniformQuantizedType::get(flags, storageType, expressedType,
|
||||
scale.value().convertToDouble(), zeroPoint,
|
||||
storageTypeMin, storageTypeMax);
|
||||
}
|
||||
void QuantDialectBytecodeInterface::write(UniformQuantizedType type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(quant_encoding::kUniformQuantizedType);
|
||||
writer.writeVarInt(type.getFlags());
|
||||
writer.writeType(type.getStorageType());
|
||||
writer.writeType(type.getExpressedType());
|
||||
writer.writeAPFloatWithKnownSemantics(APFloat(type.getScale()));
|
||||
writer.writeSignedVarInt(type.getZeroPoint());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMin());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMax());
|
||||
}
|
||||
|
||||
UniformQuantizedPerAxisType
|
||||
QuantDialectBytecodeInterface::readUniformQuantizedPerAxisType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t flags;
|
||||
Type storageType, expressedType;
|
||||
SmallVector<double> scales;
|
||||
SmallVector<int64_t> zeroPoints;
|
||||
uint64_t quantizedDimension;
|
||||
int64_t storageTypeMin, storageTypeMax;
|
||||
|
||||
auto scalesRead = [&](double &val) -> LogicalResult {
|
||||
FailureOr<APFloat> fl =
|
||||
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble());
|
||||
if (succeeded(fl)) {
|
||||
val = fl.value().convertToDouble();
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
};
|
||||
|
||||
if (failed(reader.readVarInt(flags)) ||
|
||||
failed(reader.readType(storageType)) ||
|
||||
failed(reader.readType(expressedType)) ||
|
||||
failed(reader.readList(scales, scalesRead)) ||
|
||||
failed(reader.readSignedVarInts(zeroPoints)) ||
|
||||
failed(reader.readVarInt(quantizedDimension)) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMin)) ||
|
||||
failed(reader.readSignedVarInt(storageTypeMax)))
|
||||
return reader.emitError("invalid UniformQuantizedPerAxisType"),
|
||||
UniformQuantizedPerAxisType();
|
||||
return UniformQuantizedPerAxisType::get(
|
||||
flags, storageType, expressedType, scales, zeroPoints,
|
||||
(int32_t)quantizedDimension, storageTypeMin, storageTypeMax);
|
||||
}
|
||||
void QuantDialectBytecodeInterface::write(UniformQuantizedPerAxisType type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(quant_encoding::kUniformQuantizedType);
|
||||
writer.writeVarInt(type.getFlags());
|
||||
writer.writeType(type.getStorageType());
|
||||
writer.writeType(type.getExpressedType());
|
||||
writer.writeList(type.getScales(), [&](double val) {
|
||||
writer.writeAPFloatWithKnownSemantics(APFloat(val));
|
||||
});
|
||||
writer.writeSignedVarInts(type.getZeroPoints());
|
||||
writer.writeVarInt(type.getQuantizedDimension());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMin());
|
||||
writer.writeSignedVarInt(type.getStorageTypeMax());
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
//===- QuantDialectBytecode.h - Quant Bytecode Implementation --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header defines hooks into the quantization dialect bytecode
|
||||
// implementation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
|
||||
#define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
|
||||
|
||||
namespace mlir::quant {
|
||||
class QuantizationDialect;
|
||||
|
||||
namespace detail {
|
||||
/// Add the interfaces necessary for encoding the quantization dialect
|
||||
/// components in bytecode.
|
||||
void addBytecodeInterface(QuantizationDialect *dialect);
|
||||
} // namespace detail
|
||||
} // namespace mlir::quant
|
||||
|
||||
#endif // LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "QuantDialectBytecode.h"
|
||||
#include "TypeDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h"
|
||||
|
@ -32,6 +33,7 @@ void QuantizationDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
|
||||
>();
|
||||
addBytecodeInterface(this);
|
||||
}
|
||||
|
||||
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
|
||||
|
||||
// Bytecode currently does not support big-endian platforms
|
||||
// UNSUPPORTED: s390x-
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnyQuantized
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: parseAnyFullySpecified
|
||||
module @parseAnyFullySpecified attributes {
|
||||
// CHECK: bytecode.test = !quant.any<i8<-8:7>:f32>
|
||||
bytecode.test = !quant.any<i8<-8:7>:f32>
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: parseAnyNoExpressedType
|
||||
module @parseAnyNoExpressedType attributes {
|
||||
// CHECK: bytecode.test = !quant.any<i8<-8:7>>
|
||||
bytecode.test = !quant.any<i8<-8:7>>
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: parseAnyOnlyStorageType
|
||||
module @parseAnyOnlyStorageType attributes {
|
||||
// CHECK: bytecode.test = !quant.any<i8<-8:7>>
|
||||
bytecode.test = !quant.any<i8<-8:7>>
|
||||
} {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CalibratedQuantized
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: parseCalibrated
|
||||
module @parseCalibrated attributes {
|
||||
// CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
|
||||
bytecode.test = !quant.calibrated<f32<-0.998:1.2321>>
|
||||
} {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UniformQuantized
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: parseUniformPerLayer
|
||||
module @parseUniformPerLayer attributes {
|
||||
// CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
|
||||
bytecode.test = !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
|
||||
} {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UniformQuantizedPerAxis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: parseUniformPerAxisScaleZero
|
||||
module @parseUniformPerAxisScaleZero attributes {
|
||||
// CHECK: !quant.uniform<u8:f32:1, {2.000000e+02:-120,9.987200e-01:127}>
|
||||
bytecode.test = !quant.uniform<u8:f32:1, {2.000000e+02:-120,9.987200e-01:127}>
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: parseUniformPerAxisScaleNoZero
|
||||
module @parseUniformPerAxisScaleNoZero attributes {
|
||||
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01}>
|
||||
bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872}>
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: parseUniformPerAxisMixed
|
||||
module @parseUniformPerAxisMixed attributes {
|
||||
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
|
||||
bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>
|
||||
} {}
|
||||
|
|
@ -7406,6 +7406,8 @@ gentbl_cc_library(
|
|||
cc_library(
|
||||
name = "QuantOps",
|
||||
srcs = [
|
||||
"lib/Dialect/Quant/IR/QuantDialectBytecode.h"
|
||||
"lib/Dialect/Quant/IR/QuantDialectBytecode.cpp"
|
||||
"lib/Dialect/Quant/IR/QuantOps.cpp",
|
||||
"lib/Dialect/Quant/IR/QuantTypes.cpp",
|
||||
"lib/Dialect/Quant/IR/TypeDetail.h",
|
||||
|
|
Loading…
Reference in New Issue