Add Float8E4M3FN type to MLIR.
The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well.

There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279

This change is identical to the MLIR changes in the patch that added Float8E5M2, except that Float8E4M3FN is added instead.

Reviewed By: stellaraccident, bkramer, rriddle

Differential Revision: https://reviews.llvm.org/D138075
parent 2ada5cbea4
commit e08ca4bb1d
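For orientation, here is a minimal sketch (not part of the commit) of how the new type is exercised from C++ once this change lands; every call below corresponds to an API touched in the diff:

// Sketch only: builds and queries the new f8E4M3FN type via the C++ APIs
// added in this patch (Builder::getFloat8E4M3FNType, Type::isFloat8E4M3FN,
// FloatType::getWidth).
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  mlir::FloatType f8 = builder.getFloat8E4M3FNType();
  assert(f8.isFloat8E4M3FN());  // New Type predicate.
  assert(f8.getWidth() == 8);   // FloatType::getWidth handles the new type.
  // The type prints and parses as the new keyword "f8E4M3FN".
  return 0;
}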
@@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E4M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
+
+/// Creates an f8E4M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
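The two declarations added above mirror the existing f8E5M2 entry points. A hedged usage sketch of the C API (the context creation and destruction calls come from mlir-c/IR.h, not from this patch):

// Sketch only: exercising the new C API entry points declared above.
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
#include <assert.h>

int main(void) {
  MlirContext ctx = mlirContextCreate();
  MlirType f8 = mlirFloat8E4M3FNTypeGet(ctx);   // Owned by the context.
  assert(mlirTypeIsAFloat8E4M3FN(f8));
  assert(!mlirTypeIsAFloat8E5M2(f8));           // Distinct from the E5M2 type.
  mlirContextDestroy(ctx);
  return 0;
}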
@@ -60,6 +60,7 @@ public:
 
   // Types.
   FloatType getFloat8E5M2Type();
+  FloatType getFloat8E4M3FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();
@@ -47,6 +47,7 @@ public:
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
   static FloatType getFloat8E5M2(MLIRContext *ctx);
+  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -374,14 +375,18 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
-                  Float64Type, Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
+                  Float32Type, Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
   return Float8E5M2Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
+  return Float8E4M3FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
@@ -89,7 +89,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
       * bit encoding: S1E5M2
      * exponent bias: 15
      * infinities: supported with exponent set to all 1s and mantissa 0s
-      * NaNs: supported with exponent bits set to all 1s and mantissa of
+      * NaNs: supported with exponent bits set to all 1s and mantissa of
        (01, 10, or 11)
      * denormals when exponent is 0
 
@@ -97,6 +97,27 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E4M3FNType
+
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values
+    and only two NaN representations. This type has the following
+    characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 7
+      * infinities: Not supported
+      * NaNs: supported with exponent bits and mantissa bits set to all 1s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
 
 //===----------------------------------------------------------------------===//
 // BFloat16Type
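As a hedged illustration of the characteristics listed in the new description, the snippet below probes the matching APFloat semantics from D137760 (which FloatType::getFloatSemantics() maps to later in this diff). It assumes APFloat::convertToDouble accepts these narrow semantics, as it does for other sub-double types:

// Sketch only: numeric properties of the E4M3FN format via APFloat.
#include "llvm/ADT/APFloat.h"
#include <cassert>

int main() {
  const llvm::fltSemantics &sem = llvm::APFloat::Float8E4M3FN();
  // Largest finite value is S.1111.110 = 1.75 * 2^8 = 448; there is no inf.
  assert(llvm::APFloat::getLargest(sem).convertToDouble() == 448.0);
  // The all-ones encoding (S.1111.111) is the NaN representation.
  assert(llvm::APFloat::getQNaN(sem).isNaN());
  // Smallest positive denormal is 2^-9.
  assert(llvm::APFloat::getSmallest(sem).convertToDouble() == 0.001953125);
  return 0;
}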
@@ -124,6 +124,7 @@ public:
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
   bool isFloat8E5M2() const;
+  bool isFloat8E4M3FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;
@@ -94,6 +94,7 @@ TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
 TOK_KEYWORD(f8E5M2)
+TOK_KEYWORD(f8E4M3FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
@@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_vector:
   case Token::inttype:
   case Token::kw_f8E5M2:
+  case Token::kw_f8E4M3FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -290,6 +291,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E5M2:
     consumeToken(Token::kw_f8E5M2);
     return builder.getFloat8E5M2Type();
+  case Token::kw_f8E4M3FN:
+    consumeToken(Token::kw_f8E4M3FN);
+    return builder.getFloat8E4M3FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
@@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
+  return unwrap(type).isFloat8E4M3FN();
+}
+
+MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
@@ -2244,6 +2244,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<IndexType>([&](Type) { os << "index"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
+      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
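Together with the keyword and parser hunks above, this printer case makes the new type round-trip through its textual form. A sketch, assuming the standalone mlir::parseType entry point from mlir/AsmParser/AsmParser.h:

// Sketch only: textual round trip of the new "f8E4M3FN" keyword.
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>

int main() {
  mlir::MLIRContext context;
  // Parse the new keyword back into a type (parser hunks above).
  mlir::Type parsed = mlir::parseType("f8E4M3FN", &context);
  assert(parsed && parsed.isFloat8E4M3FN());
  // Print it again; the AsmPrinter case added above emits "f8E4M3FN".
  std::string text;
  llvm::raw_string_ostream os(text);
  parsed.print(os);
  assert(os.str() == "f8E4M3FN");
  return 0;
}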
@@ -37,6 +37,10 @@ FloatType Builder::getFloat8E5M2Type() {
   return FloatType::getFloat8E5M2(context);
 }
 
+FloatType Builder::getFloat8E4M3FNType() {
+  return FloatType::getFloat8E4M3FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
@@ -88,7 +88,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -107,6 +107,8 @@ unsigned FloatType::getWidth() {
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
   if (isa<Float8E5M2Type>())
     return APFloat::Float8E5M2();
+  if (isa<Float8E4M3FNType>())
+    return APFloat::Float8E4M3FN();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())
@@ -207,6 +207,7 @@ public:
 
   /// Cached Type Instances.
   Float8E5M2Type f8E5M2Ty;
+  Float8E4M3FNType f8E4M3FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -278,6 +279,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   //// Types.
   /// Floating-point Types.
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
+  impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -861,6 +863,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
   return context->getImpl().f8E5M2Ty;
 }
+Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
@@ -19,6 +19,7 @@ using namespace mlir::detail;
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
+bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }
@@ -40,6 +40,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E5M2
     float_attr = 2. : f8E5M2
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
+    float_attr = 2. : f8E4M3FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16