From b56e65d31825fe4a1ae02fdcbad58bb7993d63a7 Mon Sep 17 00:00:00 2001 From: Jeremy Furtek Date: Wed, 26 Oct 2022 11:07:20 -0700 Subject: [PATCH] [mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2) This diff adds initial (partial) support for "fastmath" attributes for floating point operations in the arithmetic dialect. The "fastmath" attributes are implemented using a default-valued bit enum. The defined flags currently mirror the fastmath flags in the LLVM dialect (and in LLVM itself). Extending the set of flags (if necessary) is left as a future task. In this diff: - Definition of FastMathAttr as a custom attribute in the Arithmetic dialect that inherits from the EnumAttr class. - Definition of ArithFastMathInterface, which is an interface that is implemented by operations that have an arith::fastmath attribute. - Declaration of a default-valued fastmath attribute for unary and (some) binary floating point operations in the Arithmetic dialect. - Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags NOT in this diff (but planned or currently in progress): - Documentation of flag meanings - Addition of FastMathAttr attributes to other dialects that might lower to the Arithmetic dialect (e.g. Math and Complex) - Folding/rewrite implementations that are enabled by fastmath flags - Specification of fastmath values from Python bindings (pending other in- progress diffs) Reviewed By: mehdi_amini, vzakhari Differential Revision: https://reviews.llvm.org/D126305 --- .../mlir/Conversion/LLVMCommon/Pattern.h | 4 +- .../Conversion/LLVMCommon/VectorPattern.h | 27 ++++- mlir/include/mlir/Dialect/Arith/IR/Arith.h | 8 ++ .../mlir/Dialect/Arith/IR/ArithBase.td | 29 +++++ .../include/mlir/Dialect/Arith/IR/ArithOps.td | 31 ++++-- .../Dialect/Arith/IR/ArithOpsInterfaces.td | 52 +++++++++ .../mlir/Dialect/Arith/IR/CMakeLists.txt | 9 ++ .../include/mlir/Dialect/LLVMIR/LLVMDialect.h | 1 - .../mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td | 24 ++++- .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 102 ++++++++++++++++-- .../ComplexToLLVM/ComplexToLLVM.cpp | 2 +- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 5 +- .../Conversion/LLVMCommon/VectorPattern.cpp | 12 ++- .../Dialect/Arith/IR/ArithCanonicalization.td | 12 ++- mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 9 ++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 25 +++++ mlir/lib/Dialect/Arith/IR/CMakeLists.txt | 1 + .../Transforms/LegalizeForLLVMExport.cpp | 12 +-- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 +- mlir/test/CAPI/ir.c | 2 +- .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 +++ mlir/test/Dialect/Arith/ops.mlir | 24 +++++ mlir/test/Dialect/Linalg/invalid.mlir | 2 +- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 2 +- 24 files changed, 372 insertions(+), 44 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 9de4334a9d70..90e62aa11787 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -22,8 +22,10 @@ namespace detail { /// and given operands. LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); + } // namespace detail } // namespace LLVM @@ -197,7 +199,7 @@ public: matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - adaptor.getOperands(), + adaptor.getOperands(), op->getAttrs(), *this->getTypeConverter(), rewriter); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index cae1b1cf3892..d115c2d2f58f 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -56,14 +56,34 @@ LogicalResult handleMultidimensionalVectors( LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM +// Default attribute conversion class, which passes all source attributes +// through to the target op, unmodified. +template +class AttrConvertPassThrough { +public: + AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} + + ArrayRef getAttrs() const { return srcAttrs; } + +private: + ArrayRef srcAttrs; +}; + /// Basic lowering implementation to rewrite Ops with just one result to the /// LLVM Dialect. This supports higher-dimensional vector types. -template +/// The AttrConvert template template parameter should be a template class +/// with SourceOp and TargetOp type parameters, a constructor that takes +/// a SourceOp instance, and a getAttrs() method that returns +/// ArrayRef. +template typename AttrConvert = + AttrConvertPassThrough> class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -75,9 +95,12 @@ public: static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); + // Determine attributes for the target op + AttrConvert attrConvert(op); + return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - *this->getTypeConverter(), rewriter); + attrConvert.getAttrs(), *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 0ecd293d7778..3e14e4d34675 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "llvm/ADT/StringExtras.h" //===----------------------------------------------------------------------===// // ArithDialect @@ -29,6 +30,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// Arith Interfaces +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc" //===----------------------------------------------------------------------===// // Arith Dialect Operations diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index aaaa3b0d5b52..13d252cf056e 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -23,6 +23,7 @@ def Arith_Dialect : Dialect { }]; let hasConstantMaterializer = 1; + let useDefaultAttributePrinterParser = 1; } // The predicate indicates the type of the comparison to perform: @@ -92,4 +93,32 @@ def AtomicRMWKindAttr : I64EnumAttr< let cppNamespace = "::mlir::arith"; } +def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >; +def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>; +def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>; +def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>; +def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>; +def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>; +def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>; +def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>; +def FASTMATH_FAST : I32BitEnumAttrCaseGroup< + "fast", + [ + FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS, + FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT, + FASTMATH_APPROX_FUNC]>; + +def FastMathFlags : I32BitEnumAttr< + "FastMathFlags", + "Floating point fast math flags", + [ + FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS, + FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, + FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> { + let separator = ","; + let cppNamespace = "::mlir::arith"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 692338eb8370..f12a1a33f691 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -10,6 +10,7 @@ #define ARITH_OPS include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -17,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/EnumAttr.td" + +def Arith_FastMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} // Base class for Arith dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. @@ -58,15 +65,27 @@ class Arith_IntBinaryOp traits = []> : // Base class for floating point unary operations. class Arith_FloatUnaryOp traits = []> : - Arith_UnaryOp, - Arguments<(ins FloatLike:$operand)>, - Results<(outs FloatLike:$result)>; + Arith_UnaryOp], + traits)>, + Arguments<(ins FloatLike:$operand, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $operand custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for floating point binary operations. class Arith_FloatBinaryOp traits = []> : - Arith_BinaryOp, - Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>, - Results<(outs FloatLike:$result)>; + Arith_BinaryOp], + traits)>, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $lhs `,` $rhs `` custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for arithmetic cast operations. Requires a single operand and // result. If either is a shaped type, then the other must be of the same shape. diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td new file mode 100644 index 000000000000..acaecf6f409d --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -0,0 +1,52 @@ +//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the Arith interfaces definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef ARITH_OPS_INTERFACES +#define ARITH_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> { + let description = [{ + Access to operation fastmath flags. + }]; + + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "FastMathFlagsAttr", + /*methodName=*/ "getFastMathFlagsAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastMathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmath"; + }] + > + + ]; +} + +#endif // ARITH_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt index 93ff719e677b..5cdde2edd50f 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt @@ -1,5 +1,14 @@ set(LLVM_TARGET_DEFINITIONS ArithOps.td) mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls + -attrdefs-dialect=arith) +mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=arith) add_mlir_dialect(ArithOps arith) add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td) +mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRArithOpsInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index ca38f072d9df..64e5e0abfd76 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -48,7 +48,6 @@ class SmartMutex; namespace mlir { namespace LLVM { class LLVMDialect; -class LoopOptionsAttrBuilder; namespace detail { struct LLVMTypeStorage; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td index 671e19a36d5e..d9c1a41bd2b6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -23,8 +23,28 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { let cppNamespace = "::mlir::LLVM"; let methods = [ - InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", - "getFastmathFlags">, + InterfaceMethod< + /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation", + /*returnType=*/ "FastmathFlagsAttr", + /*methodName=*/ "getFastmathAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathFlagsAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastmathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmathFlags"; + }] + > ]; } diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 1610e5cee8b7..cbaa67c21532 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -24,16 +24,93 @@ using namespace mlir; namespace { +// Map arithmetic fastmath enum values to LLVMIR enum values. +static LLVM::FastmathFlags +convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { + LLVM::FastmathFlags llvmFMF{}; + const std::pair flags[] = { + {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto fmfMap : flags) { + if (bitEnumContainsAny(arithFMF, fmfMap.first)) + llvmFMF = llvmFMF | fmfMap.second; + } + return llvmFMF; +} + +// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. +static LLVM::FastmathFlagsAttr +convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) { + arith::FastMathFlags arithFMF = fmfAttr.getValue(); + return LLVM::FastmathFlagsAttr::get( + fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); +} + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes, and replacing it with an +// equivalent LLVM fastmath attribute. +template +class AttrConvertFastMath { +public: + AttrConvertFastMath(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName) + .dyn_cast_or_null(); + if (arithFMFAttr) { + llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); + convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr)); + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes. This may be useful for +// target operations that do not require the fastmath attribute, or for targets +// that do not yet support the LLVM fastmath attribute. +template +class AttrDropFastMath { +public: + AttrDropFastMath(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + convertedAttr.erase(arithFMFAttrName); + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// -using AddFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = @@ -47,23 +124,29 @@ using FPToSIOpLowering = VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath using MaxFOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath using MinFOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; +using NegFOpLowering = VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; -using RemFOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath +using RemFOpLowering = + VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = @@ -77,7 +160,8 @@ using ShRUIOpLowering = VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; @@ -153,7 +237,7 @@ LogicalResult ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), - adaptor.getOperands(), + adaptor.getOperands(), op->getAttrs(), *getTypeConverter(), rewriter); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index f37d47d744cf..14f32d66d244 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -90,7 +90,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - *getTypeConverter(), rewriter); + op->getAttrs(), *getTypeConverter(), rewriter); } }; diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 96d83eec1805..8413dcfc8395 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -308,7 +308,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector resultTypes; @@ -322,7 +323,7 @@ LogicalResult LLVM::detail::oneToOneRewrite( // Create the operation through state since we don't know its C++ type. Operation *newOp = rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, op->getAttrs()); + resultTypes, targetAttrs); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index 2f0091f99dd3..e95c702d79f3 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,7 +105,8 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -114,13 +115,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!llvmNDVectorTy.isa()) - return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); + return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, + rewriter); - auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, - ValueRange operands) { + auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, + ValueRange operands) { return rewriter .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - llvm1DVectorTy, op->getAttrs()) + llvm1DVectorTy, targetAttrs) ->getResult(0); }; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 2cb5a553634b..a30ba2eff641 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -215,17 +215,21 @@ def OrOfExtSI : //===----------------------------------------------------------------------===// // mulf(negf(x), negf(y)) -> mulf(x,y) +// (retain fastmath flags of original mulf) def MulFOfNegF : - Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_MulFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// // divf(negf(x), negf(y)) -> divf(x,y) +// (retain fastmath flags of original divf) def DivFOfNegF : - Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_DivFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; #endif // ARITH_PATTERNS diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index 35cb7da2d48a..b15f7f05b853 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -8,12 +8,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; #include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc" +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" namespace { /// This class defines the interface for handling inlining for arithmetic @@ -34,6 +39,10 @@ void arith::ArithDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" + >(); addInterfaces(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d1d03a549092..5693ad1c0e8d 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -23,6 +23,31 @@ using namespace mlir; using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// Floating point op parse/print helpers +//===----------------------------------------------------------------------===// +static ParseResult parseArithFastMathAttr(OpAsmParser &parser, + Attribute &attr) { + if (succeeded( + parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) { + attr = FastMathFlagsAttr::parse(parser, Type{}); + return success(static_cast(attr)); + } else { + // No fastmath attribute mnemonic present - defer attribute creation and use + // the default value. + return success(); + } +} + +static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op, + FastMathFlagsAttr fmAttr) { + // Elide printing the fastmath attribute when fastmath=none + if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) { + printer << " " << FastMathFlagsAttr::getMnemonic(); + fmAttr.print(printer); + } +} + //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index ffd7ee327955..0de17bbfbd12 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArithDialect DEPENDS MLIRArithOpsIncGen + MLIRArithOpsInterfacesIncGen LINK_LIBS PUBLIC MLIRDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index ea86819faf1c..1aee27560ea3 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -51,13 +51,13 @@ struct LowerToIntrinsic : public OpConversionPattern { Type elementType = getSrcVectorElementType(op); unsigned bitwidth = elementType.getIntOrFloatBitWidth(); if (bitwidth == 32) - return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), - adaptor.getOperands(), - getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, Intr32OpTy::getOperationName(), adaptor.getOperands(), + op->getAttrs(), getTypeConverter(), rewriter); if (bitwidth == 64) - return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), - adaptor.getOperands(), - getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, Intr64OpTy::getOperationName(), adaptor.getOperands(), + op->getAttrs(), getTypeConverter(), rewriter); return rewriter.notifyMatchFailure( op, "expected 'src' to be either f32 or f64"); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index e81c052e6de0..1f89a55ee363 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -160,9 +160,9 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { // clang-format on }; llvm::FastMathFlags ret; - auto fmf = op.getFastmathFlags(); + ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue(); for (auto it : handlers) - if (bitEnumContainsAll(fmf, it.first)) + if (bitEnumContainsAll(fmfMlir, it.first)) (ret.*(it.second))(true); return ret; } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index e1d6133bd065..308a3d87a8d1 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -309,7 +309,7 @@ int collectStats(MlirOperation operation) { // clang-format off // CHECK-LABEL: @stats // CHECK: Number of operations: 12 - // CHECK: Number of attributes: 4 + // CHECK: Number of attributes: 5 // CHECK: Number of blocks: 3 // CHECK: Number of regions: 3 // CHECK: Number of values: 9 diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 05706d89de74..81f402195fb4 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -448,3 +448,20 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 { %1 = arith.maxf %arg0, %arg1 : f32 return %0 : f32 } + +// ----- + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.mulf %arg0, %arg1 fastmath : f32 + %2 = arith.negf %arg0 fastmath : f32 + %3 = arith.addf %arg0, %arg1 fastmath : f32 + %4 = arith.addf %arg0, %arg1 fastmath : f32 + return +} diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index c34850ff6e30..9d5c686d73b5 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1031,3 +1031,27 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, %min_unsigned = arith.minui %i1, %i2 : i32 return } + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.negf %arg0 fastmath : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.subf %arg0, %arg1 fastmath : f32 + %2 = arith.mulf %arg0, %arg1 fastmath : f32 + %3 = arith.divf %arg0, %arg1 fastmath : f32 + %4 = arith.remf %arg0, %arg1 fastmath : f32 + %5 = arith.negf %arg0 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32 + %6 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 + %7 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 + %8 = arith.mulf %arg0, %arg1 fastmath : f32 + + return +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index e6ab837141f1..9200c6117a49 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref // ----- func.func @generic(%arg0: memref) { - // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}} + // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath} : (f32, f32) -> f32}} linalg.generic { indexing_maps = [ affine_map<(i, j) -> (i, j)> ], iterator_types = ["parallel", "parallel"]} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index fcfaf86f1229..1d0f50fc6552 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1001,7 +1001,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) { for (FormatElement *param : dir->getArguments()) { if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); - if (var->attr.isOptional()) + if (var->attr.isOptional() || var->attr.hasDefaultValue()) body << llvm::formatv(" if ({0}Attr)\n ", var->name); body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",