[mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2)

This diff adds initial (partial) support for "fastmath" attributes for floating
point operations in the arithmetic dialect. The "fastmath" attributes are
implemented using a default-valued bit enum. The defined flags currently mirror
the fastmath flags in the LLVM dialect (and in LLVM itself). Extending the
set of flags (if necessary) is left as a future task.

In this diff:
- Definition of FastMathAttr as a custom attribute in the Arithmetic dialect
  that inherits from the EnumAttr class.
- Definition of ArithFastMathInterface, which is an interface that is
  implemented by operations that have an arith::fastmath attribute.
- Declaration of a default-valued fastmath attribute for unary and (some) binary
  floating point operations in the Arithmetic dialect.
- Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags

NOT in this diff (but planned or currently in progress):
- Documentation of flag meanings
- Addition of FastMathAttr attributes to other dialects that might lower to the
  Arithmetic dialect (e.g. Math and Complex)
- Folding/rewrite implementations that are enabled by fastmath flags
- Specification of fastmath values from Python bindings (pending other in-
  progress diffs)

Reviewed By: mehdi_amini, vzakhari

Differential Revision: https://reviews.llvm.org/D126305
This commit is contained in:
Jeremy Furtek 2022-10-26 11:07:20 -07:00 committed by Slava Zakharin
parent f6eb089734
commit b56e65d318
24 changed files with 372 additions and 44 deletions

View File

@ -22,8 +22,10 @@ namespace detail {
/// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
@ -197,7 +199,7 @@ public:
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
adaptor.getOperands(),
adaptor.getOperands(), op->getAttrs(),
*this->getTypeConverter(), rewriter);
}
};

View File

@ -56,14 +56,34 @@ LogicalResult handleMultidimensionalVectors(
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
// Default attribute conversion class, which passes all source attributes
// through to the target op, unmodified.
template <typename SourceOp, typename TargetOp>
class AttrConvertPassThrough {
public:
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
private:
ArrayRef<NamedAttribute> srcAttrs;
};
/// Basic lowering implementation to rewrite Ops with just one result to the
/// LLVM Dialect. This supports higher-dimensional vector types.
template <typename SourceOp, typename TargetOp>
/// The AttrConvert template template parameter should be a template class
/// with SourceOp and TargetOp type parameters, a constructor that takes
/// a SourceOp instance, and a getAttrs() method that returns
/// ArrayRef<NamedAttribute>.
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@ -75,9 +95,12 @@ public:
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(),
*this->getTypeConverter(), rewriter);
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
}
};
} // namespace mlir

View File

@ -17,6 +17,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/StringExtras.h"
//===----------------------------------------------------------------------===//
// ArithDialect
@ -29,6 +30,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
//===----------------------------------------------------------------------===//
// Arith Interfaces
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
//===----------------------------------------------------------------------===//
// Arith Dialect Operations

View File

@ -23,6 +23,7 @@ def Arith_Dialect : Dialect {
}];
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
}
// The predicate indicates the type of the comparison to perform:
@ -92,4 +93,32 @@ def AtomicRMWKindAttr : I64EnumAttr<
let cppNamespace = "::mlir::arith";
}
def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >;
def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>;
def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>;
def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>;
def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>;
def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>;
def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>;
def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>;
def FASTMATH_FAST : I32BitEnumAttrCaseGroup<
"fast",
[
FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS,
FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
FASTMATH_APPROX_FUNC]>;
def FastMathFlags : I32BitEnumAttr<
"FastMathFlags",
"Floating point fast math flags",
[
FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS,
FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> {
let separator = ",";
let cppNamespace = "::mlir::arith";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}
#endif // ARITH_BASE

View File

@ -10,6 +10,7 @@
#define ARITH_OPS
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@ -17,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
def Arith_FastMathAttr :
EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
let assemblyFormat = "`<` $value `>`";
}
// Base class for Arith dialect ops. Ops in this dialect have no side
// effects and can be applied element-wise to vectors and tensors.
@ -58,15 +65,27 @@ class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
// Base class for floating point unary operations.
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
Arith_UnaryOp<mnemonic, traits>,
Arguments<(ins FloatLike:$operand)>,
Results<(outs FloatLike:$result)>;
Arith_UnaryOp<mnemonic,
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
traits)>,
Arguments<(ins FloatLike:$operand,
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
Results<(outs FloatLike:$result)> {
let assemblyFormat = [{ $operand custom<ArithFastMathAttr>($fastmath)
attr-dict `:` type($result) }];
}
// Base class for floating point binary operations.
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
Results<(outs FloatLike:$result)>;
Arith_BinaryOp<mnemonic,
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
traits)>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
Results<(outs FloatLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs `` custom<ArithFastMathAttr>($fastmath)
attr-dict `:` type($result) }];
}
// Base class for arithmetic cast operations. Requires a single operand and
// result. If either is a shaped type, then the other must be of the same shape.

