[mlir][sparse] Factoring magic numbers into a header
Addresses https://bugs.llvm.org/show_bug.cgi?id=52303 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D112962
This commit is contained in:
parent
f57d0e2726
commit
845561ec9d
|
@ -0,0 +1,55 @@
|
|||
//===- SparseTensorUtils.h - Enums shared with the runtime ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines several enums shared between
|
||||
// Transforms/SparseTensorConversion.cpp and ExecutionEngine/SparseUtils.cpp
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
|
||||
#define MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
|
||||
|
||||
#include <cinttypes>
|
||||
|
||||
extern "C" {
|
||||
|
||||
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
|
||||
enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
|
||||
|
||||
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
|
||||
enum class PrimaryType : uint32_t {
|
||||
kF64 = 1,
|
||||
kF32 = 2,
|
||||
kI64 = 3,
|
||||
kI32 = 4,
|
||||
kI16 = 5,
|
||||
kI8 = 6
|
||||
};
|
||||
|
||||
/// The actions performed by @newSparseTensor.
|
||||
enum class Action : uint32_t {
|
||||
kEmpty = 0,
|
||||
kFromFile = 1,
|
||||
kFromCOO = 2,
|
||||
kEmptyCOO = 3,
|
||||
kToCOO = 4,
|
||||
kToIterator = 5
|
||||
};
|
||||
|
||||
/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for
|
||||
/// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType`
|
||||
/// is the source of truth and this enum should be kept consistent with it.
|
||||
enum class DimLevelType : uint8_t {
|
||||
kDense = 0,
|
||||
kCompressed = 1,
|
||||
kSingleton = 2
|
||||
};
|
||||
|
||||
} // extern "C"
|
||||
|
||||
#endif // MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -29,69 +30,10 @@ using namespace mlir::sparse_tensor;
|
|||
|
||||
namespace {
|
||||
|
||||
/// New tensor storage action. Keep these values consistent with
|
||||
/// the sparse runtime support library.
|
||||
enum Action : uint32_t {
|
||||
kEmpty = 0,
|
||||
kFromFile = 1,
|
||||
kFromCOO = 2,
|
||||
kEmptyCOO = 3,
|
||||
kToCOO = 4,
|
||||
kToIter = 5
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns internal type encoding for primary storage. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static uint32_t getPrimaryTypeEncoding(Type tp) {
|
||||
if (tp.isF64())
|
||||
return 1;
|
||||
if (tp.isF32())
|
||||
return 2;
|
||||
if (tp.isInteger(64))
|
||||
return 3;
|
||||
if (tp.isInteger(32))
|
||||
return 4;
|
||||
if (tp.isInteger(16))
|
||||
return 5;
|
||||
if (tp.isInteger(8))
|
||||
return 6;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns internal type encoding for overhead storage. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static uint32_t getOverheadTypeEncoding(unsigned width) {
|
||||
switch (width) {
|
||||
default:
|
||||
return 1;
|
||||
case 32:
|
||||
return 2;
|
||||
case 16:
|
||||
return 3;
|
||||
case 8:
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns internal dimension level type encoding. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static uint32_t
|
||||
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
|
||||
switch (dlt) {
|
||||
case SparseTensorEncodingAttr::DimLevelType::Dense:
|
||||
return 0;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Compressed:
|
||||
return 1;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Singleton:
|
||||
return 2;
|
||||
}
|
||||
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
|
||||
}
|
||||
|
||||
/// Generates a constant zero of the given type.
|
||||
inline static Value constantZero(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type t) {
|
||||
|
@ -116,6 +58,75 @@ inline static Value constantI8(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
|
||||
}
|
||||
|
||||
/// Generates a constant of the given `Action`.
|
||||
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Action action) {
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(action));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for overhead storage.
|
||||
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc, unsigned width) {
|
||||
OverheadType sec;
|
||||
switch (width) {
|
||||
default:
|
||||
sec = OverheadType::kU64;
|
||||
break;
|
||||
case 32:
|
||||
sec = OverheadType::kU32;
|
||||
break;
|
||||
case 16:
|
||||
sec = OverheadType::kU16;
|
||||
break;
|
||||
case 8:
|
||||
sec = OverheadType::kU8;
|
||||
break;
|
||||
}
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for primary storage.
|
||||
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type tp) {
|
||||
PrimaryType primary;
|
||||
if (tp.isF64())
|
||||
primary = PrimaryType::kF64;
|
||||
else if (tp.isF32())
|
||||
primary = PrimaryType::kF32;
|
||||
else if (tp.isInteger(64))
|
||||
primary = PrimaryType::kI64;
|
||||
else if (tp.isInteger(32))
|
||||
primary = PrimaryType::kI32;
|
||||
else if (tp.isInteger(16))
|
||||
primary = PrimaryType::kI16;
|
||||
else if (tp.isInteger(8))
|
||||
primary = PrimaryType::kI8;
|
||||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal dimension level type encoding.
|
||||
static Value
|
||||
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
|
||||
SparseTensorEncodingAttr::DimLevelType dlt) {
|
||||
DimLevelType dlt2;
|
||||
switch (dlt) {
|
||||
case SparseTensorEncodingAttr::DimLevelType::Dense:
|
||||
dlt2 = DimLevelType::kDense;
|
||||
break;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Compressed:
|
||||
dlt2 = DimLevelType::kCompressed;
|
||||
break;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Singleton:
|
||||
dlt2 = DimLevelType::kSingleton;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
|
||||
}
|
||||
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
|
||||
}
|
||||
|
||||
/// Returns a function reference (first hit also inserts into module). Sets
|
||||
/// the "_emit_c_interface" on the function declaration when requested,
|
||||
/// so that LLVM lowering generates a wrapper function that takes care
|
||||
|
@ -238,7 +249,7 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
|
|||
/// computation.
|
||||
static void newParams(ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 8> ¶ms, Operation *op,
|
||||
SparseTensorEncodingAttr &enc, uint32_t action,
|
||||
SparseTensorEncodingAttr &enc, Action action,
|
||||
ValueRange szs, Value ptr = Value()) {
|
||||
Location loc = op->getLoc();
|
||||
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
|
||||
|
@ -246,7 +257,7 @@ static void newParams(ConversionPatternRewriter &rewriter,
|
|||
// Sparsity annotations.
|
||||
SmallVector<Value, 4> attrs;
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
|
||||
attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
|
||||
params.push_back(genBuffer(rewriter, loc, attrs));
|
||||
// Dimension sizes array of the enveloping tensor. Useful for either
|
||||
// verification of external data, or for construction of internal data.
|
||||
|
@ -268,18 +279,17 @@ static void newParams(ConversionPatternRewriter &rewriter,
|
|||
params.push_back(genBuffer(rewriter, loc, rev));
|
||||
// Secondary and primary types encoding.
|
||||
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
|
||||
uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
|
||||
uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
|
||||
uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
|
||||
assert(primary);
|
||||
params.push_back(constantI32(rewriter, loc, secPtr));
|
||||
params.push_back(constantI32(rewriter, loc, secInd));
|
||||
params.push_back(constantI32(rewriter, loc, primary));
|
||||
params.push_back(
|
||||
constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
|
||||
params.push_back(
|
||||
constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
|
||||
params.push_back(
|
||||
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
|
||||
// User action and pointer.
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
if (!ptr)
|
||||
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
|
||||
params.push_back(constantI32(rewriter, loc, action));
|
||||
params.push_back(constantAction(rewriter, loc, action));
|
||||
params.push_back(ptr);
|
||||
}
|
||||
|
||||
|
@ -530,7 +540,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
|
|||
SmallVector<Value, 8> params;
|
||||
sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
|
||||
Value ptr = adaptor.getOperands()[0];
|
||||
newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
|
||||
newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
|
||||
return success();
|
||||
}
|
||||
|
@ -549,7 +559,7 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
|
|||
// Generate the call to construct empty tensor. The sizes are
|
||||
// explicitly defined by the arguments to the init operator.
|
||||
SmallVector<Value, 8> params;
|
||||
newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
|
||||
newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
|
||||
return success();
|
||||
}
|
||||
|
@ -588,13 +598,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
auto enc = SparseTensorEncodingAttr::get(
|
||||
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
newParams(rewriter, params, op, enc, kToCOO, sizes, src);
|
||||
newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
|
||||
Value coo = genNewCall(rewriter, op, params);
|
||||
params[3] = constantI32(
|
||||
rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
|
||||
params[4] = constantI32(
|
||||
rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
|
||||
params[6] = constantI32(rewriter, loc, kFromCOO);
|
||||
params[3] = constantOverheadTypeEncoding(rewriter, loc,
|
||||
encDst.getPointerBitWidth());
|
||||
params[4] = constantOverheadTypeEncoding(rewriter, loc,
|
||||
encDst.getIndexBitWidth());
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
params[7] = coo;
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
|
||||
return success();
|
||||
|
@ -613,7 +623,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
Type elemTp = dstTensorTp.getElementType();
|
||||
// Fabricate a no-permutation encoding for newParams().
|
||||
// The pointer/index types must be those of `src`.
|
||||
// The dimLevelTypes aren't actually used by kToIter.
|
||||
// The dimLevelTypes aren't actually used by Action::kToIterator.
|
||||
encDst = SparseTensorEncodingAttr::get(
|
||||
op->getContext(),
|
||||
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
|
||||
|
@ -622,7 +632,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
|
||||
newParams(rewriter, params, op, encDst, kToIter, sizes, src);
|
||||
newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
|
||||
Value iter = genNewCall(rewriter, op, params);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
|
||||
|
@ -677,7 +687,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromSrc(rewriter, sizes, loc, src);
|
||||
newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
|
||||
newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
|
||||
Value ptr = genNewCall(rewriter, op, params);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value perm = params[2];
|
||||
|
@ -718,7 +728,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
return {};
|
||||
});
|
||||
// Final call to construct sparse tensor storage.
|
||||
params[6] = constantI32(rewriter, loc, kFromCOO);
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
params[7] = ptr;
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
|
||||
return success();
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
|
||||
#include "mlir/ExecutionEngine/CRunnerUtils.h"
|
||||
|
||||
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
|
||||
|
@ -162,8 +163,6 @@ private:
|
|||
/// function overloading to implement "partial" method specialization.
|
||||
class SparseTensorStorageBase {
|
||||
public:
|
||||
enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };
|
||||
|
||||
virtual uint64_t getDimSize(uint64_t) = 0;
|
||||
|
||||
// Overhead storage.
|
||||
|
@ -206,7 +205,7 @@ public:
|
|||
/// permutation, and per-dimension dense/sparse annotations, using
|
||||
/// the coordinate scheme tensor for the initial contents if provided.
|
||||
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
|
||||
const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
|
||||
const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
|
||||
: sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
|
||||
uint64_t rank = getRank();
|
||||
// Store "reverse" permutation.
|
||||
|
@ -216,17 +215,18 @@ public:
|
|||
// TODO: needs fine-tuning based on sparsity
|
||||
for (uint64_t r = 0, s = 1; r < rank; r++) {
|
||||
s *= sizes[r];
|
||||
if (sparsity[r] == kCompressed) {
|
||||
if (sparsity[r] == DimLevelType::kCompressed) {
|
||||
pointers[r].reserve(s + 1);
|
||||
indices[r].reserve(s);
|
||||
s = 1;
|
||||
} else {
|
||||
assert(sparsity[r] == kDense && "singleton not yet supported");
|
||||
assert(sparsity[r] == DimLevelType::kDense &&
|
||||
"singleton not yet supported");
|
||||
}
|
||||
}
|
||||
// Prepare sparse pointer structures for all dimensions.
|
||||
for (uint64_t r = 0; r < rank; r++)
|
||||
if (sparsity[r] == kCompressed)
|
||||
if (sparsity[r] == DimLevelType::kCompressed)
|
||||
pointers[r].push_back(0);
|
||||
// Then assign contents from coordinate scheme tensor if provided.
|
||||
if (tensor) {
|
||||
|
@ -288,7 +288,7 @@ public:
|
|||
/// permutation as is desired for the new sparse tensor storage.
|
||||
static SparseTensorStorage<P, I, V> *
|
||||
newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
|
||||
const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
|
||||
const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
|
||||
SparseTensorStorage<P, I, V> *n = nullptr;
|
||||
if (tensor) {
|
||||
assert(tensor->getRank() == rank);
|
||||
|
@ -311,8 +311,8 @@ private:
|
|||
/// Initializes sparse tensor storage scheme from a memory-resident sparse
|
||||
/// tensor in coordinate scheme. This method prepares the pointers and
|
||||
/// indices arrays under the given per-dimension dense/sparse annotations.
|
||||
void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
|
||||
uint64_t hi, uint64_t d) {
|
||||
void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
|
||||
uint64_t lo, uint64_t hi, uint64_t d) {
|
||||
const std::vector<Element<V>> &elements = tensor->getElements();
|
||||
// Once dimensions are exhausted, insert the numerical values.
|
||||
if (d == getRank()) {
|
||||
|
@ -331,7 +331,7 @@ private:
|
|||
while (seg < hi && elements[seg].indices[d] == idx)
|
||||
seg++;
|
||||
// Handle segment in interval for sparse or dense dimension.
|
||||
if (sparsity[d] == kCompressed) {
|
||||
if (sparsity[d] == DimLevelType::kCompressed) {
|
||||
indices[d].push_back(idx);
|
||||
} else {
|
||||
// For dense storage we must fill in all the zero values between
|
||||
|
@ -346,7 +346,7 @@ private:
|
|||
lo = seg;
|
||||
}
|
||||
// Finalize the sparse pointer structure at this dimension.
|
||||
if (sparsity[d] == kCompressed) {
|
||||
if (sparsity[d] == DimLevelType::kCompressed) {
|
||||
pointers[d].push_back(indices[d].size());
|
||||
} else {
|
||||
// For dense storage we must fill in all the zero values after
|
||||
|
@ -543,53 +543,35 @@ typedef uint64_t index_t;
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
|
||||
|
||||
enum PrimaryTypeEnum : uint32_t {
|
||||
kF64 = 1,
|
||||
kF32 = 2,
|
||||
kI64 = 3,
|
||||
kI32 = 4,
|
||||
kI16 = 5,
|
||||
kI8 = 6
|
||||
};
|
||||
|
||||
enum Action : uint32_t {
|
||||
kEmpty = 0,
|
||||
kFromFile = 1,
|
||||
kFromCOO = 2,
|
||||
kEmptyCOO = 3,
|
||||
kToCOO = 4,
|
||||
kToIter = 5
|
||||
};
|
||||
|
||||
#define CASE(p, i, v, P, I, V) \
|
||||
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
|
||||
SparseTensorCOO<V> *tensor = nullptr; \
|
||||
if (action <= kFromCOO) { \
|
||||
if (action == kFromFile) { \
|
||||
if (action <= Action::kFromCOO) { \
|
||||
if (action == Action::kFromFile) { \
|
||||
char *filename = static_cast<char *>(ptr); \
|
||||
tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm); \
|
||||
} else if (action == kFromCOO) { \
|
||||
} else if (action == Action::kFromCOO) { \
|
||||
tensor = static_cast<SparseTensorCOO<V> *>(ptr); \
|
||||
} else { \
|
||||
assert(action == kEmpty); \
|
||||
assert(action == Action::kEmpty); \
|
||||
} \
|
||||
return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm, \
|
||||
sparsity, tensor); \
|
||||
} else if (action == kEmptyCOO) { \
|
||||
} else if (action == Action::kEmptyCOO) { \
|
||||
return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm); \
|
||||
} else { \
|
||||
tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
|
||||
if (action == kToIter) { \
|
||||
if (action == Action::kToIterator) { \
|
||||
tensor->startIterator(); \
|
||||
} else { \
|
||||
assert(action == kToCOO); \
|
||||
assert(action == Action::kToCOO); \
|
||||
} \
|
||||
return tensor; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
|
||||
|
||||
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
|
||||
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
|
||||
assert(ref); \
|
||||
|
@ -656,78 +638,110 @@ enum Action : uint32_t {
|
|||
/// Constructs a new sparse tensor. This is the "swiss army knife"
|
||||
/// method for materializing sparse tensors into the computation.
|
||||
///
|
||||
/// action:
|
||||
/// Action:
|
||||
/// kEmpty = returns empty storage to fill later
|
||||
/// kFromFile = returns storage, where ptr contains filename to read
|
||||
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
|
||||
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
|
||||
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
|
||||
/// kToIter = returns iterator from storage in ptr (call getNext() to use)
|
||||
/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
|
||||
void *
|
||||
_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
|
||||
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
|
||||
StridedMemRefType<index_t, 1> *sref,
|
||||
StridedMemRefType<index_t, 1> *pref,
|
||||
uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
|
||||
uint32_t action, void *ptr) {
|
||||
OverheadType ptrTp, OverheadType indTp,
|
||||
PrimaryType valTp, Action action, void *ptr) {
|
||||
assert(aref && sref && pref);
|
||||
assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
|
||||
pref->strides[0] == 1);
|
||||
assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
|
||||
const uint8_t *sparsity = aref->data + aref->offset;
|
||||
const DimLevelType *sparsity = aref->data + aref->offset;
|
||||
const index_t *sizes = sref->data + sref->offset;
|
||||
const index_t *perm = pref->data + pref->offset;
|
||||
uint64_t rank = aref->sizes[0];
|
||||
|
||||
// Double matrices with all combinations of overhead storage.
|
||||
CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
|
||||
CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
|
||||
CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
|
||||
CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
|
||||
CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
|
||||
CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
|
||||
CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
|
||||
CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
|
||||
CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
|
||||
CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
|
||||
CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
|
||||
CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
|
||||
CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
|
||||
CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
|
||||
CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
|
||||
CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
|
||||
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
|
||||
uint64_t, double);
|
||||
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
|
||||
uint32_t, double);
|
||||
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
|
||||
uint16_t, double);
|
||||
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
|
||||
uint8_t, double);
|
||||
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
|
||||
uint64_t, double);
|
||||
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
|
||||
uint32_t, double);
|
||||
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
|
||||
uint16_t, double);
|
||||
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
|
||||
uint8_t, double);
|
||||
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
|
||||
uint64_t, double);
|
||||
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
|
||||
uint32_t, double);
|
||||
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
|
||||
uint16_t, double);
|
||||
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
|
||||
uint8_t, double);
|
||||
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
|
||||
uint64_t, double);
|
||||
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
|
||||
uint32_t, double);
|
||||
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
|
||||
uint16_t, double);
|
||||
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
|
||||
uint8_t, double);
|
||||
|
||||
// Float matrices with all combinations of overhead storage.
|
||||
CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
|
||||
CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
|
||||
CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
|
||||
CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
|
||||
CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
|
||||
CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
|
||||
CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
|
||||
CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
|
||||
CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
|
||||
CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
|
||||
CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
|
||||
CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
|
||||
CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
|
||||
CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
|
||||
CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
|
||||
CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
|
||||
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
|
||||
uint64_t, float);
|
||||
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
|
||||
uint32_t, float);
|
||||
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
|
||||
uint16_t, float);
|
||||
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
|
||||
uint8_t, float);
|
||||
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
|
||||
uint64_t, float);
|
||||
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
|
||||
uint32_t, float);
|
||||
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
|
||||
uint16_t, float);
|
||||
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
|
||||
uint8_t, float);
|
||||
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
|
||||
uint64_t, float);
|
||||
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
|
||||
uint32_t, float);
|
||||
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
|
||||
uint16_t, float);
|
||||
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
|
||||
uint8_t, float);
|
||||
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
|
||||
uint64_t, float);
|
||||
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
|
||||
uint32_t, float);
|
||||
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
|
||||
uint16_t, float);
|
||||
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
|
||||
uint8_t, float);
|
||||
|
||||
// Integral matrices with same overhead storage.
|
||||
CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
|
||||
CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
|
||||
CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
|
||||
CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
|
||||
CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
|
||||
CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
|
||||
CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
|
||||
CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
|
||||
CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
|
||||
CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
|
||||
CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
|
||||
CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
|
||||
CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
|
||||
// Integral matrices with both overheads of the same type.
|
||||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
|
||||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
|
||||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
|
||||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
|
||||
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
|
||||
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
|
||||
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
|
||||
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
|
||||
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
|
||||
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
|
||||
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
|
||||
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
|
||||
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
|
||||
|
||||
// Unsupported case (add above if needed).
|
||||
fputs("unsupported combination of types\n", stderr);
|
||||
|
@ -830,7 +844,7 @@ void delSparseTensor(void *tensor) {
|
|||
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
double *values, uint64_t *indices) {
|
||||
// Setup all-dims compressed and default ordering.
|
||||
std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
|
||||
std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
|
||||
std::vector<uint64_t> perm(rank);
|
||||
std::iota(perm.begin(), perm.end(), 0);
|
||||
// Convert external format to internal COO.
|
||||
|
|
|
@ -1707,7 +1707,10 @@ cc_library(
|
|||
cc_library(
|
||||
name = "SparseTensorTransforms",
|
||||
srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]),
|
||||
hdrs = ["include/mlir/Dialect/SparseTensor/Transforms/Passes.h"],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/SparseTensor/Transforms/Passes.h",
|
||||
"include/mlir/ExecutionEngine/SparseTensorUtils.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":Affine",
|
||||
|
@ -5391,7 +5394,10 @@ cc_library(
|
|||
"lib/ExecutionEngine/CRunnerUtils.cpp",
|
||||
"lib/ExecutionEngine/SparseTensorUtils.cpp",
|
||||
],
|
||||
hdrs = ["include/mlir/ExecutionEngine/CRunnerUtils.h"],
|
||||
hdrs = [
|
||||
"include/mlir/ExecutionEngine/CRunnerUtils.h",
|
||||
"include/mlir/ExecutionEngine/SparseTensorUtils.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue