[mlir][sparse] Factoring magic numbers into a header

Addresses https://bugs.llvm.org/show_bug.cgi?id=52303

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D112962
Author: wren romano
Date: 2021-11-05 15:15:39 -07:00
Parent: f57d0e2726
Commit: 845561ec9d
4 changed files with 259 additions and 174 deletions

mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h (new file)

@@ -0,0 +1,55 @@
//===- SparseTensorUtils.h - Enums shared with the runtime ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines several enums shared between
// Transforms/SparseTensorConversion.cpp and ExecutionEngine/SparseTensorUtils.cpp.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
#define MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_

#include <cinttypes>

extern "C" {

/// Encoding of overhead types (both pointer overhead and indices
/// overhead), for "overloading" @newSparseTensor.
enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };

/// Encoding of the elemental type, for "overloading" @newSparseTensor.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kI64 = 3,
  kI32 = 4,
  kI16 = 5,
  kI8 = 6
};

/// The actions performed by @newSparseTensor.
enum class Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4,
  kToIterator = 5
};

/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for
/// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType`
/// is the source of truth and this enum should be kept consistent with it.
enum class DimLevelType : uint8_t {
  kDense = 0,
  kCompressed = 1,
  kSingleton = 2
};

} // extern "C"

#endif // MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
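
Because these enumerators cross the compiler/runtime boundary as plain integers, their numeric values are effectively an ABI contract. Below is a minimal sketch (not part of the commit) of how a consumer could pin those values so accidental renumbering fails the build; it assumes only the header above.

#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include <cstdint>

// Restates the magic numbers documented above; if anyone renumbers the
// enums, compilation fails here instead of miscompiling lowered code.
static_assert(static_cast<uint32_t>(OverheadType::kU64) == 1, "ABI drift");
static_assert(static_cast<uint32_t>(PrimaryType::kF64) == 1, "ABI drift");
static_assert(static_cast<uint32_t>(Action::kToIterator) == 5, "ABI drift");
static_assert(static_cast<uint8_t>(DimLevelType::kSingleton) == 2, "ABI drift");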

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

@@ -22,6 +22,7 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -29,69 +30,10 @@ using namespace mlir::sparse_tensor;
namespace {
/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
kEmpty = 0,
kFromFile = 1,
kFromCOO = 2,
kEmptyCOO = 3,
kToCOO = 4,
kToIter = 5
};
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getPrimaryTypeEncoding(Type tp) {
if (tp.isF64())
return 1;
if (tp.isF32())
return 2;
if (tp.isInteger(64))
return 3;
if (tp.isInteger(32))
return 4;
if (tp.isInteger(16))
return 5;
if (tp.isInteger(8))
return 6;
return 0;
}
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getOverheadTypeEncoding(unsigned width) {
switch (width) {
default:
return 1;
case 32:
return 2;
case 16:
return 3;
case 8:
return 4;
}
}
/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
switch (dlt) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
return 0;
case SparseTensorEncodingAttr::DimLevelType::Compressed:
return 1;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
return 2;
}
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
Location loc, Type t) {
@@ -116,6 +58,75 @@ inline static Value constantI8(ConversionPatternRewriter &rewriter,
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}
/// Generates a constant of the given `Action`.
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
Action action) {
return constantI32(rewriter, loc, static_cast<uint32_t>(action));
}
/// Generates a constant of the internal type encoding for overhead storage.
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, unsigned width) {
OverheadType sec;
switch (width) {
default:
sec = OverheadType::kU64;
break;
case 32:
sec = OverheadType::kU32;
break;
case 16:
sec = OverheadType::kU16;
break;
case 8:
sec = OverheadType::kU8;
break;
}
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}
/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, Type tp) {
PrimaryType primary;
if (tp.isF64())
primary = PrimaryType::kF64;
else if (tp.isF32())
primary = PrimaryType::kF32;
else if (tp.isInteger(64))
primary = PrimaryType::kI64;
else if (tp.isInteger(32))
primary = PrimaryType::kI32;
else if (tp.isInteger(16))
primary = PrimaryType::kI16;
else if (tp.isInteger(8))
primary = PrimaryType::kI8;
else
llvm_unreachable("Unknown element type");
return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
}
/// Generates a constant of the internal dimension level type encoding.
static Value
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
SparseTensorEncodingAttr::DimLevelType dlt) {
DimLevelType dlt2;
switch (dlt) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
dlt2 = DimLevelType::kDense;
break;
case SparseTensorEncodingAttr::DimLevelType::Compressed:
dlt2 = DimLevelType::kCompressed;
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
dlt2 = DimLevelType::kSingleton;
break;
default:
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
@@ -238,7 +249,7 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
SmallVector<Value, 8> &params, Operation *op,
SparseTensorEncodingAttr &enc, uint32_t action,
SparseTensorEncodingAttr &enc, Action action,
ValueRange szs, Value ptr = Value()) {
Location loc = op->getLoc();
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
@@ -246,7 +257,7 @@ static void newParams(ConversionPatternRewriter &rewriter,
// Sparsity annotations.
SmallVector<Value, 4> attrs;
for (unsigned i = 0; i < sz; i++)
attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
params.push_back(genBuffer(rewriter, loc, attrs));
// Dimension sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
@@ -268,18 +279,17 @@ static void newParams(ConversionPatternRewriter &rewriter,
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
assert(primary);
params.push_back(constantI32(rewriter, loc, secPtr));
params.push_back(constantI32(rewriter, loc, secInd));
params.push_back(constantI32(rewriter, loc, primary));
params.push_back(
constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
params.push_back(
constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
params.push_back(
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
// User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
params.push_back(constantI32(rewriter, loc, action));
params.push_back(constantAction(rewriter, loc, action));
params.push_back(ptr);
}
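
An editorial aside, reconstructed from the hunks above: the parameter vector that newParams() populates has a fixed layout, which is what the conversion patterns below rely on when they rewrite params[3], params[4], params[6], and params[7] in place.

// Layout of `params` after newParams() (reconstructed, not from the commit):
//   params[0] = per-dimension DimLevelType annotations buffer
//   params[1] = dimension sizes buffer
//   params[2] = dimension ordering (permutation) buffer
//   params[3] = pointer overhead type encoding (OverheadType)
//   params[4] = index overhead type encoding (OverheadType)
//   params[5] = primary element type encoding (PrimaryType)
//   params[6] = action (Action)
//   params[7] = opaque pointer argument (filename, COO buffer, or null)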
@@ -530,7 +540,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
SmallVector<Value, 8> params;
sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
Value ptr = adaptor.getOperands()[0];
newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
@@ -549,7 +559,7 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
// Generate the call to construct empty tensor. The sizes are
// explicitly defined by the arguments to the init operator.
SmallVector<Value, 8> params;
newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
@@ -588,13 +598,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, enc, kToCOO, sizes, src);
newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantI32(
rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
params[4] = constantI32(
rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
params[6] = constantI32(rewriter, loc, kFromCOO);
params[3] = constantOverheadTypeEncoding(rewriter, loc,
encDst.getPointerBitWidth());
params[4] = constantOverheadTypeEncoding(rewriter, loc,
encDst.getIndexBitWidth());
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
@@ -613,7 +623,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
Type elemTp = dstTensorTp.getElementType();
// Fabricate a no-permutation encoding for newParams().
// The pointer/index types must be those of `src`.
// The dimLevelTypes aren't actually used by kToIter.
// The dimLevelTypes aren't actually used by Action::kToIterator.
encDst = SparseTensorEncodingAttr::get(
op->getContext(),
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
@@ -622,7 +632,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
newParams(rewriter, params, op, encDst, kToIter, sizes, src);
newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
Value iter = genNewCall(rewriter, op, params);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
@@ -677,7 +687,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromSrc(rewriter, sizes, loc, src);
newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
Value ptr = genNewCall(rewriter, op, params);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value perm = params[2];
@@ -718,7 +728,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
return {};
});
// Final call to construct sparse tensor storage.
params[6] = constantI32(rewriter, loc, kFromCOO);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = ptr;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
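
One practical effect of threading the shared Action enum through these patterns: a scoped enum does not implicitly convert from integers, so call sites must name the enumerator they mean. A small compile-time sketch (nothing here is from the commit; newParamsSketch is a hypothetical stand-in for newParams):

#include <cstdint>

enum class Action : uint32_t { kEmpty = 0, kFromFile = 1 };

// Stand-in for newParams(): accepts only a named Action.
static void newParamsSketch(Action action) { (void)action; }

int main() {
  newParamsSketch(Action::kFromFile); // OK: intent is explicit.
  // newParamsSketch(1);  // error: no implicit int -> Action conversion
  // newParamsSketch(42); // the old raw-uint32_t signature accepted this
  return 0;
}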

mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
@@ -162,8 +163,6 @@ private:
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };
virtual uint64_t getDimSize(uint64_t) = 0;
// Overhead storage.
@@ -206,7 +205,7 @@ public:
/// permutation, and per-dimension dense/sparse annotations, using
/// the coordinate scheme tensor for the initial contents if provided.
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
: sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
uint64_t rank = getRank();
// Store "reverse" permutation.
@@ -216,17 +215,18 @@ public:
// TODO: needs fine-tuning based on sparsity
for (uint64_t r = 0, s = 1; r < rank; r++) {
s *= sizes[r];
if (sparsity[r] == kCompressed) {
if (sparsity[r] == DimLevelType::kCompressed) {
pointers[r].reserve(s + 1);
indices[r].reserve(s);
s = 1;
} else {
assert(sparsity[r] == kDense && "singleton not yet supported");
assert(sparsity[r] == DimLevelType::kDense &&
"singleton not yet supported");
}
}
// Prepare sparse pointer structures for all dimensions.
for (uint64_t r = 0; r < rank; r++)
if (sparsity[r] == kCompressed)
if (sparsity[r] == DimLevelType::kCompressed)
pointers[r].push_back(0);
// Then assign contents from coordinate scheme tensor if provided.
if (tensor) {
@@ -288,7 +288,7 @@ public:
/// permutation as is desired for the new sparse tensor storage.
static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
SparseTensorStorage<P, I, V> *n = nullptr;
if (tensor) {
assert(tensor->getRank() == rank);
@@ -311,8 +311,8 @@ private:
/// Initializes sparse tensor storage scheme from a memory-resident sparse
/// tensor in coordinate scheme. This method prepares the pointers and
/// indices arrays under the given per-dimension dense/sparse annotations.
void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
uint64_t hi, uint64_t d) {
void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
uint64_t lo, uint64_t hi, uint64_t d) {
const std::vector<Element<V>> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
if (d == getRank()) {
@@ -331,7 +331,7 @@ private:
while (seg < hi && elements[seg].indices[d] == idx)
seg++;
// Handle segment in interval for sparse or dense dimension.
if (sparsity[d] == kCompressed) {
if (sparsity[d] == DimLevelType::kCompressed) {
indices[d].push_back(idx);
} else {
// For dense storage we must fill in all the zero values between
@@ -346,7 +346,7 @@ private:
lo = seg;
}
// Finalize the sparse pointer structure at this dimension.
if (sparsity[d] == kCompressed) {
if (sparsity[d] == DimLevelType::kCompressed) {
pointers[d].push_back(indices[d].size());
} else {
// For dense storage we must fill in all the zero values after
@@ -543,53 +543,35 @@ typedef uint64_t index_t;
//
//===----------------------------------------------------------------------===//
enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
enum PrimaryTypeEnum : uint32_t {
kF64 = 1,
kF32 = 2,
kI64 = 3,
kI32 = 4,
kI16 = 5,
kI8 = 6
};
enum Action : uint32_t {
kEmpty = 0,
kFromFile = 1,
kFromCOO = 2,
kEmptyCOO = 3,
kToCOO = 4,
kToIter = 5
};
#define CASE(p, i, v, P, I, V) \
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
SparseTensorCOO<V> *tensor = nullptr; \
if (action <= kFromCOO) { \
if (action == kFromFile) { \
if (action <= Action::kFromCOO) { \
if (action == Action::kFromFile) { \
char *filename = static_cast<char *>(ptr); \
tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm); \
} else if (action == kFromCOO) { \
} else if (action == Action::kFromCOO) { \
tensor = static_cast<SparseTensorCOO<V> *>(ptr); \
} else { \
assert(action == kEmpty); \
assert(action == Action::kEmpty); \
} \
return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm, \
sparsity, tensor); \
} else if (action == kEmptyCOO) { \
} else if (action == Action::kEmptyCOO) { \
return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm); \
} else { \
tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
if (action == kToIter) { \
if (action == Action::kToIterator) { \
tensor->startIterator(); \
} else { \
assert(action == kToCOO); \
assert(action == Action::kToCOO); \
} \
return tensor; \
} \
}
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
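// Illustrative expansion (editorial comment, not in the commit): with the
// shared enums, CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t,
// int64_t) expands to CASE(OverheadType::kU64, OverheadType::kU64,
// PrimaryType::kI64, uint64_t, uint64_t, int64_t), i.e. pointer and index
// overhead share one width, as in the integral cases further below.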
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
assert(ref); \
@@ -656,78 +638,110 @@ enum Action : uint32_t {
/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
/// action:
/// Action:
/// kEmpty = returns empty storage to fill later
/// kFromFile = returns storage, where ptr contains filename to read
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
/// kToIter = returns iterator from storage in ptr (call getNext() to use)
/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
StridedMemRefType<index_t, 1> *sref,
StridedMemRefType<index_t, 1> *pref,
uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
uint32_t action, void *ptr) {
OverheadType ptrTp, OverheadType indTp,
PrimaryType valTp, Action action, void *ptr) {
assert(aref && sref && pref);
assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
pref->strides[0] == 1);
assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
const uint8_t *sparsity = aref->data + aref->offset;
const DimLevelType *sparsity = aref->data + aref->offset;
const index_t *sizes = sref->data + sref->offset;
const index_t *perm = pref->data + pref->offset;
uint64_t rank = aref->sizes[0];
// Double matrices with all combinations of overhead storage.
CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
uint64_t, double);
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
uint32_t, double);
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
uint16_t, double);
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
uint8_t, double);
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
uint64_t, double);
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
uint32_t, double);
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
uint16_t, double);
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
uint8_t, double);
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
uint64_t, double);
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
uint32_t, double);
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
uint16_t, double);
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
uint8_t, double);
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
uint64_t, double);
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
uint32_t, double);
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
uint16_t, double);
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
uint8_t, double);
// Float matrices with all combinations of overhead storage.
CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
uint64_t, float);
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
uint32_t, float);
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
uint16_t, float);
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
uint8_t, float);
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
uint64_t, float);
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
uint32_t, float);
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
uint16_t, float);
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
uint8_t, float);
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
uint64_t, float);
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
uint32_t, float);
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
uint16_t, float);
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
uint8_t, float);
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
uint64_t, float);
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
uint32_t, float);
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
uint16_t, float);
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
uint8_t, float);
// Integral matrices with same overhead storage.
CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
// Integral matrices with both overheads of the same type.
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
// Unsupported case (add above if needed).
fputs("unsupported combination of types\n", stderr);
@@ -830,7 +844,7 @@ void delSparseTensor(void *tensor) {
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
double *values, uint64_t *indices) {
// Setup all-dims compressed and default ordering.
std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
std::vector<uint64_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
// Convert external format to internal COO.
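
For orientation, a hedged usage sketch of this entry point (not in the commit; the flattened nse-by-rank layout of indices is an assumption about the external COO format):

#include <cstdint>

// Declaration copied from the runtime library above.
extern "C" void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse,
                                           uint64_t *shape, double *values,
                                           uint64_t *indices);

int main() {
  uint64_t shape[] = {2, 2};     // A 2x2 matrix ...
  double values[] = {1.0, 2.0};  // ... with two nonzero entries.
  uint64_t indices[] = {0, 0,    // values[0] lives at (0, 0).
                        1, 1};   // values[1] lives at (1, 1).
  void *tensor =
      convertToMLIRSparseTensor(/*rank=*/2, /*nse=*/2, shape, values, indices);
  (void)tensor; // Opaque handle into the runtime's sparse storage scheme.
  return 0;
}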

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

@@ -1707,7 +1707,10 @@ cc_library(
cc_library(
name = "SparseTensorTransforms",
srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]),
hdrs = ["include/mlir/Dialect/SparseTensor/Transforms/Passes.h"],
hdrs = [
"include/mlir/Dialect/SparseTensor/Transforms/Passes.h",
"include/mlir/ExecutionEngine/SparseTensorUtils.h",
],
includes = ["include"],
deps = [
":Affine",
@@ -5391,7 +5394,10 @@ cc_library(
"lib/ExecutionEngine/CRunnerUtils.cpp",
"lib/ExecutionEngine/SparseTensorUtils.cpp",
],
hdrs = ["include/mlir/ExecutionEngine/CRunnerUtils.h"],
hdrs = [
"include/mlir/ExecutionEngine/CRunnerUtils.h",
"include/mlir/ExecutionEngine/SparseTensorUtils.h",
],
includes = ["include"],
)