[mlir][sparse] LICM for SparseTensorReader::readCOO
This commit performs two related changes. First we adjust `readCOOValue` to take the `IsPattern` bool as a template parameter rather than a function argument. Second we factor `readCOOLoop` out from `readCOO`, and template it on `IsPattern` and `IsSymmetric`. Together this moves all the assertions and header-dependent conditionals out of the main for-loop of `readCOO`. The only remaining conditional is in the `IsSymmetric=true` variant: checking whether the element is on the diagonal or not (which cannot be lifted out of the loop). Depends On D138363 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D138365
This commit is contained in:
parent
0d58834700
commit
1dfb9a64f1
|
@ -42,32 +42,46 @@ struct is_complex final : public std::false_type {};
|
|||
template <typename T>
|
||||
struct is_complex<std::complex<T>> final : public std::true_type {};
|
||||
|
||||
/// Reads an element of a non-complex type for the current indices in
|
||||
/// coordinate scheme.
|
||||
template <typename V>
|
||||
inline std::enable_if_t<!is_complex<V>::value, V>
|
||||
readCOOValue(char **linePtr, bool is_pattern) {
|
||||
/// Returns an element-value of non-complex type. If `IsPattern` is true,
|
||||
/// then returns an arbitrary value. If `IsPattern` is false, then
|
||||
/// reads the value from the current line buffer beginning at `linePtr`.
|
||||
template <typename V, bool IsPattern>
|
||||
inline std::enable_if_t<!is_complex<V>::value, V> readCOOValue(char **linePtr) {
|
||||
// The external formats always store these numerical values with the type
|
||||
// double, but we cast these values to the sparse tensor object type.
|
||||
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
|
||||
return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||
if constexpr (IsPattern)
|
||||
return 1.0;
|
||||
return strtod(*linePtr, linePtr);
|
||||
}
|
||||
|
||||
/// Reads an element of a complex type for the current indices in
|
||||
/// coordinate scheme.
|
||||
template <typename V>
|
||||
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
|
||||
bool is_pattern) {
|
||||
/// Returns an element-value of complex type. If `IsPattern` is true,
|
||||
/// then returns an arbitrary value. If `IsPattern` is false, then reads
|
||||
/// the value from the current line buffer beginning at `linePtr`.
|
||||
template <typename V, bool IsPattern>
|
||||
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr) {
|
||||
// Read two values to make a complex. The external formats always store
|
||||
// numerical values with the type double, but we cast these values to the
|
||||
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
|
||||
// value 1 for all entries.
|
||||
double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||
double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||
if constexpr (IsPattern)
|
||||
return V(1.0, 1.0);
|
||||
double re = strtod(*linePtr, linePtr);
|
||||
double im = strtod(*linePtr, linePtr);
|
||||
// Avoiding brace-notation since that forbids narrowing to `float`.
|
||||
return V(re, im);
|
||||
}
|
||||
|
||||
/// Returns an element-value. If `is_pattern` is true, then returns an
|
||||
/// arbitrary value. If `is_pattern` is false, then reads the value from
|
||||
/// the current line buffer beginning at `linePtr`.
|
||||
template <typename V>
|
||||
inline V readCOOValue(char **linePtr, bool is_pattern) {
|
||||
if (is_pattern)
|
||||
return readCOOValue<V, true>(linePtr);
|
||||
return readCOOValue<V, false>(linePtr);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -249,6 +263,18 @@ private:
|
|||
/// Precondition: `indices` is valid for `getRank()`.
|
||||
char *readCOOIndices(uint64_t *indices);
|
||||
|
||||
/// The internal implementation of `readCOO`. We template over
|
||||
/// `IsPattern` and `IsSymmetric` in order to perform LICM without
|
||||
/// needing to duplicate the source code.
|
||||
//
|
||||
// TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
|
||||
// since that's what `readCOO` creates. Once we update `readCOO` to
|
||||
// functionalize the mapping, then this helper will just take that
|
||||
// same function.
|
||||
template <typename V, bool IsPattern, bool IsSymmetric>
|
||||
void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
|
||||
SparseTensorCOO<V> *lvlCOO);
|
||||
|
||||
/// Reads the MME header of a general sparse matrix of type real.
|
||||
void readMMEHeader();
|
||||
|
||||
|
@ -282,36 +308,50 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
|
|||
assert(lvlRank == dimRank && "Rank mismatch");
|
||||
detail::PermutationRef d2l(dimRank, dim2lvl);
|
||||
// Prepare a COO object with the number of nonzeros as initial capacity.
|
||||
const uint64_t nnz = getNNZ();
|
||||
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
|
||||
// Read all nonzero elements.
|
||||
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
|
||||
// Do some manual LICM, to avoid assertions in the for-loop.
|
||||
const bool IsPattern = isPattern();
|
||||
const bool IsSymmetric = (isSymmetric() && getRank() == 2);
|
||||
if (IsPattern && IsSymmetric)
|
||||
readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
|
||||
else if (IsPattern)
|
||||
readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
|
||||
else if (IsSymmetric)
|
||||
readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
|
||||
else
|
||||
readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
|
||||
// Close the file and return the COO.
|
||||
closeFile();
|
||||
return lvlCOO;
|
||||
}
|
||||
|
||||
template <typename V, bool IsPattern, bool IsSymmetric>
|
||||
void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
|
||||
detail::PermutationRef dim2lvl,
|
||||
SparseTensorCOO<V> *lvlCOO) {
|
||||
const uint64_t dimRank = getRank();
|
||||
std::vector<uint64_t> dimInd(dimRank);
|
||||
std::vector<uint64_t> lvlInd(lvlRank);
|
||||
// Do some manual LICM, to avoid assertions in the for-loop.
|
||||
const bool addSymmetric = (isSymmetric() && dimRank == 2);
|
||||
const bool isPattern_ = isPattern();
|
||||
for (uint64_t k = 0; k < nnz; ++k) {
|
||||
for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) {
|
||||
// We inline `readCOOElement` here in order to avoid redundant
|
||||
// assertions, since they're guaranteed by the call to `isValid()`
|
||||
// and the construction of `dimInd` above.
|
||||
char *linePtr = readCOOIndices(dimInd.data());
|
||||
const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
|
||||
d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
|
||||
const V value = detail::readCOOValue<V, IsPattern>(&linePtr);
|
||||
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
|
||||
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
|
||||
lvlCOO->add(lvlInd, value);
|
||||
// We currently chose to deal with symmetric matrices by fully
|
||||
// constructing them. In the future, we may want to make symmetry
|
||||
// implicit for storage reasons.
|
||||
if (addSymmetric && dimInd[0] != dimInd[1]) {
|
||||
// Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
|
||||
std::swap(dimInd[0], dimInd[1]);
|
||||
d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
|
||||
lvlCOO->add(lvlInd, value);
|
||||
}
|
||||
if constexpr (IsSymmetric)
|
||||
if (dimInd[0] != dimInd[1]) {
|
||||
// Must recompute `lvlInd`, since arbitrary maps don't preserve swap.
|
||||
std::swap(dimInd[0], dimInd[1]);
|
||||
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
|
||||
lvlCOO->add(lvlInd, value);
|
||||
}
|
||||
}
|
||||
// Close the file and return the COO.
|
||||
closeFile();
|
||||
return lvlCOO;
|
||||
}
|
||||
|
||||
/// Writes the sparse tensor to `filename` in extended FROSTT format.
|
||||
|
|
Loading…
Reference in New Issue