[mlir][sparse] code refactoring, move <tid, loop id> -> dim map to Merger.
To address unresolved comments in D136185.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136780
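In short, the <tensor id, loop id> -> dimension map that CodeGen::buildLoopIdxToDimMap used to populate is now stored next to the dim level types inside Merger: it is filled by setDimAndDimLevelType and queried through getDimNum or foreachTidDimPairInBits. A minimal usage sketch (illustrative only, not part of the patch; it assumes the Merger declarations added in the header hunks below are in scope):

    // Illustrative sketch only; names and values are chosen for the example.
    void sketch() {
      using namespace mlir::sparse_tensor;
      // One input tensor, one output tensor, one loop.
      Merger merger(/*t=*/2, /*l=*/1);

      // Record, in one call, that tensor 0 keeps dimension 0 as Compressed and
      // that this dimension is iterated by loop 0. This replaces the separate
      // setDimLevelType call plus the CodeGen-side loopIdxToDim bookkeeping.
      merger.setDimAndDimLevelType(/*t=*/0, /*i=*/0, /*dim=*/0,
                                   DimLevelType::Compressed);

      // Query the dimension for a <tensor, loop> pair; llvm::None is returned
      // when the loop does not appear in the tensor's indexing map.
      if (auto dim = merger.getDimNum(/*t=*/0, /*i=*/0))
        (void)*dim; // *dim == 0 in this example.
    }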
commit 32c512e49f (parent 0cb65b0a58)
@@ -156,7 +156,8 @@ public:
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
+        loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -246,7 +247,7 @@ public:
   /// Returns true if any set bit corresponds to a sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;
 
-  /// Gets the dimension level type of the `i`th loop of the `t`th tensor.
+  /// Gets the dimension level type of the `t`th tensor on the `i`th loop.
   DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
     return dimTypes[t][i];
@@ -257,10 +258,35 @@ public:
     return getDimLevelType(tensor(b), index(b));
   }
 
-  /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
-  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
-    assert(isValidDLT(d));
-    dimTypes[t][i] = d;
+  /// Gets the dimension number of the `t`th tensor on the `i`th loop.
+  Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return loopIdxToDim[t][i];
   }
 
+  /// Gets the dimension number of `b`.
+  Optional<unsigned> getDimNum(unsigned b) const {
+    return getDimNum(tensor(b), index(b));
+  }
+
+  /// Sets the dimension and dimension level type of the `t`th tensor on the
+  /// `i`th loop.
+  void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
+                             DimLevelType dlt) {
+    assert(isValidDLT(dlt));
+    dimTypes[t][i] = dlt;
+    loopIdxToDim[t][i] = dim;
+  }
+
+  // Iterates over the bits of a lattice; for each set bit, converts it into
+  // the corresponding tensor dimension and invokes the callback.
+  void foreachTidDimPairInBits(
+      const BitVector &bits,
+      function_ref<void(unsigned b, unsigned tid, Optional<unsigned> dim,
+                        DimLevelType dlt)>
+          cb) {
+    for (unsigned b : bits.set_bits())
+      cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
+  }
+
   // Has sparse output tensor setter.
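For reference, a condensed sketch (not from the patch) of how a caller is expected to drive the new foreachTidDimPairInBits helper; it mirrors the startLoopSeq change further down, with `merger` and a lattice point `lat` assumed to be in scope:

    SmallVector<size_t> tids, dims;
    merger.foreachTidDimPairInBits(
        lat.bits, [&](unsigned b, unsigned tid, Optional<unsigned> dim,
                      DimLevelType dlt) {
          // Dense/undef levels need no <tid, dim> pair here; everything else
          // reports the dimension recorded by setDimAndDimLevelType.
          if (!isDenseDLT(dlt) && !isUndefDLT(dlt)) {
            tids.push_back(tid);
            dims.push_back(dim.value());
          }
        });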
@@ -310,7 +336,11 @@ private:
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension
+  // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
@@ -40,8 +40,6 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-constexpr unsigned INVALID_ID = std::numeric_limits<unsigned>::max();
-
 // Iteration graph sorting.
 enum SortMask {
   kSparseOnly = 0x0,
@@ -83,14 +81,6 @@ struct CodeGen {
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;
 
-  // From tensor id + loop id => dim id.
-  // TODO: This map should probably be maintained by Merger (it can be set up
-  // together with dimLvlType Map).
-  std::vector<std::vector<unsigned>> loopIdxToDim;
-
-  // Initialize the above two mapping.
-  void buildLoopIdxToDimMap(linalg::GenericOp op);
-
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
@@ -100,30 +90,6 @@ struct CodeGen {
   }
 };
 
-void CodeGen::buildLoopIdxToDimMap(linalg::GenericOp op) {
-  size_t numLoops = op.getNumLoops();
-  size_t numTensors = op.getNumOperands();
-  loopIdxToDim.assign(numTensors, std::vector<unsigned>(numLoops, INVALID_ID));
-
-  for (OpOperand &t : op->getOpOperands()) {
-    auto map = op.getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
-    // Scan all dimensions of current tensor.
-    unsigned tid = t.getOperandNumber();
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      auto a = map.getResult(toOrigDim(enc, d)).dyn_cast<AffineDimExpr>();
-      if (a) {
-        unsigned loopId = a.getPosition();
-        // Fills the mapping.
-        loopIdxToDim[tid][loopId] = d;
-      }
-      // Else a compound affine expression; do nothing (this is fine for now,
-      // since compound affine expressions are only supported on non-annotated
-      // dense tensors).
-    }
-  }
-}
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -151,8 +117,9 @@ static AffineMap permute(MLIRContext *context, AffineMap m,
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelType dim, bool setLvlFormat = true) {
+static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
+                       AffineExpr a, DimLevelType dlt,
+                       bool setLvlFormat = true) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -160,21 +127,21 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
       return false; // used more than once
 
     if (setLvlFormat)
-      merger.setDimLevelType(tensor, idx, dim);
+      merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDenseDLT(dim))
+    if (!isDenseDLT(dlt))
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     // We do not set the dim level format for an affine expression like d0 + d1
     // on either loop index d0 or d1.
-    return findAffine(merger, tensor, binOp.getLHS(), dim, false) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, false);
+    return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, false) &&
+           findAffine(merger, tensor, dim, binOp.getRHS(), dlt, false);
   }
   case AffineExprKind::Constant:
-    return isDenseDLT(dim); // const only in dense dim
+    return isDenseDLT(dlt); // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +163,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
+      if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d)))
         return false; // inadmissible affine expression
     }
   }
@@ -1024,8 +991,7 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       Value clause;
       if (isCompressedDLT(merger.getDimLevelType(b)) ||
           isSingletonDLT(merger.getDimLevelType(b))) {
-        auto dim = codegen.loopIdxToDim[tensor][idx];
-        assert(dim != INVALID_ID);
+        auto dim = merger.getDimNum(tensor, idx).value();
         Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
         Value op2 = codegen.getLoopIdxValue(idx);
         clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
@@ -1082,23 +1048,22 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   unsigned l0 = merger.set(lts)[0];
   bool needsUniv = false;
 
-  SmallVector<size_t, 4> ts;
-  SmallVector<size_t, 4> ds;
-  for (auto b : merger.lat(l0).bits.set_bits()) {
-    if (isDenseDLT(merger.getDimLevelType(b)) ||
-        isUndefDLT(merger.getDimLevelType(b))) {
-      needsUniv = true;
-    } else {
-      unsigned tensor = merger.tensor(b);
-      assert(idx == merger.index(b));
-      size_t dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
-      ts.push_back(tensor);
-      ds.push_back(dim);
-    }
-  }
+  SmallVector<size_t> tids;
+  SmallVector<size_t> dims;
+  merger.foreachTidDimPairInBits(
+      merger.lat(l0).bits,
+      [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
+        assert(merger.index(b) == idx);
+        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+          needsUniv = true;
+        } else {
+          // Sparse/singleton dim levels.
+          tids.push_back(tid);
+          dims.push_back(dim.value());
+        }
+      });
 
-  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), ts, ds);
+  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1119,17 +1084,15 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
                                        SmallVectorImpl<size_t> &condDims,
                                        SmallVectorImpl<size_t> &extraTids,
                                        SmallVectorImpl<size_t> &extraDims) {
-  const BitVector &simple = merger.lat(li).simple;
   const BitVector &all = merger.lat(li).bits;
-  assert(simple.size() == all.size());
-  // First converts bits to array + dim pair
-  for (unsigned b = 0, e = simple.size(); b < e; b++) {
-    size_t tid = merger.tensor(b);
+  const BitVector &simple = merger.lat(li).simple;
+
+  // Converts bits to array + dim pair
+  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                               Optional<unsigned> dim,
+                                               DimLevelType dlt) {
     if (simple.test(b)) {
-      // The simplified condition must be a subset of the original condition.
-      assert(all.test(b));
-      assert(merger.index(b) == idx);
-      if (isUndefDLT(merger.getDimLevelType(b))) {
+      if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to iterate based
         // on the dim of output tensor.
         // E.g., this could be a synthetic tensor (for invariants and sparse
@@ -1137,26 +1100,28 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
         // out[i][j] = invariant; or a broadcast
         // out[i][j] = in[i] (j is undef for input)
         tid = merger.getOutTensorID();
+        dim = merger.getDimNum(tid, idx);
+        // Skips invalid dim (e.g., when this is a zero ranked tensor).
+        if (!dim)
+          return;
       }
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      if (dim != INVALID_ID) {
-        // dim could be invalid if this is a zero ranked tensor
-        condTids.push_back(tid);
-        condDims.push_back(dim);
-      }
-    } else if ((all.test(b) || merger.isOutTensor(b, idx)) &&
-               isDenseDLT(merger.getDimLevelType(b))) {
-      assert(merger.index(b) == idx);
-      // Note that we generate dense indices of the output tensor
-      // unconditionally, since they may not appear in the lattice, but may be
-      // needed for linearized codegen.
-      // Only dense dimensions should be optimized from conditions.
-      assert(isDenseDLT(merger.getDimLevelType(b)));
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      assert(dim != INVALID_ID);
+      condTids.push_back(tid);
+      condDims.push_back(dim.value());
+    } else if (isDenseDLT(dlt)) {
+      // TODO: get rid of extraTids and extraDims.
       extraTids.push_back(tid);
-      extraDims.push_back(dim);
+      extraDims.push_back(dim.value());
     }
-  }
+  });
+
+  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+    // Note that we generate dense indices of the output tensor
+    // unconditionally, since they may not appear in the lattice, but may be
+    // needed for linearized codegen.
+    // Only dense dimensions should be optimized from conditions.
+    auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
+    extraTids.push_back(merger.getOutTensorID());
+    extraDims.push_back(dim);
+  }
 }
@@ -1370,8 +1335,6 @@ public:
     // Recursively generates code if admissible.
     CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
                     outerParNest, topSort);
-    // TODO: maybe merger should be responsible of maintaining the map.
-    codegen.buildLoopIdxToDimMap(op);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, exp, 0);
     genResult(merger, codegen, rewriter, op);
@@ -313,15 +313,15 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
@@ -338,19 +338,19 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
@@ -371,15 +371,15 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
@@ -400,19 +400,19 @@ protected:
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
@@ -436,15 +436,15 @@ protected:
 
     // Tensor 0: undef input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
   }
 };
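A hypothetical extra check in the spirit of the fixtures above (not part of the patch): after setDimAndDimLevelType, the recorded dimension is observable through the new getter.

    // Illustrative only; t0 and l0 follow the naming used by the fixtures above.
    merger.setDimAndDimLevelType(t0, l0, /*dim=*/0, DimLevelType::Compressed);
    EXPECT_EQ(merger.getDimNum(t0, l0).value(), 0u);
    EXPECT_TRUE(isCompressedDLT(merger.getDimLevelType(t0, l0)));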