[mlir] Add a new builtin DenseResourceElementsAttr

This attributes is intended cover the current set of use cases that abuse
DenseElementsAttr, e.g. when the data is large. Using resources for large
data is one of the major reasons why they were added; e.g. they can be
deallocated mid-compilation, they support a wide variety of data origins
(e.g, heap allocated, mmap'd, etc.), they can support mutation, etc.

I considered at length not having a builtin variant of this, and instead
having multiple versions of this attribute for dialects that are interested,
but they all boiled down to the exact same attribute definition. Given the
generality of this attribute, it feels more aligned to keep it next to DenseArrayAttr
(given that DenseArrayAttr covers the "small" case, and DenseResourcesElementsAttr
covers the "large" case). The underlying infra used to build this attribute is
general, and having a builtin attribute doesn't preclude users from defining
their own when it makes sense (they can even share a blob manager with the
builtin dialect to avoid data duplication).

Differential Revision: https://reviews.llvm.org/D130022
This commit is contained in:
River Riddle 2022-07-19 18:22:55 -07:00
parent 5f58e14b36
commit 995ab92964
17 changed files with 549 additions and 17 deletions

View File

@ -17,8 +17,12 @@
namespace mlir {
class AffineMap;
class AsmResourceBlob;
class BoolAttr;
class BuiltinDialect;
class DenseIntElementsAttr;
template <typename T>
struct DialectResourceBlobHandle;
class FlatSymbolRefAttr;
class FunctionType;
class IntegerSet;
@ -729,6 +733,13 @@ public:
return denseAttr && denseAttr.isSplat();
}
};
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>;
} // namespace mlir
//===----------------------------------------------------------------------===//
@ -743,6 +754,9 @@ public:
//===----------------------------------------------------------------------===//
namespace mlir {
//===----------------------------------------------------------------------===//
// DenseArrayAttr
namespace detail {
/// Base class for DenseArrayAttr that is instantiated and specialized for each
/// supported element type below.
@ -795,6 +809,71 @@ using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
namespace detail {
/// Base class for DenseResourceElementsAttr that is instantiated and
/// specialized for each supported element type below.
template <typename T>
class DenseResourceElementsAttrBase : public DenseResourceElementsAttr {
public:
using DenseResourceElementsAttr::DenseResourceElementsAttr;
/// A builder that inserts a new resource using the provided blob. The handle
/// of the inserted blob is used when building the attribute. The provided
/// `blobName` is used as a hint for the key of the new handle for the `blob`
/// resource, but may be changed if necessary to ensure uniqueness during
/// insertion.
static DenseResourceElementsAttrBase<T>
get(ShapedType type, StringRef blobName, AsmResourceBlob blob);
/// Return the data of this attribute as an ArrayRef<T> if it is present,
/// returns None otherwise.
Optional<ArrayRef<T>> tryGetAsArrayRef() const;
/// Support for isa<>/cast<>.
static bool classof(Attribute attr);
};
extern template class DenseResourceElementsAttrBase<bool>;
extern template class DenseResourceElementsAttrBase<int8_t>;
extern template class DenseResourceElementsAttrBase<int16_t>;
extern template class DenseResourceElementsAttrBase<int32_t>;
extern template class DenseResourceElementsAttrBase<int64_t>;
extern template class DenseResourceElementsAttrBase<uint8_t>;
extern template class DenseResourceElementsAttrBase<uint16_t>;
extern template class DenseResourceElementsAttrBase<uint32_t>;
extern template class DenseResourceElementsAttrBase<uint64_t>;
extern template class DenseResourceElementsAttrBase<float>;
extern template class DenseResourceElementsAttrBase<double>;
} // namespace detail
// Public names for all the supported DenseResourceElementsAttr.
using DenseBoolResourceElementsAttr =
detail::DenseResourceElementsAttrBase<bool>;
using DenseI8ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<int8_t>;
using DenseI16ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<int16_t>;
using DenseI32ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<int32_t>;
using DenseI64ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<int64_t>;
using DenseUI8ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<uint8_t>;
using DenseUI16ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<uint16_t>;
using DenseUI32ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<uint32_t>;
using DenseUI64ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<uint64_t>;
using DenseF32ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<float>;
using DenseF64ResourceElementsAttr =
detail::DenseResourceElementsAttrBase<double>;
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//

View File

@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the attributes defined in this file are prefixed with
@ -424,6 +425,65 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
ElementsAttrInterface, TypedAttrInterface
]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
let description = [{
Syntax:
```
dense-resource-elements-attribute ::=
`dense_resource` `<` resource-handle `>` `:` shaped-type
```
A dense resource elements attribute is an elements attribute backed by a
handle to a builtin dialect resource containing a densely packed array of
values. This class provides the low-level attribute, which should only be
interacted with in very generic terms, actual access to the underlying
resource data is intended to be managed through one of the subclasses, such
as; `DenseBoolResourceElementsAttr`, `DenseUI64ResourceElementsAttr`,
`DenseI32ResourceElementsAttr`, `DenseF32ResourceElementsAttr`,
`DenseF64ResourceElementsAttr`, etc.
Examples:
```mlir
// A tensor referencing a builtin dialect resource, `resource_1`, with two
// unsigned i32 elements.
dense_resource<resource_1> : tensor<2xui32>
```
}];
let parameters = (ins
AttributeSelfTypeParameter<"", "ShapedType">:$type,
ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle
);
let builders = [
AttrBuilderWithInferredContext<(ins
"ShapedType":$type, "DenseResourceElementsHandle":$handle
)>
];
let extraClassDeclaration = [{
protected:
/// A builder that inserts a new resource into the builtin dialect's blob
/// manager using the provided blob. The handle of the inserted blob is used
/// when building the attribute. The provided `blobName` is used as a hint
/// for the key of the new handle for the `blob` resource, but may be
/// changed if necessary to ensure uniqueness during insertion.
static DenseResourceElementsAttr get(
ShapedType type, StringRef blobName, AsmResourceBlob blob
);
public:
}];
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//

View File

@ -1023,8 +1023,17 @@ public:
template <typename ResourceT>
FailureOr<ResourceT> parseResourceHandle() {
SMLoc handleLoc = getCurrentLocation();
FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
// Try to load the dialect that owns the handle.
auto *dialect =
getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
if (!dialect) {
return emitError(handleLoc)
<< "dialect '" << ResourceT::Dialect::getDialectNamespace()
<< "' is unknown";
}
FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
if (failed(handle))
return failure();
if (auto *result = dyn_cast<ResourceT>(&*handle))

View File

@ -460,7 +460,7 @@ public:
/// Parse a handle to a resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) override {
const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
if (!interface) {
return parser.emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";

View File

@ -15,9 +15,10 @@
#include "AsmParserImpl.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
@ -97,6 +98,10 @@ Attribute Parser::parseAttribute(Type type) {
case Token::kw_dense:
return parseDenseElementsAttr(type);
// Parse a dense resource elements attribute.
case Token::kw_dense_resource:
return parseDenseResourceElementsAttr(type);
// Parse a dictionary attribute.
case Token::l_brace: {
NamedAttrList elements;
@ -241,6 +246,7 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
case Token::kw_affine_map:
case Token::kw_affine_set:
case Token::kw_dense:
case Token::kw_dense_resource:
case Token::kw_false:
case Token::kw_loc:
case Token::kw_opaque:
@ -928,6 +934,39 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return literalParser.getAttr(loc, type);
}
Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
auto loc = getToken().getLoc();
consumeToken(Token::kw_dense_resource);
if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
return nullptr;
// Parse the resource handle.
FailureOr<AsmDialectResourceHandle> rawHandle =
parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
return nullptr;
auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
if (!handle)
return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
// Parse the type of the attribute if the user didn't provide one.
SMLoc typeLoc = loc;
if (!attrType) {
typeLoc = getToken().getLoc();
if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
return nullptr;
}
ShapedType shapedType = attrType.dyn_cast<ShapedType>();
if (!shapedType) {
emitError(typeLoc, "`dense_resource` expected a shaped type");
return nullptr;
}
return DenseResourceElementsAttr::get(shapedType, *handle);
}
/// Parse an opaque elements attribute.
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
SMLoc loc = getToken().getLoc();

View File

@ -340,6 +340,17 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
return entry.second;
}
FailureOr<AsmDialectResourceHandle>
Parser::parseResourceHandle(Dialect *dialect) {
const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
if (!interface) {
return emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
}
StringRef resourceName;
return parseResourceHandle(interface, resourceName);
}
//===----------------------------------------------------------------------===//
// Code Completion

View File

@ -160,6 +160,7 @@ public:
/// Parse a handle to a dialect resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);
//===--------------------------------------------------------------------===//
// Type Parsing
@ -272,6 +273,9 @@ public:
Attribute parseDenseElementsAttr(Type attrType);
ShapedType parseElementsLiteralType(Type type);
/// Parse a dense resource elements attribute.
Attribute parseDenseResourceElementsAttr(Type attrType);
/// Parse a DenseArrayAttr.
Attribute parseDenseArrayAttr();

View File

@ -87,6 +87,7 @@ TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(complex)
TOK_KEYWORD(dense)
TOK_KEYWORD(dense_resource)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
TOK_KEYWORD(f64)

View File

@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
@ -1896,6 +1897,10 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
os << " ";
denseArrayAttr.printWithoutBraces(os);
os << "]";
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
os << "dense_resource<";
printResourceHandle(resourceAttr.getRawHandle());
os << ">";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {

View File

@ -11,6 +11,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
@ -36,11 +37,10 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerAttributes() {
addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
DenseIntOrFPElementsAttr, DenseStringElementsAttr,
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
@ -1576,6 +1576,130 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
return false;
}
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
DenseResourceElementsAttr
DenseResourceElementsAttr::get(ShapedType type,
DenseResourceElementsHandle handle) {
return Base::get(type.getContext(), type, handle);
}
DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
StringRef blobName,
AsmResourceBlob blob) {
// Extract the builtin dialect resource manager from context and construct a
// handle by inserting a new resource using the provided blob.
auto &manager =
DenseResourceElementsHandle::getManagerInterface(type.getContext());
return get(type, manager.insert(blobName, std::move(blob)));
}
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttrBase
namespace {
/// Instantiations of this class provide utilities for interacting with native
/// data types in the context of DenseResourceElementsAttr.
template <typename T>
struct DenseResourceAttrUtil;
template <size_t width, bool isSigned>
struct DenseResourceElementsAttrIntUtil {
static bool checkElementType(Type eltType) {
IntegerType type = eltType.dyn_cast<IntegerType>();
if (!type || type.getWidth() != width)
return false;
return isSigned ? !type.isUnsigned() : !type.isSigned();
}
};
template <>
struct DenseResourceAttrUtil<bool> {
static bool checkElementType(Type eltType) {
return eltType.isSignlessInteger(1);
}
};
template <>
struct DenseResourceAttrUtil<int8_t>
: public DenseResourceElementsAttrIntUtil<8, true> {};
template <>
struct DenseResourceAttrUtil<uint8_t>
: public DenseResourceElementsAttrIntUtil<8, false> {};
template <>
struct DenseResourceAttrUtil<int16_t>
: public DenseResourceElementsAttrIntUtil<16, true> {};
template <>
struct DenseResourceAttrUtil<uint16_t>
: public DenseResourceElementsAttrIntUtil<16, false> {};
template <>
struct DenseResourceAttrUtil<int32_t>
: public DenseResourceElementsAttrIntUtil<32, true> {};
template <>
struct DenseResourceAttrUtil<uint32_t>
: public DenseResourceElementsAttrIntUtil<32, false> {};
template <>
struct DenseResourceAttrUtil<int64_t>
: public DenseResourceElementsAttrIntUtil<64, true> {};
template <>
struct DenseResourceAttrUtil<uint64_t>
: public DenseResourceElementsAttrIntUtil<64, false> {};
template <>
struct DenseResourceAttrUtil<float> {
static bool checkElementType(Type eltType) { return eltType.isF32(); }
};
template <>
struct DenseResourceAttrUtil<double> {
static bool checkElementType(Type eltType) { return eltType.isF64(); }
};
} // namespace
template <typename T>
DenseResourceElementsAttrBase<T>
DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
AsmResourceBlob blob) {
// Check that the blob is in the form we were expecting.
assert(blob.getDataAlignment() == alignof(T) &&
"alignment mismatch between expected alignment and blob alignment");
assert(((blob.getData().size() % sizeof(T)) == 0) &&
"size mismatch between expected element width and blob size");
assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
"invalid shape element type for provided type `T`");
return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
.template cast<DenseResourceElementsAttrBase<T>>();
}
template <typename T>
Optional<ArrayRef<T>>
DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
return blob->template getDataAs<T>();
return llvm::None;
}
template <typename T>
bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
resourceAttr.getElementType());
}
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseResourceElementsAttr.
template class DenseResourceElementsAttrBase<bool>;
template class DenseResourceElementsAttrBase<int8_t>;
template class DenseResourceElementsAttrBase<int16_t>;
template class DenseResourceElementsAttrBase<int32_t>;
template class DenseResourceElementsAttrBase<int64_t>;
template class DenseResourceElementsAttrBase<uint8_t>;
template class DenseResourceElementsAttrBase<uint16_t>;
template class DenseResourceElementsAttrBase<uint32_t>;
template class DenseResourceElementsAttrBase<uint64_t>;
template class DenseResourceElementsAttrBase<float>;
template class DenseResourceElementsAttrBase<double>;
} // namespace detail
} // namespace mlir
//===----------------------------------------------------------------------===//
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//

View File

@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
@ -23,14 +24,27 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
// Builtin Dialect
// TableGen'erated dialect
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinBlobManagerInterface
//===----------------------------------------------------------------------===//
using BuiltinBlobManagerInterface =
ResourceBlobManagerDialectInterfaceBase<DenseResourceElementsHandle>;
//===----------------------------------------------------------------------===//
// BuiltinOpAsmDialectInterface
//===----------------------------------------------------------------------===//
namespace {
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
BuiltinOpAsmDialectInterface(Dialect *dialect,
BuiltinBlobManagerInterface &mgr)
: OpAsmDialectInterface(dialect), blobManager(mgr) {}
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<AffineMapAttr>()) {
@ -57,6 +71,38 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
}
return AliasResult::NoAlias;
}
//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//
std::string
getResourceKey(const AsmDialectResourceHandle &handle) const override {
return cast<DenseResourceElementsHandle>(handle).getKey().str();
}
FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const final {
return blobManager.insert(key);
}
LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
if (failed(blob))
return failure();
// Update the blob for this entry.
blobManager.update(entry.getKey(), std::move(*blob));
return success();
}
void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &provider) const final {
blobManager.buildResources(provider, referencedResources.getArrayRef());
}
private:
/// The blob manager for the dialect.
BuiltinBlobManagerInterface &blobManager;
};
} // namespace
@ -68,7 +114,9 @@ void BuiltinDialect::initialize() {
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
>();
addInterfaces<BuiltinOpAsmDialectInterface>();
auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
}
//===----------------------------------------------------------------------===//

