[Polly] Move MatMul optimization into its own file. NFC.

Functions shared between generalized matrix-multiplication optimization
and other post-reschedule optimizations (tiling, prevect) are moved into
the schedule tree transformation utility ScheduleTreeTransform.
This commit is contained in:
Michael Kruse 2021-06-04 23:17:41 -05:00
parent d8a4a2cb93
commit d123e983b3
9 changed files with 1244 additions and 1192 deletions

View File

@ -0,0 +1,74 @@
//===- MatmulOptimizer.h -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef POLLY_MATMULOPTIMIZER_H
#define POLLY_MATMULOPTIMIZER_H
#include "isl/isl-noexceptions.h"
namespace llvm {
class TargetTransformInfo;
}
namespace polly {
struct Dependences;
/// Apply the BLIS matmul optimization pattern if possible.
///
/// Make the loops containing the matrix multiplication be the innermost
/// loops and apply the BLIS matmul optimization pattern. BLIS implements
/// gemm as three nested loops around a macro-kernel, plus two packing
/// routines. The macro-kernel is implemented in terms of two additional
/// loops around a micro-kernel. The micro-kernel is a loop around a rank-1
/// (i.e., outer product) update.
///
/// For a detailed description please see [1].
///
/// The order of the loops defines the data reused in the BLIS implementation
/// of gemm ([1]). In particular, elements of the matrix B, the second
/// operand of matrix multiplication, are reused between iterations of the
/// innermost loop. To keep the reused data in cache, only elements of matrix
/// A, the first operand of matrix multiplication, should be evicted during
/// an iteration of the innermost loop. To provide such a cache replacement
/// policy, elements of the matrix A can, in particular, be loaded first and,
/// consequently, be least-recently-used.
///
/// In our case matrices are stored in row-major order instead of the
/// column-major order used in the BLIS implementation ([1]). This affects
/// only the form of the BLIS micro kernel and the computation of its
/// parameters. In particular, reused elements of the matrix B are
/// successively multiplied by specific elements of the matrix A.
///
/// Refs.:
/// [1] - Analytical Modeling is Enough for High Performance BLIS
/// Tze Meng Low, Francisco D Igual, Tyler M Smith, Enrique S Quintana-Orti
/// Technical Report, 2014
/// http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf
///
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
/// @see getMicroKernelParams
/// @see getMacroKernelParams
///
/// TODO: Implement the packing transformation.
///
/// @param Node The node that contains a band to be optimized. The node
/// is required to successfully pass
/// ScheduleTreeOptimizer::isMatrMultPattern.
/// @param TTI Target Transform Info.
/// @param D The dependencies.
///
/// @returns The transformed schedule or nullptr if the optimization
/// cannot be applied.
isl::schedule_node
tryOptimizeMatMulPattern(isl::schedule_node Node,
const llvm::TargetTransformInfo *TTI,
const Dependences *D);
} // namespace polly
#endif // POLLY_MATMULOPTIMIZER_H

View File

