[flang] Propagate fastmath flags during intrinsics simplification.

In general, the meaning of fastmath flags on a call during inlining
is that the call's operation flags must be ignored. For user functions
that means that the fastmath flags used for the function definition
override any call site's fastmath flags. For intrinsic functions
we can use the call site's fastmath flags, but we have to make sure
that the call sites with different flags produce/use different
simplified versions of the same intrinsic function.

Differential Revision: https://reviews.llvm.org/D138048
This commit is contained in:
Slava Zakharin 2022-11-15 11:09:59 -08:00
parent f44e846402
commit ffe1661fab
3 changed files with 148 additions and 5 deletions

View File

@ -419,6 +419,9 @@ public:
/// config.
void setFastMathFlags(Fortran::common::MathOptionsBase options);
/// Return the FastMathFlags value currently stored in this builder.
mlir::arith::FastMathFlags getFastMathFlags() const {
  return fastMathFlags;
}
/// Dump the current function. (debug)
LLVM_DUMP_METHOD void dumpFunc();

View File

@ -85,6 +85,35 @@ private:
} // namespace
/// Build a FirOpBuilder from \p op and \p kindMap.
/// When \p op implements ArithFastMathInterface, the FastMathFlags attached
/// to \p op replace whatever flags the freshly constructed builder holds;
/// otherwise the builder is returned as constructed.
static fir::FirOpBuilder
getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
  fir::FirOpBuilder builder{op, kindMap};
  // The operation's own FastMathFlags take precedence over the builder's
  // defaults, whatever those defaults are.
  if (auto fastMathOp =
          mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op))
    builder.setFastMathFlags(fastMathOp.getFastMathFlagsAttr().getValue());
  return builder;
}
/// Turn the FastMathFlags held by \p builder into a string suitable for
/// mangling into a function name: the flags are stringified and every ','
/// in the result is replaced with '_'.
/// An empty string is returned when the flags are 'none'.
static std::string getFastMathFlagsString(const fir::FirOpBuilder &builder) {
  const mlir::arith::FastMathFlags fmf = builder.getFastMathFlags();
  if (fmf == mlir::arith::FastMathFlags::none)
    return std::string{};

  std::string mangled{mlir::arith::stringifyFastMathFlags(fmf)};
  // Commas cannot stay in the mangled name; swap each one for '_'.
  for (char &ch : mangled)
    if (ch == ',')
      ch = '_';
  return mangled;
}
/// Generate function type for the simplified version of RTNAME(Sum) and
/// similar functions with a fir.box<none> type returning \p elementType.
static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
@ -511,7 +540,8 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank > 0) {
mlir::Location loc = call.getLoc();
fir::FirOpBuilder builder(call, kindMap);
fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
std::string fmfString{getFastMathFlagsString(builder)};
// Support only floating point and integer results now.
mlir::Type resultType = call.getResult(0).getType();
@ -535,7 +565,10 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
// Mangle the function name with the rank value as "x<rank>".
std::string funcName =
(mlir::Twine{callee.getLeafReference().getValue(), "x"} +
mlir::Twine{rank})
mlir::Twine{rank} +
// We must mangle the generated function name with FastMathFlags
// value.
(fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
.str();
mlir::func::FuncOp newFunc =
getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
@ -576,7 +609,10 @@ void SimplifyIntrinsicsPass::runOnOperation() {
const mlir::Value &v1 = args[0];
const mlir::Value &v2 = args[1];
mlir::Location loc = call.getLoc();
fir::FirOpBuilder builder(op, kindMap);
fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
// Stringify the builder's FastMathFlags for mangling
// the generated function name.
std::string fmfString{getFastMathFlagsString(builder)};
mlir::Type type = call.getResult(0).getType();
if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
@ -611,9 +647,13 @@ void SimplifyIntrinsicsPass::runOnOperation() {
// of the arguments.
std::string typedFuncName(funcName);
llvm::raw_string_ostream nameOS(typedFuncName);
nameOS << "_";
// We must mangle the generated function name with FastMathFlags
// value.
if (!fmfString.empty())
nameOS << '_' << fmfString;
nameOS << '_';
arg1Type->print(nameOS);
nameOS << "_";
nameOS << '_';
arg2Type->print(nameOS);
mlir::func::FuncOp newFunc = getOrCreateFunction(

View File

@ -998,3 +998,103 @@ fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
// CHECK-NOT: call{{.*}}_FortranASumInteger8(
// CHECK: call @_FortranASumInteger8x2_simplified(
// CHECK-NOT: call{{.*}}_FortranASumInteger8(
// -----
// Test input: calls @_FortranADotProductReal4 on two boxed f32 arrays with
// fastmath<contract,reassoc> attached to the fir.call. The expectations
// further down look for a simplified callee whose name carries the
// "_reassoc_contract" suffix.
func.func @dot_f32_contract_reassoc(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
%1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
%c3_i32 = arith.constant 3 : i32
// Both array boxes are converted to !fir.box<none> before being passed on.
%2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
%3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
%4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
%5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath<contract,reassoc> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
fir.store %5 to %0 : !fir.ref<f32>
%6 = fir.load %0 : !fir.ref<f32>
return %6 : f32
}
// Test input: same shape as @dot_f32_contract_reassoc, but the fir.call to
// @_FortranADotProductReal4 carries fastmath<fast>. The expectations further
// down look for a distinct simplified callee with the "_fast" suffix.
func.func @dot_f32_fast(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
%1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
%c3_i32 = arith.constant 3 : i32
// Both array boxes are converted to !fir.box<none> before being passed on.
%2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
%3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
%4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
%5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath<fast> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
fir.store %5 to %0 : !fir.ref<f32>
%6 = fir.load %0 : !fir.ref<f32>
return %6 : f32
}
// Declaration of the callee used by the two dot-product test functions;
// it carries the fir.runtime attribute.
func.func private @_FortranADotProductReal4(!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32 attributes {fir.runtime}
// Constant 10-character string global whose address is taken by the test
// functions and passed as the third call argument.
fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
%0 = fir.string_lit "./dot.f90\00"(10) : !fir.char<1,10>
fir.has_value %0 : !fir.char<1,10>
}
// CHECK-LABEL: @dot_f32_contract_reassoc
// CHECK: fir.call @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified(%2, %3) fastmath<reassoc,contract>
// CHECK-LABEL: @dot_f32_fast
// CHECK: fir.call @_FortranADotProductReal4_fast_f32_f32_simplified(%2, %3) fastmath<fast>
// CHECK-LABEL: func.func private @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified
// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f32
// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f32
// CHECK-LABEL: func.func private @_FortranADotProductReal4_fast_f32_f32_simplified
// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath<fast> : f32
// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<fast> : f32
// -----
// Test input: calls @_FortranASumReal8 on a boxed 10-element f64 array with
// fastmath<contract,reassoc> attached to the fir.call. The expectations
// further down look for a simplified callee named with the "x1" rank marker
// followed by the "_reassoc_contract" suffix.
func.func @sum_1d_real_contract_reassoc(%arg0: !fir.ref<!fir.array<10xf64>> {fir.bindc_name = "a"}) -> f64 {
%c10 = arith.constant 10 : index
%0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"}
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// The last call argument is built from fir.absent, the one before it from
// the constant 0 (presumably the mask and dim operands - confirm against
// the runtime signature).
%3 = fir.absent !fir.box<i1>
%c0 = arith.constant 0 : index
%4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref<!fir.char<1,13>>
%c5_i32 = arith.constant 5 : i32
%5 = fir.convert %2 : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
%6 = fir.convert %4 : (!fir.ref<!fir.char<1,13>>) -> !fir.ref<i8>
%7 = fir.convert %c0 : (index) -> i32
%8 = fir.convert %3 : (!fir.box<i1>) -> !fir.box<none>
%9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath<contract,reassoc> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64
fir.store %9 to %0 : !fir.ref<f64>
%10 = fir.load %0 : !fir.ref<f64>
return %10 : f64
}
// Test input: same shape as @sum_1d_real_contract_reassoc, but the fir.call
// to @_FortranASumReal8 carries fastmath<fast>. The expectations further
// down look for a distinct simplified callee with the "_fast" suffix.
func.func @sum_1d_real_fast(%arg0: !fir.ref<!fir.array<10xf64>> {fir.bindc_name = "a"}) -> f64 {
%c10 = arith.constant 10 : index
%0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"}
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// The last call argument is built from fir.absent, the one before it from
// the constant 0 (presumably the mask and dim operands - confirm against
// the runtime signature).
%3 = fir.absent !fir.box<i1>
%c0 = arith.constant 0 : index
%4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref<!fir.char<1,13>>
%c5_i32 = arith.constant 5 : i32
%5 = fir.convert %2 : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
%6 = fir.convert %4 : (!fir.ref<!fir.char<1,13>>) -> !fir.ref<i8>
%7 = fir.convert %c0 : (index) -> i32
%8 = fir.convert %3 : (!fir.box<i1>) -> !fir.box<none>
%9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath<fast> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64
fir.store %9 to %0 : !fir.ref<f64>
%10 = fir.load %0 : !fir.ref<f64>
return %10 : f64
}
// Declaration of the callee used by the two sum test functions; it carries
// the fir.runtime attribute.
func.func private @_FortranASumReal8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64 attributes {fir.runtime}
// Constant 13-character string global whose address is taken by the test
// functions and passed as the second call argument.
fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13> {
%0 = fir.string_lit "./isum_5.f90\00"(13) : !fir.char<1,13>
fir.has_value %0 : !fir.char<1,13>
}
// CHECK-LABEL: @sum_1d_real_contract_reassoc
// CHECK: fir.call @_FortranASumReal8x1_reassoc_contract_simplified(%5) fastmath<reassoc,contract>
// CHECK-LABEL: @sum_1d_real_fast
// CHECK: fir.call @_FortranASumReal8x1_fast_simplified(%5) fastmath<fast>
// CHECK-LABEL: func.func private @_FortranASumReal8x1_reassoc_contract_simplified
// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f64
// CHECK-LABEL: func.func private @_FortranASumReal8x1_fast_simplified
// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<fast> : f64