From e08ca4bb1dfec860eefce636f4eff472fc7081ea Mon Sep 17 00:00:00 2001 From: Reed Date: Wed, 16 Nov 2022 10:24:24 +0100 Subject: [PATCH] Add Float8E4M3FN type to MLIR. The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well. There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279. This change is identical to the MLIR changes in the patch that added Float8E5M2, except that Float8E4M3FN is added instead. Reviewed By: stellaraccident, bkramer, rriddle Differential Revision: https://reviews.llvm.org/D138075 --- mlir/include/mlir-c/BuiltinTypes.h | 7 +++++++ mlir/include/mlir/IR/Builders.h | 1 + mlir/include/mlir/IR/BuiltinTypes.h | 9 +++++++-- mlir/include/mlir/IR/BuiltinTypes.td | 23 ++++++++++++++++++++++- mlir/include/mlir/IR/Types.h | 1 + mlir/lib/AsmParser/TokenKinds.def | 1 + mlir/lib/AsmParser/TypeParser.cpp | 4 ++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 ++++++++ mlir/lib/IR/AsmPrinter.cpp | 1 + mlir/lib/IR/Builders.cpp | 4 ++++ mlir/lib/IR/BuiltinTypes.cpp | 4 +++- mlir/lib/IR/MLIRContext.cpp | 5 +++++ mlir/lib/IR/Types.cpp | 1 + mlir/test/IR/attribute.mlir | 4 ++++ 14 files changed, 69 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 9bd3d510b248..1c4a1638205b 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E4M3FN type. 
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); + +/// Creates an f8E4M3FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 870c834ce2b0..f2a547ef4f75 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -60,6 +60,7 @@ public: // Types. FloatType getFloat8E5M2Type(); + FloatType getFloat8E4M3FNType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 192512725155..ceba71d51758 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -47,6 +47,7 @@ public: static FloatType getF80(MLIRContext *ctx); static FloatType getF128(MLIRContext *ctx); static FloatType getFloat8E5M2(MLIRContext *ctx); + static FloatType getFloat8E4M3FN(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. 
static bool classof(Type type); @@ -374,14 +375,18 @@ inline bool BaseMemRefType::isValidElementType(Type type) { } inline bool FloatType::classof(Type type) { - return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type, - Float64Type, Float80Type, Float128Type>(); + return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type, + Float32Type, Float64Type, Float80Type, Float128Type>(); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { return Float8E5M2Type::get(ctx); } +inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) { + return Float8E4M3FNType::get(ctx); +} + inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 50d8b3a0cb44..fbd9c6350fcf 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -89,7 +89,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> { * bit encoding: S1E5M2 * exponent bias: 15 * infinities: supported with exponent set to all 1s and mantissa 0s - * NaNs: supported with exponent bits set to all 1s and mantissa of + * NaNs: supported with exponent bits set to all 1s and mantissa of (01, 10, or 11) * denormals when exponent is 0 @@ -97,6 +97,27 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> { }]; } +//===----------------------------------------------------------------------===// +// Float8E4M3FNType + +def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> { + let summary = "8-bit floating point with 3 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values + and only two NaN representations. 
This type has the following + characteristics: + + * bit encoding: S1E4M3 + * exponent bias: 7 + * infinities: Not supported + * NaNs: supported with exponent bits and mantissa bits set to all 1s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2209.05433 + }]; +} //===----------------------------------------------------------------------===// // BFloat16Type diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 1c4db1b6c0f9..9d64a77742ef 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -124,6 +124,7 @@ public: // derived types should use isa/dyn_cast. bool isIndex() const; bool isFloat8E5M2() const; + bool isFloat8E4M3FN() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 02eba88f78b0..9bd7b60afd28 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -94,6 +94,7 @@ TOK_KEYWORD(f32) TOK_KEYWORD(f64) TOK_KEYWORD(f80) TOK_KEYWORD(f8E5M2) +TOK_KEYWORD(f8E4M3FN) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index fa428b2f06fa..fc8c3fdbb58d 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) { case Token::kw_vector: case Token::inttype: case Token::kw_f8E5M2: + case Token::kw_f8E4M3FN: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -290,6 +291,9 @@ Type Parser::parseNonFunctionType() { case Token::kw_f8E5M2: consumeToken(Token::kw_f8E5M2); return builder.getFloat8E5M2Type(); + case Token::kw_f8E4M3FN: + consumeToken(Token::kw_f8E4M3FN); + return builder.getFloat8E4M3FNType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp 
b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index ad9a5bc6640e..596a760b99e8 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +bool mlirTypeIsAFloat8E4M3FN(MlirType type) { + return unwrap(type).isFloat8E4M3FN(); +} + +MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 9a3d3e031dc3..32a26470d94a 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2244,6 +2244,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { }) .Case<IndexType>([&](Type) { os << "index"; }) .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; }) + .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; }) .Case<BFloat16Type>([&](Type) { os << "bf16"; }) .Case<Float16Type>([&](Type) { os << "f16"; }) .Case<Float32Type>([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 053ffce1b157..2f4e07990a0d 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -37,6 +37,10 @@ FloatType Builder::getFloat8E5M2Type() { return FloatType::getFloat8E5M2(context); } +FloatType Builder::getFloat8E4M3FNType() { + return FloatType::getFloat8E4M3FN(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index d65c5e9d28b1..f4d64c97836d 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -88,7 +88,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - if (isa<Float8E5M2Type>()) + if (isa<Float8E5M2Type, Float8E4M3FNType>()) return 8; if 
(isa<Float16Type, BFloat16Type>()) return 16; @@ -107,6 +107,8 @@ unsigned FloatType::getWidth() { const llvm::fltSemantics &FloatType::getFloatSemantics() { if (isa<Float8E5M2Type>()) return APFloat::Float8E5M2(); + if (isa<Float8E4M3FNType>()) + return APFloat::Float8E4M3FN(); if (isa<BFloat16Type>()) return APFloat::BFloat(); if (isa<Float16Type>()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 182e249810e1..298f722da361 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -207,6 +207,7 @@ public: /// Cached Type Instances. Float8E5M2Type f8E5M2Ty; + Float8E4M3FNType f8E4M3FNTy; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -278,6 +279,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting) //// Types. /// Floating-point Types. impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this); + impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this); impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); impl->f16Ty = TypeUniquer::get<Float16Type>(this); impl->f32Ty = TypeUniquer::get<Float32Type>(this); @@ -861,6 +863,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) { return context->getImpl().f8E5M2Ty; } +Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { + return context->getImpl().f8E4M3FNTy; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index b97388bf33f5..670974bbf837 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -19,6 +19,7 @@ using namespace mlir::detail; MLIRContext *Type::getContext() const { return getDialect().getContext(); } bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); } +bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); } bool Type::isBF16() const { return isa<BFloat16Type>(); } bool Type::isF16() const { return isa<Float16Type>(); } bool Type::isF32() const { return isa<Float32Type>(); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index 540578e61527..ebfbb8982503 100644 --- 
a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -40,6 +40,10 @@ func.func @float_attrs_pass() { // CHECK: float_attr = 2.000000e+00 : f8E5M2 float_attr = 2. : f8E5M2 } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E4M3FN + float_attr = 2. : f8E4M3FN + } : () -> () "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f16 float_attr = 2. : f16