@ -37,26 +37,6 @@ struct IslScheduleOptimizerPrinterPass
private:
llvm::raw_ostream &OS;
};
/// Build the desired set of partial tile prefixes.
///
/// We build a set of partial tile prefixes, which are prefixes of the vector
/// loop that have exactly VectorWidth iterations.
///
/// 1. Drop all constraints involving the dimension that represents the
/// vector loop.
/// 2. Constrain the last dimension to get a set, which has exactly VectorWidth
/// iterations.
/// 3. Subtract loop domain from it, project out the vector loop dimension and
/// get a set that contains prefixes, which do not have exactly VectorWidth
/// iterations.
/// 4. Project out the vector loop dimension of the set that was built on the
/// first step and subtract the set built on the previous step to get the
/// desired set of prefixes.
///
/// @param ScheduleRange A range of a map, which describes a prefix schedule
/// relation.
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
} // namespace polly
namespace llvm {

View File

@ -13,6 +13,7 @@
#ifndef POLLY_SCHEDULETREETRANSFORM_H
#define POLLY_SCHEDULETREETRANSFORM_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "isl/isl-noexceptions.h"
#include <cassert>
@ -164,6 +165,65 @@ isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll);
/// Replace the AST band @p BandToUnroll by a partially unrolled equivalent.
isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
/// Build the desired set of partial tile prefixes.
///
/// We build a set of partial tile prefixes, which are prefixes of the vector
/// loop that have exactly VectorWidth iterations.
///
/// 1. Drop all constraints involving the dimension that represents the
/// vector loop.
/// 2. Constrain the last dimension to get a set, which has exactly VectorWidth
/// iterations.
/// 3. Subtract loop domain from it, project out the vector loop dimension and
/// get a set that contains prefixes, which do not have exactly VectorWidth
/// iterations.
/// 4. Project out the vector loop dimension of the set that was built on the
/// first step and subtract the set built on the previous step to get the
/// desired set of prefixes.
///
/// @param ScheduleRange A range of a map, which describes a prefix schedule
/// relation. Its last dimension represents the vector
/// loop.
/// @param VectorWidth The number of iterations of the vector loop that a
/// prefix must have to be part of the result.
///
/// @returns The set of prefixes (with the vector loop dimension projected out)
/// for which the vector loop has exactly VectorWidth iterations.
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
/// Create an isl::union_set, which describes the isolate option based on
/// IsolateDomain.
///
/// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should
/// belong to the current band node.
/// @param OutDimsNum A number of dimensions that should belong to
/// the current band node. Must not exceed the number of
/// dimensions of @p IsolateDomain.
///
/// @returns A union set holding one set whose tuple is named "isolate" and
/// that wraps the relation from the leading dimensions of
/// @p IsolateDomain to its last @p OutDimsNum dimensions.
isl::union_set getIsolateOptions(isl::set IsolateDomain, isl_size OutDimsNum);
/// Create an isl::union_set, which describes the specified option for the
/// dimension of the current node.
///
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
/// @param Option The name of the option.
///
/// @returns A union set containing the universe of a one-dimensional set
/// space whose tuple is named @p Option.
isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
/// Tile a schedule node.
///
/// The tile loops are marked with "<Identifier> - Tiles" and the point loops
/// with "<Identifier> - Points".
///
/// @param Node The node to tile.
/// @param Identifier A name that identifies this kind of tiling and
/// that is used to mark the tiled loops in the
/// generated AST.
/// @param TileSizes A vector of tile sizes that should be used for
/// tiling.
/// @param DefaultTileSize A default tile size that is used for dimensions
/// that are not covered by the TileSizes vector.
///
/// @returns The child of the mark node that was inserted above the point
/// loops.
isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
/// Tile a schedule node and unroll point loops.
///
/// @param Node The node to register tile.
/// @param TileSizes A vector of tile sizes that should be used for
/// tiling.
/// @param DefaultTileSize A default tile size that is used for dimensions
/// that are not covered by the TileSizes vector.
///
/// @returns The point-loop band produced by the tiling, with its AST build
/// options set to "{unroll[x]}".
isl::schedule_node applyRegisterTiling(isl::schedule_node Node,
llvm::ArrayRef<int> TileSizes,
int DefaultTileSize);
} // namespace polly
#endif // POLLY_SCHEDULETREETRANSFORM_H

View File

