[NFC][mlir] VectorUtils / IndexingUtils simplifications and cleanups
This revision refactors and cleans up vector-, shape-, and indexing-related infrastructure into more reusable APIs.

Differential Revision: https://reviews.llvm.org/D138501
Parent: f5eeda037f
Commit: 7a69a9d7ae
@@ -28,8 +28,39 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
/// Given the strides together with a linear index in the dimension
/// space, returns the vector-space offsets in each dimension for a
/// de-linearized index.
SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
                                    int64_t linearIndex);
SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
                                 int64_t linearIndex);

/// Given a set of sizes, compute and return the strides (i.e. the number of
/// linear indices to skip along the (k-1) most minor dimensions to get the next
/// k-slice). This is also the basis that one can use to linearize an n-D offset
/// confined to `[0 .. sizes]`.
SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);

/// Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
                                           ArrayRef<int64_t> v2);

/// Compute and return the multi-dimensional integral ratio of `subShape` to
/// the trailing dimensions of `shape`. This represents how many times
/// `subShape` fits within `shape`.
/// If integral division is not possible, return None.
/// The trailing `subShape.size()` entries of both shapes are assumed (and
/// enforced) to only contain nonnegative values.
///
/// Examples:
///   - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
///   - shapeRatio({3, 8}, {2, 5, 2}) returns None (subshape has higher rank).
///   - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
///     derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
///   - shapeRatio({42, 2, 11, 32}, {2, 5, 2}) returns None which is
///     derived as {42(leading shape dim), 2/2, 11/5(not divisible), 32/2}.
Optional<SmallVector<int64_t>> computeShapeRatio(ArrayRef<int64_t> shape,
                                                 ArrayRef<int64_t> subShape);

/// Return the number of elements of basis (i.e. the max linear index).
/// Return `0` if `basis` is empty.
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);

/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
@@ -45,16 +76,15 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
}

/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                       unsigned dropFront = 0,
                                       unsigned dropBack = 0);
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                    unsigned dropBack = 0);

/// Computes and returns linearized affine expression w.r.t. `basis`.
mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);

/// Given the strides in the dimension space, returns the affine expressions for
/// vector-space offsets in each dimension for a de-linearized index.
SmallVector<mlir::AffineExpr, 4>
SmallVector<mlir::AffineExpr>
getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);

} // namespace mlir
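A short usage sketch for `getI64SubArray` (the attribute values are illustrative assumptions, not part of the diff):

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;

void subArraySketch() {
  MLIRContext ctx;
  Builder b(&ctx);
  ArrayAttr attr = b.getI64ArrayAttr({10, 20, 30, 40, 50});
  // Drop one element from the front and two from the back: yields {20, 30}.
  SmallVector<int64_t> sub = getI64SubArray(attr, /*dropFront=*/1, /*dropBack=*/2);
}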
@@ -111,7 +111,7 @@ struct UnrollVectorOptions {
  }

  using NativeShapeFnType =
      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns `llvm::None`.
  NativeShapeFnType nativeShape = nullptr;
@@ -122,8 +122,8 @@ struct UnrollVectorOptions {

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
    SmallVector<int64_t> tsShape(shape.begin(), shape.end());
    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
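A sketch of how these options are typically wired up when populating unroll patterns (the 4x4 shape and the contraction special case are illustrative assumptions, not part of the diff):

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
using namespace mlir;

void configureUnrolling(RewritePatternSet &patterns) {
  vector::UnrollVectorOptions options;
  // Fixed native shape for every matched op...
  options.setNativeShape({4, 4});
  // ...or a per-op callback assigned to the public `nativeShape` member shown
  // above; returning llvm::None aborts unrolling for that op.
  options.nativeShape = [](Operation *op) -> Optional<SmallVector<int64_t>> {
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t>(3, 4); // three iterator dims, tile size 4.
    return llvm::None;
  };
  vector::populateVectorUnrollPatterns(patterns, options);
}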
@@ -36,43 +36,6 @@ namespace vector {
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
} // namespace vector

/// Return the number of elements of basis, `0` if empty.
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);

/// Given the shape and sizes of a vector, returns the corresponding
/// strides for each dimension.
/// TODO: needs better doc of how it is used.
SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> sizes);

/// Given the target sizes of a vector, together with vector-space offsets,
/// returns the element-space offsets for each dimension.
SmallVector<int64_t, 4>
computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
                                            ArrayRef<int64_t> vectorOffsets);

/// Computes and returns the multi-dimensional ratio of `superShape` to
/// `subShape`. This is calculated by performing a traversal from minor to major
/// dimensions (i.e. in reverse shape order). If integral division is not
/// possible, returns None.
/// The ArrayRefs are assumed (and enforced) to only contain > 1 values.
/// This constraint comes from the fact that they are meant to be used with
/// VectorTypes, for which the property holds by construction.
///
/// Examples:
///   - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
///   - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
///   - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
                                             ArrayRef<int64_t> subShape);

/// Computes and returns the multi-dimensional ratio of the shapes of
/// `superVector` to `subVector`. If integral division is not possible, returns
/// None.
/// Assumes and enforces that the VectorTypes have the same elemental type.
Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
                                             VectorType subVectorType);

/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
@@ -80,8 +80,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
  Value result = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(
               vecType, IntegerAttr::get(vecType.getElementType(), 0)));
  SmallVector<int64_t> ones(shape.size(), 1);
  SmallVector<int64_t> strides = computeStrides(shape, ones);
  SmallVector<int64_t> strides = computeStrides(shape);
  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
    SmallVector<Value> operands;
@@ -79,8 +79,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
  Value result = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(
               vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
  SmallVector<int64_t> ones(shape.size(), 1);
  SmallVector<int64_t> strides = computeStrides(shape, ones);
  SmallVector<int64_t> strides = computeStrides(shape);
  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
    SmallVector<Value> operands;
@@ -127,15 +127,13 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,

  // Iterate over all outer dimensions of the compute shape vector type.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);

  SmallVector<int64_t> ones(iterationDims.size(), 1);
  auto strides = computeStrides(iterationDims, ones);
  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
  auto strides = computeStrides(iterationDims);

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxLinearIndex);
  SmallVector<Value> results(maxIndex);

  for (int64_t i = 0; i < maxLinearIndex; ++i) {
  for (int64_t i = 0; i < maxIndex; ++i) {
    auto offsets = delinearize(strides, i);

    SmallVector<Value> extracted(expandedOperands.size());

@@ -152,7 +150,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
  Value result = builder.create<arith::ConstantOp>(
      resultExpandedType, builder.getZeroAttr(resultExpandedType));

  for (int64_t i = 0; i < maxLinearIndex; ++i)
  for (int64_t i = 0; i < maxIndex; ++i)
    result = builder.create<vector::InsertOp>(results[i], result,
                                              delinearize(strides, i));
@@ -12,6 +12,53 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

#include <numeric>

using namespace mlir;

SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
  SmallVector<int64_t> strides(sizes.size(), 1);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  SmallVector<int64_t> result;
  for (auto it : llvm::zip(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

Optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return None;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be nonnegative");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be nonnegative");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return None;
    result.push_back(size / subSize);
  }
  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(offsets.size() == basis.size());
  int64_t linearIndex = 0;
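A small check of the behavior documented above, as a sketch (not part of the diff; it only assumes the functions defined in this file):

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>

void shapeRatioSketch() {
  // How many 2x5x2 tiles fit in the trailing dims of 42x2x10x32? -> {42, 1, 2, 16}.
  auto ratio = mlir::computeShapeRatio({42, 2, 10, 32}, {2, 5, 2});
  assert(ratio && (*ratio)[3] == 16);
  // 11 is not divisible by 5, so no integral ratio exists.
  assert(!mlir::computeShapeRatio({42, 2, 11, 32}, {2, 5, 2}));
  // Element-wise product: {2, 3} * {4, 5} == {8, 15}.
  auto prod = mlir::computeElementwiseMul({2, 3}, {4, 5});
  assert(prod[0] == 8 && prod[1] == 15);
}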
@@ -20,10 +67,10 @@ int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  return linearIndex;
}

llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
                                                int64_t index) {
llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
                                             int64_t index) {
  int64_t rank = sliceStrides.size();
  SmallVector<int64_t, 4> vectorOffsets(rank);
  SmallVector<int64_t> vectorOffsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    assert(sliceStrides[r] > 0);
    vectorOffsets[r] = index / sliceStrides[r];
@@ -32,12 +79,19 @@ llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
  return vectorOffsets;
}

llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                                   unsigned dropFront,
                                                   unsigned dropBack) {
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}

llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                                unsigned dropFront,
                                                unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
@@ -54,11 +108,11 @@ mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
  return resultExpr;
}

llvm::SmallVector<mlir::AffineExpr, 4>
llvm::SmallVector<mlir::AffineExpr>
mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
  AffineExpr resultExpr = b.getAffineDimExpr(0);
  int64_t rank = strides.size();
  SmallVector<AffineExpr, 4> vectorOffsets(rank);
  SmallVector<AffineExpr> vectorOffsets(rank);
  vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
  resultExpr = resultExpr % strides[0];
  for (unsigned i = 1; i < rank; i++) {
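A sketch of building the de-linearization expressions (the strides are illustrative; only the leading expression is spelled out, following the code above):

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;

void delinearizedExprsSketch() {
  MLIRContext ctx;
  Builder b(&ctx);
  // exprs[0] is `d0 floordiv 12`; later entries keep taking the remainder
  // (`mod` the previous stride) before dividing by the current stride.
  SmallVector<AffineExpr> exprs = getDelinearizedAffineExpr({12, 4, 1}, b);
}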
@@ -54,9 +54,9 @@ static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)

@@ -70,7 +70,7 @@ static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)

@@ -140,7 +140,7 @@ static Value reshapeStore(Location loc, Value val, Value result,
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));

@@ -399,7 +399,7 @@ public:
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    SmallVector<int64_t> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

@@ -430,12 +430,11 @@ public:
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose

@@ -448,7 +447,7 @@ public:
    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);

@@ -488,7 +487,7 @@ public:
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    SmallVector<int64_t> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)

@@ -685,8 +684,8 @@ struct ContractOpToElementwise
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(
        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
    newRhs = rewriter.create<vector::ExtractOp>(

@@ -752,7 +751,7 @@ public:
    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      SmallVector<bool> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(

@@ -762,7 +761,7 @@ public:

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    SmallVector<int64_t> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(

@@ -931,8 +930,8 @@ public:
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {

@@ -948,7 +947,7 @@ public:
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;

@@ -1039,7 +1038,7 @@ struct CombineContractABTranspose final

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();

@@ -1169,7 +1168,7 @@ struct CombineContractBroadcast

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();

@@ -1234,7 +1233,7 @@ struct CombineContractBroadcast
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
    SmallVector<Attribute, 4> iterators;
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);

@@ -1328,7 +1327,7 @@ struct ReorderElementwiseOpsOnTranspose final

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayAttr, 4> transposeMaps;
    SmallVector<ArrayAttr> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.

@@ -1350,7 +1349,7 @@ struct ReorderElementwiseOpsOnTranspose final
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value, 4> srcValues;
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for

@@ -1724,7 +1723,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.

@@ -1940,7 +1939,7 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator type index and result index.
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {

@@ -2011,7 +2010,7 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())

@@ -2087,7 +2086,7 @@ struct TransferReadToVectorLoadLowering
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned, 4> broadcastedDims;
    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.

@@ -2106,8 +2105,8 @@ struct TransferReadToVectorLoadLowering
    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
                                                     vectorShape.end());
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(

@@ -2286,7 +2285,7 @@ struct TransferWriteToVectorStoreLowering
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));

@@ -2410,7 +2409,7 @@ struct BubbleDownBitCastForStridedSliceExtract
    // dimension's offset given we are extracting from less elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;

@@ -2420,14 +2419,14 @@ struct BubbleDownBitCastForStridedSliceExtract
    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t, 4> dims =
    SmallVector<int64_t> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =

@@ -2500,13 +2499,13 @@ struct BubbleUpBitCastForStridedSliceInsert

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t, 4> srcDims =
    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =

@@ -2515,7 +2514,7 @@ struct BubbleUpBitCastForStridedSliceInsert
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t, 4> dstDims =
    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
@@ -27,24 +27,19 @@ using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape` return the offset for
/// the slice `index`.
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
                                            int64_t index,
                                            ArrayRef<int64_t> targetShape) {
  return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
}

/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
/// A functor that accomplishes the same thing as `getVectorOffset` but
/// allows for reordering the traversal of the dimensions. The order of
/// traversal is given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};
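To make the new `getVectorOffset` contract concrete, a worked sketch with illustrative shapes (not part of the diff): unrolling an 8x6 vector into 4x2 slices.

#include "mlir/Dialect/Utils/IndexingUtils.h"
using namespace mlir;

void sliceOffsetsSketch() {
  SmallVector<int64_t> originalShape = {8, 6};
  SmallVector<int64_t> targetShape = {4, 2};
  SmallVector<int64_t> ratio = *computeShapeRatio(originalShape, targetShape); // {2, 3}
  SmallVector<int64_t> ratioStrides = computeStrides(ratio);                   // {3, 1}
  // Slice index 4 -> ratio offsets delinearize(ratioStrides, 4) == {1, 1};
  // multiplying element-wise by targetShape gives element offsets {4, 2},
  // which is exactly what getVectorOffset(ratioStrides, 4, targetShape) returns.
  SmallVector<int64_t> offsets =
      computeElementwiseMul(delinearize(ratioStrides, 4), targetShape);
}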
@@ -56,15 +51,15 @@ public:
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() == targetShape.size());
    assert(loopOrder.size() == targetShape.size());
    assert(originalShape.size() >= targetShape.size());
    assert(loopOrder.size() == originalShape.size());

    // Compute the count for each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }
    auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
    assert(maybeShapeRatio && "Shape does not evenly divide");
    // Pad `sliceDimCounts` with leading 1s so that all sizes match.
    SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
    maxIndexVal = computeMaxLinearIndex(sliceDimCounts);

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
@@ -95,7 +90,7 @@ public:
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
        computeElementwiseMul(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};

@@ -139,7 +134,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
static Optional<SmallVector<int64_t, 4>>
static Optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;

@@ -152,10 +147,10 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  Optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;

@@ -197,7 +192,7 @@ struct UnrollTransferReadPattern
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

@@ -206,17 +201,16 @@ struct UnrollTransferReadPattern
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(

@@ -255,11 +249,11 @@ struct UnrollTransferWritePattern
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);

@@ -267,11 +261,10 @@ struct UnrollTransferWritePattern
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value, 4> indices =
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(

@@ -321,7 +314,7 @@ struct UnrollContractionPattern
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();

@@ -337,16 +330,16 @@ struct UnrollContractionPattern
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
      SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to coompute the new shape of each operand and extract the slice.
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

@@ -420,12 +413,12 @@ struct UnrollMultiReductionPattern

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
    Optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>

@@ -433,12 +426,16 @@ struct UnrollMultiReductionPattern
    // Compute shape ratio of 'shape' and 'sizes'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> offsets =
          getVectorOffset(ratioStrides, i, *targetShape);

      SmallVector<Value> operands;
      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);

@@ -451,7 +448,7 @@ struct UnrollMultiReductionPattern
        }
      }
      Value acc;
      SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(destOffset);

@@ -500,21 +497,25 @@ struct UnrollElementwisePattern : public RewritePattern {
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      SmallVector<int64_t> offsets =
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {

@@ -547,19 +548,24 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
    Optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
    auto ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = ratio[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);

@@ -600,21 +606,25 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
    if (!targetShape)
      return failure();
    auto originalVectorType = tranposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = tranposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector;
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    tranposeOp.getTransp(permutation);

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      SmallVector<int64_t> elementOffsets =
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"

@@ -25,7 +26,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include <numeric>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
@@ -43,78 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
  llvm_unreachable("Expected MemRefType or TensorType");
}

/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}

SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
                                             ArrayRef<int64_t> sizes) {
  int64_t rank = shape.size();
  // Compute the count for each dimension.
  SmallVector<int64_t, 4> sliceDimCounts(rank);
  for (int64_t r = 0; r < rank; ++r)
    sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
  // Use that to compute the slice stride for each dimension.
  SmallVector<int64_t, 4> sliceStrides(rank);
  sliceStrides[rank - 1] = 1;
  for (int64_t r = rank - 2; r >= 0; --r)
    sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
  return sliceStrides;
}

SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
  SmallVector<int64_t, 4> result;
  for (auto it : llvm::zip(vectorOffsets, sizes))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
                                                   ArrayRef<int64_t> subShape) {
  if (superShape.size() < subShape.size()) {
    return None;
  }

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(superShape.size());
  for (auto [superSize, subSize] :
       llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
    assert(superSize > 0 && "superSize must be > 0");
    assert(subSize > 0 && "subSize must be > 0");

    // If integral division does not occur, return and let the caller decide.
    if (superSize % subSize != 0)
      return None;
    result.push_back(superSize / subSize);
  }

  // At this point we computed the ratio (in reverse) for the common
  // size. Fill with the remaining entries from the super-vector shape (still in
  // reverse).
  int commonSize = subShape.size();
  std::copy(superShape.rbegin() + commonSize, superShape.rend(),
            std::back_inserter(result));

  assert(result.size() == superShape.size() &&
         "super to sub shape ratio is not of the same size as the super rank");

  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
}

Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
                                                   VectorType subVectorType) {
  assert(superVectorType.getElementType() == subVectorType.getElementType() &&
         "vector types must be of the same elemental type");
  return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
}

/// Constructs a permutation map from memref indices to vector dimension.
///
/// The implementation uses the knowledge of the mapping of enclosing loop to

@@ -144,8 +72,8 @@ static AffineMap makePermutationMap(
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
                                  getAffineConstantExpr(0, context));
  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
                               getAffineConstantExpr(0, context));

  for (auto kvp : enclosingLoopToVectorDim) {
    assert(kvp.second < perm.size());

@@ -252,7 +180,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
  }

  // Get the ratio.
  auto ratio = shapeRatio(superVectorType, subVectorType);
  auto ratio =
      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());

  // Sanity check.
  assert((ratio || !mustDivide) &&
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"

@@ -126,7 +127,8 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
    // purpose of this test. If we need to test more intricate behavior in the
    // future we can always extend.
    auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
    auto ratio = shapeRatio(superVectorType, subVectorType);
    auto ratio =
        computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
    if (!ratio) {
      opInst->emitRemark("NOT MATCHED");
    } else {
@@ -72,11 +72,11 @@ struct TestVectorToVectorLowering

private:
  // Return the target shape based on op type.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
  static Optional<SmallVector<int64_t>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
      return SmallVector<int64_t>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
      return SmallVector<int64_t>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {

@@ -90,15 +90,15 @@ private:
          return llvm::None;
        dstVec = vecType;
      }
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
      return SmallVector<int64_t>(dstVec.getShape().begin(),
                                  dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
      return SmallVector<int64_t>(shape.begin(), shape.end());
    }
    return llvm::None;
  }
@@ -314,10 +314,10 @@ struct TestVectorUnrollingPatterns

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape(
            contractOp.getIteratorTypes().size(), 4);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;

@@ -339,12 +339,11 @@ struct TestVectorUnrollingPatterns
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
      auto nativeShapeFn = [](Operation *op) -> Optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return None;
        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
@@ -61,7 +61,6 @@ struct LinalgTransformationFilter {
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                         Operation *op) const;
  bool hasReplacementFilter(Operation *op) const;

  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
    if (f)

@@ -100,15 +99,6 @@ LinalgTransformationFilter::LinalgTransformationFilter(
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult
LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
                                           Operation *op) const {

@@ -150,13 +140,6 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter(
  op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
}

bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast<StringAttr>();
  return attr && attr == *replacement;
}

/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
/// using a `filter` to avoid recursive application.