[mlir][quant] Initial bytecode encoding for quantized types

Add bytecode encoding for quantized types. These mostly follow the
storage representation of these.

Differential Revision: https://reviews.llvm.org/D136004
This commit is contained in:
Jacques Pienaar 2022-10-17 16:28:46 -07:00
parent dc8035bddd
commit 7732c97f52
6 changed files with 401 additions and 0 deletions

View File

@ -1,4 +1,6 @@
add_mlir_dialect_library(MLIRQuantDialect
QuantDialectBytecode.h
QuantDialectBytecode.cpp
QuantOps.cpp
QuantTypes.cpp
TypeDetail.h

View File

@ -0,0 +1,299 @@
//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation
//------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "QuantDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::quant;
//===----------------------------------------------------------------------===//
// Encoding
//===----------------------------------------------------------------------===//
namespace {
namespace quant_encoding {
/// This enum contains marker codes used to indicate which type is currently
/// being decoded, and how it should be decoded. The order of these codes should
/// generally be unchanged, as any changes will inevitably break compatibility
/// with older bytecode.
enum TypeCode {
/// AnyQuantizedType {
/// flags: varint
/// storageType: Type
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kAnyQuantizedType = 1,
/// AnyQuantizedType {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kAnyQuantizedTypeWithExpressedType = 2,
/// CalibratedQuantizedType {
/// expressedType: Type
/// min: APFloat
/// max: APFloat
/// }
///
kCalibratedQuantizedType = 3,
/// UniformQuantizedType {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// scale: APFloat
/// zeroPoint: svarint
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kUniformQuantizedType = 4,
/// UniformQuantizedPerAxisType {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// quantizedDimension: varint
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// scale: APFloat[]
/// zeroPoint: svarint[]
/// }
///
kUniformQuantizedPerAxisType = 5,
};
} // namespace quant_encoding
} // namespace
//===----------------------------------------------------------------------===//
// QuantDialectBytecodeInterface
//===----------------------------------------------------------------------===//
namespace {
/// This class implements the bytecode interface for the Quant dialect.
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
QuantDialectBytecodeInterface(Dialect *dialect)
: BytecodeDialectInterface(dialect) {}
//===--------------------------------------------------------------------===//
// Types
Type readType(DialectBytecodeReader &reader) const override;
LogicalResult writeType(Type type,
DialectBytecodeWriter &writer) const override;
AnyQuantizedType readAnyQuantizedType(bool withExpressedType,
DialectBytecodeReader &reader) const;
void write(AnyQuantizedType type, DialectBytecodeWriter &writer) const;
CalibratedQuantizedType
readCalibratedQuantizedType(DialectBytecodeReader &reader) const;
void write(CalibratedQuantizedType type, DialectBytecodeWriter &writer) const;
UniformQuantizedType
readUniformQuantizedType(DialectBytecodeReader &reader) const;
void write(UniformQuantizedType type, DialectBytecodeWriter &writer) const;
UniformQuantizedPerAxisType
readUniformQuantizedPerAxisType(DialectBytecodeReader &reader) const;
void write(UniformQuantizedPerAxisType type,
DialectBytecodeWriter &writer) const;
};
} // namespace
void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
dialect->addInterfaces<QuantDialectBytecodeInterface>();
}
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
Type QuantDialectBytecodeInterface::readType(
DialectBytecodeReader &reader) const {
uint64_t code;
if (failed(reader.readVarInt(code)))
return Type();
switch (code) {
case quant_encoding::kAnyQuantizedType:
return readAnyQuantizedType(/*withExpressedType=*/false, reader);
case quant_encoding::kAnyQuantizedTypeWithExpressedType:
return readAnyQuantizedType(/*withExpressedType=*/true, reader);
case quant_encoding::kCalibratedQuantizedType:
return readCalibratedQuantizedType(reader);
case quant_encoding::kUniformQuantizedType:
return readUniformQuantizedType(reader);
case quant_encoding::kUniformQuantizedPerAxisType:
return readUniformQuantizedPerAxisType(reader);
default:
reader.emitError() << "unknown builtin type code: " << code;
return Type();
}
}
LogicalResult
QuantDialectBytecodeInterface::writeType(Type type,
DialectBytecodeWriter &writer) const {
return TypeSwitch<Type, LogicalResult>(type)
.Case<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType>(
[&](auto attr) { return write(attr, writer), success(); })
.Default([&](Type) { return failure(); });
}
AnyQuantizedType QuantDialectBytecodeInterface::readAnyQuantizedType(
bool withExpressedType, DialectBytecodeReader &reader) const {
uint64_t flags;
Type storageType, expressedType;
int64_t storageTypeMin, storageTypeMax;
if (failed(reader.readVarInt(flags)) ||
failed(reader.readType(storageType)) ||
(withExpressedType && failed(reader.readType(expressedType))) ||
failed(reader.readSignedVarInt(storageTypeMin)) ||
failed(reader.readSignedVarInt(storageTypeMax)))
return reader.emitError("invalid AnyQuantizedType"), AnyQuantizedType();
return AnyQuantizedType::get(flags, storageType, expressedType,
storageTypeMin, storageTypeMax);
}
void QuantDialectBytecodeInterface::write(AnyQuantizedType type,
DialectBytecodeWriter &writer) const {
if (type.getExpressedType())
writer.writeVarInt(quant_encoding::kAnyQuantizedTypeWithExpressedType);
else
writer.writeVarInt(quant_encoding::kAnyQuantizedType);
writer.writeVarInt(type.getFlags());
writer.writeType(type.getStorageType());
if (type.getExpressedType())
writer.writeType(type.getExpressedType());
writer.writeSignedVarInt(type.getStorageTypeMin());
writer.writeSignedVarInt(type.getStorageTypeMax());
}
CalibratedQuantizedType
QuantDialectBytecodeInterface::readCalibratedQuantizedType(
DialectBytecodeReader &reader) const {
Type expressedType;
FailureOr<APFloat> min, max;
if (failed(reader.readType(expressedType)) ||
failed(min = reader.readAPFloatWithKnownSemantics(
llvm::APFloat::IEEEdouble())) ||
failed(max = reader.readAPFloatWithKnownSemantics(
llvm::APFloat::IEEEdouble())))
return reader.emitError("invalid CalibratedQuantizedType"),
CalibratedQuantizedType();
return CalibratedQuantizedType::get(expressedType,
min.value().convertToDouble(),
max.value().convertToDouble());
}
void QuantDialectBytecodeInterface::write(CalibratedQuantizedType type,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(quant_encoding::kCalibratedQuantizedType);
writer.writeType(type.getExpressedType());
writer.writeAPFloatWithKnownSemantics(APFloat(type.getMin()));
writer.writeAPFloatWithKnownSemantics(APFloat(type.getMax()));
}
UniformQuantizedType QuantDialectBytecodeInterface::readUniformQuantizedType(
DialectBytecodeReader &reader) const {
uint64_t flags;
Type storageType, expressedType;
FailureOr<APFloat> scale;
int64_t zeroPoint, storageTypeMin, storageTypeMax;
if (failed(reader.readVarInt(flags)) ||
failed(reader.readType(storageType)) ||
failed(reader.readType(expressedType)) ||
failed(scale = reader.readAPFloatWithKnownSemantics(
llvm::APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(zeroPoint)) ||
failed(reader.readSignedVarInt(storageTypeMin)) ||
failed(reader.readSignedVarInt(storageTypeMax)))
return reader.emitError("invalid UniformQuantizedType"),
UniformQuantizedType();
return UniformQuantizedType::get(flags, storageType, expressedType,
scale.value().convertToDouble(), zeroPoint,
storageTypeMin, storageTypeMax);
}
void QuantDialectBytecodeInterface::write(UniformQuantizedType type,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(quant_encoding::kUniformQuantizedType);
writer.writeVarInt(type.getFlags());
writer.writeType(type.getStorageType());
writer.writeType(type.getExpressedType());
writer.writeAPFloatWithKnownSemantics(APFloat(type.getScale()));
writer.writeSignedVarInt(type.getZeroPoint());
writer.writeSignedVarInt(type.getStorageTypeMin());
writer.writeSignedVarInt(type.getStorageTypeMax());
}
UniformQuantizedPerAxisType
QuantDialectBytecodeInterface::readUniformQuantizedPerAxisType(
DialectBytecodeReader &reader) const {
uint64_t flags;
Type storageType, expressedType;
SmallVector<double> scales;
SmallVector<int64_t> zeroPoints;
uint64_t quantizedDimension;
int64_t storageTypeMin, storageTypeMax;
auto scalesRead = [&](double &val) -> LogicalResult {
FailureOr<APFloat> fl =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble());
if (succeeded(fl)) {
val = fl.value().convertToDouble();
return success();
}
return failure();
};
if (failed(reader.readVarInt(flags)) ||
failed(reader.readType(storageType)) ||
failed(reader.readType(expressedType)) ||
failed(reader.readList(scales, scalesRead)) ||
failed(reader.readSignedVarInts(zeroPoints)) ||
failed(reader.readVarInt(quantizedDimension)) ||
failed(reader.readSignedVarInt(storageTypeMin)) ||
failed(reader.readSignedVarInt(storageTypeMax)))
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisType();
return UniformQuantizedPerAxisType::get(
flags, storageType, expressedType, scales, zeroPoints,
(int32_t)quantizedDimension, storageTypeMin, storageTypeMax);
}
void QuantDialectBytecodeInterface::write(UniformQuantizedPerAxisType type,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(quant_encoding::kUniformQuantizedType);
writer.writeVarInt(type.getFlags());
writer.writeType(type.getStorageType());
writer.writeType(type.getExpressedType());
writer.writeList(type.getScales(), [&](double val) {
writer.writeAPFloatWithKnownSemantics(APFloat(val));
});
writer.writeSignedVarInts(type.getZeroPoints());
writer.writeVarInt(type.getQuantizedDimension());
writer.writeSignedVarInt(type.getStorageTypeMin());
writer.writeSignedVarInt(type.getStorageTypeMax());
}