View File

@ -57,7 +57,7 @@ auto DialectResourceBlobManager::insert(StringRef name,
Twine(nameCounter++).toVector(nameStorage);
// Try inserting with the new name.
if (BlobEntry *entry = tryInsertion(name))
if (BlobEntry *entry = tryInsertion(nameStorage))
return *entry;
nameStorage.resize(name.size() + 1);
} while (true);

View File

@ -712,8 +712,9 @@ public:
/// Signal a completion for an attribute.
void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
"loc", "opaque", "sparse", "true", "unit"},
appendSimpleCompletions({"affine_set", "affine_map", "dense",
"dense_resource", "false", "loc", "opaque",
"sparse", "true", "unit"},
lsp::CompletionItemKind::Field,
/*sortText=*/"1");

View File

@ -0,0 +1,13 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s
// CHECK: attr = dense_resource<blob1> : tensor<3xi64>
"test.user_op"() {attr = dense_resource<blob1> : tensor<3xi64> } : () -> ()
{-#
dialect_resources: {
builtin: {
// CHECK: blob1: "0x08000000010000000000000002000000000000000300000000000000"
blob1: "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}

View File

@ -519,3 +519,23 @@ func.func @duplicate_dictionary_attr_key() {
"J// -----
" // expected-error {{expected}}
// -----
// expected-error@+1 {{expected '<' after 'dense_resource'}}
#attr = dense_resource>
// -----
// expected-error@+1 {{expected '>'}}
#attr = dense_resource<resource
// -----
// expected-error@+1 {{expected ':'}}
#attr = dense_resource<resource>
// -----
// expected-error@+1 {{`dense_resource` expected a shaped type}}
#attr = dense_resource<resource> : i32

View File

@ -59,10 +59,10 @@
// -----
// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}}
// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'ml_program'}}
{-#
dialect_resources: {
builtin: {
ml_program: {
unknown_entry: "foo"
}
}

View File

@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
@ -13,6 +15,10 @@
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
@ -203,7 +209,119 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
}
} // namespace
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
template <typename AttrT, typename T>
static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
Type elementType) {
auto type = RankedTensorType::get(data.size(), elementType);
auto attr =
AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocate(data));
// Check that we can access and iterate the data properly.
Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
EXPECT_TRUE(attrData.hasValue());
EXPECT_EQ(*attrData, data);
// Check that we cast to this attribute when possible.
Attribute genericAttr = attr;
EXPECT_TRUE(genericAttr.template isa<AttrT>());
}
template <typename AttrT, typename T>
static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
T data[] = {0, 1, 2};
checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
builder.getIntegerType(intWidth));
}
namespace {
TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
MLIRContext context;
Builder builder(&context);
// Bool
bool boolData[] = {true, false, true};
checkNativeAccess<DenseBoolResourceElementsAttr>(
&context, llvm::makeArrayRef(boolData), builder.getI1Type());
// Unsigned integers
checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
// Signed integers
checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
// Float
float floatData[] = {0, 1, 2};
checkNativeAccess<DenseF32ResourceElementsAttr>(
&context, llvm::makeArrayRef(floatData), builder.getF32Type());
// Double
double doubleData[] = {0, 1, 2};
checkNativeAccess<DenseF64ResourceElementsAttr>(
&context, llvm::makeArrayRef(doubleData), builder.getF64Type());
}
TEST(DenseResourceElementsAttrTest, CheckNoCast) {
MLIRContext context;
Builder builder(&context);
// Create a i32 attribute.
ArrayRef<uint32_t> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));
EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
}
TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
MLIRContext context;
Builder builder(&context);
// Create a bool attribute with data of the incorrect type.
ArrayRef<uint32_t> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
ASSERT_DEATH(
{
DenseBoolResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));
},
"alignment mismatch between expected alignment and blob alignment");
}
TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
MLIRContext context;
Builder builder(&context);
// Create a bool attribute with incorrect type.
ArrayRef<bool> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
ASSERT_DEATH(
{
DenseBoolResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));
},
"invalid shape element type for provided type `T`");
}
} // namespace
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
namespace {
TEST(SparseElementsAttrTest, GetZero) {
MLIRContext context;
context.allowUnregisteredDialects();