[mlir][math] Add math.ctlz expansion to control flow + arith operations

Ctlz is an intrinsic in LLVM but does not have equivalent operations in SPIR-V.
Including a decomposition gives an alternative path for these platforms.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D126261
This commit is contained in:
Rob Suderman 2022-06-01 11:35:26 -07:00 committed by Rob Suderman
parent 02f640672e
commit f3bdb56d61
7 changed files with 102 additions and 14 deletions

View File

@ -13,6 +13,7 @@ namespace mlir {
class RewritePatternSet;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

View File

@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandTanh.cpp
ExpandPatterns.cpp
PolynomialApproximation.cpp
ADDITIONAL_HEADER_DIRS

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
@ -53,6 +54,67 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
return success();
}
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) {
auto operand = op.getOperand();
auto elementTy = operand.getType();
auto resultTy = op.getType();
Location loc = op.getLoc();
int bitWidth = elementTy.getIntOrFloatBitWidth();
auto zero =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto leadingZeros = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, bitWidth));
SmallVector<Value> operands = {operand, leadingZeros, zero};
SmallVector<Type> types = {elementTy, elementTy, elementTy};
SmallVector<Location> locations = {loc, loc, loc};
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
Block *before =
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
Block *after =
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
// The conditional block of the while loop.
{
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
Value input = before->getArgument(0);
Value zero = before->getArgument(2);
Value inputNotZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, input, zero);
rewriter.create<scf::ConditionOp>(loc, inputNotZero,
before->getArguments());
}
// The body of the while loop: shift right until reaching a value of 0.
{
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
Value input = after->getArgument(0);
Value leadingZeros = after->getArgument(1);
auto one =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
auto leadingZerosMinusOne =
rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
rewriter.create<scf::YieldOp>(
loc,
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
}
rewriter.setInsertionPointAfter(whileOp);
rewriter.replaceOp(op, whileOp->getResult(1));
return success();
}
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
// CHECK-LABEL: func @tanh
func.func @tanh(%arg: f32) -> f32 {
@ -21,3 +21,22 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
// CHECK: return %[[RESULT]]
// ----
// CHECK-LABEL: func @ctlz
func.func @ctlz(%arg: i32) -> i32 {
// CHECK: %[[C0:.+]] = arith.constant 0 : i32
// CHECK: %[[C32:.+]] = arith.constant 32 : i32
// CHECK: %[[C1:.+]] = arith.constant 1 : i32
// CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
// CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
// CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
// CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
// CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
// CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]]
%res = math.ctlz %arg : i32
// CHECK: return %[[WHILE]]#1
return %res : i32
}

View File

@ -1,7 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
TestAlgebraicSimplification.cpp
TestExpandTanh.cpp
TestExpandMath.cpp
TestPolynomialApproximation.cpp
EXCLUDE_FROM_LIBMLIR

View File

@ -1,4 +1,4 @@
//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,35 +6,41 @@
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for expanding tanh.
// This file contains test passes for expanding math operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestExpandTanhPass
: public PassWrapper<TestExpandTanhPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandTanhPass)
struct TestExpandMathPass
: public PassWrapper<TestExpandMathPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
void runOnOperation() override;
StringRef getArgument() const final { return "test-expand-tanh"; }
StringRef getDescription() const final { return "Test expanding tanh"; }
StringRef getArgument() const final { return "test-expand-math"; }
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
}
StringRef getDescription() const final { return "Test expanding math"; }
};
} // namespace
void TestExpandTanhPass::runOnOperation() {
void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateExpandCtlzPattern(patterns);
populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
namespace mlir {
namespace test {
void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
} // namespace test
} // namespace mlir

View File

@ -76,7 +76,7 @@ void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandTanhPass();
void registerTestExpandMathPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
@ -172,7 +172,7 @@ void registerTestPasses() {
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestExpandTanhPass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();