[mlir][math] Promote (b)f16 to f32 when lowering to libm calls
libm doesn't have overloads for the small types, so promote them to a bigger type and use the f32 function.

Differential Revision: https://reviews.llvm.org/D125093
This commit is contained in:
parent
ae7fe65cf6
commit
a48adc5658
|
@ -30,6 +30,14 @@ public:
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
|
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
|
||||||
};
|
};
|
||||||
|
// Pattern to promote an op of a smaller floating point type to F32.
// An op whose result type is f16 or bf16 is recreated with an f32 result:
// its operands are extended with arith.extf and the f32 result is truncated
// back to the original type with arith.truncf. Ops with any other result type
// are not matched.
|
||||||
|
template <typename Op>
|
||||||
|
struct PromoteOpToF32 : public OpRewritePattern<Op> {
|
||||||
|
public:
|
||||||
|
// Reuse the constructors of OpRewritePattern<Op>.
using OpRewritePattern<Op>::OpRewritePattern;
|
||||||
|
|
||||||
|
// Performs the promotion described above; returns failure() when the result
// type of `op` is neither f16 nor bf16.
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
|
||||||
|
};
|
||||||
// Pattern to convert scalar math operations to calls to libm functions.
|
// Pattern to convert scalar math operations to calls to libm functions.
|
||||||
// Additionally the libm function signatures are declared.
|
// Additionally the libm function signatures are declared.
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
|
@ -82,13 +90,30 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rewrites an f16/bf16 `op` into: extf on every operand -> the same op kind
// on f32 -> truncf back to the original result type. Returns failure() and
// leaves `op` unchanged when its result type is neither f16 nor bf16.
template <typename Op>
|
||||||
|
LogicalResult
|
||||||
|
PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
|
||||||
|
auto opType = op.getType();
|
||||||
|
// Only the result type is inspected; f32/f64 (and any other type) bail out.
if (!opType.template isa<Float16Type, BFloat16Type>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto f32 = rewriter.getF32Type();
|
||||||
|
// Materialize one arith.extf per operand, in operand order.
// NOTE(review): assumes every operand has the same small float type as the
// result; operand types are not checked here -- confirm against the op's
// verifier.
auto extendedOperands = llvm::to_vector(
|
||||||
|
llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
|
||||||
|
return rewriter.create<arith::ExtFOp>(loc, f32, operand);
|
||||||
|
}));
|
||||||
|
// Recreate the same op kind with an f32 result and the extended operands.
auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
|
||||||
|
// Replace the original op with a truncation of the f32 result back to the
// original (f16/bf16) type, so users of `op` see an unchanged result type.
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
LogicalResult
|
LogicalResult
|
||||||
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
|
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
|
||||||
PatternRewriter &rewriter) const {
|
PatternRewriter &rewriter) const {
|
||||||
auto module = SymbolTable::getNearestSymbolTable(op);
|
auto module = SymbolTable::getNearestSymbolTable(op);
|
||||||
auto type = op.getType();
|
auto type = op.getType();
|
||||||
// TODO: Support Float16 by upcasting to Float32
|
|
||||||
if (!type.template isa<Float32Type, Float64Type>())
|
if (!type.template isa<Float32Type, Float64Type>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -117,6 +142,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
|
||||||
PatternBenefit benefit) {
|
PatternBenefit benefit) {
|
||||||
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
|
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
|
||||||
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
|
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
|
||||||
|
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
|
||||||
|
PromoteOpToF32<math::TanhOp>>(patterns.getContext(), benefit);
|
||||||
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
|
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
|
||||||
"atan2f", "atan2", benefit);
|
"atan2f", "atan2", benefit);
|
||||||
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
|
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
|
||||||
|
|
|
@ -25,13 +25,25 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
|
||||||
// CHECK-LABEL: func @atan2_caller
|
// CHECK-LABEL: func @atan2_caller
|
||||||
// CHECK-SAME: %[[FLOAT:.*]]: f32
|
// CHECK-SAME: %[[FLOAT:.*]]: f32
|
||||||
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
||||||
func.func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
|
// CHECK-SAME: %[[HALF:.*]]: f16
|
||||||
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
|
// CHECK-SAME: %[[BFLOAT:.*]]: bf16
|
||||||
|
func.func @atan2_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (f32, f64, f16, bf16) {
|
||||||
|
// CHECK: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
|
||||||
%float_result = math.atan2 %float, %float : f32
|
%float_result = math.atan2 %float, %float : f32
|
||||||
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
|
// CHECK: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
|
||||||
%double_result = math.atan2 %double, %double : f64
|
%double_result = math.atan2 %double, %double : f64
|
||||||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
|
// CHECK: %[[HALF_PROMOTED1:.*]] = arith.extf %[[HALF]] : f16 to f32
|
||||||
return %float_result, %double_result : f32, f64
|
// CHECK: %[[HALF_PROMOTED2:.*]] = arith.extf %[[HALF]] : f16 to f32
|
||||||
|
// CHECK: %[[HALF_CALL:.*]] = call @atan2f(%[[HALF_PROMOTED1]], %[[HALF_PROMOTED2]]) : (f32, f32) -> f32
|
||||||
|
// CHECK: %[[HALF_RESULT:.*]] = arith.truncf %[[HALF_CALL]] : f32 to f16
|
||||||
|
%half_result = math.atan2 %half, %half : f16
|
||||||
|
// CHECK: %[[BFLOAT_PROMOTED1:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32
|
||||||
|
// CHECK: %[[BFLOAT_PROMOTED2:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32
|
||||||
|
// CHECK: %[[BFLOAT_CALL:.*]] = call @atan2f(%[[BFLOAT_PROMOTED1]], %[[BFLOAT_PROMOTED2]]) : (f32, f32) -> f32
|
||||||
|
// CHECK: %[[BFLOAT_RESULT:.*]] = arith.truncf %[[BFLOAT_CALL]] : f32 to bf16
|
||||||
|
%bfloat_result = math.atan2 %bfloat, %bfloat : bf16
|
||||||
|
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]], %[[HALF_RESULT]], %[[BFLOAT_RESULT]]
|
||||||
|
return %float_result, %double_result, %half_result, %bfloat_result : f32, f64, f16, bf16
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @erf_caller
|
// CHECK-LABEL: func @erf_caller
|
||||||
|
|
Loading…
Reference in New Issue