@ -99,6 +99,7 @@ add_llvm_pass_plugin(Polly
Transform/RewriteByReferenceParameters.cpp
Transform/ScopInliner.cpp
Transform/ManualOptimizer.cpp
Transform/MatmulOptimizer.cpp
${POLLY_HEADER_FILES}
LINK_COMPONENTS

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -480,6 +480,23 @@ static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
return Modulo.domain();
}
/// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
///
/// Two inequality constraints are added on the last set dimension d:
/// a lower bound (d >= 0) and an upper bound (VectorWidth - 1 - d >= 0).
///
/// @param Set         A set, which should be modified. It must have at least
///                    one set dimension.
/// @param VectorWidth A parameter, which determines the constraint.
///
/// @returns @p Set with the two additional constraints on its last dimension.
static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
  // isl reports dimension counts as the signed isl_size; keep that type (as
  // the other helpers in this file do) instead of converting to unsigned.
  isl_size Dims = Set.dim(isl::dim::set);
  assert(Dims >= 1 &&
         "The set must have a last dimension that can be constrained.");
  isl::space Space = Set.get_space();
  isl::local_space LocalSpace = isl::local_space(Space);

  // Lower bound: 1 * d + 0 >= 0.
  isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
  ExtConstr = ExtConstr.set_constant_si(0);
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
  Set = Set.add_constraint(ExtConstr);

  // Upper bound: -1 * d + (VectorWidth - 1) >= 0.
  ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
  ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
  return Set.add_constraint(ExtConstr);
}
} // namespace
bool polly::isBandMark(const isl::schedule_node &Node) {
@ -631,3 +648,76 @@ isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
return NewLoop.get_schedule();
}
isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
                                       int VectorWidth) {
  // The vector loop is represented by the last set dimension.
  isl_size VectorDim = ScheduleRange.dim(isl::dim::set) - 1;

  // Step 1: forget every constraint that mentions the vector loop dimension.
  isl::set Unconstrained = ScheduleRange.drop_constraints_involving_dims(
      isl::dim::set, VectorDim, 1);

  // Step 2: force the vector loop dimension into [0, VectorWidth - 1].
  isl::set FullExtent = addExtentConstraints(Unconstrained, VectorWidth);

  // Step 3: points of the full extent that are missing from the real schedule
  // range identify prefixes without exactly VectorWidth iterations.
  isl::set Incomplete = FullExtent.subtract(ScheduleRange)
                            .project_out(isl::dim::set, VectorDim, 1);

  // Step 4: all prefixes minus the incomplete ones.
  isl::set AllPrefixes = Unconstrained.project_out(isl::dim::set, VectorDim, 1);
  return AllPrefixes.subtract(Incomplete);
}
isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
                                        isl_size OutDimsNum) {
  isl_size Dims = IsolateDomain.dim(isl::dim::set);
  assert(OutDimsNum <= Dims &&
         "The isl::set IsolateDomain is used to describe the range of schedule "
         "dimensions values, which should be isolated. Consequently, the "
         "number of its dimensions should be greater than or equal to the "
         "number of the schedule dimensions.");

  // Turn the set into a relation whose range consists of the trailing
  // OutDimsNum dimensions; the remaining leading dimensions stay in the domain.
  isl::map Relation =
      isl::map::from_domain(IsolateDomain)
          .move_dims(isl::dim::out, 0, isl::dim::in, Dims - OutDimsNum,
                     OutDimsNum);

  // Wrap the relation back into a set and tag it as the "isolate" option.
  isl::set Wrapped = Relation.wrap();
  isl::id IsolateId = isl::id::alloc(Wrapped.get_ctx(), "isolate", nullptr);
  return isl::union_set(Wrapped.set_tuple_id(IsolateId));
}
isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
  // Universe of a set space with no parameters and a single dimension.
  isl::set Universe = isl::set::universe(isl::space(Ctx, 0, 1));

  // Name the tuple after the requested option.
  isl::id OptionId = isl::id::alloc(Ctx, Option, nullptr);
  return isl::union_set(Universe.set_tuple_id(OptionId));
}
isl::schedule_node polly::tileNode(isl::schedule_node Node,
                                   const char *Identifier,
                                   ArrayRef<int> TileSizes,
                                   int DefaultTileSize) {
  isl::space BandSpace =
      isl::manage(isl_schedule_node_band_get_space(Node.get()));
  isl_size NumBandDims = BandSpace.dim(isl::dim::set);
  isl_size NumExplicitSizes = static_cast<isl_size>(TileSizes.size());
  std::string MarkerPrefix(Identifier);

  // One tile size per band dimension; dimensions beyond the explicitly given
  // sizes fall back to the default.
  isl::multi_val Sizes = isl::multi_val::zero(BandSpace);
  for (auto Pos : seq<isl_size>(0, NumBandDims)) {
    int Size = DefaultTileSize;
    if (Pos < NumExplicitSizes)
      Size = TileSizes[Pos];
    Sizes = Sizes.set_val(Pos, isl::val(Node.get_ctx(), Size));
  }

  // Mark the tile loops, then step down from the mark to the band itself.
  isl::id TilesMark =
      isl::id::alloc(Node.get_ctx(), MarkerPrefix + " - Tiles", nullptr);
  Node = Node.insert_mark(TilesMark).child(0);

  // Tile the band; its child is the band of point loops.
  Node =
      isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
  Node = Node.child(0);

  // Mark the point loops and return the node below that mark.
  isl::id PointsMark =
      isl::id::alloc(Node.get_ctx(), MarkerPrefix + " - Points", nullptr);
  return Node.insert_mark(PointsMark).child(0);
}
isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
                                              ArrayRef<int> TileSizes,
                                              int DefaultTileSize) {
  // Tile first; the result is the band of point loops.
  isl::schedule_node PointBand =
      tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);

  // Ask the AST generator to unroll the point loops.
  isl::union_set UnrollOption(PointBand.get_ctx(), "{unroll[x]}");
  return PointBand.band_set_ast_build_options(UnrollOption);
}

View File

@ -1,3 +1,3 @@
add_polly_unittest(ScheduleOptimizerTests
ScheduleOptimizerTest.cpp
ScheduleTreeTransformTest.cpp
)

View File

@ -1,4 +1,4 @@
//===- ScheduleOptimizerTest.cpp ------------------------------------------===//
//===- ScheduleTreeTransformTest.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,18 +6,16 @@
//
//===----------------------------------------------------------------------===//
#include "polly/ScheduleOptimizer.h"
#include "polly/ScheduleTreeTransform.h"
#include "gtest/gtest.h"
#include "isl/stream.h"
#include "isl/val.h"
#include "isl/ctx.h"
using namespace isl;
using namespace polly;
namespace {
TEST(ScheduleOptimizer, getPartialTilePrefixes) {
TEST(ScheduleTreeTransform, getPartialTilePrefixes) {
isl_ctx *ctx = isl_ctx_alloc();
{