Extend loop unrolling to unroll by a given factor; add builder for affine

apply op.

- add builder for AffineApplyOp (first one for an operation that has
  non-zero operands)
- add support for loop unrolling by a given factor; uses the affine apply op
  builder.

While on this, change 'step' of ForStmt to be 'unsigned' instead of
AffineConstantExpr *. Add setters for ForStmt lb, ub, step.

Sample Input:

// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
  for %i = 1 to 100 {
    for %j = 0 to 17 {
      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
      %y = "addi32"(%x, %x) : (i32, i32) -> i32
    }
  }
  return
}

Output:

$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
  for %i0 = 1 to 100 {
    for %i1 = 0 to 17 step 4 {
      %0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
      %1 = "addi32"(%0, %0) : (i32, i32) -> i32
      %2 = affine_apply #map0(%i1)
      %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
      %4 = affine_apply #map1(%i1)
      %5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
      %6 = affine_apply #map2(%i1)
      %7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
    }
    for %i2 = 16 to 17 {
      %8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
      %9 = "addi32"(%8, %8) : (i32, i32) -> i32
    }
  }
  return
}

PiperOrigin-RevId: 209676220
This commit is contained in:
Uday Bondhugula 2018-08-21 16:01:23 -07:00 committed by jpienaar
parent 6911c24e97
commit 00bed4bd99
13 changed files with 276 additions and 68 deletions

View File