View File

@ -0,0 +1,52 @@
//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the Arith interfaces definition file.
//
//===----------------------------------------------------------------------===//
#ifndef ARITH_OPS_INTERFACES
#define ARITH_OPS_INTERFACES
include "mlir/IR/OpBase.td"
def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
let description = [{
Access to operation fastmath flags.
}];
let cppNamespace = "::mlir::arith";
let methods = [
InterfaceMethod<
/*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
/*returnType=*/ "FastMathFlagsAttr",
/*methodName=*/ "getFastMathFlagsAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getFastMathAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "fastmath";
}]
>
];
}
#endif // ARITH_OPS_INTERFACES

View File

@ -1,5 +1,14 @@
set(LLVM_TARGET_DEFINITIONS ArithOps.td)
mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls
-attrdefs-dialect=arith)
mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs
-attrdefs-dialect=arith)
add_mlir_dialect(ArithOps arith)
add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td)
mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRArithOpsInterfacesIncGen)

View File

@ -48,7 +48,6 @@ class SmartMutex;
namespace mlir {
namespace LLVM {
class LLVMDialect;
class LoopOptionsAttrBuilder;
namespace detail {
struct LLVMTypeStorage;

View File

@ -23,8 +23,28 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
"getFastmathFlags">,
InterfaceMethod<
/*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
/*returnType=*/ "FastmathFlagsAttr",
/*methodName=*/ "getFastmathAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getFastmathAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "fastmathFlags";
}]
>
];
}

View File

@ -24,16 +24,93 @@ using namespace mlir;
namespace {
// Map arithmetic fastmath enum values to LLVMIR enum values.
static LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
LLVM::FastmathFlags llvmFMF{};
const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
{arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
{arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
{arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
{arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
for (auto fmfMap : flags) {
if (bitEnumContainsAny(arithFMF, fmfMap.first))
llvmFMF = llvmFMF | fmfMap.second;
}
return llvmFMF;
}
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
static LLVM::FastmathFlagsAttr
convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
arith::FastMathFlags arithFMF = fmfAttr.getValue();
return LLVM::FastmathFlagsAttr::get(
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}
// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
template <typename SourceOp, typename TargetOp>
class AttrConvertFastMath {
public:
AttrConvertFastMath(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName)
.dyn_cast_or_null<arith::FastMathFlagsAttr>();
if (arithFMFAttr) {
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr));
}
}
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
private:
NamedAttrList convertedAttr;
};
// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes. This may be useful for
// target operations that do not require the fastmath attribute, or for targets
// that do not yet support the LLVM fastmath attribute.
template <typename SourceOp, typename TargetOp>
class AttrDropFastMath {
public:
AttrDropFastMath(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
convertedAttr.erase(arithFMFAttrName);
}
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
private:
NamedAttrList convertedAttr;
};
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
AttrConvertFastMath>;
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
AttrConvertFastMath>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
@ -47,23 +124,29 @@ using FPToSIOpLowering =
VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering =
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
// TODO: Add LLVM intrinsic support for fastmath
using MaxFOpLowering =
VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp>;
VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
// TODO: Add LLVM intrinsic support for fastmath
using MinFOpLowering =
VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp>;
VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
AttrConvertFastMath>;
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
AttrConvertFastMath>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
// TODO: Add LLVM intrinsic support for fastmath
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@ -77,7 +160,8 @@ using ShRUIOpLowering =
VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
AttrConvertFastMath>;
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
using TruncFOpLowering =
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
@ -153,7 +237,7 @@ LogicalResult
ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(),
adaptor.getOperands(), op->getAttrs(),
*getTypeConverter(), rewriter);
}

View File

@ -90,7 +90,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
*getTypeConverter(), rewriter);
op->getAttrs(), *getTypeConverter(), rewriter);
}
};

View File

@ -308,7 +308,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes;
@ -322,7 +323,7 @@ LogicalResult LLVM::detail::oneToOneRewrite(
// Create the operation through state since we don't know its C++ type.
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, op->getAttrs());
resultTypes, targetAttrs);
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)

View File

