Add Float8E4M3FN type to MLIR.

The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well.

There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.

This change is identical to the MLIR changes in the patch that added Float8E5M2, except that Float8E4M3FN is added instead.

Reviewed By: stellaraccident, bkramer, rriddle

Differential Revision: https://reviews.llvm.org/D138075
This commit is contained in:
Reed 2022-11-16 10:24:24 +01:00 committed by Benjamin Kramer
parent 2ada5cbea4
commit e08ca4bb1d
14 changed files with 69 additions and 4 deletions

View File

@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
/// Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// Creates an f8E4M3FN type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);

View File

@ -60,6 +60,7 @@ public:
// Types.
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3FNType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();

View File

@ -47,6 +47,7 @@ public:
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@ -374,14 +375,18 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
Float64Type, Float80Type, Float128Type>();
return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
Float32Type, Float64Type, Float80Type, Float128Type>();
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
return Float8E5M2Type::get(ctx);
}
inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
return Float8E4M3FNType::get(ctx);
}
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}

View File

@ -89,7 +89,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
* bit encoding: S1E5M2
* exponent bias: 15
* infinities: supported with exponent set to all 1s and mantissa 0s
* NaNs: supported with exponent bits set to all 1s and mantissa of
* NaNs: supported with exponent bits set to all 1s and mantissa of
(01, 10, or 11)
* denormals when exponent is 0
@ -97,6 +97,27 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
}];
}
//===----------------------------------------------------------------------===//
// Float8E4M3FNType
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
mantissa. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exception that there are no infinity values
and only two NaN representations. This type has the following
characteristics:
* bit encoding: S1E4M3
* exponent bias: 7
* infinities: Not supported
* NaNs: supported with exponent bits and mantissa bits set to all 1s
* denormals when exponent is 0
Described in: https://arxiv.org/abs/2209.05433
}];
}
//===----------------------------------------------------------------------===//
// BFloat16Type

View File

@ -124,6 +124,7 @@ public:
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3FN() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;

View File

@ -94,6 +94,7 @@ TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f8E5M2)
TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)

View File

@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_vector:
case Token::inttype:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3FN:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
@ -290,6 +291,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();

View File

@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
}
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
return unwrap(type).isFloat8E4M3FN();
}
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
}
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {

View File

@ -2244,6 +2244,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })

View File

@ -37,6 +37,10 @@ FloatType Builder::getFloat8E5M2Type() {
return FloatType::getFloat8E5M2(context);
}
FloatType Builder::getFloat8E4M3FNType() {
return FloatType::getFloat8E4M3FN(context);
}
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }

View File

@ -88,7 +88,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
if (isa<Float8E5M2Type>())
if (isa<Float8E5M2Type, Float8E4M3FNType>())
return 8;
if (isa<Float16Type, BFloat16Type>())
return 16;
@ -107,6 +107,8 @@ unsigned FloatType::getWidth() {
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (isa<Float8E5M2Type>())
return APFloat::Float8E5M2();
if (isa<Float8E4M3FNType>())
return APFloat::Float8E4M3FN();
if (isa<BFloat16Type>())
return APFloat::BFloat();
if (isa<Float16Type>())

View File

@ -207,6 +207,7 @@ public:
/// Cached Type Instances.
Float8E5M2Type f8E5M2Ty;
Float8E4M3FNType f8E4M3FNTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
Float32Type f32Ty;
@ -278,6 +279,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
//// Types.
/// Floating-point Types.
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@ -861,6 +863,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}

View File

@ -19,6 +19,7 @@ using namespace mlir::detail;
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }

View File

@ -40,6 +40,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E5M2
float_attr = 2. : f8E5M2
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FN
float_attr = 2. : f8E4M3FN
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16