@ -252,7 +252,8 @@ public:
this->insertPoint = insertPoint;
}
/// Set the insertion point to the specified operation.
/// Set the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void setInsertionPoint(Statement *stmt) {
setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
}
@ -298,8 +299,7 @@ public:
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step = nullptr);
AffineConstantExpr *upperBound, int64_t step = 1);
IfStmt *createIf(IntegerSet *condition) {
auto *stmt = new IfStmt(condition);

View File

@ -47,6 +47,7 @@ typedef std::pair<Identifier, Attribute*> NamedAttribute;
struct OperationState {
Identifier name;
SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation.
SmallVector<Type *, 4> types;
SmallVector<NamedAttribute, 4> attributes;

View File

@ -77,6 +77,10 @@ private:
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
/// Builds an affine apply op with the specified map and operands.
static OperationState build(Builder *builder, AffineMap *map,
ArrayRef<SSAValue *> operands);
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
@ -163,6 +167,7 @@ protected:
///
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static OperationState build(Builder *builder, double value, FloatType *type);
double getValue() const {

View File

@ -199,7 +199,7 @@ public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, AffineConstantExpr *step,
AffineConstantExpr *upperBound, int64_t step,
MLIRContext *context);
~ForStmt() {
@ -216,7 +216,11 @@ public:
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
AffineConstantExpr *getStep() const { return step; }
int64_t getStep() const { return step; }
void setLowerBound(AffineConstantExpr *lb) { lowerBound = lb; }
void setUpperBound(AffineConstantExpr *ub) { upperBound = ub; }
void setStep(unsigned s) { step = s; }
using Statement::dump;
using Statement::print;
@ -242,7 +246,7 @@ private:
// an affinemap and its operands as AffineBound.
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;
AffineConstantExpr *step;
int64_t step;
};
/// An if clause represents statements contained within a then or an else clause

View File

@ -28,9 +28,9 @@ namespace mlir {
class MLFunctionPass;
class ModulePass;
/// Loop unrolling passes.
MLFunctionPass *createLoopUnrollPass();
MLFunctionPass *createLoopUnrollPass(unsigned);
// Loop unrolling passes.
/// Creates a loop unrolling pass.
MLFunctionPass *createLoopUnrollPass(int unrollFactor, int unrollFull);
/// Replaces all ML functions in the module with equivalent CFG functions.
/// Function references are appropriately patched to refer to the newly

View File

@ -1316,8 +1316,8 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
printOperand(stmt);
os << " = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)
os << " step " << *stmt->getStep();
if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
os << " {\n";
print(static_cast<const StmtBlock *>(stmt));

View File

@ -103,8 +103,8 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
return ArrayAttr::get(value, context);
}
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
return AffineMapAttr::get(value, context);
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
return AffineMapAttr::get(map, context);
}
TypeAttr *Builder::getTypeAttr(Type *type) {
@ -207,9 +207,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step) {
if (!step)
step = getConstantExpr(1);
int64_t step) {
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
block->getStatements().insert(insertPoint, stmt);
return stmt;

View File

@ -304,6 +304,22 @@ OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
return result;
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
ArrayRef<SSAValue *> operands) {
SmallVector<Type *, 4> resultTypes(map->getNumResults(),
builder->getAffineIntType());
OperationState result(
builder->getIdentifier("affine_apply"), operands, resultTypes,
{{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
return result;
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

View File

@ -198,7 +198,7 @@ MLIRContext *OperationStmt::getContext() const {
//===----------------------------------------------------------------------===//
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
int64_t step, MLIRContext *context)
: Statement(Kind::For),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),

View File

@ -2164,7 +2164,7 @@ ParseResult MLFunctionParser::parseFunctionBody() {
ParseResult MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
// Parse induction variable
// Parse induction variable.
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
@ -2175,7 +2175,7 @@ ParseResult MLFunctionParser::parseForStmt() {
if (parseToken(Token::equal, "expected '='"))
return ParseFailure;
// Parse loop bounds
// Parse loop bounds.
AffineConstantExpr *lowerBound = parseIntConstant();
if (!lowerBound)
return ParseFailure;
@ -2187,12 +2187,13 @@ ParseResult MLFunctionParser::parseForStmt() {
if (!upperBound)
return ParseFailure;
// Parse step
AffineConstantExpr *step = nullptr;
// Parse step.
int64_t step = 1;
if (consumeIf(Token::kw_step)) {
step = parseIntConstant();
if (!step)
AffineConstantExpr *stepExpr = parseIntConstant();
if (!stepExpr)
return ParseFailure;
step = stepExpr->getValue();
}
// Create for statement.

View File

@ -31,30 +31,49 @@
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace llvm;
// Loop unrolling factor.
static llvm::cl::opt<unsigned>
clUnrollFactor("unroll-factor", cl::Hidden,
cl::desc("Use this unroll factor for all loops"));
static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
cl::desc("Fully unroll loops"));
static llvm::cl::opt<unsigned> clUnrollFullThreshold(
"unroll-full-threshold", cl::Hidden,
cl::desc("Unroll all loops with trip count less than or equal to this"));
namespace {
/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
/// MLFunction.
/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
/// full unroll threshold was specified, in which case, fully unrolls all loops
/// with trip count less than the specified threshold. The latter is for testing
/// purposes, especially for testing outer loop unrolling.
struct LoopUnroll : public MLFunctionPass {
void runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
};
Optional<unsigned> unrollFactor;
Optional<bool> unrollFull;
explicit LoopUnroll(Optional<unsigned> unrollFactor,
Optional<bool> unrollFull)
: unrollFactor(unrollFactor), unrollFull(unrollFull) {}
/// Unrolls all loops with trip count <= minTripCount.
struct ShortLoopUnroll : public LoopUnroll {
const unsigned minTripCount;
void runOnMLFunction(MLFunction *f) override;
ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
bool loopUnrollFull(ForStmt *forStmt);
bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
};
} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
return new ShortLoopUnroll(minTripCount);
MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
return new LoopUnroll(unrollFactor == -1 ? None
: Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull));
}
void LoopUnroll::runOnMLFunction(MLFunction *f) {
@ -81,7 +100,6 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
return true;
}
@ -101,14 +119,6 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
InnermostLoopGatherer ilg;
ilg.walkPostOrder(f);
auto &loops = ilg.loops;
for (auto *forStmt : loops)
runOnForStmt(forStmt);
}
void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
public:
@ -120,27 +130,55 @@ void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
void visitForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
auto step = forStmt->getStep();
if ((ub - lb) / step + 1 <= minTripCount)
loops.push_back(forStmt);
}
};
ShortLoopGatherer slg(minTripCount);
// Do a post order walk so that loops are gathered from innermost to
// outermost (or else unrolling an outer one may delete gathered inner ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
if (clUnrollFull.getNumOccurrences() > 0 &&
clUnrollFullThreshold.getNumOccurrences() > 0) {
ShortLoopGatherer slg(clUnrollFullThreshold);
// Do a post order walk so that loops are gathered from innermost to
// outermost (or else unrolling an outer one may delete gathered inner
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
for (auto *forStmt : loops)
loopUnrollFull(forStmt);
return;
}
InnermostLoopGatherer ilg;
ilg.walkPostOrder(f);
auto &loops = ilg.loops;
for (auto *forStmt : loops)
runOnForStmt(forStmt);
}
/// Unroll this For loop completely.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
/// Unroll a for stmt. Default unroll factor is 4.
bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
return loopUnrollFull(forStmt);
// Unroll by the specified factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
return loopUnrollByFactor(forStmt, clUnrollFactor);
else if (unrollFactor.hasValue())
return loopUnrollByFactor(forStmt, unrollFactor.getValue());
// Unroll by four otherwise.
return loopUnrollByFactor(forStmt, 4);
}
// Unrolls this loop completely.
bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
auto step = forStmt->getStep();
// Builder to add constants need for the unrolled iterator.
auto *mlFunc = forStmt->findFunction();
@ -164,9 +202,75 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
// Clone the body of the loop.
for (auto &childStmt : *forStmt) {
(void)builder.clone(childStmt, operandMapping);
builder.clone(childStmt, operandMapping);
}
}
// Erase the original 'for' stmt from the block.
forStmt->eraseFromBlock();
return true;
}
/// Unrolls this loop by the specified unroll factor.
bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
if (unrollFactor == 1 || forStmt->getStatements().empty())
return false;
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep();
int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
if (tripCount < unrollFactor)
return true;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (tripCount % unrollFactor) {
DenseMap<const MLValue *, MLValue *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setLowerBound(builder.getConstantExpr(
lb + (tripCount - tripCount % unrollFactor) * step));
}
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
forStmt->setStep(step * unrollFactor);
forStmt->setUpperBound(builder.getConstantExpr(
lb + (tripCount - tripCount % unrollFactor - 1) * step));
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).
StmtBlock::iterator srcBlockEnd = --forStmt->end();
// Unroll the contents of 'forStmt' (unrollFactor-1 additional copies
// appended).
for (unsigned i = 1; i < unrollFactor; i++) {
DenseMap<const MLValue *, MLValue *> operandMapping;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forStmt->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
builder.getConstantExpr(i * step));
auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
operandMapping[forStmt] = cast<MLValue>(ivUnroll);
}
// Clone the original body of the loop (this doesn't include the last stmt).
for (auto it = forStmt->begin(); it != srcBlockEnd; it++) {
builder.clone(*it, operandMapping);
}
// Clone the last statement in the original body.
builder.clone(*srcBlockEnd, operandMapping);
}
return true;
}

