[mlir][math] Add math.ctlz expansion to control flow + arith operations
Ctlz is an intrinsic in LLVM but does not have equivalent operations in SPIR-V. Including a decomposition gives an alternative path for these platforms. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D126261
This commit is contained in:
parent
02f640672e
commit
f3bdb56d61
|
@ -13,6 +13,7 @@ namespace mlir {
|
|||
|
||||
class RewritePatternSet;
|
||||
|
||||
void populateExpandCtlzPattern(RewritePatternSet &patterns);
|
||||
void populateExpandTanhPattern(RewritePatternSet &patterns);
|
||||
|
||||
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRMathTransforms
|
||||
AlgebraicSimplification.cpp
|
||||
ExpandTanh.cpp
|
||||
ExpandPatterns.cpp
|
||||
PolynomialApproximation.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
@ -53,6 +54,67 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
auto operand = op.getOperand();
|
||||
auto elementTy = operand.getType();
|
||||
auto resultTy = op.getType();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
int bitWidth = elementTy.getIntOrFloatBitWidth();
|
||||
auto zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
|
||||
auto leadingZeros = rewriter.create<arith::ConstantOp>(
|
||||
loc, IntegerAttr::get(elementTy, bitWidth));
|
||||
|
||||
SmallVector<Value> operands = {operand, leadingZeros, zero};
|
||||
SmallVector<Type> types = {elementTy, elementTy, elementTy};
|
||||
SmallVector<Location> locations = {loc, loc, loc};
|
||||
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
|
||||
Block *before =
|
||||
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
|
||||
Block *after =
|
||||
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
|
||||
|
||||
// The conditional block of the while loop.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
|
||||
Value input = before->getArgument(0);
|
||||
Value zero = before->getArgument(2);
|
||||
|
||||
Value inputNotZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, input, zero);
|
||||
rewriter.create<scf::ConditionOp>(loc, inputNotZero,
|
||||
before->getArguments());
|
||||
}
|
||||
|
||||
// The body of the while loop: shift right until reaching a value of 0.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
|
||||
Value input = after->getArgument(0);
|
||||
Value leadingZeros = after->getArgument(1);
|
||||
|
||||
auto one =
|
||||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
|
||||
auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
|
||||
auto leadingZerosMinusOne =
|
||||
rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
|
||||
|
||||
rewriter.create<scf::YieldOp>(
|
||||
loc,
|
||||
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointAfter(whileOp);
|
||||
rewriter.replaceOp(op, whileOp->getResult(1));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertCtlzOp);
|
||||
}
|
||||
|
||||
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
|
||||
patterns.add(convertTanhOp);
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
|
||||
// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @tanh
|
||||
func.func @tanh(%arg: f32) -> f32 {
|
||||
|
@ -21,3 +21,22 @@ func.func @tanh(%arg: f32) -> f32 {
|
|||
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
|
||||
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// ----
|
||||
|
||||
// CHECK-LABEL: func @ctlz
|
||||
func.func @ctlz(%arg: i32) -> i32 {
|
||||
// CHECK: %[[C0:.+]] = arith.constant 0 : i32
|
||||
// CHECK: %[[C32:.+]] = arith.constant 32 : i32
|
||||
// CHECK: %[[C1:.+]] = arith.constant 1 : i32
|
||||
// CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
|
||||
// CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
|
||||
// CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
|
||||
// CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
|
||||
// CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
|
||||
// CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]]
|
||||
%res = math.ctlz %arg : i32
|
||||
|
||||
// CHECK: return %[[WHILE]]#1
|
||||
return %res : i32
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRMathTestPasses
|
||||
TestAlgebraicSimplification.cpp
|
||||
TestExpandTanh.cpp
|
||||
TestExpandMath.cpp
|
||||
TestPolynomialApproximation.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
|
||||
//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,35 +6,41 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains test passes for expanding tanh.
|
||||
// This file contains test passes for expanding math operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestExpandTanhPass
|
||||
: public PassWrapper<TestExpandTanhPass, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandTanhPass)
|
||||
struct TestExpandMathPass
|
||||
: public PassWrapper<TestExpandMathPass, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
|
||||
|
||||
void runOnOperation() override;
|
||||
StringRef getArgument() const final { return "test-expand-tanh"; }
|
||||
StringRef getDescription() const final { return "Test expanding tanh"; }
|
||||
StringRef getArgument() const final { return "test-expand-math"; }
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
|
||||
}
|
||||
StringRef getDescription() const final { return "Test expanding math"; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void TestExpandTanhPass::runOnOperation() {
|
||||
void TestExpandMathPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateExpandCtlzPattern(patterns);
|
||||
populateExpandTanhPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
|
||||
void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
|
@ -76,7 +76,7 @@ void registerTestDecomposeCallGraphTypes();
|
|||
void registerTestDiagnosticsPass();
|
||||
void registerTestDominancePass();
|
||||
void registerTestDynamicPipelinePass();
|
||||
void registerTestExpandTanhPass();
|
||||
void registerTestExpandMathPass();
|
||||
void registerTestComposeSubView();
|
||||
void registerTestMultiBuffering();
|
||||
void registerTestIRVisitorsPass();
|
||||
|
@ -172,7 +172,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestDataLayoutQuery();
|
||||
mlir::test::registerTestDominancePass();
|
||||
mlir::test::registerTestDynamicPipelinePass();
|
||||
mlir::test::registerTestExpandTanhPass();
|
||||
mlir::test::registerTestExpandMathPass();
|
||||
mlir::test::registerTestComposeSubView();
|
||||
mlir::test::registerTestMultiBuffering();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
|
|
Loading…
Reference in New Issue