[mlir][sparse] introduce vectorization pass for sparse loops

This brings back previous SIMD functionality, but in a separate pass.
The idea is to improve this new pass incrementally, going beyond for-loops
to while-loops for co-iteration as well (masking), while introducing new
abstractions to make the lowering more progressive. The separation of
sparsification and vectorization is a very good first step on this journey.

Also brings back ArmSVE support.
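
As a usage sketch (mirroring the updated RUN lines in sparse_vector.mlir),
the new pass simply runs right after sparsification, for example:

  mlir-opt -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse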

Still to be fine-tuned:
  + use of "index" in SIMD loop (viz. a[i] = i)
  + check that all ops really have SIMD support
  + check all forms of reductions
  + chain reduction SIMD values

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138236
Aart Bik 2022-11-18 12:18:00 -08:00
parent 9df8ba631d
commit 99b3849d89
7 changed files with 1016 additions and 89 deletions


@ -172,6 +172,16 @@ std::unique_ptr<Pass> createSparseBufferRewritePass();
std::unique_ptr<Pass>
createSparseBufferRewritePass(bool enableBufferInitialization);
void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32);
std::unique_ptr<Pass> createSparseVectorizationPass();
std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32);
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//


@ -225,4 +225,64 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
];
}
def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
let summary = "Vectorizes loops after sparsification";
let description = [{
A pass that converts loops after sparsification into vector loops.
The vector dialect is used as the target to provide an architecture-neutral
way of exploiting any platform that supports SIMD instructions.
The vector length (viz. `vl`) describes the number of packed data elements
(e.g. both vector<16xf32> and vector<16xf64> have a vector length of 16 even
though the actual bitwidths differ). A small multiple of the actual lengths
supported in hardware typically results in efficient SIMD code, since the
backend will map longer vectors to multiple vector registers, thereby
effectively unrolling an additional level within the generated for-loop.
Example of the conversion:
```mlir
Before:
%3 = memref.load %2[] : memref<f32>
%4 = scf.for %arg3 = %c0 to %c1024 step %c1 iter_args(%arg4 = %3) -> (f32) {
%6 = memref.load %0[%arg3] : memref<?xf32>
%7 = memref.load %1[%arg3] : memref<1024xf32>
%8 = arith.mulf %6, %7 : f32
%9 = arith.addf %arg4, %8 : f32
scf.yield %9 : f32
}
memref.store %4, %2[] : memref<f32>
After:
%3 = memref.load %2[] : memref<f32>
%4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
%5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
%8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
%9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
%10 = arith.mulf %8, %9 : vector<32xf32>
%11 = arith.addf %arg4, %10 : vector<32xf32>
scf.yield %11 : vector<32xf32>
}
%6 = vector.reduction <add>, %5 : vector<32xf32> into f32
memref.store %6, %2[] : memref<f32>
```
}];
let constructor = "mlir::createSparseVectorizationPass()";
let dependentDialects = [
"arith::ArithDialect",
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
"vector::VectorDialect",
];
let options = [
Option<"vectorLength", "vl", "int32_t", "0",
"Set the vector length (use 0 to disable vectorization)">,
Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
"false", "Enable vector length agnostic vectorization">,
Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
"Enable i32 indexing into vectors (for efficient gather/scatter)">,
];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES


@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
SparseVectorization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor


@ -27,6 +27,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@ -67,10 +68,9 @@ struct SparsificationPass
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
SparsificationOptions options(parallelization);
// Apply sparsification and vector cleanup rewriting.
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
@ -250,6 +250,27 @@ struct SparseBufferRewritePass
}
};
struct SparseVectorizationPass
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
SparseVectorizationPass() = default;
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
vectorLength = vl;
enableVLAVectorization = vla;
enableSIMDIndex32 = sidx32;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseVectorizationPatterns(
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
//===----------------------------------------------------------------------===//
@ -322,3 +343,15 @@ std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
return std::make_unique<SparseVectorizationPass>();
}
std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32) {
return std::make_unique<SparseVectorizationPass>(
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}


@ -0,0 +1,485 @@
//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with a proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
/// Target SIMD properties:
/// vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
/// enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
/// enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
unsigned vectorLength;
bool enableVLAVectorization;
bool enableSIMDIndex32;
};
/// Helper to test for given index value.
static bool isIntValue(Value val, int64_t idx) {
if (auto ival = getConstantIntValue(val))
return *ival == idx;
return false;
}
/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
unsigned numScalableDims = vl.enableVLAVectorization;
return VectorType::get(vl.vectorLength, etp, numScalableDims);
}
/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}
/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
Value iv, Value lo, Value hi, Value step) {
VectorType mtp = vectorType(vl, rewriter.getI1Type());
// Special case if the vector length evenly divides the trip count (for
// example, "for i = 0, 128, 16"). A constant all-true mask is generated
// so that all subsequent masked memory operations are immediately folded
// into unconditional memory operations.
IntegerAttr loInt, hiInt, stepInt;
if (matchPattern(lo, m_Constant(&loInt)) &&
matchPattern(hi, m_Constant(&hiInt)) &&
matchPattern(step, m_Constant(&stepInt))) {
if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
Value trueVal = constantI1(rewriter, loc, true);
return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
}
}
// Otherwise, generate a vector mask that avoids overrunning the upper bound
// during vector execution. Here we rely on subsequent loop optimizations to
// avoid executing the mask in all iterations, for example, by splitting the
// loop into an unconditional vector loop and a scalar cleanup loop.
auto min = AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/1,
{rewriter.getAffineSymbolExpr(0),
rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
rewriter.getContext());
Value end =
rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
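// Illustration only (a hedged sketch with made-up SSA names, vl=16): for a
// loop with upper bound %hi, induction variable %iv, and step %step, the
// general path above roughly materializes
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>
// whereas an evenly dividing constant trip count folds to an all-true mask.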
/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
Value val) {
VectorType vtp = vectorType(vl, val.getType());
return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}
/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
Value ptr, ArrayRef<Value> idxs, Value vmask) {
VectorType vtp = vectorType(vl, ptr);
Value pass = constantZero(rewriter, loc, vtp);
if (idxs.back().getType().isa<VectorType>()) {
SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
indexVec, vmask, pass);
}
return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
pass);
}
/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
ArrayRef<Value> idxs, Value vmask, Value rhs) {
if (idxs.back().getType().isa<VectorType>()) {
SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
rhs);
return;
}
rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}
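// Illustration only (a hedged sketch with made-up SSA names, vl=16): a
// unit-stride access vectorizes into a masked memory operation
//   %v = vector.maskedload %mem[%iv], %vmask, %pass
//        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// while an access subscripted by an index vector becomes a gather (or, for
// stores, a scatter)
//   %v = vector.gather %mem[%c0] [%idxs], %vmask, %pass
//        : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>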
/// Maps operation to combining kind for reduction.
static vector::CombiningKind getCombiningKind(Operation *def) {
if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
return vector::CombiningKind::ADD;
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
return vector::CombiningKind::MUL;
if (isa<arith::AndIOp>(def))
return vector::CombiningKind::AND;
if (isa<arith::OrIOp>(def))
return vector::CombiningKind::OR;
if (isa<arith::XOrIOp>(def))
return vector::CombiningKind::XOR;
llvm_unreachable("unknown reduction kind");
}
/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
/// denotes the accumulation operation, which is solely used here to determine
/// the kind of combining reduction (viz. addf -> sum-accumulation).
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
VectorType vtp, Value r, Value rd) {
vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
// Initialize reduction vector to: | 0 | .. | 0 | r |
return rewriter.create<vector::InsertElementOp>(
loc, r, constantZero(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::MUL:
// Initialize reduction vector to: | 1 | .. | 1 | r |
return rewriter.create<vector::InsertElementOp>(
loc, r, constantOne(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::AND:
case vector::CombiningKind::OR:
// Initialize reduction vector to: | r | .. | r | r |
return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
default:
break;
}
llvm_unreachable("unknown reduction kind");
}
/// Generates final value for a vector reduction.
static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
Value vexp, Value rd) {
vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
}
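// Illustration only (a hedged sketch with made-up SSA names, vl=16): for a
// sum reduction, the scalar init %r is embedded at lane 0 of a zero vector
// before the loop, and the partial sums are folded horizontally afterwards
//   %vinit = vector.insertelement %r, %cst_zero[%c0 : index] : vector<16xf32>
//   ...
//   %res   = vector.reduction <add>, %vred : vector<16xf32> into f32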
/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// In the scalar case, narrower subscripts are simply zero extended into
/// 64-bit values before casting to an index type, without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
VL vl, ValueRange subs, bool codegen,
Value vmask, SmallVectorImpl<Value> &idxs) {
for (auto sub : subs) {
// Invariant indices simply pass through.
if (sub.dyn_cast<BlockArgument>() ||
sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
if (codegen)
idxs.push_back(sub);
continue; // success so far
}
// Look under the hood of casting.
auto cast = sub;
while (1) {
if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
cast = icast->getOperand(0);
else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
cast = ecast->getOperand(0);
else
break;
}
// Since the index vector is used in subsequent gather/scatter
// operations, which effectively define an unsigned pointer + signed
// index, we must zero extend the vector to an index width. For 8-bit
// and 16-bit values, a 32-bit index width suffices. For 32-bit values,
// zero extending the elements into 64-bit loses some performance since
// the 32-bit indexed gather/scatter is more efficient than the 64-bit
// index variant (if the negative 32-bit index space is unused, the
// enableSIMDIndex32 flag can preserve this performance). For 64-bit
// values, there is no good way to state that the indices are unsigned,
// which creates the potential of incorrect address calculations in the
// unlikely case we need such extremely large offsets.
if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
if (codegen) {
SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
Location loc = forOp.getLoc();
Value vload =
genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
Type etp = vload.getType().cast<VectorType>().getElementType();
if (!etp.isa<IndexType>()) {
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<arith::ExtUIOp>(
loc, vectorType(vl, rewriter.getI32Type()), vload);
else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
vload = rewriter.create<arith::ExtUIOp>(
loc, vectorType(vl, rewriter.getI64Type()), vload);
}
idxs.push_back(vload);
}
continue; // success so far
}
return false;
}
return true;
}
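// Illustration only (a hedged sketch with made-up SSA names, vl=16): a 32-bit
// index load is zero extended to i64 before feeding the gather/scatter,
// unless enable-simd-index32 keeps it at i32
//   %li = vector.maskedload %idxmem[%iv], %vmask, %passi
//         : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
//   %zi = arith.extui %li : vector<16xi32> to vector<16xi64>
//   %lb = vector.gather %dense[%c0] [%zi], %vmask, %passf
//         : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>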
#define UNAOP(xxx) \
if (isa<xxx>(def)) { \
if (codegen) \
vexp = rewriter.create<xxx>(loc, vx); \
return true; \
}
#define BINOP(xxx) \
if (isa<xxx>(def)) { \
if (codegen) \
vexp = rewriter.create<xxx>(loc, vx, vy); \
return true; \
}
/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
Value exp, bool codegen, Value vmask, Value &vexp) {
// A block argument is invariant.
if (auto arg = exp.dyn_cast<BlockArgument>()) {
if (codegen)
vexp = genVectorInvariantValue(rewriter, vl, exp);
return true;
}
// Something defined outside the loop-body is invariant as well.
Operation *def = exp.getDefiningOp();
if (def->getBlock() != &forOp.getRegion().front()) {
if (codegen)
vexp = genVectorInvariantValue(rewriter, vl, exp);
return true;
}
// Unary and binary operations inside the loop-body. Note that it would be
// nicer if we could somehow test and build the operations in a more
// concise manner than just listing them all (although this way we know
// for certain that they can vectorize).
Location loc = forOp.getLoc();
if (auto load = dyn_cast<memref::LoadOp>(def)) {
auto subs = load.getIndices();
SmallVector<Value> idxs;
if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
if (codegen)
vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
return true;
}
} else if (def->getNumOperands() == 1) {
Value vx;
if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
vx)) {
UNAOP(math::AbsFOp)
UNAOP(math::AbsIOp)
UNAOP(math::CeilOp)
UNAOP(math::FloorOp)
UNAOP(math::SqrtOp)
UNAOP(math::ExpM1Op)
UNAOP(math::Log1pOp)
UNAOP(math::SinOp)
UNAOP(math::TanhOp)
UNAOP(arith::NegFOp)
}
} else if (def->getNumOperands() == 2) {
Value vx, vy;
if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
vx) &&
vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
vy)) {
BINOP(arith::MulFOp)
BINOP(arith::MulIOp)
BINOP(arith::DivFOp)
BINOP(arith::DivSIOp)
BINOP(arith::DivUIOp)
BINOP(arith::AddFOp)
BINOP(arith::AddIOp)
BINOP(arith::SubFOp)
BINOP(arith::SubIOp)
BINOP(arith::AndIOp)
BINOP(arith::OrIOp)
BINOP(arith::XOrIOp)
}
}
return false;
}
#undef UNAOP
#undef BINOP
/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
bool codegen) {
Location loc = forOp.getLoc();
Block &block = forOp.getRegion().front();
scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
auto &last = *++block.rbegin();
scf::ForOp forOpNew;
// Perform initial set up during codegen (we know that the first analysis
// pass was successful). For reductions, we need to construct a completely
// new for-loop, since the incoming and outgoing reduction type
// changes into SIMD form. For stores, we can simply adjust the stride
// and insert into the existing for-loop. In both cases, we set up a vector
// mask for all operations, which takes care of confining vectors to
// the original iteration space (later cleanup loops or other
// optimizations can remove the masking overhead where it is not needed).
Value vmask;
if (codegen) {
Value step = constantIndex(rewriter, loc, vl.vectorLength);
if (vl.enableVLAVectorization) {
Value vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
step = rewriter.create<arith::MulIOp>(loc, vscale, step);
}
if (!yield.getResults().empty()) {
Value init = forOp.getInitArgs()[0];
VectorType vtp = vectorType(vl, init.getType());
Value vinit =
genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
forOpNew = rewriter.create<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
forOp.setStep(step);
rewriter.setInsertionPoint(yield);
}
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
forOp.getLowerBound(), forOp.getUpperBound(), step);
}
// Sparse for-loops are terminated either by a non-empty yield operation
// (reduction loop) or by a store operation (parallel loop).
if (!yield.getResults().empty()) {
if (yield->getNumOperands() != 1)
return false;
Value redOp = yield->getOperand(0);
// Analyze/vectorize reduction.
// TODO: use linalg utils to verify the actual reduction?
Value vrhs;
if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
if (codegen) {
Value vpass =
genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
rewriter.create<scf::YieldOp>(loc, vred);
rewriter.setInsertionPointAfter(forOpNew);
Value vres = genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
// Now do some relinking (last one is not completely type safe
// but all bad ones are removed right away). This also folds away
// nop broadcast operations.
forOp.getResult(0).replaceAllUsesWith(vres);
forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
forOp.getRegionIterArg(0).replaceAllUsesWith(
forOpNew.getRegionIterArg(0));
rewriter.eraseOp(forOp);
}
return true;
}
} else if (auto store = dyn_cast<memref::StoreOp>(last)) {
// Analyze/vectorize store operation.
auto subs = store.getIndices();
SmallVector<Value> idxs;
Value rhs = store.getValue();
Value vrhs;
if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
if (codegen) {
genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
rewriter.eraseOp(store);
}
return true;
}
}
assert(!codegen && "cannot call codegen when analysis failed");
return false;
}
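// Illustration only (a hedged sketch with made-up SSA names, scalable vl=4):
// for a masked reduction, lanes beyond the trip count keep their previous
// value so that the final horizontal reduction remains correct
//   %vred = arith.select %vmask, %vnew, %vold : vector<[4]xi1>, vector<[4]xf32>
//   scf.yield %vred : vector<[4]xf32>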
/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
: OpRewritePattern(context),
vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
// Check for a single-block, unit-stride for-loop that was generated by the
// sparse compiler, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
!op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
return success();
return failure();
}
private:
const VL vl;
};
} // namespace
//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//
/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32) {
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
}

mlir/test/Dialect/SparseTensor/sparse_vector.mlir Normal file → Executable file

@ -1,5 +1,11 @@
// RUN: mlir-opt %s -sparsification -cse -split-input-file | \
// RUN: FileCheck %s
// RUN: FileCheck %s --check-prefix=CHECK-SCALAR
// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC16
// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC16-IDX32
// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC4-SVE
#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
@ -13,18 +19,59 @@
}
//
// CHECK-LABEL: func @scale_d
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
// CHECK: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
// CHECK: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK: }
// CHECK: return
// CHECK-SCALAR-LABEL: func @scale_d
// CHECK-SCALAR-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-SCALAR-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-SCALAR-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-SCALAR: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
// CHECK-SCALAR: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-SCALAR: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
// CHECK-SCALAR: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-SCALAR: }
// CHECK-SCALAR: return
//
// CHECK-VEC16-LABEL: func @scale_d
// CHECK-VEC16-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC16: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC16: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC16: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC16: }
// CHECK-VEC16: return
//
// CHECK-VEC16-IDX32-LABEL: func @scale_d
// CHECK-VEC16-IDX32-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-IDX32-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC16-IDX32: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC16-IDX32: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: return
//
// CHECK-VEC4-SVE: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-SVE-LABEL: func @scale_d
// CHECK-VEC4-SVE-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-SVE-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-SVE-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC4-SVE-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4-SVE-DAG: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4-SVE: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
// CHECK-VEC4-SVE: %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
// CHECK-VEC4-SVE: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: return
//
func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = linalg.generic #trait_scale_d
ins(%arga: tensor<1024xf32, #DenseVector>)
@ -55,27 +102,101 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
}
//
// CHECK-LABEL: func @mul_s
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
// CHECK: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK: }
// CHECK: return
// CHECK-SCALAR-LABEL: func @mul_s
// CHECK-SCALAR-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-SCALAR-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-SCALAR: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-SCALAR: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-SCALAR: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-SCALAR: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-SCALAR: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-SCALAR: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-SCALAR: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-SCALAR: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-SCALAR: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
// CHECK-SCALAR: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
// CHECK-SCALAR: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-SCALAR: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-SCALAR: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-SCALAR: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-SCALAR: }
// CHECK-SCALAR: return
//
func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-VEC16: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-LABEL: func @mul_s
// CHECK-VEC16-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC16: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC16: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC16: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC16: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC16: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC16: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC16: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC16: }
// CHECK-VEC16: return
//
// CHECK-VEC16-IDX32: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-IDX32-LABEL: func @mul_s
// CHECK-VEC16-IDX32-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-IDX32-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-IDX32-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-IDX32: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC16-IDX32: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC16-IDX32: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC16-IDX32: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC16-IDX32: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC16-IDX32: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16-IDX32: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: return
//
// CHECK-VEC4-SVE: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-SVE-LABEL: func @mul_s
// CHECK-VEC4-SVE-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-SVE-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-SVE-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-SVE-DAG: %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-VEC4-SVE-DAG: %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC4-SVE: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC4-SVE: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC4-SVE: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC4-SVE: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC4-SVE: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4-SVE: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK-VEC4-SVE: %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4-SVE: vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: return
//
func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
%argb: tensor<1024xf32>,
%argx: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = linalg.generic #trait_mul_s
ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
outs(%argx: tensor<1024xf32>) {
@ -101,20 +222,79 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>
}
//
// CHECK-LABEL: func @reduction_d
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
// CHECK: scf.yield %[[a]] : f32
// CHECK: }
// CHECK: return
// CHECK-SCALAR-LABEL: func @reduction_d
// CHECK-SCALAR-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-SCALAR-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-SCALAR-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-SCALAR: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
// CHECK-SCALAR: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-SCALAR: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-SCALAR: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-SCALAR: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
// CHECK-SCALAR: scf.yield %[[a]] : f32
// CHECK-SCALAR: }
// CHECK-SCALAR: return
//
func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
// CHECK-VEC16-LABEL: func @reduction_d
// CHECK-VEC16-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC16: scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC16: }
// CHECK-VEC16: %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
// CHECK-VEC16: return
//
// CHECK-VEC16-IDX32-LABEL: func @reduction_d
// CHECK-VEC16-IDX32-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-IDX32-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
// CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16-IDX32: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC16-IDX32: scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
// CHECK-VEC16-IDX32: return
//
// CHECK-VEC4-SVE: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-SVE-LABEL: func @reduction_d
// CHECK-VEC4-SVE-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-SVE-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-SVE-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC4-SVE-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32>
// CHECK-VEC4-SVE: %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4-SVE: scf.yield %[[sa]] : vector<[4]xf32>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: %{{.*}} = vector.reduction <add>, %[[red]] : vector<[4]xf32> into f32
// CHECK-VEC4-SVE: return
//
func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>,
%argb: tensor<1024xf32>,
%argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_reduction_d
ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
outs(%argx: tensor<f32>) {
@ -145,31 +325,117 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024
}
//
// CHECK-LABEL: func @mul_ds
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
// CHECK: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK-SCALAR-LABEL: func @mul_ds
// CHECK-SCALAR-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-SCALAR-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-SCALAR-DAG: %[[c512:.*]] = arith.constant 512 : index
// CHECK-SCALAR: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-SCALAR: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-SCALAR: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-SCALAR: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-SCALAR: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-SCALAR: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-SCALAR: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-SCALAR: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-SCALAR: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-SCALAR: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-SCALAR: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
// CHECK-SCALAR: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
// CHECK-SCALAR: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-SCALAR: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-SCALAR: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
// CHECK-SCALAR: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-SCALAR: }
// CHECK-SCALAR: }
// CHECK-SCALAR: return
//
func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
// CHECK-VEC16: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-LABEL: func @mul_ds
// CHECK-VEC16-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-DAG: %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC16: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC16: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC16: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC16: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC16: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC16: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC16: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC16: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC16: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC16: %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC16: }
// CHECK-VEC16: }
// CHECK-VEC16: return
//
// CHECK-VEC16-IDX32: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-IDX32-LABEL: func @mul_ds
// CHECK-VEC16-IDX32-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-IDX32-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-IDX32-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-IDX32-DAG: %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC16-IDX32: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC16-IDX32: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC16-IDX32: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC16-IDX32: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC16-IDX32: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC16-IDX32: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC16-IDX32: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC16-IDX32: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16-IDX32: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: return
//
// CHECK-VEC4-SVE: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-SVE-LABEL: func @mul_ds
// CHECK-VEC4-SVE-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-SVE-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-SVE-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-SVE-DAG: %[[c512:.*]] = arith.constant 512 : index
// CHECK-VEC4-SVE-DAG: %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-VEC4-SVE-DAG: %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC4-SVE: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC4-SVE: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK-VEC4-SVE: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
// CHECK-VEC4-SVE: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC4-SVE: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC4-SVE: %[[b:.*]] = arith.extui %[[r]] : i32 to i64
// CHECK-VEC4-SVE: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4-SVE: %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK-VEC4-SVE: %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64>
// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
// CHECK-VEC4-SVE: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: return
//
func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>,
%argb: tensor<512x1024xf32>,
%argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
%0 = linalg.generic #trait_mul_ds
ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
outs(%argx: tensor<512x1024xf32>) {
@ -194,26 +460,96 @@ func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x
}
//
// CHECK-LABEL: func @add_dense
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
// CHECK: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK-SCALAR-LABEL: func @add_dense
// CHECK-SCALAR-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-SCALAR-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-SCALAR-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK-SCALAR: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-SCALAR: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-SCALAR: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-SCALAR: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-SCALAR: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK-SCALAR: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK-SCALAR: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-SCALAR: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK-SCALAR: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
// CHECK-SCALAR: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-SCALAR: }
// CHECK-SCALAR: }
// CHECK-SCALAR: return
//
// CHECK-VEC16: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-LABEL: func @add_dense
// CHECK-VEC16-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC16: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC16: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC16: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC16: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC16: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
// CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
// CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
// CHECK-VEC16: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
// CHECK-VEC16: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
// CHECK-VEC16: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
// CHECK-VEC16: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC16: }
// CHECK-VEC16: }
// CHECK-VEC16: return
//
// CHECK-VEC16-IDX32: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC16-IDX32-LABEL: func @add_dense
// CHECK-VEC16-IDX32-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC16-IDX32-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC16-IDX32-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-VEC16-IDX32-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC16-IDX32: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC16-IDX32: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC16-IDX32: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC16-IDX32: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
// CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
// CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC16-IDX32: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
// CHECK-VEC16-IDX32: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
// CHECK-VEC16-IDX32: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
// CHECK-VEC16-IDX32: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
// CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: }
// CHECK-VEC16-IDX32: return
//
// CHECK-VEC4-SVE: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
// CHECK-VEC4-SVE-LABEL: func @add_dense
// CHECK-VEC4-SVE-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-VEC4-SVE-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-VEC4-SVE-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-VEC4-SVE-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK-VEC4-SVE-DAG: %[[v0idx:.*]] = arith.constant dense<0> : vector<[4]xindex>
// CHECK-VEC4-SVE-DAG: %[[v0f64:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf64>
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC4-SVE: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC4-SVE: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
// CHECK-VEC4-SVE: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
// CHECK-VEC4-SVE: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
// CHECK-VEC4-SVE: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0idx]] : memref<?xindex>
// CHECK-VEC4-SVE: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[v0f64]] : memref<33x64xf64>
// CHECK-VEC4-SVE: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0f64]] : memref<?xf64>
// CHECK-VEC4-SVE: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<[4]xf64>
// CHECK-VEC4-SVE: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: return
//
func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
%argx: tensor<33x64xf64>) -> tensor<33x64xf64> {
%0 = linalg.generic #trait_affine
ins(%arga: tensor<32x64xf64, #SparseMatrix>)
outs(%argx: tensor<33x64xf64>) {


@ -2224,6 +2224,7 @@ cc_library(
":LinalgDialect",
":LinalgTransforms",
":LinalgUtils",
":MathDialect",
":MemRefDialect",
":Pass",
":SCFDialect",
@ -2235,6 +2236,7 @@ cc_library(
":Support",
":TensorDialect",
":Transforms",
":VectorDialect",
"//llvm:Support",
],
)