View File

@ -1,5 +1,7 @@
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
// RUN: mlir-opt %s -o - -loop-unroll -unroll-full | FileCheck %s
// RUN: mlir-opt %s -o - -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT
// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4
// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=3 | FileCheck %s --check-prefix UNROLL-BY-3
// CHECK: #map0 = (d0) -> (d0 + 1)
@ -279,3 +281,87 @@ mlfunc @loop_nest_seq_long() -> i32 {
%ret = load %C[%zero_idx, %zero_idx] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2>
return %ret : i32
}
// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_no_cleanup() {
mlfunc @unroll_unit_stride_no_cleanup() {
// UNROLL-BY-4: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
// UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: }
for %j = 1 to 8 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
// empty loop
// UNROLL-BY-4: for %i2 = 1 to 8 {
for %k = 1 to 8 {
}
}
return
}
// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_cleanup() {
mlfunc @unroll_unit_stride_cleanup() {
// UNROLL-BY-4: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
// UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: }
// UNROLL-BY-4-NEXT: for [[L2:%i[0-9]+]] = 9 to 10 {
// UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
// UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32
// UNROLL-BY-4-NEXT: }
for %j = 1 to 10 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
}
return
}
// UNROLL-BY-3-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
mlfunc @unroll_non_unit_stride_cleanup() {
// UNROLL-BY-3: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// UNROLL-BY-3: for [[L1:%i[0-9]+]] = 2 to 12 step 15 {
// UNROLL-BY-3-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
// UNROLL-BY-3-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
// UNROLL-BY-3-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-3-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
// UNROLL-BY-3-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
// UNROLL-BY-3-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
// UNROLL-BY-3-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
// UNROLL-BY-3-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3-NEXT: for [[L2:%i[0-9]+]] = 17 to 20 step 5 {
// UNROLL-BY-3-NEXT: %8 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
// UNROLL-BY-3-NEXT: %9 = "addi32"(%8, %8) : (i32, i32) -> i32
// UNROLL-BY-3-NEXT: }
for %j = 2 to 20 step 5 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
}
return
}

View File

@ -53,8 +53,7 @@ checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
enum Passes {
ConvertToCFG,
UnrollInnermostLoops,
UnrollShortLoops,
LoopUnroll,
TFRaiseControlFlow,
};
@ -62,10 +61,7 @@ static cl::list<Passes> passList(
"", cl::desc("Compiler passes to run"),
cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
"Unroll innermost loops"),
clEnumValN(UnrollShortLoops, "unroll-short-loops",
"Unroll loops of trip count <= 2"),
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG")));
@ -112,11 +108,8 @@ OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
case ConvertToCFG:
pass = createConvertToCFGPass();
break;
case UnrollInnermostLoops:
pass = createLoopUnrollPass();
break;
case UnrollShortLoops:
pass = createLoopUnrollPass(2);
case LoopUnroll:
pass = createLoopUnrollPass(-1, -1);
break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();