@ -105,7 +105,8 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
@ -114,13 +115,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
auto llvmNDVectorTy = operands[0].getType();
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter);
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
return rewriter
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
llvm1DVectorTy, op->getAttrs())
llvm1DVectorTy, targetAttrs)
->getResult(0);
};

View File

@ -215,17 +215,21 @@ def OrOfExtSI :
//===----------------------------------------------------------------------===//
// mulf(negf(x), negf(y)) -> mulf(x,y)
// (retain fastmath flags of original mulf)
def MulFOfNegF :
Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
(Arith_MulFOp $x, $y, $fmf),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//
// divf(negf(x), negf(y)) -> divf(x,y)
// (retain fastmath flags of original divf)
def DivFOfNegF :
Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
(Arith_DivFOp $x, $y, $fmf),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
#endif // ARITH_PATTERNS

View File

@ -8,12 +8,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arith;
#include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc"
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
namespace {
/// This class defines the interface for handling inlining for arithmetic
@ -34,6 +39,10 @@ void arith::ArithDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
>();
addInterfaces<ArithInlinerInterface>();
}

View File

@ -23,6 +23,31 @@
using namespace mlir;
using namespace mlir::arith;
//===----------------------------------------------------------------------===//
// Floating point op parse/print helpers
//===----------------------------------------------------------------------===//
static ParseResult parseArithFastMathAttr(OpAsmParser &parser,
Attribute &attr) {
if (succeeded(
parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) {
attr = FastMathFlagsAttr::parse(parser, Type{});
return success(static_cast<bool>(attr));
} else {
// No fastmath attribute mnemonic present - defer attribute creation and use
// the default value.
return success();
}
}
static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op,
FastMathFlagsAttr fmAttr) {
// Elide printing the fastmath attribute when fastmath=none
if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) {
printer << " " << FastMathFlagsAttr::getMnemonic();
fmAttr.print(printer);
}
}
//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArithDialect
DEPENDS
MLIRArithOpsIncGen
MLIRArithOpsInterfacesIncGen
LINK_LIBS PUBLIC
MLIRDialect

View File

@ -51,13 +51,13 @@ struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
Type elementType = getSrcVectorElementType<OpTy>(op);
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
if (bitwidth == 32)
return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
adaptor.getOperands(),
getTypeConverter(), rewriter);
return LLVM::detail::oneToOneRewrite(
op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
op->getAttrs(), getTypeConverter(), rewriter);
if (bitwidth == 64)
return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
adaptor.getOperands(),
getTypeConverter(), rewriter);
return LLVM::detail::oneToOneRewrite(
op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
op->getAttrs(), getTypeConverter(), rewriter);
return rewriter.notifyMatchFailure(
op, "expected 'src' to be either f32 or f64");
}

View File

@ -160,9 +160,9 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
// clang-format on
};
llvm::FastMathFlags ret;
auto fmf = op.getFastmathFlags();
::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
for (auto it : handlers)
if (bitEnumContainsAll(fmf, it.first))
if (bitEnumContainsAll(fmfMlir, it.first))
(ret.*(it.second))(true);
return ret;
}

View File

@ -309,7 +309,7 @@ int collectStats(MlirOperation operation) {
// clang-format off
// CHECK-LABEL: @stats
// CHECK: Number of operations: 12
// CHECK: Number of attributes: 4
// CHECK: Number of attributes: 5
// CHECK: Number of blocks: 3
// CHECK: Number of regions: 3
// CHECK: Number of values: 9

View File

@ -448,3 +448,20 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
%1 = arith.maxf %arg0, %arg1 : f32
return %0 : f32
}
// -----
// CHECK-LABEL: @fastmath
func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
%1 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
%2 = arith.negf %arg0 fastmath<fast> : f32
%3 = arith.addf %arg0, %arg1 fastmath<none> : f32
%4 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
return
}

View File

@ -1031,3 +1031,27 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%min_unsigned = arith.minui %i1, %i2 : i32
return
}
// CHECK-LABEL: @fastmath
func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<fast> : f32
// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath<fast> : f32
// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
// CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
%0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
%1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
%2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
%3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
%4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
%5 = arith.negf %arg0 fastmath<fast> : f32
// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
%6 = arith.addf %arg0, %arg1 fastmath<none> : f32
// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
%7 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
%8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
return
}

View File

@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
// -----
func.func @generic(%arg0: memref<?x?xf32>) {
// expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}}
// expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath<none>} : (f32, f32) -> f32}}
linalg.generic {
indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
iterator_types = ["parallel", "parallel"]}

View File

@ -1001,7 +1001,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
for (FormatElement *param : dir->getArguments()) {
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional())
if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",