[mlir][sparse] LICM for SparseTensorReader::readCOO

This commit performs two related changes.  First we adjust `readCOOValue` to take the `IsPattern` bool as a template parameter rather than a function argument.  Second we factor `readCOOLoop` out from `readCOO`, and template it on `IsPattern` and `IsSymmetric`.  Together these changes amount to manual loop-invariant code motion (LICM): they move all the assertions and header-dependent conditionals out of the main for-loop of `readCOO`.  The only remaining conditional is in the `IsSymmetric=true` variant: checking whether the element is on the diagonal or not (which cannot be lifted out of the loop, since it depends on the indices of each individual element).

Depends On D138363

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138365
This commit is contained in:
wren romano 2022-12-01 18:29:22 -08:00
parent 0d58834700
commit 1dfb9a64f1
1 changed files with 71 additions and 31 deletions

View File

@ -42,32 +42,46 @@ struct is_complex final : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> final : public std::true_type {};
/// Reads an element of a non-complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<!is_complex<V>::value, V>
readCOOValue(char **linePtr, bool is_pattern) {
/// Returns an element-value of non-complex type. If `IsPattern` is true,
/// then returns an arbitrary value. If `IsPattern` is false, then
/// reads the value from the current line buffer beginning at `linePtr`.
template <typename V, bool IsPattern>
inline std::enable_if_t<!is_complex<V>::value, V> readCOOValue(char **linePtr) {
// The external formats always store these numerical values with the type
// double, but we cast these values to the sparse tensor object type.
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
if constexpr (IsPattern)
return 1.0;
return strtod(*linePtr, linePtr);
}
/// Reads an element of a complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
bool is_pattern) {
/// Returns an element-value of complex type. If `IsPattern` is true,
/// then returns an arbitrary value. If `IsPattern` is false, then reads
/// the value from the current line buffer beginning at `linePtr`.
template <typename V, bool IsPattern>
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr) {
// Read two values to make a complex. The external formats always store
// numerical values with the type double, but we cast these values to the
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
// value 1 for all entries.
double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
if constexpr (IsPattern)
return V(1.0, 1.0);
double re = strtod(*linePtr, linePtr);
double im = strtod(*linePtr, linePtr);
// Avoiding brace-notation since that forbids narrowing to `float`.
return V(re, im);
}
/// Runtime-flag front end for the `IsPattern`-templated `readCOOValue`
/// overloads. If `is_pattern` is true, the result comes from the
/// `IsPattern=true` instantiation (an arbitrary value). If `is_pattern`
/// is false, the result comes from the `IsPattern=false` instantiation
/// (the value read from the current line buffer beginning at `linePtr`).
template <typename V>
inline V readCOOValue(char **linePtr, bool is_pattern) {
  // Both arms have type `V`, so the conditional merely selects which
  // compile-time variant is invoked; `linePtr` is forwarded untouched.
  return is_pattern ? readCOOValue<V, true>(linePtr)
                    : readCOOValue<V, false>(linePtr);
}
} // namespace detail
//===----------------------------------------------------------------------===//
@ -249,6 +263,18 @@ private:
/// Precondition: `indices` is valid for `getRank()`.
char *readCOOIndices(uint64_t *indices);
/// The internal implementation of `readCOO`: runs the loop that reads
/// each stored element and adds it to `lvlCOO`. We template over
/// `IsPattern` and `IsSymmetric` in order to perform LICM (loop-invariant
/// code motion) without needing to duplicate the source code; `readCOO`
/// inspects the header once and picks the matching instantiation.
///
/// Arguments:
/// * `lvlRank` is the size of the level-index buffer the loop fills.
/// * `dim2lvl` is the mapping applied (via `pushforward`) to the indices
///   read from the file, yielding the indices passed to `lvlCOO->add`.
/// * `lvlCOO` is the COO object that receives the elements.
///
/// When `IsSymmetric` is true, every element whose two indices differ is
/// additionally added with those two indices swapped.
//
// TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
// since that's what `readCOO` creates. Once we update `readCOO` to
// functionalize the mapping, then this helper will just take that
// same function.
template <typename V, bool IsPattern, bool IsSymmetric>
void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
SparseTensorCOO<V> *lvlCOO);
/// Reads the MME header of a general sparse matrix of type real.
void readMMEHeader();
@ -282,36 +308,50 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
assert(lvlRank == dimRank && "Rank mismatch");
detail::PermutationRef d2l(dimRank, dim2lvl);
// Prepare a COO object with the number of nonzeros as initial capacity.
const uint64_t nnz = getNNZ();
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
// Read all nonzero elements.
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
// Do some manual LICM, to avoid assertions in the for-loop.
const bool IsPattern = isPattern();
const bool IsSymmetric = (isSymmetric() && getRank() == 2);
if (IsPattern && IsSymmetric)
readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
else if (IsPattern)
readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
else if (IsSymmetric)
readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
else
readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
// Close the file and return the COO.
closeFile();
return lvlCOO;
}
template <typename V, bool IsPattern, bool IsSymmetric>
void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
detail::PermutationRef dim2lvl,
SparseTensorCOO<V> *lvlCOO) {
const uint64_t dimRank = getRank();
std::vector<uint64_t> dimInd(dimRank);
std::vector<uint64_t> lvlInd(lvlRank);
// Do some manual LICM, to avoid assertions in the for-loop.
const bool addSymmetric = (isSymmetric() && dimRank == 2);
const bool isPattern_ = isPattern();
for (uint64_t k = 0; k < nnz; ++k) {
for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) {
// We inline `readCOOElement` here in order to avoid redundant
// assertions, since they're guaranteed by the call to `isValid()`
// and the construction of `dimInd` above.
char *linePtr = readCOOIndices(dimInd.data());
const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
const V value = detail::readCOOValue<V, IsPattern>(&linePtr);
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
lvlCOO->add(lvlInd, value);
// We currently choose to deal with symmetric matrices by fully
// constructing them. In the future, we may want to make symmetry
// implicit for storage reasons.
if (addSymmetric && dimInd[0] != dimInd[1]) {
// Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
std::swap(dimInd[0], dimInd[1]);
d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
lvlCOO->add(lvlInd, value);
}
if constexpr (IsSymmetric)
if (dimInd[0] != dimInd[1]) {
// Must recompute `lvlInd`, since arbitrary maps don't preserve swap.
std::swap(dimInd[0], dimInd[1]);
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
lvlCOO->add(lvlInd, value);
}
}
// Close the file and return the COO.
closeFile();
return lvlCOO;
}
/// Writes the sparse tensor to `filename` in extended FROSTT format.