[mlir][sparse] code refactoring, move <tid, loop id> -> dim map to Merger.

To address unresolved comments in D136185

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136780
Peiming Liu 2022-10-26 19:07:25 +00:00
parent 0cb65b0a58
commit 32c512e49f
3 changed files with 104 additions and 111 deletions
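For orientation, a minimal standalone sketch of the map being moved into Merger (simplified std:: types for illustration only; the actual implementation below uses llvm::Optional and maintains a parallel dimTypes table for the dimension level types):

#include <optional>
#include <vector>

// Sketch of the <tensor id, loop id> -> dimension map this commit moves
// from CodeGen into Merger (simplified; not the actual MLIR types).
struct LoopToDimMap {
  // loopIdxToDim[t][i] is the dimension of tensor t iterated by loop i,
  // or empty when loop i does not directly index tensor t (e.g., under a
  // compound affine expression).
  std::vector<std::vector<std::optional<unsigned>>> loopIdxToDim;

  LoopToDimMap(unsigned numTensors, unsigned numLoops)
      : loopIdxToDim(numTensors,
                     std::vector<std::optional<unsigned>>(numLoops)) {}

  void set(unsigned t, unsigned i, unsigned dim) { loopIdxToDim[t][i] = dim; }
  std::optional<unsigned> get(unsigned t, unsigned i) const {
    return loopIdxToDim[t][i];
  }
};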


@@ -156,7 +156,8 @@ public:
Merger(unsigned t, unsigned l)
: outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
hasSparseOut(false),
dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}
/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -246,7 +247,7 @@ public:
/// Returns true if any set bit corresponds to sparse dimension level type.
bool hasAnySparse(const BitVector &bits) const;
/// Gets the dimension level type of the `i`th loop of the `t`th tensor.
/// Gets the dimension level type of the `t`th tensor on the `i`th loop.
DimLevelType getDimLevelType(unsigned t, unsigned i) const {
assert(t < numTensors && i < numLoops);
return dimTypes[t][i];
@@ -257,10 +258,35 @@ public:
return getDimLevelType(tensor(b), index(b));
}
/// Sets the dimension level type of the `i`th loop of the `t`th tensor.
void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
assert(isValidDLT(d));
dimTypes[t][i] = d;
/// Gets the dimension number of the `t`th tensor on the `i`th loop.
Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
assert(t < numTensors && i < numLoops);
return loopIdxToDim[t][i];
}
/// Gets the dimension number of `b`.
Optional<unsigned> getDimNum(unsigned b) const {
return getDimNum(tensor(b), index(b));
}
/// Sets the dimension and dimension level type of the `t`th tensor on the
/// `i`th loop.
void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
DimLevelType dlt) {
assert(isValidDLT(dlt));
dimTypes[t][i] = dlt;
loopIdxToDim[t][i] = dim;
}
// Iterates over the bits of a lattice; for each set bit, converts it into
// the corresponding tensor dimension and invokes the callback.
void foreachTidDimPairInBits(
const BitVector &bits,
function_ref<void(unsigned b, unsigned tid, Optional<unsigned> dim,
DimLevelType dlt)>
cb) {
for (unsigned b : bits.set_bits())
cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
}
// Has sparse output tensor setter.
@@ -310,7 +336,11 @@ private:
const unsigned numTensors;
const unsigned numLoops;
bool hasSparseOut;
// Map that converts pair<tensor id, loop id> to the corresponding dimension
// level type.
std::vector<std::vector<DimLevelType>> dimTypes;
// Map that converts pair<tensor id, loop id> to the corresponding dimension.
std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
llvm::SmallVector<TensorExp, 32> tensorExps;
llvm::SmallVector<LatPoint, 16> latPoints;
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
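Taken together, a minimal usage sketch of the new Merger API (hypothetical tensor/loop counts, mirroring the unit-test setup further below; assumes the MLIR sparse tensor headers are available):

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

void mergerMapSketch() {
  // Two input tensors (t0, t1) and one output tensor (t2), one loop (l0).
  Merger merger(/*t=*/3, /*l=*/1);
  // Record that loop l0 iterates dimension 0 of tensor t0, stored compressed.
  merger.setDimAndDimLevelType(/*t=*/0, /*i=*/0, /*dim=*/0,
                               DimLevelType::Compressed);
  // Query the dimension back; yields llvm::None for unmapped loops.
  if (llvm::Optional<unsigned> d = merger.getDimNum(/*t=*/0, /*i=*/0))
    (void)*d; // here *d == 0
}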


@@ -40,8 +40,6 @@ using namespace mlir::sparse_tensor;
namespace {
constexpr unsigned INVALID_ID = std::numeric_limits<unsigned>::max();
// Iteration graph sorting.
enum SortMask {
kSparseOnly = 0x0,
@@ -83,14 +81,6 @@ struct CodeGen {
// Topsort (reference should remain in scope).
std::vector<unsigned> &topSort;
// From tensor id + loop id => dim id.
// TODO: This map should probably be maintained by Merger (it can be set up
// together with dimLvlType Map).
std::vector<std::vector<unsigned>> loopIdxToDim;
// Initializes the above mapping.
void buildLoopIdxToDimMap(linalg::GenericOp op);
Value getLoopIdxValue(size_t loopIdx) const {
for (unsigned lv = 0; lv < topSort.size(); lv++)
if (topSort[lv] == loopIdx)
@@ -100,30 +90,6 @@ }
}
};
void CodeGen::buildLoopIdxToDimMap(linalg::GenericOp op) {
size_t numLoops = op.getNumLoops();
size_t numTensors = op.getNumOperands();
loopIdxToDim.assign(numTensors, std::vector<unsigned>(numLoops, INVALID_ID));
for (OpOperand &t : op->getOpOperands()) {
auto map = op.getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
// Scan all dimensions of current tensor.
unsigned tid = t.getOperandNumber();
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
auto a = map.getResult(toOrigDim(enc, d)).dyn_cast<AffineDimExpr>();
if (a) {
unsigned loopId = a.getPosition();
// Fills the mapping.
loopIdxToDim[tid][loopId] = d;
}
// Else this is a compound affine expression; do nothing. (We are fine for
// now, as we only support compound affine expressions on non-annotated
// dense tensors.)
}
}
}
} // namespace
//===----------------------------------------------------------------------===//
@@ -151,8 +117,9 @@ static AffineMap permute(MLIRContext *context, AffineMap m,
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
DimLevelType dim, bool setLvlFormat = true) {
static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
AffineExpr a, DimLevelType dlt,
bool setLvlFormat = true) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -160,21 +127,21 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
return false; // used more than once
if (setLvlFormat)
merger.setDimLevelType(tensor, idx, dim);
merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
return true;
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
if (!isDenseDLT(dim))
if (!isDenseDLT(dlt))
return false; // compound only in dense dim
auto binOp = a.cast<AffineBinaryOpExpr>();
// We do not set the dim level format for an affine expression like d0 + d1
// on either loop index d0 or d1.
return findAffine(merger, tensor, binOp.getLHS(), dim, false) &&
findAffine(merger, tensor, binOp.getRHS(), dim, false);
return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, false) &&
findAffine(merger, tensor, dim, binOp.getRHS(), dlt, false);
}
case AffineExprKind::Constant:
return isDenseDLT(dim); // const only in dense dim
return isDenseDLT(dlt); // const only in dense dim
default:
return false;
}
@@ -196,7 +163,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned tensor = t.getOperandNumber();
AffineExpr a = map.getResult(toOrigDim(enc, d));
if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d)))
return false; // inadmissible affine expression
}
}
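In prose, the rule implemented above: a plain loop index is admissible on any dimension level type, while compound expressions (d0 + d1, d0 * 2) and constants are admissible only on dense levels, and the <loop, dim, dlt> triple is recorded only for plain indices. A standalone sketch of that acceptance rule (simplified stand-in types, not the mlir::AffineExpr hierarchy; the duplicate-index check is omitted):

enum class ExprKind { Dim, Add, Mul, Constant };

struct Expr {
  ExprKind kind;
  const Expr *lhs = nullptr; // set for Add/Mul
  const Expr *rhs = nullptr; // set for Add/Mul
};

// 'denseLevel' plays the role of isDenseDLT(dlt) above.
bool acceptAffine(const Expr &e, bool denseLevel) {
  switch (e.kind) {
  case ExprKind::Dim:
    return true; // a plain loop index is admissible on any level type
  case ExprKind::Add:
  case ExprKind::Mul:
    // Compound expressions such as d0 + d1 are admissible only on dense
    // levels, and only if both operands are admissible as well.
    return denseLevel && acceptAffine(*e.lhs, denseLevel) &&
           acceptAffine(*e.rhs, denseLevel);
  case ExprKind::Constant:
    return denseLevel; // constants only on dense levels
  }
  return false;
}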
@@ -1024,8 +991,7 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
Value clause;
if (isCompressedDLT(merger.getDimLevelType(b)) ||
isSingletonDLT(merger.getDimLevelType(b))) {
auto dim = codegen.loopIdxToDim[tensor][idx];
assert(dim != INVALID_ID);
auto dim = merger.getDimNum(tensor, idx).value();
Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
Value op2 = codegen.getLoopIdxValue(idx);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
@@ -1082,23 +1048,22 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned l0 = merger.set(lts)[0];
bool needsUniv = false;
SmallVector<size_t, 4> ts;
SmallVector<size_t, 4> ds;
for (auto b : merger.lat(l0).bits.set_bits()) {
if (isDenseDLT(merger.getDimLevelType(b)) ||
isUndefDLT(merger.getDimLevelType(b))) {
needsUniv = true;
} else {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
size_t dim = codegen.loopIdxToDim[tensor][idx];
assert(dim != INVALID_ID);
ts.push_back(tensor);
ds.push_back(dim);
}
}
SmallVector<size_t> tids;
SmallVector<size_t> dims;
merger.foreachTidDimPairInBits(
merger.lat(l0).bits,
[&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
assert(merger.index(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
// Sparse/singleton dim levels.
tids.push_back(tid);
dims.push_back(dim.value());
}
});
codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), ts, ds);
codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
@@ -1119,17 +1084,15 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
SmallVectorImpl<size_t> &condDims,
SmallVectorImpl<size_t> &extraTids,
SmallVectorImpl<size_t> &extraDims) {
const BitVector &simple = merger.lat(li).simple;
const BitVector &all = merger.lat(li).bits;
assert(simple.size() == all.size());
// First converts bits to array + dim pair
for (unsigned b = 0, e = simple.size(); b < e; b++) {
size_t tid = merger.tensor(b);
const BitVector &simple = merger.lat(li).simple;
// Converts bits to array + dim pair
merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
Optional<unsigned> dim,
DimLevelType dlt) {
if (simple.test(b)) {
// The simplified condition must be a subset of the original condition.
assert(all.test(b));
assert(merger.index(b) == idx);
if (isUndefDLT(merger.getDimLevelType(b))) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattice means we probably intend to iterate
// based on the dim of the output tensor.
// E.g., this could be a synthetic tensor (for invariants and sparse
@@ -1137,26 +1100,28 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
// out[i][j] = invariant; or a broadcast
// out[i][j] = in[i] (j is undef for input)
tid = merger.getOutTensorID();
dim = merger.getDimNum(tid, idx);
// Skips an invalid dim (e.g., when this is a zero-ranked tensor).
if (!dim)
return;
}
auto dim = codegen.loopIdxToDim[tid][idx];
if (dim != INVALID_ID) {
// dim could be invalid if this is a zero ranked tensor
condTids.push_back(tid);
condDims.push_back(dim);
}
} else if ((all.test(b) || merger.isOutTensor(b, idx)) &&
isDenseDLT(merger.getDimLevelType(b))) {
assert(merger.index(b) == idx);
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized codegen.
// Only dense dimensions should be optimized from conditions.
assert(isDenseDLT(merger.getDimLevelType(b)));
auto dim = codegen.loopIdxToDim[tid][idx];
assert(dim != INVALID_ID);
condTids.push_back(tid);
condDims.push_back(dim.value());
} else if (isDenseDLT(dlt)) {
// TODO: get rid of extraTids and extraDims.
extraTids.push_back(tid);
extraDims.push_back(dim);
extraDims.push_back(dim.value());
}
});
if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized codegen.
// Only dense dimensions should be optimized from conditions.
auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
extraTids.push_back(merger.getOutTensorID());
extraDims.push_back(dim);
}
}
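The net effect of the rewritten translation, as a standalone sketch (hypothetical simplified types; the real code above walks the Merger lattice bits via foreachTidDimPairInBits): simple bits become loop conditions, with undefined dlts redirected to the output tensor, while dense bits and the output's dense dim become unconditional extras.

#include <optional>
#include <vector>

enum class DLT { Dense, Compressed, Singleton, Undef };

struct TidDim {
  unsigned tid, dim;
};

// dimOf[t]/dltOf[t]: dim and dlt of tensor t under the current loop;
// simple[t]: whether t survives lattice simplification.
void translateSketch(const std::vector<std::optional<unsigned>> &dimOf,
                     const std::vector<DLT> &dltOf,
                     const std::vector<bool> &simple, unsigned outTid,
                     std::vector<TidDim> &cond, std::vector<TidDim> &extra) {
  for (unsigned tid = 0; tid < dltOf.size(); ++tid) {
    if (simple[tid]) {
      unsigned t = tid;
      std::optional<unsigned> dim = dimOf[tid];
      if (dltOf[tid] == DLT::Undef) {
        // E.g., a broadcast out[i][j] = in[i]: j is undef on the input, so
        // iterate on the corresponding dim of the output tensor instead.
        t = outTid;
        dim = dimOf[outTid];
        if (!dim)
          continue; // zero-ranked tensor: nothing to iterate
      }
      cond.push_back({t, *dim});
    } else if (dltOf[tid] == DLT::Dense && dimOf[tid]) {
      extra.push_back({tid, *dimOf[tid]});
    }
  }
  // Dense output dims are emitted unconditionally: they may be absent from
  // the lattice but are still needed for linearized codegen.
  if (dltOf[outTid] == DLT::Dense && dimOf[outTid])
    extra.push_back({outTid, *dimOf[outTid]});
}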
@@ -1370,8 +1335,6 @@ public:
// Recursively generates code if admissible.
CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
outerParNest, topSort);
// TODO: maybe merger should be responsible of maintaining the map.
codegen.buildLoopIdxToDimMap(op);
genBuffers(merger, codegen, rewriter, op);
genStmt(merger, codegen, rewriter, op, exp, 0);
genResult(merger, codegen, rewriter, op);


@@ -313,15 +313,15 @@ protected:
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
merger.addExp(Kind::kTensor, t1, -1u);
merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: dense output vector.
merger.addExp(Kind::kTensor, t2, -1u);
merger.setDimLevelType(t2, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -338,19 +338,19 @@ protected:
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
merger.addExp(Kind::kTensor, t1, -1u);
merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: sparse input vector
merger.addExp(Kind::kTensor, t2, -1u);
merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
// Tensor 3: dense output vector
merger.addExp(Kind::kTensor, t3, -1u);
merger.setDimLevelType(t3, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -371,15 +371,15 @@ protected:
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: dense input vector.
merger.addExp(Kind::kTensor, t1, -1u);
merger.setDimLevelType(t1, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: dense output vector.
merger.addExp(Kind::kTensor, t2, -1u);
merger.setDimLevelType(t2, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -400,19 +400,19 @@ protected:
// Tensor 0: undef input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Undef);
merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: dense input vector.
merger.addExp(Kind::kTensor, t1, -1u);
merger.setDimLevelType(t1, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: undef input vector.
merger.addExp(Kind::kTensor, t2, -1u);
merger.setDimLevelType(t2, l0, DimLevelType::Undef);
merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Undef);
// Tensor 3: dense output vector.
merger.addExp(Kind::kTensor, t3, -1u);
merger.setDimLevelType(t3, l0, DimLevelType::Dense);
merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -436,15 +436,15 @@ protected:
// Tensor 0: undef input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Undef);
merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: undef input vector.
merger.addExp(Kind::kTensor, t1, -1u);
merger.setDimLevelType(t1, l0, DimLevelType::Undef);
merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Undef);
// Tensor 2: sparse output vector.
merger.addExp(Kind::kTensor, t2, -1u);
merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
}
};