View File

@ -0,0 +1,27 @@
//===- QuantDialectBytecode.h - Quant Bytecode Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header defines hooks into the quantization dialect bytecode
// implementation.
//
//===----------------------------------------------------------------------===//
#ifndef LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
#define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
namespace mlir::quant {
class QuantizationDialect;
namespace detail {
/// Add the interfaces necessary for encoding the quantization dialect
/// components in bytecode.
void addBytecodeInterface(QuantizationDialect *dialect);
} // namespace detail
} // namespace mlir::quant
#endif // LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
@ -32,6 +33,7 @@ void QuantizationDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
>();
addBytecodeInterface(this);
}
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {

View File

@ -0,0 +1,69 @@
// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
// Bytecode currently does not support big-endian platforms
// UNSUPPORTED: s390x-
//===----------------------------------------------------------------------===//
// AnyQuantized
//===----------------------------------------------------------------------===//
// CHECK-LABEL: parseAnyFullySpecified
module @parseAnyFullySpecified attributes {
// CHECK: bytecode.test = !quant.any<i8<-8:7>:f32>
bytecode.test = !quant.any<i8<-8:7>:f32>
} {}
// CHECK-LABEL: parseAnyNoExpressedType
module @parseAnyNoExpressedType attributes {
// CHECK: bytecode.test = !quant.any<i8<-8:7>>
bytecode.test = !quant.any<i8<-8:7>>
} {}
// CHECK-LABEL: parseAnyOnlyStorageType
module @parseAnyOnlyStorageType attributes {
// CHECK: bytecode.test = !quant.any<i8<-8:7>>
bytecode.test = !quant.any<i8<-8:7>>
} {}
//===----------------------------------------------------------------------===//
// CalibratedQuantized
//===----------------------------------------------------------------------===//
// CHECK-LABEL: parseCalibrated
module @parseCalibrated attributes {
// CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
bytecode.test = !quant.calibrated<f32<-0.998:1.2321>>
} {}
//===----------------------------------------------------------------------===//
// UniformQuantized
//===----------------------------------------------------------------------===//
// CHECK-LABEL: parseUniformPerLayer
module @parseUniformPerLayer attributes {
// CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
bytecode.test = !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
} {}
//===----------------------------------------------------------------------===//
// UniformQuantizedPerAxis
//===----------------------------------------------------------------------===//
// CHECK-LABEL: parseUniformPerAxisScaleZero
module @parseUniformPerAxisScaleZero attributes {
// CHECK: !quant.uniform<u8:f32:1, {2.000000e+02:-120,9.987200e-01:127}>
bytecode.test = !quant.uniform<u8:f32:1, {2.000000e+02:-120,9.987200e-01:127}>
} {}
// CHECK-LABEL: parseUniformPerAxisScaleNoZero
module @parseUniformPerAxisScaleNoZero attributes {
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01}>
bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872}>
} {}
// CHECK-LABEL: parseUniformPerAxisMixed
module @parseUniformPerAxisMixed attributes {
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>
} {}

View File

@ -7406,6 +7406,8 @@ gentbl_cc_library(
cc_library(
name = "QuantOps",
srcs = [
"lib/Dialect/Quant/IR/QuantDialectBytecode.h"
"lib/Dialect/Quant/IR/QuantDialectBytecode.cpp"
"lib/Dialect/Quant/IR/QuantOps.cpp",
"lib/Dialect/Quant/IR/QuantTypes.cpp",
"lib/Dialect/Quant/IR/TypeDetail.h",