[mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2)
This diff adds initial (partial) support for "fastmath" attributes for floating point operations in the arithmetic dialect. The "fastmath" attributes are implemented using a default-valued bit enum. The defined flags currently mirror the fastmath flags in the LLVM dialect (and in LLVM itself). Extending the set of flags (if necessary) is left as a future task. In this diff: - Definition of FastMathAttr as a custom attribute in the Arithmetic dialect that inherits from the EnumAttr class. - Definition of ArithFastMathInterface, which is an interface that is implemented by operations that have an arith::fastmath attribute. - Declaration of a default-valued fastmath attribute for unary and (some) binary floating point operations in the Arithmetic dialect. - Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags NOT in this diff (but planned or currently in progress): - Documentation of flag meanings - Addition of FastMathAttr attributes to other dialects that might lower to the Arithmetic dialect (e.g. Math and Complex) - Folding/rewrite implementations that are enabled by fastmath flags - Specification of fastmath values from Python bindings (pending other in- progress diffs) Reviewed By: mehdi_amini, vzakhari Differential Revision: https://reviews.llvm.org/D126305
This commit is contained in:
parent
f6eb089734
commit
b56e65d318
|
@ -22,8 +22,10 @@ namespace detail {
|
|||
/// and given operands.
|
||||
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
ArrayRef<NamedAttribute> targetAttrs,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
|
||||
|
@ -197,7 +199,7 @@ public:
|
|||
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
|
||||
adaptor.getOperands(),
|
||||
adaptor.getOperands(), op->getAttrs(),
|
||||
*this->getTypeConverter(), rewriter);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -56,14 +56,34 @@ LogicalResult handleMultidimensionalVectors(
|
|||
|
||||
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
ArrayRef<NamedAttribute> targetAttrs,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
|
||||
// Default attribute conversion class, which passes all source attributes
|
||||
// through to the target op, unmodified.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class AttrConvertPassThrough {
|
||||
public:
|
||||
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
|
||||
|
||||
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
|
||||
|
||||
private:
|
||||
ArrayRef<NamedAttribute> srcAttrs;
|
||||
};
|
||||
|
||||
/// Basic lowering implementation to rewrite Ops with just one result to the
|
||||
/// LLVM Dialect. This supports higher-dimensional vector types.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
/// The AttrConvert template template parameter should be a template class
|
||||
/// with SourceOp and TargetOp type parameters, a constructor that takes
|
||||
/// a SourceOp instance, and a getAttrs() method that returns
|
||||
/// ArrayRef<NamedAttribute>.
|
||||
template <typename SourceOp, typename TargetOp,
|
||||
template <typename, typename> typename AttrConvert =
|
||||
AttrConvertPassThrough>
|
||||
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
|
@ -75,9 +95,12 @@ public:
|
|||
static_assert(
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
// Determine attributes for the target op
|
||||
AttrConvert<SourceOp, TargetOp> attrConvert(op);
|
||||
|
||||
return LLVM::detail::vectorOneToOneRewrite(
|
||||
op, TargetOp::getOperationName(), adaptor.getOperands(),
|
||||
*this->getTypeConverter(), rewriter);
|
||||
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
|
||||
}
|
||||
};
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArithDialect
|
||||
|
@ -29,6 +30,13 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Arith Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Arith Dialect Operations
|
||||
|
|
|
@ -23,6 +23,7 @@ def Arith_Dialect : Dialect {
|
|||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
}
|
||||
|
||||
// The predicate indicates the type of the comparison to perform:
|
||||
|
@ -92,4 +93,32 @@ def AtomicRMWKindAttr : I64EnumAttr<
|
|||
let cppNamespace = "::mlir::arith";
|
||||
}
|
||||
|
||||
def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >;
|
||||
def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>;
|
||||
def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>;
|
||||
def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>;
|
||||
def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>;
|
||||
def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>;
|
||||
def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>;
|
||||
def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>;
|
||||
def FASTMATH_FAST : I32BitEnumAttrCaseGroup<
|
||||
"fast",
|
||||
[
|
||||
FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS,
|
||||
FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
|
||||
FASTMATH_APPROX_FUNC]>;
|
||||
|
||||
def FastMathFlags : I32BitEnumAttr<
|
||||
"FastMathFlags",
|
||||
"Floating point fast math flags",
|
||||
[
|
||||
FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS,
|
||||
FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
|
||||
FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> {
|
||||
let separator = ",";
|
||||
let cppNamespace = "::mlir::arith";
|
||||
let genSpecializedAttr = 0;
|
||||
let printBitEnumPrimaryGroups = 1;
|
||||
}
|
||||
|
||||
#endif // ARITH_BASE
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define ARITH_OPS
|
||||
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/InferIntRangeInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
@ -17,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
include "mlir/Interfaces/VectorInterfaces.td"
|
||||
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
def Arith_FastMathAttr :
|
||||
EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
// Base class for Arith dialect ops. Ops in this dialect have no side
|
||||
// effects and can be applied element-wise to vectors and tensors.
|
||||
|
@ -58,15 +65,27 @@ class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
|
|||
|
||||
// Base class for floating point unary operations.
|
||||
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
|
||||
Arith_UnaryOp<mnemonic, traits>,
|
||||
Arguments<(ins FloatLike:$operand)>,
|
||||
Results<(outs FloatLike:$result)>;
|
||||
Arith_UnaryOp<mnemonic,
|
||||
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
|
||||
traits)>,
|
||||
Arguments<(ins FloatLike:$operand,
|
||||
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
|
||||
Results<(outs FloatLike:$result)> {
|
||||
let assemblyFormat = [{ $operand custom<ArithFastMathAttr>($fastmath)
|
||||
attr-dict `:` type($result) }];
|
||||
}
|
||||
|
||||
// Base class for floating point binary operations.
|
||||
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
|
||||
Arith_BinaryOp<mnemonic, traits>,
|
||||
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
|
||||
Results<(outs FloatLike:$result)>;
|
||||
Arith_BinaryOp<mnemonic,
|
||||
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
|
||||
traits)>,
|
||||
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
|
||||
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
|
||||
Results<(outs FloatLike:$result)> {
|
||||
let assemblyFormat = [{ $lhs `,` $rhs `` custom<ArithFastMathAttr>($fastmath)
|
||||
attr-dict `:` type($result) }];
|
||||
}
|
||||
|
||||
// Base class for arithmetic cast operations. Requires a single operand and
|
||||
// result. If either is a shaped type, then the other must be of the same shape.
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the Arith interfaces definition file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ARITH_OPS_INTERFACES
|
||||
#define ARITH_OPS_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
|
||||
let description = [{
|
||||
Access to operation fastmath flags.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::arith";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
|
||||
/*returnType=*/ "FastMathFlagsAttr",
|
||||
/*methodName=*/ "getFastMathFlagsAttr",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
|
||||
return op.getFastmathAttr();
|
||||
}]
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
|
||||
for the operation}],
|
||||
/*returnType=*/ "StringRef",
|
||||
/*methodName=*/ "getFastMathAttrName",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
return "fastmath";
|
||||
}]
|
||||
>
|
||||
|
||||
];
|
||||
}
|
||||
|
||||
#endif // ARITH_OPS_INTERFACES
|
|
@ -1,5 +1,14 @@
|
|||
set(LLVM_TARGET_DEFINITIONS ArithOps.td)
|
||||
mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls
|
||||
-attrdefs-dialect=arith)
|
||||
mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs
|
||||
-attrdefs-dialect=arith)
|
||||
add_mlir_dialect(ArithOps arith)
|
||||
add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td)
|
||||
mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRArithOpsInterfacesIncGen)
|
||||
|
|
|
@ -48,7 +48,6 @@ class SmartMutex;
|
|||
namespace mlir {
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
class LoopOptionsAttrBuilder;
|
||||
|
||||
namespace detail {
|
||||
struct LLVMTypeStorage;
|
||||
|
|
|
@ -23,8 +23,28 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
|
|||
let cppNamespace = "::mlir::LLVM";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
|
||||
"getFastmathFlags">,
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
|
||||
/*returnType=*/ "FastmathFlagsAttr",
|
||||
/*methodName=*/ "getFastmathAttr",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
|
||||
return op.getFastmathFlagsAttr();
|
||||
}]
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
|
||||
for the operation}],
|
||||
/*returnType=*/ "StringRef",
|
||||
/*methodName=*/ "getFastmathAttrName",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
return "fastmathFlags";
|
||||
}]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -24,16 +24,93 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
// Map arithmetic fastmath enum values to LLVMIR enum values.
|
||||
static LLVM::FastmathFlags
|
||||
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
|
||||
LLVM::FastmathFlags llvmFMF{};
|
||||
const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
|
||||
{arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
|
||||
{arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
|
||||
{arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
|
||||
{arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
|
||||
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
|
||||
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
|
||||
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
|
||||
for (auto fmfMap : flags) {
|
||||
if (bitEnumContainsAny(arithFMF, fmfMap.first))
|
||||
llvmFMF = llvmFMF | fmfMap.second;
|
||||
}
|
||||
return llvmFMF;
|
||||
}
|
||||
|
||||
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
|
||||
static LLVM::FastmathFlagsAttr
|
||||
convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
|
||||
arith::FastMathFlags arithFMF = fmfAttr.getValue();
|
||||
return LLVM::FastmathFlagsAttr::get(
|
||||
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
|
||||
}
|
||||
|
||||
// Attribute converter that populates a NamedAttrList by removing the fastmath
|
||||
// attribute from the source operation attributes, and replacing it with an
|
||||
// equivalent LLVM fastmath attribute.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class AttrConvertFastMath {
|
||||
public:
|
||||
AttrConvertFastMath(SourceOp srcOp) {
|
||||
// Copy the source attributes.
|
||||
convertedAttr = NamedAttrList{srcOp->getAttrs()};
|
||||
// Get the name of the arith fastmath attribute.
|
||||
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
|
||||
// Remove the source fastmath attribute.
|
||||
auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName)
|
||||
.dyn_cast_or_null<arith::FastMathFlagsAttr>();
|
||||
if (arithFMFAttr) {
|
||||
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
|
||||
convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr));
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
||||
|
||||
private:
|
||||
NamedAttrList convertedAttr;
|
||||
};
|
||||
|
||||
// Attribute converter that populates a NamedAttrList by removing the fastmath
|
||||
// attribute from the source operation attributes. This may be useful for
|
||||
// target operations that do not require the fastmath attribute, or for targets
|
||||
// that do not yet support the LLVM fastmath attribute.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class AttrDropFastMath {
|
||||
public:
|
||||
AttrDropFastMath(SourceOp srcOp) {
|
||||
// Copy the source attributes.
|
||||
convertedAttr = NamedAttrList{srcOp->getAttrs()};
|
||||
// Get the name of the arith fastmath attribute.
|
||||
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
|
||||
// Remove the source fastmath attribute.
|
||||
convertedAttr.erase(arithFMFAttrName);
|
||||
}
|
||||
|
||||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
||||
|
||||
private:
|
||||
NamedAttrList convertedAttr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Straightforward Op Lowerings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
|
||||
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
|
||||
AttrConvertFastMath>;
|
||||
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
|
||||
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
|
||||
using BitcastOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
|
||||
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
|
||||
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
|
||||
AttrConvertFastMath>;
|
||||
using DivSIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
|
||||
using DivUIOpLowering =
|
||||
|
@ -47,23 +124,29 @@ using FPToSIOpLowering =
|
|||
VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
|
||||
using FPToUIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
|
||||
// TODO: Add LLVM intrinsic support for fastmath
|
||||
using MaxFOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp>;
|
||||
VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
|
||||
using MaxSIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
|
||||
using MaxUIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
|
||||
// TODO: Add LLVM intrinsic support for fastmath
|
||||
using MinFOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp>;
|
||||
VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
|
||||
using MinSIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
|
||||
using MinUIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
|
||||
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
|
||||
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
|
||||
AttrConvertFastMath>;
|
||||
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
|
||||
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
|
||||
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
|
||||
AttrConvertFastMath>;
|
||||
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
|
||||
using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
|
||||
// TODO: Add LLVM intrinsic support for fastmath
|
||||
using RemFOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
|
||||
using RemSIOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
|
||||
using RemUIOpLowering =
|
||||
|
@ -77,7 +160,8 @@ using ShRUIOpLowering =
|
|||
VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
|
||||
using SIToFPOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
|
||||
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
|
||||
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
|
||||
AttrConvertFastMath>;
|
||||
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
|
||||
using TruncFOpLowering =
|
||||
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
|
||||
|
@ -153,7 +237,7 @@ LogicalResult
|
|||
ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
|
||||
adaptor.getOperands(),
|
||||
adaptor.getOperands(), op->getAttrs(),
|
||||
*getTypeConverter(), rewriter);
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
|
||||
*getTypeConverter(), rewriter);
|
||||
op->getAttrs(), *getTypeConverter(), rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -308,7 +308,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
|
|||
/// and given operands.
|
||||
LogicalResult LLVM::detail::oneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
unsigned numResults = op->getNumResults();
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
|
@ -322,7 +323,7 @@ LogicalResult LLVM::detail::oneToOneRewrite(
|
|||
// Create the operation through state since we don't know its C++ type.
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
||||
resultTypes, op->getAttrs());
|
||||
resultTypes, targetAttrs);
|
||||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
|
|
|
@ -105,7 +105,8 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
|
|||
|
||||
LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(!operands.empty());
|
||||
|
||||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
|
@ -114,13 +115,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
|||
|
||||
auto llvmNDVectorTy = operands[0].getType();
|
||||
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
|
||||
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
|
||||
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
|
||||
rewriter);
|
||||
|
||||
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
|
||||
ValueRange operands) {
|
||||
auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
|
||||
ValueRange operands) {
|
||||
return rewriter
|
||||
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
||||
llvm1DVectorTy, op->getAttrs())
|
||||
llvm1DVectorTy, targetAttrs)
|
||||
->getResult(0);
|
||||
};
|
||||
|
||||
|
|
|
@ -215,17 +215,21 @@ def OrOfExtSI :
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// mulf(negf(x), negf(y)) -> mulf(x,y)
|
||||
// (retain fastmath flags of original mulf)
|
||||
def MulFOfNegF :
|
||||
Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y),
|
||||
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
|
||||
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
|
||||
(Arith_MulFOp $x, $y, $fmf),
|
||||
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// divf(negf(x), negf(y)) -> divf(x,y)
|
||||
// (retain fastmath flags of original divf)
|
||||
def DivFOfNegF :
|
||||
Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y),
|
||||
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
|
||||
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
|
||||
(Arith_DivFOp $x, $y, $fmf),
|
||||
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
|
||||
|
||||
#endif // ARITH_PATTERNS
|
||||
|
|
|
@ -8,12 +8,17 @@
|
|||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::arith;
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc"
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc"
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
|
||||
|
||||
namespace {
|
||||
/// This class defines the interface for handling inlining for arithmetic
|
||||
|
@ -34,6 +39,10 @@ void arith::ArithDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
|
||||
>();
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
|
||||
>();
|
||||
addInterfaces<ArithInlinerInterface>();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,31 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::arith;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Floating point op parse/print helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
static ParseResult parseArithFastMathAttr(OpAsmParser &parser,
|
||||
Attribute &attr) {
|
||||
if (succeeded(
|
||||
parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) {
|
||||
attr = FastMathFlagsAttr::parse(parser, Type{});
|
||||
return success(static_cast<bool>(attr));
|
||||
} else {
|
||||
// No fastmath attribute mnemonic present - defer attribute creation and use
|
||||
// the default value.
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op,
|
||||
FastMathFlagsAttr fmAttr) {
|
||||
// Elide printing the fastmath attribute when fastmath=none
|
||||
if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) {
|
||||
printer << " " << FastMathFlagsAttr::getMnemonic();
|
||||
fmAttr.print(printer);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArithDialect
|
|||
|
||||
DEPENDS
|
||||
MLIRArithOpsIncGen
|
||||
MLIRArithOpsInterfacesIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
|
|
|
@ -51,13 +51,13 @@ struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
|
|||
Type elementType = getSrcVectorElementType<OpTy>(op);
|
||||
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
|
||||
if (bitwidth == 32)
|
||||
return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
|
||||
adaptor.getOperands(),
|
||||
getTypeConverter(), rewriter);
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
|
||||
op->getAttrs(), getTypeConverter(), rewriter);
|
||||
if (bitwidth == 64)
|
||||
return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
|
||||
adaptor.getOperands(),
|
||||
getTypeConverter(), rewriter);
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
|
||||
op->getAttrs(), getTypeConverter(), rewriter);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected 'src' to be either f32 or f64");
|
||||
}
|
||||
|
|
|
@ -160,9 +160,9 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
|
|||
// clang-format on
|
||||
};
|
||||
llvm::FastMathFlags ret;
|
||||
auto fmf = op.getFastmathFlags();
|
||||
::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
|
||||
for (auto it : handlers)
|
||||
if (bitEnumContainsAll(fmf, it.first))
|
||||
if (bitEnumContainsAll(fmfMlir, it.first))
|
||||
(ret.*(it.second))(true);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -309,7 +309,7 @@ int collectStats(MlirOperation operation) {
|
|||
// clang-format off
|
||||
// CHECK-LABEL: @stats
|
||||
// CHECK: Number of operations: 12
|
||||
// CHECK: Number of attributes: 4
|
||||
// CHECK: Number of attributes: 5
|
||||
// CHECK: Number of blocks: 3
|
||||
// CHECK: Number of regions: 3
|
||||
// CHECK: Number of values: 9
|
||||
|
|
|
@ -448,3 +448,20 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
|
|||
%1 = arith.maxf %arg0, %arg1 : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fastmath
|
||||
func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
|
||||
// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
|
||||
// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
|
||||
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
|
||||
%0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
|
||||
%1 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
|
||||
%2 = arith.negf %arg0 fastmath<fast> : f32
|
||||
%3 = arith.addf %arg0, %arg1 fastmath<none> : f32
|
||||
%4 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1031,3 +1031,27 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
|
|||
%min_unsigned = arith.minui %i1, %i2 : i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fastmath
|
||||
func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
|
||||
// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
|
||||
%0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
|
||||
%1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
|
||||
%2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
|
||||
%3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
|
||||
%4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
|
||||
%5 = arith.negf %arg0 fastmath<fast> : f32
|
||||
// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
|
||||
%6 = arith.addf %arg0, %arg1 fastmath<none> : f32
|
||||
// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
|
||||
%7 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
|
||||
// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
|
||||
%8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
|
|||
// -----
|
||||
|
||||
func.func @generic(%arg0: memref<?x?xf32>) {
|
||||
// expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}}
|
||||
// expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath<none>} : (f32, f32) -> f32}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
|
|
|
@ -1001,7 +1001,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
|
|||
for (FormatElement *param : dir->getArguments()) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
if (var->attr.isOptional())
|
||||
if (var->attr.isOptional() || var->attr.hasDefaultValue())
|
||||
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
|
||||
|
||||
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
|
||||
|
|
Loading…
Reference in New Issue