Extend loop unrolling to unroll by a given factor; add builder for affine
apply op. - add builder for AffineApplyOp (first one for an operation that has non-zero operands) - add support for loop unrolling by a given factor; uses the affine apply op builder. While on this, change 'step' of ForStmt to be 'unsigned' instead of AffineConstantExpr *. Add setters for ForStmt lb, ub, step. Sample Input: // CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() { mlfunc @loop_nest_unroll_cleanup() { for %i = 1 to 100 { for %j = 0 to 17 { %x = "addi32"(%j, %j) : (affineint, affineint) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } } return } Output: $ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir #map0 = (d0) -> (d0 + 1) #map1 = (d0) -> (d0 + 2) #map2 = (d0) -> (d0 + 3) mlfunc @loop_nest_unroll_cleanup() { for %i0 = 1 to 100 { for %i1 = 0 to 17 step 4 { %0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32 %1 = "addi32"(%0, %0) : (i32, i32) -> i32 %2 = affine_apply #map0(%i1) %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32 %4 = affine_apply #map1(%i1) %5 = "addi32"(%4, %4) : (affineint, affineint) -> i32 %6 = affine_apply #map2(%i1) %7 = "addi32"(%6, %6) : (affineint, affineint) -> i32 } for %i2 = 16 to 17 { %8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32 %9 = "addi32"(%8, %8) : (i32, i32) -> i32 } } return } PiperOrigin-RevId: 209676220
This commit is contained in:
parent
6911c24e97
commit
00bed4bd99
|
@ -252,7 +252,8 @@ public:
|
|||
this->insertPoint = insertPoint;
|
||||
}
|
||||
|
||||
/// Set the insertion point to the specified operation.
|
||||
/// Set the insertion point to the specified operation, which will cause
|
||||
/// subsequent insertions to go right before it.
|
||||
void setInsertionPoint(Statement *stmt) {
|
||||
setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
|
||||
}
|
||||
|
@ -298,8 +299,7 @@ public:
|
|||
|
||||
// Creates for statement. When step is not specified, it is set to 1.
|
||||
ForStmt *createFor(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step = nullptr);
|
||||
AffineConstantExpr *upperBound, int64_t step = 1);
|
||||
|
||||
IfStmt *createIf(IntegerSet *condition) {
|
||||
auto *stmt = new IfStmt(condition);
|
||||
|
|
|
@ -47,6 +47,7 @@ typedef std::pair<Identifier, Attribute*> NamedAttribute;
|
|||
struct OperationState {
|
||||
Identifier name;
|
||||
SmallVector<SSAValue *, 4> operands;
|
||||
/// Types of the results of this operation.
|
||||
SmallVector<Type *, 4> types;
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
|
||||
|
|
|
@ -77,6 +77,10 @@ private:
|
|||
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
|
||||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
/// Builds an affine apply op with the specified map and operands.
|
||||
static OperationState build(Builder *builder, AffineMap *map,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
|
@ -163,6 +167,7 @@ protected:
|
|||
///
|
||||
class ConstantFloatOp : public ConstantOp {
|
||||
public:
|
||||
/// Builds a constant float op producing a float of the specified type.
|
||||
static OperationState build(Builder *builder, double value, FloatType *type);
|
||||
|
||||
double getValue() const {
|
||||
|
|
|
@ -199,7 +199,7 @@ public:
|
|||
// TODO: lower and upper bounds should be affine maps with
|
||||
// dimension and symbol use lists.
|
||||
explicit ForStmt(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound, AffineConstantExpr *step,
|
||||
AffineConstantExpr *upperBound, int64_t step,
|
||||
MLIRContext *context);
|
||||
|
||||
~ForStmt() {
|
||||
|
@ -216,7 +216,11 @@ public:
|
|||
|
||||
AffineConstantExpr *getLowerBound() const { return lowerBound; }
|
||||
AffineConstantExpr *getUpperBound() const { return upperBound; }
|
||||
AffineConstantExpr *getStep() const { return step; }
|
||||
int64_t getStep() const { return step; }
|
||||
|
||||
void setLowerBound(AffineConstantExpr *lb) { lowerBound = lb; }
|
||||
void setUpperBound(AffineConstantExpr *ub) { upperBound = ub; }
|
||||
void setStep(unsigned s) { step = s; }
|
||||
|
||||
using Statement::dump;
|
||||
using Statement::print;
|
||||
|
@ -242,7 +246,7 @@ private:
|
|||
// an affinemap and its operands as AffineBound.
|
||||
AffineConstantExpr *lowerBound;
|
||||
AffineConstantExpr *upperBound;
|
||||
AffineConstantExpr *step;
|
||||
int64_t step;
|
||||
};
|
||||
|
||||
/// An if clause represents statements contained within a then or an else clause
|
||||
|
|
|
@ -28,9 +28,9 @@ namespace mlir {
|
|||
class MLFunctionPass;
|
||||
class ModulePass;
|
||||
|
||||
/// Loop unrolling passes.
|
||||
MLFunctionPass *createLoopUnrollPass();
|
||||
MLFunctionPass *createLoopUnrollPass(unsigned);
|
||||
// Loop unrolling passes.
|
||||
/// Creates a loop unrolling pass.
|
||||
MLFunctionPass *createLoopUnrollPass(int unrollFactor, int unrollFull);
|
||||
|
||||
/// Replaces all ML functions in the module with equivalent CFG functions.
|
||||
/// Function references are appropriately patched to refer to the newly
|
||||
|
|
|
@ -1316,8 +1316,8 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
|
|||
printOperand(stmt);
|
||||
os << " = " << *stmt->getLowerBound();
|
||||
os << " to " << *stmt->getUpperBound();
|
||||
if (stmt->getStep()->getValue() != 1)
|
||||
os << " step " << *stmt->getStep();
|
||||
if (stmt->getStep() != 1)
|
||||
os << " step " << stmt->getStep();
|
||||
|
||||
os << " {\n";
|
||||
print(static_cast<const StmtBlock *>(stmt));
|
||||
|
|
|
@ -103,8 +103,8 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
|||
return ArrayAttr::get(value, context);
|
||||
}
|
||||
|
||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
|
||||
return AffineMapAttr::get(value, context);
|
||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
|
||||
return AffineMapAttr::get(map, context);
|
||||
}
|
||||
|
||||
TypeAttr *Builder::getTypeAttr(Type *type) {
|
||||
|
@ -207,9 +207,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
|
|||
|
||||
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step) {
|
||||
if (!step)
|
||||
step = getConstantExpr(1);
|
||||
int64_t step) {
|
||||
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
return stmt;
|
||||
|
|
|
@ -304,6 +304,22 @@ OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
|
|||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineApplyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
SmallVector<Type *, 4> resultTypes(map->getNumResults(),
|
||||
builder->getAffineIntType());
|
||||
|
||||
OperationState result(
|
||||
builder->getIdentifier("affine_apply"), operands, resultTypes,
|
||||
{{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DeallocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -198,7 +198,7 @@ MLIRContext *OperationStmt::getContext() const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step, MLIRContext *context)
|
||||
int64_t step, MLIRContext *context)
|
||||
: Statement(Kind::For),
|
||||
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
|
||||
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
|
||||
|
|
|
@ -2164,7 +2164,7 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
ParseResult MLFunctionParser::parseForStmt() {
|
||||
consumeToken(Token::kw_for);
|
||||
|
||||
// Parse induction variable
|
||||
// Parse induction variable.
|
||||
if (getToken().isNot(Token::percent_identifier))
|
||||
return emitError("expected SSA identifier for the loop variable");
|
||||
|
||||
|
@ -2175,7 +2175,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
if (parseToken(Token::equal, "expected '='"))
|
||||
return ParseFailure;
|
||||
|
||||
// Parse loop bounds
|
||||
// Parse loop bounds.
|
||||
AffineConstantExpr *lowerBound = parseIntConstant();
|
||||
if (!lowerBound)
|
||||
return ParseFailure;
|
||||
|
@ -2187,12 +2187,13 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
if (!upperBound)
|
||||
return ParseFailure;
|
||||
|
||||
// Parse step
|
||||
AffineConstantExpr *step = nullptr;
|
||||
// Parse step.
|
||||
int64_t step = 1;
|
||||
if (consumeIf(Token::kw_step)) {
|
||||
step = parseIntConstant();
|
||||
if (!step)
|
||||
AffineConstantExpr *stepExpr = parseIntConstant();
|
||||
if (!stepExpr)
|
||||
return ParseFailure;
|
||||
step = stepExpr->getValue();
|
||||
}
|
||||
|
||||
// Create for statement.
|
||||
|
|
|
@ -31,30 +31,49 @@
|
|||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace llvm;
|
||||
|
||||
// Loop unrolling factor.
|
||||
static llvm::cl::opt<unsigned>
|
||||
clUnrollFactor("unroll-factor", cl::Hidden,
|
||||
cl::desc("Use this unroll factor for all loops"));
|
||||
|
||||
static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
|
||||
cl::desc("Fully unroll loops"));
|
||||
|
||||
static llvm::cl::opt<unsigned> clUnrollFullThreshold(
|
||||
"unroll-full-threshold", cl::Hidden,
|
||||
cl::desc("Unroll all loops with trip count less than or equal to this"));
|
||||
|
||||
namespace {
|
||||
/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
|
||||
/// MLFunction.
|
||||
/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
|
||||
/// full unroll threshold was specified, in which case, fully unrolls all loops
|
||||
/// with trip count less than the specified threshold. The latter is for testing
|
||||
/// purposes, especially for testing outer loop unrolling.
|
||||
struct LoopUnroll : public MLFunctionPass {
|
||||
void runOnMLFunction(MLFunction *f) override;
|
||||
void runOnForStmt(ForStmt *forStmt);
|
||||
};
|
||||
Optional<unsigned> unrollFactor;
|
||||
Optional<bool> unrollFull;
|
||||
|
||||
explicit LoopUnroll(Optional<unsigned> unrollFactor,
|
||||
Optional<bool> unrollFull)
|
||||
: unrollFactor(unrollFactor), unrollFull(unrollFull) {}
|
||||
|
||||
/// Unrolls all loops with trip count <= minTripCount.
|
||||
struct ShortLoopUnroll : public LoopUnroll {
|
||||
const unsigned minTripCount;
|
||||
void runOnMLFunction(MLFunction *f) override;
|
||||
ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
|
||||
/// Unroll this for stmt. Returns false if nothing was done.
|
||||
bool runOnForStmt(ForStmt *forStmt);
|
||||
bool loopUnrollFull(ForStmt *forStmt);
|
||||
bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
|
||||
|
||||
MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
|
||||
return new ShortLoopUnroll(minTripCount);
|
||||
MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
|
||||
return new LoopUnroll(unrollFactor == -1 ? None
|
||||
: Optional<unsigned>(unrollFactor),
|
||||
unrollFull == -1 ? None : Optional<bool>(unrollFull));
|
||||
}
|
||||
|
||||
void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
||||
|
@ -81,7 +100,6 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
|
||||
if (!hasInnerLoops)
|
||||
loops.push_back(forStmt);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -101,14 +119,6 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
|
||||
};
|
||||
|
||||
InnermostLoopGatherer ilg;
|
||||
ilg.walkPostOrder(f);
|
||||
auto &loops = ilg.loops;
|
||||
for (auto *forStmt : loops)
|
||||
runOnForStmt(forStmt);
|
||||
}
|
||||
|
||||
void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
|
||||
// Gathers all loops with trip count <= minTripCount.
|
||||
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
|
||||
public:
|
||||
|
@ -120,27 +130,55 @@ void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
void visitForStmt(ForStmt *forStmt) {
|
||||
auto lb = forStmt->getLowerBound()->getValue();
|
||||
auto ub = forStmt->getUpperBound()->getValue();
|
||||
auto step = forStmt->getStep()->getValue();
|
||||
auto step = forStmt->getStep();
|
||||
|
||||
if ((ub - lb) / step + 1 <= minTripCount)
|
||||
loops.push_back(forStmt);
|
||||
}
|
||||
};
|
||||
|
||||
ShortLoopGatherer slg(minTripCount);
|
||||
// Do a post order walk so that loops are gathered from innermost to
|
||||
// outermost (or else unrolling an outer one may delete gathered inner ones).
|
||||
slg.walkPostOrder(f);
|
||||
auto &loops = slg.loops;
|
||||
if (clUnrollFull.getNumOccurrences() > 0 &&
|
||||
clUnrollFullThreshold.getNumOccurrences() > 0) {
|
||||
ShortLoopGatherer slg(clUnrollFullThreshold);
|
||||
// Do a post order walk so that loops are gathered from innermost to
|
||||
// outermost (or else unrolling an outer one may delete gathered inner
|
||||
// ones).
|
||||
slg.walkPostOrder(f);
|
||||
auto &loops = slg.loops;
|
||||
for (auto *forStmt : loops)
|
||||
loopUnrollFull(forStmt);
|
||||
return;
|
||||
}
|
||||
|
||||
InnermostLoopGatherer ilg;
|
||||
ilg.walkPostOrder(f);
|
||||
auto &loops = ilg.loops;
|
||||
for (auto *forStmt : loops)
|
||||
runOnForStmt(forStmt);
|
||||
}
|
||||
|
||||
/// Unroll this For loop completely.
|
||||
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
||||
/// Unroll a for stmt. Default unroll factor is 4.
|
||||
bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
||||
// Unroll completely if full loop unroll was specified.
|
||||
if (clUnrollFull.getNumOccurrences() > 0 ||
|
||||
(unrollFull.hasValue() && unrollFull.getValue()))
|
||||
return loopUnrollFull(forStmt);
|
||||
|
||||
// Unroll by the specified factor if one was specified.
|
||||
if (clUnrollFactor.getNumOccurrences() > 0)
|
||||
return loopUnrollByFactor(forStmt, clUnrollFactor);
|
||||
else if (unrollFactor.hasValue())
|
||||
return loopUnrollByFactor(forStmt, unrollFactor.getValue());
|
||||
|
||||
// Unroll by four otherwise.
|
||||
return loopUnrollByFactor(forStmt, 4);
|
||||
}
|
||||
|
||||
// Unrolls this loop completely.
|
||||
bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
|
||||
auto lb = forStmt->getLowerBound()->getValue();
|
||||
auto ub = forStmt->getUpperBound()->getValue();
|
||||
auto step = forStmt->getStep()->getValue();
|
||||
auto step = forStmt->getStep();
|
||||
|
||||
// Builder to add constants need for the unrolled iterator.
|
||||
auto *mlFunc = forStmt->findFunction();
|
||||
|
@ -164,9 +202,75 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
|
|||
|
||||
// Clone the body of the loop.
|
||||
for (auto &childStmt : *forStmt) {
|
||||
(void)builder.clone(childStmt, operandMapping);
|
||||
builder.clone(childStmt, operandMapping);
|
||||
}
|
||||
}
|
||||
// Erase the original 'for' stmt from the block.
|
||||
forStmt->eraseFromBlock();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Unrolls this loop by the specified unroll factor.
|
||||
bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
|
||||
assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
|
||||
|
||||
if (unrollFactor == 1 || forStmt->getStatements().empty())
|
||||
return false;
|
||||
|
||||
auto lb = forStmt->getLowerBound()->getValue();
|
||||
auto ub = forStmt->getUpperBound()->getValue();
|
||||
auto step = forStmt->getStep();
|
||||
|
||||
int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
|
||||
|
||||
// If the trip count is lower than the unroll factor, no unrolled body.
|
||||
// TODO(bondhugula): option to specify cleanup loop unrolling.
|
||||
if (tripCount < unrollFactor)
|
||||
return true;
|
||||
|
||||
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
|
||||
if (tripCount % unrollFactor) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
cleanupForStmt->setLowerBound(builder.getConstantExpr(
|
||||
lb + (tripCount - tripCount % unrollFactor) * step));
|
||||
}
|
||||
|
||||
// Builder to insert unrolled bodies right after the last statement in the
|
||||
// body of 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
|
||||
forStmt->setStep(step * unrollFactor);
|
||||
forStmt->setUpperBound(builder.getConstantExpr(
|
||||
lb + (tripCount - tripCount % unrollFactor - 1) * step));
|
||||
|
||||
// Keep a pointer to the last statement in the original block so that we know
|
||||
// what to clone (since we are doing this in-place).
|
||||
StmtBlock::iterator srcBlockEnd = --forStmt->end();
|
||||
|
||||
// Unroll the contents of 'forStmt' (unrollFactor-1 additional copies
|
||||
// appended).
|
||||
for (unsigned i = 1; i < unrollFactor; i++) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMapping;
|
||||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
if (!forStmt->use_empty()) {
|
||||
// iv' = iv + 1/2/3...unrollFactor-1;
|
||||
auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
|
||||
builder.getConstantExpr(i * step));
|
||||
auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
|
||||
operandMapping[forStmt] = cast<MLValue>(ivUnroll);
|
||||
}
|
||||
|
||||
// Clone the original body of the loop (this doesn't include the last stmt).
|
||||
for (auto it = forStmt->begin(); it != srcBlockEnd; it++) {
|
||||
builder.clone(*it, operandMapping);
|
||||
}
|
||||
// Clone the last statement in the original body.
|
||||
builder.clone(*srcBlockEnd, operandMapping);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
|
||||
// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
|
||||
// RUN: mlir-opt %s -o - -loop-unroll -unroll-full | FileCheck %s
|
||||
// RUN: mlir-opt %s -o - -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT
|
||||
// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4
|
||||
// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=3 | FileCheck %s --check-prefix UNROLL-BY-3
|
||||
|
||||
// CHECK: #map0 = (d0) -> (d0 + 1)
|
||||
|
||||
|
@ -279,3 +281,87 @@ mlfunc @loop_nest_seq_long() -> i32 {
|
|||
%ret = load %C[%zero_idx, %zero_idx] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2>
|
||||
return %ret : i32
|
||||
}
|
||||
|
||||
// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_no_cleanup() {
|
||||
mlfunc @unroll_unit_stride_no_cleanup() {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 {
|
||||
for %i = 1 to 100 {
|
||||
// UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
|
||||
// UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
for %j = 1 to 8 {
|
||||
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
|
||||
%y = "addi32"(%x, %x) : (i32, i32) -> i32
|
||||
}
|
||||
// empty loop
|
||||
// UNROLL-BY-4: for %i2 = 1 to 8 {
|
||||
for %k = 1 to 8 {
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_cleanup() {
|
||||
mlfunc @unroll_unit_stride_cleanup() {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 {
|
||||
for %i = 1 to 100 {
|
||||
// UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
|
||||
// UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// UNROLL-BY-4-NEXT: for [[L2:%i[0-9]+]] = 9 to 10 {
|
||||
// UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
for %j = 1 to 10 {
|
||||
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
|
||||
%y = "addi32"(%x, %x) : (i32, i32) -> i32
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UNROLL-BY-3-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
|
||||
mlfunc @unroll_non_unit_stride_cleanup() {
|
||||
// UNROLL-BY-3: for %i0 = 1 to 100 {
|
||||
for %i = 1 to 100 {
|
||||
// UNROLL-BY-3: for [[L1:%i[0-9]+]] = 2 to 12 step 15 {
|
||||
// UNROLL-BY-3-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-3-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
|
||||
// UNROLL-BY-3-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-3-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-3-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
|
||||
// UNROLL-BY-3-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
|
||||
// UNROLL-BY-3-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-3-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
|
||||
// UNROLL-BY-3-NEXT: }
|
||||
// UNROLL-BY-3-NEXT: for [[L2:%i[0-9]+]] = 17 to 20 step 5 {
|
||||
// UNROLL-BY-3-NEXT: %8 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
|
||||
// UNROLL-BY-3-NEXT: %9 = "addi32"(%8, %8) : (i32, i32) -> i32
|
||||
// UNROLL-BY-3-NEXT: }
|
||||
for %j = 2 to 20 step 5 {
|
||||
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
|
||||
%y = "addi32"(%x, %x) : (i32, i32) -> i32
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -53,8 +53,7 @@ checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
|
|||
|
||||
enum Passes {
|
||||
ConvertToCFG,
|
||||
UnrollInnermostLoops,
|
||||
UnrollShortLoops,
|
||||
LoopUnroll,
|
||||
TFRaiseControlFlow,
|
||||
};
|
||||
|
||||
|
@ -62,10 +61,7 @@ static cl::list<Passes> passList(
|
|||
"", cl::desc("Compiler passes to run"),
|
||||
cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
|
||||
"Unroll innermost loops"),
|
||||
clEnumValN(UnrollShortLoops, "unroll-short-loops",
|
||||
"Unroll loops of trip count <= 2"),
|
||||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
|
||||
"Dynamic TensorFlow Switch/Match nodes to a CFG")));
|
||||
|
||||
|
@ -112,11 +108,8 @@ OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
|
|||
case ConvertToCFG:
|
||||
pass = createConvertToCFGPass();
|
||||
break;
|
||||
case UnrollInnermostLoops:
|
||||
pass = createLoopUnrollPass();
|
||||
break;
|
||||
case UnrollShortLoops:
|
||||
pass = createLoopUnrollPass(2);
|
||||
case LoopUnroll:
|
||||
pass = createLoopUnrollPass(-1, -1);
|
||||
break;
|
||||
case TFRaiseControlFlow:
|
||||
pass = createRaiseTFControlFlowPass();
|
||||
|
|
Loading…
Reference in New Issue