[mlir] Remove types from attributes
This patch removes the `type` field from `Attribute` along with the `Attribute::getType` accessor. Going forward, this means that attributes in MLIR will no longer have types as a first-class concept. This patch lays the groundwork to incrementally remove or refactor code that relies on generic attributes being typed. The immediate impact will be on attributes that rely on `Attribute` containing a type, such as `IntegerAttr`, `DenseElementsAttr`, and `ml_program::ExternAttr`, which will now need to define a type parameter on their storage classes. This will save memory as all other attribute kinds will no longer contain a type. Moreover, it will not be possible to generically query the type of an attribute directly. This patch provides an attribute interface `TypedAttr` that implements only one method, `getType`, which can be used to generically query the types of attributes that implement the interface. This interface can be used to retain the concept of a "typed attribute". The ODS-generated accessor for a `type` parameter automatically implements this method. Next steps will be to refactor the assembly formats of certain operations that rely on `parseAttribute(type)` and `printAttributeWithoutType` to remove special handling of type elision until `type` can be removed from the dialect parsing hook entirely; and incrementally remove uses of `TypedAttr`. Reviewed By: lattner, rriddle, jpienaar Differential Revision: https://reviews.llvm.org/D130092
This commit is contained in:
parent
af1328ef45
commit
e179532284
|
@ -126,9 +126,9 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
|
||||||
An integer attribute is a literal attribute that represents an integral
|
An integer attribute is a literal attribute that represents an integral
|
||||||
value of the specified integer type.
|
value of the specified integer type.
|
||||||
}];
|
}];
|
||||||
/// Here we've defined two parameters, one is the `self` type of the attribute
|
/// Here we've defined two parameters, one is a "self" type parameter, and the
|
||||||
/// (i.e. the type of the Attribute itself), and the other is the integer value
|
/// other is the integer value of the attribute. The self type parameter is
|
||||||
/// of the attribute.
|
/// specially handled by the assembly format.
|
||||||
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
|
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
|
||||||
|
|
||||||
/// Here we've defined a custom builder for the type, that removes the need to pass
|
/// Here we've defined a custom builder for the type, that removes the need to pass
|
||||||
|
@ -146,6 +146,8 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
|
||||||
///
|
///
|
||||||
/// #my.int<50> : !my.int<32> // a 32-bit integer of value 50.
|
/// #my.int<50> : !my.int<32> // a 32-bit integer of value 50.
|
||||||
///
|
///
|
||||||
|
/// Note that the self type parameter is not included in the assembly format.
|
||||||
|
/// Its value is derived from the optional trailing type on all attributes.
|
||||||
let assemblyFormat = "`<` $value `>`";
|
let assemblyFormat = "`<` $value `>`";
|
||||||
|
|
||||||
/// Indicate that our attribute will add additional verification to the parameters.
|
/// Indicate that our attribute will add additional verification to the parameters.
|
||||||
|
@ -271,9 +273,8 @@ MLIR includes several specialized classes for common situations:
|
||||||
- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays of
|
- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays of
|
||||||
objects which self-allocate as per the last specialization.
|
objects which self-allocate as per the last specialization.
|
||||||
|
|
||||||
- `AttributeSelfTypeParameter` is a special AttrParameter that corresponds to
|
- `AttributeSelfTypeParameter` is a special `AttrParameter` that represents
|
||||||
the `Type` of the attribute. Only one parameter of the attribute may be of
|
parameters derived from the optional trailing type on attributes.
|
||||||
this parameter type.
|
|
||||||
|
|
||||||
### Traits
|
### Traits
|
||||||
|
|
||||||
|
@ -702,6 +703,54 @@ available through `$_ctxt`. E.g.
|
||||||
DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctxt, 32)">
|
DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctxt, 32)">
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The value of parameters that appear __before__ the default-valued parameter in
|
||||||
|
the parameter declaration list are available as substitutions. E.g.
|
||||||
|
|
||||||
|
```tablegen
|
||||||
|
let parameters = (ins
|
||||||
|
"IntegerAttr":$value,
|
||||||
|
DefaultValuedParameter<"Type", "$value.getType()">:$type
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
###### Attribute Self Type Parameter
|
||||||
|
|
||||||
|
An attribute optionally has a trailing type after the assembly format of the
|
||||||
|
attribute value itself. MLIR parses over the attribute value and optionally
|
||||||
|
parses a colon-type before passing the `Type` into the dialect parser hook.
|
||||||
|
|
||||||
|
```
|
||||||
|
dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
|
||||||
|
(`:` type)?
|
||||||
|
| `#` alias-name pretty-dialect-sym-body? (`:` type)?
|
||||||
|
```
|
||||||
|
|
||||||
|
`AttributeSelfTypeParameter` is an attribute parameter specially handled by the
|
||||||
|
assembly format generator. Only one such parameter can be specified, and its
|
||||||
|
value is derived from the trailing type. This parameter's default value is
|
||||||
|
`NoneType::get($_ctxt)`.
|
||||||
|
|
||||||
|
In order for the type to be printed by
|
||||||
|
MLIR, however, the attribute must implement `TypedAttrInterface`. For example,
|
||||||
|
|
||||||
|
```tablegen
|
||||||
|
// This attribute has only a self type parameter.
|
||||||
|
def MyExternAttr : AttrDef<MyDialect, "MyExtern", [TypedAttrInterface]> {
|
||||||
|
let parameters = (AttributeSelfTypeParameter<"">:$type);
|
||||||
|
let mnemonic = "extern";
|
||||||
|
let assemblyFormat = "";
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This attribute can look like:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
#my_dialect.extern // none
|
||||||
|
#my_dialect.extern : i32
|
||||||
|
#my_dialect.extern : tensor<4xi32>
|
||||||
|
#my_dialect.extern : !my_dialect.my_type
|
||||||
|
```
|
||||||
|
|
||||||
##### Assembly Format Directives
|
##### Assembly Format Directives
|
||||||
|
|
||||||
Attribute and type assembly formats have the following directives:
|
Attribute and type assembly formats have the following directives:
|
||||||
|
|
|
@ -15,6 +15,7 @@ include "mlir/Interfaces/InferIntRangeInterface.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/VectorInterfaces.td"
|
include "mlir/Interfaces/VectorInterfaces.td"
|
||||||
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||||
include "mlir/IR/OpAsmInterface.td"
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
|
|
||||||
// Base class for Arithmetic dialect ops. Ops in this dialect have no side
|
// Base class for Arithmetic dialect ops. Ops in this dialect have no side
|
||||||
|
@ -147,7 +148,7 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyAttr:$value);
|
let arguments = (ins TypedAttrInterface:$value);
|
||||||
// TODO: Disallow arith.constant to return anything other than a signless
|
// TODO: Disallow arith.constant to return anything other than a signless
|
||||||
// integer or float like. Downstream users of Arithmetic should only be
|
// integer or float like. Downstream users of Arithmetic should only be
|
||||||
// working with signless integers, floats, or vectors/tensors thereof.
|
// working with signless integers, floats, or vectors/tensors thereof.
|
||||||
|
|
|
@ -32,12 +32,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||||
assert(operands.size() == 2 && "binary op takes two operands");
|
assert(operands.size() == 2 && "binary op takes two operands");
|
||||||
if (!operands[0] || !operands[1])
|
if (!operands[0] || !operands[1])
|
||||||
return {};
|
return {};
|
||||||
if (operands[0].getType() != operands[1].getType())
|
|
||||||
return {};
|
|
||||||
|
|
||||||
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
|
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
|
||||||
auto lhs = operands[0].cast<AttrElementT>();
|
auto lhs = operands[0].cast<AttrElementT>();
|
||||||
auto rhs = operands[1].cast<AttrElementT>();
|
auto rhs = operands[1].cast<AttrElementT>();
|
||||||
|
if (lhs.getType() != rhs.getType())
|
||||||
|
return {};
|
||||||
|
|
||||||
auto calRes = calculate(lhs.getValue(), rhs.getValue());
|
auto calRes = calculate(lhs.getValue(), rhs.getValue());
|
||||||
|
|
||||||
|
@ -53,6 +53,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||||
// just fold based on the splat value.
|
// just fold based on the splat value.
|
||||||
auto lhs = operands[0].cast<SplatElementsAttr>();
|
auto lhs = operands[0].cast<SplatElementsAttr>();
|
||||||
auto rhs = operands[1].cast<SplatElementsAttr>();
|
auto rhs = operands[1].cast<SplatElementsAttr>();
|
||||||
|
if (lhs.getType() != rhs.getType())
|
||||||
|
return {};
|
||||||
|
|
||||||
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
|
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
|
||||||
rhs.getSplatValue<ElementValueT>());
|
rhs.getSplatValue<ElementValueT>());
|
||||||
|
@ -66,6 +68,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
|
||||||
// expanding the values.
|
// expanding the values.
|
||||||
auto lhs = operands[0].cast<ElementsAttr>();
|
auto lhs = operands[0].cast<ElementsAttr>();
|
||||||
auto rhs = operands[1].cast<ElementsAttr>();
|
auto rhs = operands[1].cast<ElementsAttr>();
|
||||||
|
if (lhs.getType() != rhs.getType())
|
||||||
|
return {};
|
||||||
|
|
||||||
auto lhsIt = lhs.value_begin<ElementValueT>();
|
auto lhsIt = lhs.value_begin<ElementValueT>();
|
||||||
auto rhsIt = rhs.value_begin<ElementValueT>();
|
auto rhsIt = rhs.value_begin<ElementValueT>();
|
||||||
|
|
|
@ -10,18 +10,21 @@
|
||||||
#define COMPLEX_ATTRIBUTE
|
#define COMPLEX_ATTRIBUTE
|
||||||
|
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||||
include "mlir/Dialect/Complex/IR/ComplexBase.td"
|
include "mlir/Dialect/Complex/IR/ComplexBase.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Complex Attributes.
|
// Complex Attributes.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class Complex_Attr<string attrName, string attrMnemonic>
|
class Complex_Attr<string attrName, string attrMnemonic,
|
||||||
: AttrDef<Complex_Dialect, attrName> {
|
list<Trait> traits = []>
|
||||||
|
: AttrDef<Complex_Dialect, attrName, traits> {
|
||||||
let mnemonic = attrMnemonic;
|
let mnemonic = attrMnemonic;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Complex_NumberAttr : Complex_Attr<"Number", "number"> {
|
def Complex_NumberAttr : Complex_Attr<"Number", "number",
|
||||||
|
[TypedAttrInterface]> {
|
||||||
let summary = "A complex number attribute";
|
let summary = "A complex number attribute";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -139,7 +139,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyAttr:$value);
|
let arguments = (ins TypedAttrInterface:$value);
|
||||||
let results = (outs AnyType);
|
let results = (outs AnyType);
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -212,7 +212,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyAttr:$value);
|
let arguments = (ins TypedAttrInterface:$value);
|
||||||
let results = (outs AnyType);
|
let results = (outs AnyType);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
#define MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
|
#define MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
|
||||||
|
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||||
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
|
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// EmitC attribute definitions
|
// EmitC attribute definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class EmitC_Attr<string name, string attrMnemonic>
|
class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
|
||||||
: AttrDef<EmitC_Dialect, name> {
|
: AttrDef<EmitC_Dialect, name, traits> {
|
||||||
let mnemonic = attrMnemonic;
|
let mnemonic = attrMnemonic;
|
||||||
}
|
}
|
||||||
|
|
||||||
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
|
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque", [TypedAttrInterface]> {
|
||||||
let summary = "An opaque attribute";
|
let summary = "An opaque attribute";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -40,8 +41,9 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let parameters = (ins StringRefParameter<"the opaque value">:$value);
|
let parameters = (ins "Type":$type,
|
||||||
|
StringRefParameter<"the opaque value">:$value);
|
||||||
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
|
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
|
||||||
|
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tablegen Attribute Declarations
|
// Tablegen Attribute Declarations
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#define MLPROGRAM_ATTRIBUTES
|
#define MLPROGRAM_ATTRIBUTES
|
||||||
|
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||||
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
|
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
|
||||||
|
|
||||||
// Base class for MLProgram dialect attributes.
|
// Base class for MLProgram dialect attributes.
|
||||||
|
@ -22,7 +23,7 @@ class MLProgram_Attr<string name, list<Trait> traits = []>
|
||||||
// ExternAttr
|
// ExternAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> {
|
def MLProgram_ExternAttr : MLProgram_Attr<"Extern", [TypedAttrInterface]> {
|
||||||
let summary = "Value used for a global signalling external resolution";
|
let summary = "Value used for a global signalling external resolution";
|
||||||
let description = [{
|
let description = [{
|
||||||
When used as the value for a GlobalOp, this indicates that the actual
|
When used as the value for a GlobalOp, this indicates that the actual
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS
|
#define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS
|
||||||
|
|
||||||
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
|
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
|
||||||
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||||
include "mlir/IR/FunctionInterfaces.td"
|
include "mlir/IR/FunctionInterfaces.td"
|
||||||
include "mlir/IR/OpAsmInterface.td"
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
@ -600,7 +601,7 @@ def SPV_SpecConstantOp : SPV_Op<"SpecConstant", [InModuleScope, Symbol]> {
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
StrAttr:$sym_name,
|
StrAttr:$sym_name,
|
||||||
AnyAttr:$default_value
|
TypedAttrInterface:$default_value
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
|
|
|
@ -257,14 +257,6 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
|
||||||
let convertFromStorage = "$_self.cast<" # dialect.cppNamespace #
|
let convertFromStorage = "$_self.cast<" # dialect.cppNamespace #
|
||||||
"::" # cppClassName # ">()";
|
"::" # cppClassName # ">()";
|
||||||
|
|
||||||
// A code block used to build the value 'Type' of an Attribute when
|
|
||||||
// initializing its storage instance. This field is optional, and if not
|
|
||||||
// present the attribute will have its value type set to `NoneType`. This code
|
|
||||||
// block may reference any of the attributes parameters via
|
|
||||||
// `$_<parameter-name`. If one of the parameters of the attribute is of type
|
|
||||||
// `AttributeSelfTypeParameter`, this field is ignored.
|
|
||||||
code typeBuilder = ?;
|
|
||||||
|
|
||||||
// The predicate for when this def is used as a constraint.
|
// The predicate for when this def is used as a constraint.
|
||||||
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
|
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
|
||||||
"::" # cppClassName # ">()">;
|
"::" # cppClassName # ">()">;
|
||||||
|
@ -334,7 +326,7 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
|
||||||
// which by default is the C++ equality operator. The current MLIR context is
|
// which by default is the C++ equality operator. The current MLIR context is
|
||||||
// made available through `$_ctxt`, e.g., for constructing default values for
|
// made available through `$_ctxt`, e.g., for constructing default values for
|
||||||
// attributes and types.
|
// attributes and types.
|
||||||
string defaultValue = ?;
|
string defaultValue = "";
|
||||||
}
|
}
|
||||||
class AttrParameter<string type, string desc, string accessorType = "">
|
class AttrParameter<string type, string desc, string accessorType = "">
|
||||||
: AttrOrTypeParameter<type, desc, accessorType>;
|
: AttrOrTypeParameter<type, desc, accessorType>;
|
||||||
|
@ -392,11 +384,21 @@ class DefaultValuedParameter<string type, string value, string desc = ""> :
|
||||||
let defaultValue = value;
|
let defaultValue = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a special parameter used for AttrDefs that represents a `mlir::Type`
|
// This is a special attribute parameter that represents the "self" type of the
|
||||||
// that is also used as the value `Type` of the attribute. Only one parameter
|
// attribute. It is specially handled by the assembly format generator to derive
|
||||||
// of the attribute may be of this type.
|
// its value from the optional trailing type after each attribute.
|
||||||
|
//
|
||||||
|
// By default, the self type parameter is optional and has a default value of
|
||||||
|
// `none`. If a derived type other than `::mlir::Type` is specified, the
|
||||||
|
// parameter loses its default value unless another one is specified by
|
||||||
|
// `typeBuilder`.
|
||||||
class AttributeSelfTypeParameter<string desc,
|
class AttributeSelfTypeParameter<string desc,
|
||||||
string derivedType = "::mlir::Type"> :
|
string derivedType = "::mlir::Type",
|
||||||
AttrOrTypeParameter<derivedType, desc> {}
|
string typeBuilder = ""> :
|
||||||
|
AttrOrTypeParameter<derivedType, desc> {
|
||||||
|
let defaultValue = !if(!and(!empty(typeBuilder),
|
||||||
|
!eq(derivedType, "::mlir::Type")),
|
||||||
|
"::mlir::NoneType::get($_ctxt)", typeBuilder);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // ATTRTYPEBASE_TD
|
#endif // ATTRTYPEBASE_TD
|
||||||
|
|
|
@ -129,9 +129,6 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
|
||||||
friend StorageUniquer;
|
friend StorageUniquer;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// Get the type of this attribute.
|
|
||||||
Type getType() const { return type; }
|
|
||||||
|
|
||||||
/// Return the abstract descriptor for this attribute.
|
/// Return the abstract descriptor for this attribute.
|
||||||
const AbstractAttribute &getAbstractAttribute() const {
|
const AbstractAttribute &getAbstractAttribute() const {
|
||||||
assert(abstractAttribute && "Malformed attribute storage object.");
|
assert(abstractAttribute && "Malformed attribute storage object.");
|
||||||
|
@ -139,15 +136,6 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Construct a new attribute storage instance with the given type.
|
|
||||||
/// Note: All attributes require a valid type. If no type is provided here,
|
|
||||||
/// the type of the attribute will automatically default to NoneType
|
|
||||||
/// upon initialization in the uniquer.
|
|
||||||
AttributeStorage(Type type = nullptr) : type(type) {}
|
|
||||||
|
|
||||||
/// Set the type of this attribute.
|
|
||||||
void setType(Type newType) { type = newType; }
|
|
||||||
|
|
||||||
/// Set the abstract attribute for this storage instance. This is used by the
|
/// Set the abstract attribute for this storage instance. This is used by the
|
||||||
/// AttributeUniquer when initializing a newly constructed storage object.
|
/// AttributeUniquer when initializing a newly constructed storage object.
|
||||||
void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
|
void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
|
||||||
|
@ -159,9 +147,6 @@ protected:
|
||||||
void initialize(MLIRContext *context) {}
|
void initialize(MLIRContext *context) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// The type of the attribute value.
|
|
||||||
Type type;
|
|
||||||
|
|
||||||
/// The abstract descriptor for this attribute.
|
/// The abstract descriptor for this attribute.
|
||||||
const AbstractAttribute *abstractAttribute = nullptr;
|
const AbstractAttribute *abstractAttribute = nullptr;
|
||||||
};
|
};
|
||||||
|
|
|
@ -66,9 +66,6 @@ public:
|
||||||
/// to support dynamic type casting.
|
/// to support dynamic type casting.
|
||||||
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
|
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
|
||||||
|
|
||||||
/// Return the type of this attribute.
|
|
||||||
Type getType() const { return impl->getType(); }
|
|
||||||
|
|
||||||
/// Return the context this attribute belongs to.
|
/// Return the context this attribute belongs to.
|
||||||
MLIRContext *getContext() const;
|
MLIRContext *getContext() const;
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "llvm/ADT/Any.h"
|
#include "llvm/ADT/Any.h"
|
||||||
|
@ -18,7 +19,6 @@
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class ShapedType;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ElementsAttr
|
// ElementsAttr
|
||||||
|
@ -237,10 +237,10 @@ class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
|
||||||
public:
|
public:
|
||||||
using reference = typename IteratorT::reference;
|
using reference = typename IteratorT::reference;
|
||||||
|
|
||||||
ElementsAttrRange(Type shapeType,
|
ElementsAttrRange(ShapedType shapeType,
|
||||||
const llvm::iterator_range<IteratorT> &range)
|
const llvm::iterator_range<IteratorT> &range)
|
||||||
: llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
|
: llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
|
||||||
ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt)
|
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
|
||||||
: ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
|
: ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
|
||||||
|
|
||||||
/// Return the value at the given index.
|
/// Return the value at the given index.
|
||||||
|
@ -254,7 +254,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// The shaped type of the parent ElementsAttr.
|
/// The shaped type of the parent ElementsAttr.
|
||||||
Type shapeType;
|
ShapedType shapeType;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
|
@ -154,7 +154,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
}], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{
|
}], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{
|
||||||
// By default, only check for a single element splat.
|
// By default, only check for a single element splat.
|
||||||
return $_attr.getNumElements() == 1;
|
return $_attr.getNumElements() == 1;
|
||||||
}]>
|
}]>,
|
||||||
|
InterfaceMethod<[{
|
||||||
|
Returns the shaped type of the elements attribute.
|
||||||
|
}], "::mlir::ShapedType", "getType">
|
||||||
];
|
];
|
||||||
|
|
||||||
string ElementsAttrInterfaceAccessors = [{
|
string ElementsAttrInterfaceAccessors = [{
|
||||||
|
@ -280,7 +283,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
auto getValues() const {
|
auto getValues() const {
|
||||||
auto beginIt = $_attr.template value_begin<T>();
|
auto beginIt = $_attr.template value_begin<T>();
|
||||||
return detail::ElementsAttrRange<decltype(beginIt)>(
|
return detail::ElementsAttrRange<decltype(beginIt)>(
|
||||||
Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
|
$_attr.getType(), beginIt, std::next(beginIt, size()));
|
||||||
}
|
}
|
||||||
}] # ElementsAttrInterfaceAccessors;
|
}] # ElementsAttrInterfaceAccessors;
|
||||||
|
|
||||||
|
@ -294,19 +297,17 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
// Accessors
|
// Accessors
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Return the type of this attribute.
|
|
||||||
ShapedType getType() const;
|
|
||||||
|
|
||||||
/// Return the element type of this ElementsAttr.
|
/// Return the element type of this ElementsAttr.
|
||||||
Type getElementType() const { return getElementType(*this); }
|
Type getElementType() const { return getElementType(*this); }
|
||||||
static Type getElementType(Attribute elementsAttr);
|
static Type getElementType(ElementsAttr elementsAttr);
|
||||||
|
|
||||||
/// Return if the given 'index' refers to a valid element in this attribute.
|
/// Return if the given 'index' refers to a valid element in this attribute.
|
||||||
bool isValidIndex(ArrayRef<uint64_t> index) const {
|
bool isValidIndex(ArrayRef<uint64_t> index) const {
|
||||||
return isValidIndex(*this, index);
|
return isValidIndex(*this, index);
|
||||||
}
|
}
|
||||||
static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
|
static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
|
||||||
static bool isValidIndex(Attribute elementsAttr, ArrayRef<uint64_t> index);
|
static bool isValidIndex(ElementsAttr elementsAttr,
|
||||||
|
ArrayRef<uint64_t> index);
|
||||||
|
|
||||||
/// Return the 1 dimensional flattened row-major index from the given
|
/// Return the 1 dimensional flattened row-major index from the given
|
||||||
/// multi-dimensional index.
|
/// multi-dimensional index.
|
||||||
|
@ -315,14 +316,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
}
|
}
|
||||||
static uint64_t getFlattenedIndex(Type type,
|
static uint64_t getFlattenedIndex(Type type,
|
||||||
ArrayRef<uint64_t> index);
|
ArrayRef<uint64_t> index);
|
||||||
static uint64_t getFlattenedIndex(Attribute elementsAttr,
|
static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
|
||||||
ArrayRef<uint64_t> index) {
|
ArrayRef<uint64_t> index) {
|
||||||
return getFlattenedIndex(elementsAttr.getType(), index);
|
return getFlattenedIndex(elementsAttr.getType(), index);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of elements held by this attribute.
|
/// Returns the number of elements held by this attribute.
|
||||||
int64_t getNumElements() const { return getNumElements(*this); }
|
int64_t getNumElements() const { return getNumElements(*this); }
|
||||||
static int64_t getNumElements(Attribute elementsAttr);
|
static int64_t getNumElements(ElementsAttr elementsAttr);
|
||||||
|
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
// Value Iteration
|
// Value Iteration
|
||||||
|
@ -349,7 +350,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
/// Return the elements of this attribute as a value of type 'T'.
|
/// Return the elements of this attribute as a value of type 'T'.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
|
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
|
||||||
return {Attribute::getType(), value_begin<T>(), value_end<T>()};
|
return {getType(), value_begin<T>(), value_end<T>()};
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DefaultValueCheckT<T, iterator<T>> value_begin() const;
|
DefaultValueCheckT<T, iterator<T>> value_begin() const;
|
||||||
|
@ -369,8 +370,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||||
DerivedAttrValueIteratorRange<T> getValues() const {
|
DerivedAttrValueIteratorRange<T> getValues() const {
|
||||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||||
return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
|
return {getType(), llvm::map_range(getValues<Attribute>(),
|
||||||
static_cast<T (*)(Attribute)>(castFn))};
|
static_cast<T (*)(Attribute)>(castFn))};
|
||||||
}
|
}
|
||||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||||
DerivedAttrValueIterator<T> value_begin() const {
|
DerivedAttrValueIterator<T> value_begin() const {
|
||||||
|
@ -388,10 +389,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
/// return the iterable range. Otherwise, return llvm::None.
|
/// return the iterable range. Otherwise, return llvm::None.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
|
DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
|
||||||
if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
|
if (Optional<iterator<T>> beginIt = try_value_begin<T>())
|
||||||
return iterator_range<T>(Attribute::getType(), *beginIt,
|
return iterator_range<T>(getType(), *beginIt, value_end<T>());
|
||||||
value_end<T>());
|
|
||||||
}
|
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -407,7 +406,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||||
|
|
||||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||||
return DerivedAttrValueIteratorRange<T>(
|
return DerivedAttrValueIteratorRange<T>(
|
||||||
Attribute::getType(),
|
getType(),
|
||||||
llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
|
llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -468,4 +467,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TypedAttrInterface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TypedAttrInterface : AttrInterface<"TypedAttr"> {
|
||||||
|
let cppNamespace = "::mlir";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This interface is used for attributes that have a type. The type of an
|
||||||
|
attribute is understood to represent the type of the data contained in the
|
||||||
|
attribute and is often used as the type of a value with this data.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [InterfaceMethod<
|
||||||
|
"Get the attribute's type",
|
||||||
|
"::mlir::Type", "getType"
|
||||||
|
>];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
|
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
|
||||||
|
|
|
@ -25,7 +25,6 @@ class IntegerSet;
|
||||||
class IntegerType;
|
class IntegerType;
|
||||||
class Location;
|
class Location;
|
||||||
class Operation;
|
class Operation;
|
||||||
class ShapedType;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Elements Attributes
|
// Elements Attributes
|
||||||
|
@ -402,7 +401,7 @@ public:
|
||||||
std::numeric_limits<T>::is_signed));
|
std::numeric_limits<T>::is_signed));
|
||||||
const char *rawData = getRawData().data();
|
const char *rawData = getRawData().data();
|
||||||
bool splat = isSplat();
|
bool splat = isSplat();
|
||||||
return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
|
return {getType(), ElementIterator<T>(rawData, splat, 0),
|
||||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||||
}
|
}
|
||||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||||
|
@ -431,7 +430,7 @@ public:
|
||||||
std::numeric_limits<ElementT>::is_signed));
|
std::numeric_limits<ElementT>::is_signed));
|
||||||
const char *rawData = getRawData().data();
|
const char *rawData = getRawData().data();
|
||||||
bool splat = isSplat();
|
bool splat = isSplat();
|
||||||
return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
|
return {getType(), ElementIterator<T>(rawData, splat, 0),
|
||||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||||
}
|
}
|
||||||
template <typename T, typename ElementT = typename T::value_type,
|
template <typename T, typename ElementT = typename T::value_type,
|
||||||
|
@ -458,7 +457,7 @@ public:
|
||||||
auto stringRefs = getRawStringData();
|
auto stringRefs = getRawStringData();
|
||||||
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
|
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
|
||||||
bool splat = isSplat();
|
bool splat = isSplat();
|
||||||
return {Attribute::getType(), ElementIterator<StringRef>(ptr, splat, 0),
|
return {getType(), ElementIterator<StringRef>(ptr, splat, 0),
|
||||||
ElementIterator<StringRef>(ptr, splat, getNumElements())};
|
ElementIterator<StringRef>(ptr, splat, getNumElements())};
|
||||||
}
|
}
|
||||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||||
|
@ -478,8 +477,7 @@ public:
|
||||||
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
|
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
|
||||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||||
iterator_range_impl<AttributeElementIterator> getValues() const {
|
iterator_range_impl<AttributeElementIterator> getValues() const {
|
||||||
return {Attribute::getType(), value_begin<Attribute>(),
|
return {getType(), value_begin<Attribute>(), value_end<Attribute>()};
|
||||||
value_end<Attribute>()};
|
|
||||||
}
|
}
|
||||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||||
AttributeElementIterator value_begin() const {
|
AttributeElementIterator value_begin() const {
|
||||||
|
@ -510,7 +508,7 @@ public:
|
||||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||||
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
|
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
|
||||||
using DerivedIterT = DerivedAttributeElementIterator<T>;
|
using DerivedIterT = DerivedAttributeElementIterator<T>;
|
||||||
return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
|
return {getType(), DerivedIterT(value_begin<Attribute>()),
|
||||||
DerivedIterT(value_end<Attribute>())};
|
DerivedIterT(value_end<Attribute>())};
|
||||||
}
|
}
|
||||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||||
|
@ -530,7 +528,7 @@ public:
|
||||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||||
iterator_range_impl<BoolElementIterator> getValues() const {
|
iterator_range_impl<BoolElementIterator> getValues() const {
|
||||||
assert(isValidBool() && "bool is not the value of this elements attribute");
|
assert(isValidBool() && "bool is not the value of this elements attribute");
|
||||||
return {Attribute::getType(), BoolElementIterator(*this, 0),
|
return {getType(), BoolElementIterator(*this, 0),
|
||||||
BoolElementIterator(*this, getNumElements())};
|
BoolElementIterator(*this, getNumElements())};
|
||||||
}
|
}
|
||||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||||
|
@ -552,7 +550,7 @@ public:
|
||||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||||
iterator_range_impl<IntElementIterator> getValues() const {
|
iterator_range_impl<IntElementIterator> getValues() const {
|
||||||
assert(getElementType().isIntOrIndex() && "expected integral type");
|
assert(getElementType().isIntOrIndex() && "expected integral type");
|
||||||
return {Attribute::getType(), raw_int_begin(), raw_int_end()};
|
return {getType(), raw_int_begin(), raw_int_end()};
|
||||||
}
|
}
|
||||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||||
IntElementIterator value_begin() const {
|
IntElementIterator value_begin() const {
|
||||||
|
@ -991,8 +989,6 @@ inline bool operator==(StringRef lhs, StringAttr rhs) {
|
||||||
}
|
}
|
||||||
inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }
|
inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }
|
||||||
|
|
||||||
inline Type StringAttr::getType() const { return Attribute::getType(); }
|
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -64,7 +64,6 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
|
||||||
AffineMap getAffineMap() const { return getValue(); }
|
AffineMap getAffineMap() const { return getValue(); }
|
||||||
}];
|
}];
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let typeBuilder = "IndexType::get($_value.getContext())";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -140,11 +139,11 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DenseIntOrFPElementsAttr
|
// DenseArrayBaseAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_DenseArrayBase : Builtin_Attr<
|
def Builtin_DenseArrayBase : Builtin_Attr<
|
||||||
"DenseArrayBase", [ElementsAttrInterface]> {
|
"DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
|
||||||
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
|
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
|
||||||
let description = [{
|
let description = [{
|
||||||
A dense array attribute is an attribute that represents a dense array of
|
A dense array attribute is an attribute that represents a dense array of
|
||||||
|
@ -197,8 +196,12 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
||||||
const float *value_begin_impl(OverloadToken<float>) const;
|
const float *value_begin_impl(OverloadToken<float>) const;
|
||||||
const double *value_begin_impl(OverloadToken<double>) const;
|
const double *value_begin_impl(OverloadToken<double>) const;
|
||||||
|
|
||||||
/// Methods to support type inquiry through isa, cast, and dyn_cast.
|
/// Returns the shaped type, containing the number of elements in the array
|
||||||
|
/// and the array element type.
|
||||||
|
ShapedType getType() const;
|
||||||
|
/// Returns the element type.
|
||||||
EltType getElementType() const;
|
EltType getElementType() const;
|
||||||
|
|
||||||
/// Printer for the short form: will dispatch to the appropriate subclass.
|
/// Printer for the short form: will dispatch to the appropriate subclass.
|
||||||
void print(AsmPrinter &printer) const;
|
void print(AsmPrinter &printer) const;
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
|
@ -216,7 +219,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
|
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
|
||||||
"DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
|
"DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface],
|
||||||
|
"DenseElementsAttr"
|
||||||
> {
|
> {
|
||||||
let summary = "An Attribute containing a dense multi-dimensional array of "
|
let summary = "An Attribute containing a dense multi-dimensional array of "
|
||||||
"integer or floating-point values";
|
"integer or floating-point values";
|
||||||
|
@ -355,7 +359,8 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_DenseStringElementsAttr : Builtin_Attr<
|
def Builtin_DenseStringElementsAttr : Builtin_Attr<
|
||||||
"DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr"
|
"DenseStringElements", [ElementsAttrInterface, TypedAttrInterface],
|
||||||
|
"DenseElementsAttr"
|
||||||
> {
|
> {
|
||||||
let summary = "An Attribute containing a dense multi-dimensional array of "
|
let summary = "An Attribute containing a dense multi-dimensional array of "
|
||||||
"strings";
|
"strings";
|
||||||
|
@ -523,7 +528,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
|
||||||
// FloatAttr
|
// FloatAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_FloatAttr : Builtin_Attr<"Float"> {
|
def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
|
||||||
let summary = "An Attribute containing a floating-point value";
|
let summary = "An Attribute containing a floating-point value";
|
||||||
let description = [{
|
let description = [{
|
||||||
Syntax:
|
Syntax:
|
||||||
|
@ -586,7 +591,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float"> {
|
||||||
// IntegerAttr
|
// IntegerAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_IntegerAttr : Builtin_Attr<"Integer"> {
|
def Builtin_IntegerAttr : Builtin_Attr<"Integer", [TypedAttrInterface]> {
|
||||||
let summary = "An Attribute containing a integer value";
|
let summary = "An Attribute containing a integer value";
|
||||||
let description = [{
|
let description = [{
|
||||||
Syntax:
|
Syntax:
|
||||||
|
@ -703,7 +708,7 @@ def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet"> {
|
||||||
// OpaqueAttr
|
// OpaqueAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
|
def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> {
|
||||||
let summary = "An opaque representation of another Attribute";
|
let summary = "An opaque representation of another Attribute";
|
||||||
let description = [{
|
let description = [{
|
||||||
Syntax:
|
Syntax:
|
||||||
|
@ -741,7 +746,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_OpaqueElementsAttr : Builtin_Attr<
|
def Builtin_OpaqueElementsAttr : Builtin_Attr<
|
||||||
"OpaqueElements", [ElementsAttrInterface]
|
"OpaqueElements", [ElementsAttrInterface, TypedAttrInterface]
|
||||||
> {
|
> {
|
||||||
let summary = "An opaque representation of a multi-dimensional array";
|
let summary = "An opaque representation of a multi-dimensional array";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -803,7 +808,7 @@ def Builtin_OpaqueElementsAttr : Builtin_Attr<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_SparseElementsAttr : Builtin_Attr<
|
def Builtin_SparseElementsAttr : Builtin_Attr<
|
||||||
"SparseElements", [ElementsAttrInterface]
|
"SparseElements", [ElementsAttrInterface, TypedAttrInterface]
|
||||||
> {
|
> {
|
||||||
let summary = "An opaque representation of a multi-dimensional array";
|
let summary = "An opaque representation of a multi-dimensional array";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -966,7 +971,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
|
||||||
// StringAttr
|
// StringAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Builtin_StringAttr : Builtin_Attr<"String"> {
|
def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> {
|
||||||
let summary = "An Attribute containing a string";
|
let summary = "An Attribute containing a string";
|
||||||
let description = [{
|
let description = [{
|
||||||
Syntax:
|
Syntax:
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
//===- BuiltinTypeInterfaces.h - Builtin Type Interfaces --------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_IR_BUILTINTYPEINTERFACES_H
|
||||||
|
#define MLIR_IR_BUILTINTYPEINTERFACES_H
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
|
||||||
|
|
||||||
|
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
|
|
@ -9,8 +9,9 @@
|
||||||
#ifndef MLIR_IR_BUILTINTYPES_H
|
#ifndef MLIR_IR_BUILTINTYPES_H
|
||||||
#define MLIR_IR_BUILTINTYPES_H
|
#define MLIR_IR_BUILTINTYPES_H
|
||||||
|
|
||||||
#include "BuiltinAttributeInterfaces.h"
|
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||||
#include "SubElementInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
|
#include "mlir/IR/SubElementInterfaces.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
class BitVector;
|
class BitVector;
|
||||||
|
@ -21,8 +22,6 @@ struct fltSemantics;
|
||||||
// Tablegen Interface Declarations
|
// Tablegen Interface Declarations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class AffineExpr;
|
class AffineExpr;
|
||||||
class AffineMap;
|
class AffineMap;
|
||||||
|
|
|
@ -215,8 +215,9 @@ static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
|
||||||
/// Parse an extended attribute.
|
/// Parse an extended attribute.
|
||||||
///
|
///
|
||||||
/// extended-attribute ::= (dialect-attribute | attribute-alias)
|
/// extended-attribute ::= (dialect-attribute | attribute-alias)
|
||||||
/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
|
/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
|
||||||
/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
|
/// (`:` type)?
|
||||||
|
/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
|
||||||
/// attribute-alias ::= `#` alias-name
|
/// attribute-alias ::= `#` alias-name
|
||||||
///
|
///
|
||||||
Attribute Parser::parseExtendedAttr(Type type) {
|
Attribute Parser::parseExtendedAttr(Type type) {
|
||||||
|
@ -250,9 +251,10 @@ Attribute Parser::parseExtendedAttr(Type type) {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Ensure that the attribute has the same type as requested.
|
// Ensure that the attribute has the same type as requested.
|
||||||
if (attr && type && attr.getType() != type) {
|
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
|
||||||
|
if (type && typedAttr && typedAttr.getType() != type) {
|
||||||
emitError("attribute type different than expected: expected ")
|
emitError("attribute type different than expected: expected ")
|
||||||
<< type << ", but got " << attr.getType();
|
<< type << ", but got " << typedAttr.getType();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return attr;
|
return attr;
|
||||||
|
|
|
@ -753,7 +753,10 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType mlirAttributeGetType(MlirAttribute attribute) {
|
MlirType mlirAttributeGetType(MlirAttribute attribute) {
|
||||||
return wrap(unwrap(attribute).getType());
|
Attribute attr = unwrap(attribute);
|
||||||
|
if (auto typedAttr = attr.dyn_cast<TypedAttr>())
|
||||||
|
return wrap(typedAttr.getType());
|
||||||
|
return wrap(NoneType::get(attr.getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
|
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
|
||||||
|
|
|
@ -395,8 +395,8 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Attribute cstAttr = constOp.getValue();
|
Attribute cstAttr = constOp.getValue();
|
||||||
if (cstAttr.getType().isa<ShapedType>())
|
if (auto elementsAttr = cstAttr.dyn_cast<DenseElementsAttr>())
|
||||||
cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
|
cstAttr = elementsAttr.getSplatValue<Attribute>();
|
||||||
|
|
||||||
Type dstType = getTypeConverter()->convertType(srcType);
|
Type dstType = getTypeConverter()->convertType(srcType);
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
|
|
|
@ -698,8 +698,8 @@ static void convertConstantOp(arith::ConstantOp op,
|
||||||
llvm::DenseMap<Value, Value> &valueMapping) {
|
llvm::DenseMap<Value, Value> &valueMapping) {
|
||||||
assert(constantSupportsMMAMatrixType(op));
|
assert(constantSupportsMMAMatrixType(op));
|
||||||
OpBuilder b(op);
|
OpBuilder b(op);
|
||||||
Attribute splat =
|
auto splat =
|
||||||
op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
|
op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
|
||||||
auto scalarConstant =
|
auto scalarConstant =
|
||||||
b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
|
b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
|
||||||
const char *fragType = inferFragType(op);
|
const char *fragType = inferFragType(op);
|
||||||
|
|
|
@ -128,7 +128,8 @@ LogicalResult arith::ConstantOp::verify() {
|
||||||
|
|
||||||
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
|
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
|
||||||
// The value's type must be the same as the provided type.
|
// The value's type must be the same as the provided type.
|
||||||
if (value.getType() != type)
|
auto typedAttr = value.dyn_cast<TypedAttr>();
|
||||||
|
if (!typedAttr || typedAttr.getType() != type)
|
||||||
return false;
|
return false;
|
||||||
// Integer values must be signless.
|
// Integer values must be signless.
|
||||||
if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
|
if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
|
||||||
|
|
|
@ -30,11 +30,13 @@ void ConstantOp::getAsmResultNames(
|
||||||
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
||||||
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
|
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
|
||||||
auto complexTy = type.dyn_cast<ComplexType>();
|
auto complexTy = type.dyn_cast<ComplexType>();
|
||||||
if (!complexTy)
|
if (!complexTy || arrAttr.size() != 2)
|
||||||
return false;
|
return false;
|
||||||
auto complexEltTy = complexTy.getElementType();
|
auto complexEltTy = complexTy.getElementType();
|
||||||
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
|
auto re = arrAttr[0].dyn_cast<FloatAttr>();
|
||||||
arrAttr[1].getType() == complexEltTy;
|
auto im = arrAttr[1].dyn_cast<FloatAttr>();
|
||||||
|
return re && im && re.getType() == complexEltTy &&
|
||||||
|
im.getType() == complexEltTy;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -48,11 +50,14 @@ LogicalResult ConstantOp::verify() {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto complexEltTy = getType().getElementType();
|
auto complexEltTy = getType().getElementType();
|
||||||
if (complexEltTy != arrayAttr[0].getType() ||
|
auto re = arrayAttr[0].dyn_cast<FloatAttr>();
|
||||||
complexEltTy != arrayAttr[1].getType()) {
|
auto im = arrayAttr[1].dyn_cast<FloatAttr>();
|
||||||
|
if (!re || !im)
|
||||||
|
return emitOpError("requires attribute's elements to be float attributes");
|
||||||
|
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
|
||||||
return emitOpError()
|
return emitOpError()
|
||||||
<< "requires attribute's element types (" << arrayAttr[0].getType()
|
<< "requires attribute's element types (" << re.getType() << ", "
|
||||||
<< ", " << arrayAttr[1].getType()
|
<< im.getType()
|
||||||
<< ") to match the element type of the op's return type ("
|
<< ") to match the element type of the op's return type ("
|
||||||
<< complexEltTy << ")";
|
<< complexEltTy << ")";
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,15 +86,17 @@ LogicalResult emitc::CallOp::verify() {
|
||||||
|
|
||||||
if (Optional<ArrayAttr> argsAttr = getArgs()) {
|
if (Optional<ArrayAttr> argsAttr = getArgs()) {
|
||||||
for (Attribute arg : *argsAttr) {
|
for (Attribute arg : *argsAttr) {
|
||||||
if (arg.getType().isa<IndexType>()) {
|
auto intAttr = arg.dyn_cast<IntegerAttr>();
|
||||||
int64_t index = arg.cast<IntegerAttr>().getInt();
|
if (intAttr && intAttr.getType().isa<IndexType>()) {
|
||||||
|
int64_t index = intAttr.getInt();
|
||||||
// Args with elements of type index must be in range
|
// Args with elements of type index must be in range
|
||||||
// [0..operands.size).
|
// [0..operands.size).
|
||||||
if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
|
if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
|
||||||
return emitOpError("index argument is out of range");
|
return emitOpError("index argument is out of range");
|
||||||
|
|
||||||
// Args with elements of type ArrayAttr must have a type.
|
// Args with elements of type ArrayAttr must have a type.
|
||||||
} else if (arg.isa<ArrayAttr>() && arg.getType().isa<NoneType>()) {
|
} else if (arg.isa<ArrayAttr>() /*&& arg.getType().isa<NoneType>()*/) {
|
||||||
|
// FIXME: Array attributes never have types
|
||||||
return emitOpError("array argument has no type");
|
return emitOpError("array argument has no type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,8 +104,7 @@ LogicalResult emitc::CallOp::verify() {
|
||||||
|
|
||||||
if (Optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
|
if (Optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
|
||||||
for (Attribute tArg : *templateArgsAttr) {
|
for (Attribute tArg : *templateArgsAttr) {
|
||||||
if (!tArg.isa<TypeAttr>() && !tArg.isa<IntegerAttr>() &&
|
if (!tArg.isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>())
|
||||||
!tArg.isa<FloatAttr>() && !tArg.isa<emitc::OpaqueAttr>())
|
|
||||||
return emitOpError("template argument has invalid type");
|
return emitOpError("template argument has invalid type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,7 +118,7 @@ LogicalResult emitc::CallOp::verify() {
|
||||||
|
|
||||||
/// The constant op requires that the attribute's type matches the return type.
|
/// The constant op requires that the attribute's type matches the return type.
|
||||||
LogicalResult emitc::ConstantOp::verify() {
|
LogicalResult emitc::ConstantOp::verify() {
|
||||||
Attribute value = getValueAttr();
|
TypedAttr value = getValueAttr();
|
||||||
Type type = getType();
|
Type type = getType();
|
||||||
if (!value.getType().isa<NoneType>() && type != value.getType())
|
if (!value.getType().isa<NoneType>() && type != value.getType())
|
||||||
return emitOpError() << "requires attribute's type (" << value.getType()
|
return emitOpError() << "requires attribute's type (" << value.getType()
|
||||||
|
@ -171,7 +172,7 @@ ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
|
||||||
/// The variable op requires that the attribute's type matches the return type.
|
/// The variable op requires that the attribute's type matches the return type.
|
||||||
LogicalResult emitc::VariableOp::verify() {
|
LogicalResult emitc::VariableOp::verify() {
|
||||||
Attribute value = getValueAttr();
|
TypedAttr value = getValueAttr();
|
||||||
Type type = getType();
|
Type type = getType();
|
||||||
if (!value.getType().isa<NoneType>() && type != value.getType())
|
if (!value.getType().isa<NoneType>() && type != value.getType())
|
||||||
return emitOpError() << "requires attribute's type (" << value.getType()
|
return emitOpError() << "requires attribute's type (" << value.getType()
|
||||||
|
@ -204,7 +205,9 @@ Attribute emitc::OpaqueAttr::parse(AsmParser &parser, Type type) {
|
||||||
}
|
}
|
||||||
if (parser.parseGreater())
|
if (parser.parseGreater())
|
||||||
return Attribute();
|
return Attribute();
|
||||||
return get(parser.getContext(), value);
|
|
||||||
|
return get(parser.getContext(),
|
||||||
|
type ? type : NoneType::get(parser.getContext()), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void emitc::OpaqueAttr::print(AsmPrinter &printer) const {
|
void emitc::OpaqueAttr::print(AsmPrinter &printer) const {
|
||||||
|
|
|
@ -2409,11 +2409,16 @@ LogicalResult LLVM::ConstantOp::verify() {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
|
auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
|
||||||
if (!arrayAttr || arrayAttr.size() != 2 ||
|
if (!arrayAttr || arrayAttr.size() != 2) {
|
||||||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
|
|
||||||
return emitOpError() << "expected array attribute with two elements, "
|
return emitOpError() << "expected array attribute with two elements, "
|
||||||
"representing a complex constant";
|
"representing a complex constant";
|
||||||
}
|
}
|
||||||
|
auto re = arrayAttr[0].dyn_cast<TypedAttr>();
|
||||||
|
auto im = arrayAttr[1].dyn_cast<TypedAttr>();
|
||||||
|
if (!re || !im || re.getType() != im.getType()) {
|
||||||
|
return emitOpError()
|
||||||
|
<< "expected array attribute with two elements of the same type";
|
||||||
|
}
|
||||||
|
|
||||||
Type elementType = structType.getBody()[0];
|
Type elementType = structType.getBody()[0];
|
||||||
if (!elementType
|
if (!elementType
|
||||||
|
|
|
@ -400,8 +400,10 @@ public:
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
Location loc = builder.getUnknownLoc();
|
Location loc = builder.getUnknownLoc();
|
||||||
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
||||||
return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
|
Type type = NoneType::get(builder.getContext());
|
||||||
valueAttr);
|
if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
|
||||||
|
type = typedAttr.getType();
|
||||||
|
return builder.create<arith::ConstantOp>(loc, type, valueAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value index(int64_t dim) {
|
Value index(int64_t dim) {
|
||||||
|
|
|
@ -530,7 +530,11 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
|
||||||
SmallVector<Attribute> paddingValues;
|
SmallVector<Attribute> paddingValues;
|
||||||
for (auto const &it :
|
for (auto const &it :
|
||||||
llvm::zip(getPaddingValues(), target->getOperandTypes())) {
|
llvm::zip(getPaddingValues(), target->getOperandTypes())) {
|
||||||
Attribute attr = std::get<0>(it);
|
auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
|
||||||
|
if (!attr) {
|
||||||
|
emitOpError("expects padding values to be typed attributes");
|
||||||
|
return DiagnosedSilenceableFailure::definiteFailure();
|
||||||
|
}
|
||||||
Type elementType = getElementTypeOrSelf(std::get<1>(it));
|
Type elementType = getElementTypeOrSelf(std::get<1>(it));
|
||||||
// Try to parse string attributes to obtain an attribute of element type.
|
// Try to parse string attributes to obtain an attribute of element type.
|
||||||
if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
|
if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
|
||||||
|
|
|
@ -1509,14 +1509,14 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||||
Operation *def = opOperand->get().getDefiningOp();
|
Operation *def = opOperand->get().getDefiningOp();
|
||||||
Attribute constantAttr;
|
TypedAttr constantAttr;
|
||||||
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
|
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
|
||||||
{
|
{
|
||||||
DenseElementsAttr splatAttr;
|
DenseElementsAttr splatAttr;
|
||||||
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
|
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
|
||||||
splatAttr.isSplat() &&
|
splatAttr.isSplat() &&
|
||||||
splatAttr.getType().getElementType().isIntOrFloat()) {
|
splatAttr.getType().getElementType().isIntOrFloat()) {
|
||||||
constantAttr = splatAttr.getSplatValue<Attribute>();
|
constantAttr = splatAttr.getSplatValue<TypedAttr>();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,8 +198,11 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
|
||||||
if (opOperand->getOperandNumber() >= paddingValues.size())
|
if (opOperand->getOperandNumber() >= paddingValues.size())
|
||||||
return failure();
|
return failure();
|
||||||
Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
|
Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
|
||||||
Value paddingValue = b.create<arith::ConstantOp>(
|
Type paddingType = b.getType<NoneType>();
|
||||||
opToPad.getLoc(), paddingAttr.getType(), paddingAttr);
|
if (auto typedAttr = paddingAttr.dyn_cast<TypedAttr>())
|
||||||
|
paddingType = typedAttr.getType();
|
||||||
|
Value paddingValue =
|
||||||
|
b.create<arith::ConstantOp>(opToPad.getLoc(), paddingType, paddingAttr);
|
||||||
|
|
||||||
// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
|
// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
|
||||||
OpOperand *currOpOperand = opOperand;
|
OpOperand *currOpOperand = opOperand;
|
||||||
|
|
|
@ -1309,8 +1309,8 @@ LogicalResult GlobalOp::verify() {
|
||||||
|
|
||||||
// Check that the type of the initial value is compatible with the type of
|
// Check that the type of the initial value is compatible with the type of
|
||||||
// the global variable.
|
// the global variable.
|
||||||
if (initValue.isa<ElementsAttr>()) {
|
if (auto elementsAttr = initValue.dyn_cast<ElementsAttr>()) {
|
||||||
Type initType = initValue.getType();
|
Type initType = elementsAttr.getType();
|
||||||
Type tensorType = getTensorTypeFromMemRefType(memrefType);
|
Type tensorType = getTensorTypeFromMemRefType(memrefType);
|
||||||
if (initType != tensorType)
|
if (initType != tensorType)
|
||||||
return emitOpError("initial value expected to be of type ")
|
return emitOpError("initial value expected to be of type ")
|
||||||
|
|
|
@ -28,20 +28,15 @@ using namespace mlir;
|
||||||
|
|
||||||
/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
|
/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
|
||||||
/// or splat vector bool constant.
|
/// or splat vector bool constant.
|
||||||
static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
|
static Optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
|
||||||
if (!boolAttr)
|
if (!attr)
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
auto type = boolAttr.getType();
|
if (auto boolAttr = attr.dyn_cast<BoolAttr>())
|
||||||
if (type.isInteger(1)) {
|
return boolAttr.getValue();
|
||||||
auto attr = boolAttr.cast<BoolAttr>();
|
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
|
||||||
return attr.getValue();
|
if (splatAttr.getElementType().isInteger(1))
|
||||||
}
|
return splatAttr.getSplatValue<bool>();
|
||||||
if (auto vecType = type.cast<VectorType>()) {
|
|
||||||
if (vecType.getElementType().isInteger(1))
|
|
||||||
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
|
|
||||||
return attr.getSplatValue<bool>();
|
|
||||||
}
|
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1803,7 +1803,9 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
|
||||||
if (parser.parseAttribute(value, kValueAttrName, state.attributes))
|
if (parser.parseAttribute(value, kValueAttrName, state.attributes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type type = value.getType();
|
Type type = NoneType::get(parser.getContext());
|
||||||
|
if (auto typedAttr = value.dyn_cast<TypedAttr>())
|
||||||
|
type = typedAttr.getType();
|
||||||
if (type.isa<NoneType, TensorType>()) {
|
if (type.isa<NoneType, TensorType>()) {
|
||||||
if (parser.parseColonType(type))
|
if (parser.parseColonType(type))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -1820,15 +1822,15 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) {
|
||||||
|
|
||||||
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
|
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
|
||||||
Type opType) {
|
Type opType) {
|
||||||
auto valueType = value.getType();
|
|
||||||
|
|
||||||
if (value.isa<IntegerAttr, FloatAttr>()) {
|
if (value.isa<IntegerAttr, FloatAttr>()) {
|
||||||
|
auto valueType = value.cast<TypedAttr>().getType();
|
||||||
if (valueType != opType)
|
if (valueType != opType)
|
||||||
return op.emitOpError("result type (")
|
return op.emitOpError("result type (")
|
||||||
<< opType << ") does not match value type (" << valueType << ")";
|
<< opType << ") does not match value type (" << valueType << ")";
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
|
if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
|
||||||
|
auto valueType = value.cast<TypedAttr>().getType();
|
||||||
if (valueType == opType)
|
if (valueType == opType)
|
||||||
return success();
|
return success();
|
||||||
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
||||||
|
@ -1873,7 +1875,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
return op.emitOpError("cannot have value of type ") << valueType;
|
return op.emitOpError("cannot have attribute: ") << value;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::ConstantOp::verify() {
|
LogicalResult spirv::ConstantOp::verify() {
|
||||||
|
|
|
@ -1737,7 +1737,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (!operands[0])
|
if (!operands[0])
|
||||||
return {};
|
return {};
|
||||||
auto vectorType = getVectorType();
|
auto vectorType = getVectorType();
|
||||||
if (operands[0].getType().isIntOrIndexOrFloat())
|
if (operands[0].isa<IntegerAttr, FloatAttr>())
|
||||||
return DenseElementsAttr::get(vectorType, operands[0]);
|
return DenseElementsAttr::get(vectorType, operands[0]);
|
||||||
if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
|
if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
|
||||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
||||||
|
@ -1855,7 +1855,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto lhsType = lhs.getType().cast<VectorType>();
|
auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
|
||||||
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
|
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
|
||||||
// manipulation.
|
// manipulation.
|
||||||
if (lhsType.getRank() != 1)
|
if (lhsType.getRank() != 1)
|
||||||
|
|
|
@ -1752,7 +1752,6 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||||
if (succeeded(printAlias(attr)))
|
if (succeeded(printAlias(attr)))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto attrType = attr.getType();
|
|
||||||
if (!isa<BuiltinDialect>(attr.getDialect())) {
|
if (!isa<BuiltinDialect>(attr.getDialect())) {
|
||||||
printDialectAttribute(attr);
|
printDialectAttribute(attr);
|
||||||
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
|
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
|
||||||
|
@ -1768,7 +1767,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||||
os << '}';
|
os << '}';
|
||||||
|
|
||||||
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
|
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
|
||||||
if (attrType.isSignlessInteger(1)) {
|
Type intType = intAttr.getType();
|
||||||
|
if (intType.isSignlessInteger(1)) {
|
||||||
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
|
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
|
||||||
|
|
||||||
// Boolean integer attributes always elides the type.
|
// Boolean integer attributes always elides the type.
|
||||||
|
@ -1779,18 +1779,18 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||||
// signless 1-bit values. Indexes, signed values, and multi-bit signless
|
// signless 1-bit values. Indexes, signed values, and multi-bit signless
|
||||||
// values print as signed.
|
// values print as signed.
|
||||||
bool isUnsigned =
|
bool isUnsigned =
|
||||||
attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
|
intType.isUnsignedInteger() || intType.isSignlessInteger(1);
|
||||||
intAttr.getValue().print(os, !isUnsigned);
|
intAttr.getValue().print(os, !isUnsigned);
|
||||||
|
|
||||||
// IntegerAttr elides the type if I64.
|
// IntegerAttr elides the type if I64.
|
||||||
if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
|
if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
|
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
|
||||||
printFloatValue(floatAttr.getValue(), os);
|
printFloatValue(floatAttr.getValue(), os);
|
||||||
|
|
||||||
// FloatAttr elides the type if F64.
|
// FloatAttr elides the type if F64.
|
||||||
if (typeElision == AttrTypeElision::May && attrType.isF64())
|
if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
} else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
|
} else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
|
||||||
|
@ -1892,7 +1892,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||||
os << "[:f64";
|
os << "[:f64";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (denseArrayAttr.getType().cast<ShapedType>().getRank())
|
if (denseArrayAttr.getType().getRank())
|
||||||
os << " ";
|
os << " ";
|
||||||
denseArrayAttr.printWithoutBraces(os);
|
denseArrayAttr.printWithoutBraces(os);
|
||||||
os << "]";
|
os << "]";
|
||||||
|
@ -1902,9 +1902,14 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||||
llvm::report_fatal_error("Unknown builtin attribute");
|
llvm::report_fatal_error("Unknown builtin attribute");
|
||||||
}
|
}
|
||||||
// Don't print the type if we must elide it, or if it is a None type.
|
// Don't print the type if we must elide it, or if it is a None type.
|
||||||
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
|
if (typeElision != AttrTypeElision::Must) {
|
||||||
os << " : ";
|
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
|
||||||
printType(attrType);
|
Type attrType = typedAttr.getType();
|
||||||
|
if (!attrType.isa<NoneType>()) {
|
||||||
|
os << " : ";
|
||||||
|
printType(attrType);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,9 +43,10 @@ inline size_t getDenseElementBitWidth(Type eltType) {
|
||||||
/// An attribute representing a reference to a dense vector or tensor object.
|
/// An attribute representing a reference to a dense vector or tensor object.
|
||||||
struct DenseElementsAttributeStorage : public AttributeStorage {
|
struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||||
public:
|
public:
|
||||||
DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
|
DenseElementsAttributeStorage(ShapedType type, bool isSplat)
|
||||||
: AttributeStorage(ty), isSplat(isSplat) {}
|
: type(type), isSplat(isSplat) {}
|
||||||
|
|
||||||
|
ShapedType type;
|
||||||
bool isSplat;
|
bool isSplat;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -75,7 +76,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
|
||||||
|
|
||||||
/// Compare this storage instance with the provided key.
|
/// Compare this storage instance with the provided key.
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
if (key.type != getType())
|
if (key.type != type)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// For boolean splats we need to explicitly check that the first bit is the
|
// For boolean splats we need to explicitly check that the first bit is the
|
||||||
|
@ -228,7 +229,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
|
||||||
|
|
||||||
/// Compare this storage instance with the provided key.
|
/// Compare this storage instance with the provided key.
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
if (key.type != getType())
|
if (key.type != type)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Otherwise, we can default to just checking the data. StringRefs compare
|
// Otherwise, we can default to just checking the data. StringRefs compare
|
||||||
|
@ -324,12 +325,12 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
|
||||||
|
|
||||||
struct StringAttrStorage : public AttributeStorage {
|
struct StringAttrStorage : public AttributeStorage {
|
||||||
StringAttrStorage(StringRef value, Type type)
|
StringAttrStorage(StringRef value, Type type)
|
||||||
: AttributeStorage(type), value(value), referencedDialect(nullptr) {}
|
: type(type), value(value), referencedDialect(nullptr) {}
|
||||||
|
|
||||||
/// The hash key is a tuple of the parameter types.
|
/// The hash key is a tuple of the parameter types.
|
||||||
using KeyTy = std::pair<StringRef, Type>;
|
using KeyTy = std::pair<StringRef, Type>;
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return value == key.first && getType() == key.second;
|
return value == key.first && type == key.second;
|
||||||
}
|
}
|
||||||
static ::llvm::hash_code hashKey(const KeyTy &key) {
|
static ::llvm::hash_code hashKey(const KeyTy &key) {
|
||||||
return DenseMapInfo<KeyTy>::getHashValue(key);
|
return DenseMapInfo<KeyTy>::getHashValue(key);
|
||||||
|
@ -346,6 +347,8 @@ struct StringAttrStorage : public AttributeStorage {
|
||||||
/// Initialize the storage given an MLIRContext.
|
/// Initialize the storage given an MLIRContext.
|
||||||
void initialize(MLIRContext *context);
|
void initialize(MLIRContext *context);
|
||||||
|
|
||||||
|
/// The type of the string.
|
||||||
|
Type type;
|
||||||
/// The raw string value.
|
/// The raw string value.
|
||||||
StringRef value;
|
StringRef value;
|
||||||
/// If the string value contains a dialect namespace prefix (e.g.
|
/// If the string value contains a dialect namespace prefix (e.g.
|
||||||
|
|
|
@ -24,16 +24,12 @@ using namespace mlir::detail;
|
||||||
// ElementsAttr
|
// ElementsAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ShapedType ElementsAttr::getType() const {
|
Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
|
||||||
return Attribute::getType().cast<ShapedType>();
|
return elementsAttr.getType().getElementType();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type ElementsAttr::getElementType(Attribute elementsAttr) {
|
int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
|
||||||
return elementsAttr.getType().cast<ShapedType>().getElementType();
|
return elementsAttr.getType().getNumElements();
|
||||||
}
|
|
||||||
|
|
||||||
int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
|
|
||||||
return elementsAttr.getType().cast<ShapedType>().getNumElements();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
|
bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
|
||||||
|
@ -51,9 +47,9 @@ bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
|
||||||
return 0 <= dim && dim < shape[i];
|
return 0 <= dim && dim < shape[i];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
bool ElementsAttr::isValidIndex(Attribute elementsAttr,
|
bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
|
||||||
ArrayRef<uint64_t> index) {
|
ArrayRef<uint64_t> index) {
|
||||||
return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
|
return isValidIndex(elementsAttr.getType(), index);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
|
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
|
||||||
|
|
|
@ -261,6 +261,8 @@ StringAttr StringAttr::get(const Twine &twine, Type type) {
|
||||||
|
|
||||||
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
||||||
|
|
||||||
|
Type StringAttr::getType() const { return getImpl()->type; }
|
||||||
|
|
||||||
Dialect *StringAttr::getReferencedDialect() const {
|
Dialect *StringAttr::getReferencedDialect() const {
|
||||||
return getImpl()->referencedDialect;
|
return getImpl()->referencedDialect;
|
||||||
}
|
}
|
||||||
|
@ -688,29 +690,28 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
|
||||||
/// Custom storage to ensure proper memory alignment for the allocation of
|
/// Custom storage to ensure proper memory alignment for the allocation of
|
||||||
/// DenseArray of any element type.
|
/// DenseArray of any element type.
|
||||||
struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
|
struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
|
||||||
using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
|
using KeyTy =
|
||||||
::llvm::ArrayRef<char>>;
|
std::tuple<ShapedType, DenseArrayBaseAttr::EltType, ArrayRef<char>>;
|
||||||
DenseArrayBaseAttrStorage(ShapedType type,
|
DenseArrayBaseAttrStorage(ShapedType type,
|
||||||
DenseArrayBaseAttr::EltType eltType,
|
DenseArrayBaseAttr::EltType eltType,
|
||||||
::llvm::ArrayRef<char> elements)
|
ArrayRef<char> elements)
|
||||||
: AttributeStorage(type), eltType(eltType), elements(elements) {}
|
: type(type), eltType(eltType), elements(elements) {}
|
||||||
|
|
||||||
bool operator==(const KeyTy &tblgenKey) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return (getType() == std::get<0>(tblgenKey)) &&
|
return (type == std::get<0>(key)) && (eltType == std::get<1>(key)) &&
|
||||||
(eltType == std::get<1>(tblgenKey)) &&
|
(elements == std::get<2>(key));
|
||||||
(elements == std::get<2>(tblgenKey));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||||
return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
|
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
|
||||||
std::get<2>(tblgenKey));
|
std::get<2>(key));
|
||||||
}
|
}
|
||||||
|
|
||||||
static DenseArrayBaseAttrStorage *
|
static DenseArrayBaseAttrStorage *
|
||||||
construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
|
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto type = std::get<0>(tblgenKey);
|
auto type = std::get<0>(key);
|
||||||
auto eltType = std::get<1>(tblgenKey);
|
auto eltType = std::get<1>(key);
|
||||||
auto elements = std::get<2>(tblgenKey);
|
auto elements = std::get<2>(key);
|
||||||
if (!elements.empty()) {
|
if (!elements.empty()) {
|
||||||
char *alloc = static_cast<char *>(
|
char *alloc = static_cast<char *>(
|
||||||
allocator.allocate(elements.size(), alignof(uint64_t)));
|
allocator.allocate(elements.size(), alignof(uint64_t)));
|
||||||
|
@ -721,14 +722,17 @@ struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
|
||||||
DenseArrayBaseAttrStorage(type, eltType, elements);
|
DenseArrayBaseAttrStorage(type, eltType, elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShapedType type;
|
||||||
DenseArrayBaseAttr::EltType eltType;
|
DenseArrayBaseAttr::EltType eltType;
|
||||||
::llvm::ArrayRef<char> elements;
|
ArrayRef<char> elements;
|
||||||
};
|
};
|
||||||
|
|
||||||
DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
|
DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
|
||||||
return getImpl()->eltType;
|
return getImpl()->eltType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
|
||||||
|
|
||||||
const int8_t *
|
const int8_t *
|
||||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
||||||
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
|
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
|
||||||
|
@ -974,8 +978,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
|
|
||||||
// If the element type is not based on int/float/index, assume it is a string
|
// If the element type is not based on int/float/index, assume it is a string
|
||||||
// type.
|
// type.
|
||||||
auto eltType = type.getElementType();
|
Type eltType = type.getElementType();
|
||||||
if (!type.getElementType().isIntOrIndexOrFloat()) {
|
if (!eltType.isIntOrIndexOrFloat()) {
|
||||||
SmallVector<StringRef, 8> stringValues;
|
SmallVector<StringRef, 8> stringValues;
|
||||||
stringValues.reserve(values.size());
|
stringValues.reserve(values.size());
|
||||||
for (Attribute attr : values) {
|
for (Attribute attr : values) {
|
||||||
|
@ -995,14 +999,16 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
|
llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
|
||||||
APInt intVal;
|
APInt intVal;
|
||||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||||
assert(eltType == values[i].getType() &&
|
if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
|
||||||
"expected attribute value to have element type");
|
assert(floatAttr.getType() == eltType &&
|
||||||
if (eltType.isa<FloatType>())
|
"expected float attribute type to equal element type");
|
||||||
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
|
intVal = floatAttr.getValue().bitcastToAPInt();
|
||||||
else if (eltType.isa<IntegerType, IndexType>())
|
} else {
|
||||||
intVal = values[i].cast<IntegerAttr>().getValue();
|
auto intAttr = values[i].cast<IntegerAttr>();
|
||||||
else
|
assert(intAttr.getType() == eltType &&
|
||||||
llvm_unreachable("unexpected element type");
|
"expected integer attribute type to equal element type");
|
||||||
|
intVal = intAttr.getValue();
|
||||||
|
}
|
||||||
|
|
||||||
assert(intVal.getBitWidth() == bitWidth &&
|
assert(intVal.getBitWidth() == bitWidth &&
|
||||||
"expected value to have same bitwidth as element type");
|
"expected value to have same bitwidth as element type");
|
||||||
|
@ -1010,7 +1016,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle the special encoding of splat of bool.
|
// Handle the special encoding of splat of bool.
|
||||||
if (values.size() == 1 && values[0].getType().isInteger(1))
|
if (values.size() == 1 && eltType.isInteger(1))
|
||||||
data[0] = data[0] ? -1 : 0;
|
data[0] = data[0] ? -1 : 0;
|
||||||
|
|
||||||
return DenseIntOrFPElementsAttr::getRaw(type, data);
|
return DenseIntOrFPElementsAttr::getRaw(type, data);
|
||||||
|
@ -1326,7 +1332,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedType DenseElementsAttr::getType() const {
|
ShapedType DenseElementsAttr::getType() const {
|
||||||
return Attribute::getType().cast<ShapedType>();
|
return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type DenseElementsAttr::getElementType() const {
|
Type DenseElementsAttr::getElementType() const {
|
||||||
|
@ -1546,8 +1552,9 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
|
||||||
|
|
||||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||||
bool DenseFPElementsAttr::classof(Attribute attr) {
|
bool DenseFPElementsAttr::classof(Attribute attr) {
|
||||||
return attr.isa<DenseElementsAttr>() &&
|
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
|
||||||
attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
return denseAttr.getType().getElementType().isa<FloatType>();
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1564,8 +1571,9 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
||||||
|
|
||||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||||
bool DenseIntElementsAttr::classof(Attribute attr) {
|
bool DenseIntElementsAttr::classof(Attribute attr) {
|
||||||
return attr.isa<DenseElementsAttr>() &&
|
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
|
||||||
attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
|
return denseAttr.getType().getElementType().isIntOrIndex();
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -896,10 +896,6 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
|
||||||
MLIRContext *ctx,
|
MLIRContext *ctx,
|
||||||
TypeID attrID) {
|
TypeID attrID) {
|
||||||
storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
|
storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
|
||||||
|
|
||||||
// If the attribute did not provide a type, then default to NoneType.
|
|
||||||
if (!storage->getType())
|
|
||||||
storage->setType(NoneType::get(ctx));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
|
BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
|
||||||
|
|
|
@ -32,7 +32,9 @@ Type mlir::getElementTypeOrSelf(Value val) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type mlir::getElementTypeOrSelf(Attribute attr) {
|
Type mlir::getElementTypeOrSelf(Attribute attr) {
|
||||||
return getElementTypeOrSelf(attr.getType());
|
if (auto typedAttr = attr.dyn_cast<TypedAttr>())
|
||||||
|
return getElementTypeOrSelf(typedAttr.getType());
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
|
SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
|
||||||
|
|
|
@ -1652,7 +1652,9 @@ void ByteCodeExecutor::executeGetAttributeType() {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
|
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
|
||||||
unsigned memIndex = read();
|
unsigned memIndex = read();
|
||||||
Attribute attr = read<Attribute>();
|
Attribute attr = read<Attribute>();
|
||||||
Type type = attr ? attr.getType() : Type();
|
Type type;
|
||||||
|
if (auto typedAttr = attr.dyn_cast<TypedAttr>())
|
||||||
|
type = typedAttr.getType();
|
||||||
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
|
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
|
||||||
<< " * Result: " << type << "\n");
|
<< " * Result: " << type << "\n");
|
||||||
|
|
|
@ -283,7 +283,8 @@ bool AttrOrTypeParameter::isOptional() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
|
Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
|
||||||
return getDefValue<llvm::StringInit>("defaultValue");
|
Optional<StringRef> result = getDefValue<llvm::StringInit>("defaultValue");
|
||||||
|
return result && !result->empty() ? result : llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
|
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
|
||||||
|
|
|
@ -823,7 +823,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
|
||||||
if (auto type = attr.dyn_cast<TypeAttr>())
|
if (auto type = attr.dyn_cast<TypeAttr>())
|
||||||
return emitType(loc, type.getValue());
|
return emitType(loc, type.getValue());
|
||||||
|
|
||||||
return emitError(loc, "cannot emit attribute of type ") << attr.getType();
|
return emitError(loc, "cannot emit attribute: ") << attr;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult CppEmitter::emitOperands(Operation &op) {
|
LogicalResult CppEmitter::emitOperands(Operation &op) {
|
||||||
|
|
|
@ -769,7 +769,8 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
|
||||||
|
|
||||||
// Process the type for this bool literal
|
// Process the type for this bool literal
|
||||||
uint32_t typeID = 0;
|
uint32_t typeID = 0;
|
||||||
if (failed(processType(loc, boolAttr.getType(), typeID))) {
|
if (failed(
|
||||||
|
processType(loc, boolAttr.cast<IntegerAttr>().getType(), typeID))) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -332,7 +332,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
|
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
|
||||||
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
|
// expected-error @+1 {{expected array attribute with two elements of the same type}}
|
||||||
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
||||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||||
}
|
}
|
||||||
|
@ -547,7 +547,7 @@ func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
|
||||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||||
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
|
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
|
||||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||||
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||||
}
|
}
|
||||||
|
@ -571,7 +571,7 @@ func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||||
// expected-error@+1 {{op requires attribute 'layoutA'}}
|
// expected-error@+1 {{op requires attribute 'layoutA'}}
|
||||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||||
{shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
{shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||||
}
|
}
|
||||||
|
@ -594,7 +594,7 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
|
||||||
// expected-error@+1 {{op requires b1Op attribute}}
|
// expected-error@+1 {{op requires b1Op attribute}}
|
||||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||||
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
||||||
shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ func.func @const() -> () {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func.func @unaccepted_std_attr() -> () {
|
func.func @unaccepted_std_attr() -> () {
|
||||||
// expected-error @+1 {{cannot have value of type 'none'}}
|
// expected-error @+1 {{cannot have attribute: unit}}
|
||||||
%0 = spv.Constant unit : none
|
%0 = spv.Constant unit : none
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000"
|
// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000"
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
module attributes { test.blob_ref = #test.e1di64_elements<blob1> } {}
|
module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>} {}
|
||||||
|
|
||||||
{-#
|
{-#
|
||||||
dialect_resources: {
|
dialect_resources: {
|
||||||
|
|
|
@ -53,18 +53,22 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// An attribute testing AttributeSelfTypeParameter.
|
// An attribute testing AttributeSelfTypeParameter.
|
||||||
def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
|
def AttrWithSelfTypeParam
|
||||||
|
: Test_Attr<"AttrWithSelfTypeParam", [TypedAttrInterface]> {
|
||||||
let mnemonic = "attr_with_self_type_param";
|
let mnemonic = "attr_with_self_type_param";
|
||||||
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
|
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
|
||||||
let assemblyFormat = "";
|
let assemblyFormat = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
// An attribute testing AttributeSelfTypeParameter.
|
// An attribute testing AttributeSelfTypeParameter.
|
||||||
def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
|
def AttrWithTypeBuilder
|
||||||
|
: Test_Attr<"AttrWithTypeBuilder", [TypedAttrInterface]> {
|
||||||
let mnemonic = "attr_with_type_builder";
|
let mnemonic = "attr_with_type_builder";
|
||||||
let parameters = (ins "::mlir::IntegerAttr":$attr);
|
let parameters = (ins
|
||||||
let typeBuilder = "$_attr.getType()";
|
"::mlir::IntegerAttr":$attr,
|
||||||
let hasCustomAssemblyFormat = 1;
|
AttributeSelfTypeParameter<"", "mlir::Type", "$attr.getType()">:$type
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$attr";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
|
def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
|
||||||
|
@ -76,7 +80,7 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
|
||||||
|
|
||||||
// Test support for ElementsAttrInterface.
|
// Test support for ElementsAttrInterface.
|
||||||
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
|
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
|
||||||
ElementsAttrInterface
|
ElementsAttrInterface, TypedAttrInterface
|
||||||
]> {
|
]> {
|
||||||
let mnemonic = "i64_elements";
|
let mnemonic = "i64_elements";
|
||||||
let parameters = (ins
|
let parameters = (ins
|
||||||
|
@ -215,7 +219,7 @@ def TestAttrWithTypeParam : Test_Attr<"TestAttrWithTypeParam"> {
|
||||||
|
|
||||||
// Test self type parameter with assembly format.
|
// Test self type parameter with assembly format.
|
||||||
def TestAttrSelfTypeParameterFormat
|
def TestAttrSelfTypeParameterFormat
|
||||||
: Test_Attr<"TestAttrSelfTypeParameterFormat"> {
|
: Test_Attr<"TestAttrSelfTypeParameterFormat", [TypedAttrInterface]> {
|
||||||
let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
|
let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
|
||||||
|
|
||||||
let mnemonic = "attr_self_type_format";
|
let mnemonic = "attr_self_type_format";
|
||||||
|
@ -237,7 +241,7 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
|
||||||
|
|
||||||
// Test simple extern 1D vector using ElementsAttrInterface.
|
// Test simple extern 1D vector using ElementsAttrInterface.
|
||||||
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
|
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
|
||||||
ElementsAttrInterface
|
ElementsAttrInterface, TypedAttrInterface
|
||||||
]> {
|
]> {
|
||||||
let mnemonic = "e1di64_elements";
|
let mnemonic = "e1di64_elements";
|
||||||
let parameters = (ins
|
let parameters = (ins
|
||||||
|
|
|
@ -27,21 +27,6 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace test;
|
using namespace test;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// AttrWithTypeBuilderAttr
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
|
|
||||||
IntegerAttr element;
|
|
||||||
if (parser.parseAttribute(element))
|
|
||||||
return Attribute();
|
|
||||||
return get(parser.getContext(), element);
|
|
||||||
}
|
|
||||||
|
|
||||||
void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
|
|
||||||
printer << " " << getAttr();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CompoundAAttr
|
// CompoundAAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -114,10 +99,11 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult TestAttrWithFormatAttr::verify(
|
LogicalResult
|
||||||
function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
|
TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||||
IntegerAttr three, ArrayRef<int> four,
|
int64_t one, std::string two, IntegerAttr three,
|
||||||
ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) {
|
ArrayRef<int> four,
|
||||||
|
ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
|
||||||
if (four.size() != static_cast<unsigned>(one))
|
if (four.size() != static_cast<unsigned>(one))
|
||||||
return emitError() << "expected 'one' to equal 'four.size()'";
|
return emitError() << "expected 'one' to equal 'four.size()'";
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -554,7 +554,7 @@ def OperandsHaveSameType :
|
||||||
def ResultHasSameTypeAsAttr :
|
def ResultHasSameTypeAsAttr :
|
||||||
TEST_Op<"result_has_same_type_as_attr",
|
TEST_Op<"result_has_same_type_as_attr",
|
||||||
[AllTypesMatch<["attr", "result"]>]> {
|
[AllTypesMatch<["attr", "result"]>]> {
|
||||||
let arguments = (ins AnyAttr:$attr);
|
let arguments = (ins TypedAttrInterface:$attr);
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
let assemblyFormat = "$attr `->` type($result) attr-dict";
|
let assemblyFormat = "$attr `->` type($result) attr-dict";
|
||||||
}
|
}
|
||||||
|
@ -2310,7 +2310,7 @@ def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
|
||||||
def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
|
def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
|
||||||
AllTypesMatch<["value1", "value2", "result"]>
|
AllTypesMatch<["value1", "value2", "result"]>
|
||||||
]> {
|
]> {
|
||||||
let arguments = (ins AnyAttr:$value1, AnyType:$value2);
|
let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2);
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
let assemblyFormat = "attr-dict $value1 `,` $value2";
|
let assemblyFormat = "attr-dict $value1 `,` $value2";
|
||||||
}
|
}
|
||||||
|
@ -2338,7 +2338,7 @@ def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
|
||||||
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
|
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
|
||||||
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
|
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
|
||||||
]> {
|
]> {
|
||||||
let arguments = (ins AnyAttr:$value);
|
let arguments = (ins TypedAttrInterface:$value);
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
let assemblyFormat = "attr-dict $value";
|
let assemblyFormat = "attr-dict $value";
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,9 +164,13 @@ def AttrC : TestAttr<"TestF"> {
|
||||||
|
|
||||||
/// Test attribute with self type parameter
|
/// Test attribute with self type parameter
|
||||||
|
|
||||||
// ATTR: TestGAttr::parse
|
// ATTR-LABEL: Attribute TestGAttr::parse
|
||||||
// ATTR: return TestGAttr::get
|
// ATTR: if (odsType)
|
||||||
// ATTR: odsType
|
// ATTR: if (auto reqType = odsType.dyn_cast<::mlir::Type>())
|
||||||
|
// ATTR: _result_type = reqType
|
||||||
|
// ATTR: TestGAttr::get
|
||||||
|
// ATTR-NEXT: *_result_a
|
||||||
|
// ATTR-NEXT: _result_type.value_or(::mlir::NoneType::get(
|
||||||
def AttrD : TestAttr<"TestG"> {
|
def AttrD : TestAttr<"TestG"> {
|
||||||
let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
|
let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
|
||||||
let mnemonic = "attr_d";
|
let mnemonic = "attr_d";
|
||||||
|
|
|
@ -77,11 +77,12 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
|
||||||
// DECL: int getWidthOfSomething() const;
|
// DECL: int getWidthOfSomething() const;
|
||||||
// DECL: ::test::SimpleTypeA getExampleTdType() const;
|
// DECL: ::test::SimpleTypeA getExampleTdType() const;
|
||||||
// DECL: ::llvm::APFloat getApFloat() const;
|
// DECL: ::llvm::APFloat getApFloat() const;
|
||||||
|
// DECL: ::mlir::Type getInner() const;
|
||||||
|
|
||||||
// Check that AttributeSelfTypeParameter is handled properly.
|
// Check that AttributeSelfTypeParameter is handled properly.
|
||||||
// DEF-LABEL: struct CompoundAAttrStorage
|
// DEF-LABEL: struct CompoundAAttrStorage
|
||||||
// DEF: CompoundAAttrStorage(
|
// DEF: CompoundAAttrStorage(
|
||||||
// DEF-SAME: : ::mlir::AttributeStorage(inner),
|
// DEF-SAME: inner(inner)
|
||||||
|
|
||||||
// DEF: bool operator==(const KeyTy &tblgenKey) const {
|
// DEF: bool operator==(const KeyTy &tblgenKey) const {
|
||||||
// DEF-NEXT: return
|
// DEF-NEXT: return
|
||||||
|
@ -89,14 +90,14 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
|
||||||
// DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) &&
|
// DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) &&
|
||||||
// DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) &&
|
// DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) &&
|
||||||
// DEF-SAME: (dims == std::get<3>(tblgenKey)) &&
|
// DEF-SAME: (dims == std::get<3>(tblgenKey)) &&
|
||||||
// DEF-SAME: (getType() == std::get<4>(tblgenKey));
|
// DEF-SAME: (inner == std::get<4>(tblgenKey));
|
||||||
|
|
||||||
// DEF: static CompoundAAttrStorage *construct
|
// DEF: static CompoundAAttrStorage *construct
|
||||||
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
|
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
|
||||||
// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
|
// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
|
||||||
|
|
||||||
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
|
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
|
||||||
// DEF-NEXT: return getImpl()->getType().cast<::mlir::Type>();
|
// DEF-NEXT: return getImpl()->inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
def C_IndexAttr : TestAttr<"Index"> {
|
def C_IndexAttr : TestAttr<"Index"> {
|
||||||
|
@ -127,18 +128,6 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
|
||||||
// DECL-SAME: detail::SingleParameterAttrStorage
|
// DECL-SAME: detail::SingleParameterAttrStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
// An attribute testing AttributeSelfTypeParameter.
|
|
||||||
def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
|
|
||||||
let mnemonic = "attr_with_type_builder";
|
|
||||||
let parameters = (ins "::mlir::IntegerAttr":$attr);
|
|
||||||
let typeBuilder = "$_attr.getType()";
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
|
|
||||||
// DEF: AttrWithTypeBuilderAttrStorage(::mlir::IntegerAttr attr)
|
|
||||||
// DEF-SAME: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
|
|
||||||
|
|
||||||
def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
|
def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
|
||||||
let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param);
|
let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param);
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||||
|
|
||||||
// CHECK-LABEL: OpE definitions
|
// CHECK-LABEL: OpE definitions
|
||||||
// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
|
// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
|
||||||
// CHECK: odsState.addTypes({attr.getValue().getType()});
|
// CHECK: odsState.addTypes({attr.getValue().cast<::mlir::TypedAttr>().getType()});
|
||||||
|
|
||||||
def OpF : NS_Op<"one_variadic_result_op", []> {
|
def OpF : NS_Op<"one_variadic_result_op", []> {
|
||||||
let results = (outs Variadic<I32>:$x);
|
let results = (outs Variadic<I32>:$x);
|
||||||
|
@ -155,5 +155,5 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
|
||||||
|
|
||||||
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
|
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
|
||||||
// CHECK-NOT: }
|
// CHECK-NOT: }
|
||||||
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
|
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
|
||||||
// CHECK: inferredReturnTypes[0] = odsInferredType0;
|
// CHECK: inferredReturnTypes[0] = odsInferredType0;
|
||||||
|
|
|
@ -295,11 +295,7 @@ void DefGen::emitAccessors() {
|
||||||
// class. Otherwise, let the user define the exact accessor definition.
|
// class. Otherwise, let the user define the exact accessor definition.
|
||||||
if (!def.genStorageClass())
|
if (!def.genStorageClass())
|
||||||
continue;
|
continue;
|
||||||
auto scope = m->body().indent().scope("return getImpl()->", ";");
|
m->body().indent() << "return getImpl()->" << param.getName() << ";";
|
||||||
if (isa<AttributeSelfTypeParameter>(param))
|
|
||||||
m->body() << formatv("getType().cast<{0}>()", param.getCppType());
|
|
||||||
else
|
|
||||||
m->body() << param.getName();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -450,37 +446,8 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
|
||||||
void DefGen::emitStorageConstructor() {
|
void DefGen::emitStorageConstructor() {
|
||||||
Constructor *ctor =
|
Constructor *ctor =
|
||||||
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
|
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
|
||||||
if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
|
for (auto ¶m : params)
|
||||||
// For attributes, a parameter marked with AttributeSelfTypeParameter is
|
ctor->addMemberInitializer(param.getName(), param.getName());
|
||||||
// the type initializer that must be passed to the parent constructor.
|
|
||||||
const auto isSelfType = [](const AttrOrTypeParameter ¶m) {
|
|
||||||
return isa<AttributeSelfTypeParameter>(param);
|
|
||||||
};
|
|
||||||
auto *selfTypeParam = llvm::find_if(params, isSelfType);
|
|
||||||
if (std::count_if(selfTypeParam, params.end(), isSelfType) > 1) {
|
|
||||||
PrintFatalError(def.getLoc(),
|
|
||||||
"Only one attribute parameter can be marked as "
|
|
||||||
"AttributeSelfTypeParameter");
|
|
||||||
}
|
|
||||||
// Alternatively, if a type builder was specified, use that instead.
|
|
||||||
std::string attrStorageInit =
|
|
||||||
selfTypeParam == params.end() ? "" : selfTypeParam->getName().str();
|
|
||||||
if (attrDef->getTypeBuilder()) {
|
|
||||||
FmtContext ctx;
|
|
||||||
for (auto ¶m : params)
|
|
||||||
ctx.addSubst(strfmt("_{0}", param.getName()), param.getName());
|
|
||||||
attrStorageInit = tgfmt(*attrDef->getTypeBuilder(), &ctx);
|
|
||||||
}
|
|
||||||
ctor->addMemberInitializer("::mlir::AttributeStorage",
|
|
||||||
std::move(attrStorageInit));
|
|
||||||
// Initialize members that aren't the attribute's type.
|
|
||||||
for (auto ¶m : params)
|
|
||||||
if (selfTypeParam == params.end() || *selfTypeParam != param)
|
|
||||||
ctor->addMemberInitializer(param.getName(), param.getName());
|
|
||||||
} else {
|
|
||||||
for (auto ¶m : params)
|
|
||||||
ctor->addMemberInitializer(param.getName(), param.getName());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DefGen::emitKeyType() {
|
void DefGen::emitKeyType() {
|
||||||
|
@ -498,9 +465,7 @@ void DefGen::emitEquals() {
|
||||||
auto &body = eq->body().indent();
|
auto &body = eq->body().indent();
|
||||||
auto scope = body.scope("return (", ");");
|
auto scope = body.scope("return (", ");");
|
||||||
const auto eachFn = [&](auto it) {
|
const auto eachFn = [&](auto it) {
|
||||||
FmtContext ctx({{"_lhs", isa<AttributeSelfTypeParameter>(it.value())
|
FmtContext ctx({{"_lhs", it.value().getName()},
|
||||||
? "getType()"
|
|
||||||
: it.value().getName()},
|
|
||||||
{"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
|
{"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
|
||||||
body << tgfmt(it.value().getComparator(), &ctx);
|
body << tgfmt(it.value().getComparator(), &ctx);
|
||||||
};
|
};
|
||||||
|
@ -566,8 +531,7 @@ void DefGen::emitStorageClass() {
|
||||||
// Emit the storage class members as public, at the very end of the struct.
|
// Emit the storage class members as public, at the very end of the struct.
|
||||||
storageCls->finalize();
|
storageCls->finalize();
|
||||||
for (auto ¶m : params)
|
for (auto ¶m : params)
|
||||||
if (!isa<AttributeSelfTypeParameter>(param))
|
storageCls->declare<Field>(param.getCppType(), param.getName());
|
||||||
storageCls->declare<Field>(param.getCppType(), param.getName());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -246,6 +246,39 @@ private:
|
||||||
// ParserGen
|
// ParserGen
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Generate a special-case "parser" for an attribute's self type parameter. The
|
||||||
|
/// self type parameter has special handling in the assembly format in that it
|
||||||
|
/// is derived from the optional trailing colon type after the attribute.
|
||||||
|
static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
|
||||||
|
const AttributeSelfTypeParameter ¶m) {
|
||||||
|
// "Parser" for an attribute self type parameter that checks the
|
||||||
|
// optionally-parsed trailing colon type.
|
||||||
|
//
|
||||||
|
// $0: The C++ storage class of the type parameter.
|
||||||
|
// $1: The self type parameter name.
|
||||||
|
const char *const selfTypeParser = R"(
|
||||||
|
if ($_type) {
|
||||||
|
if (auto reqType = $_type.dyn_cast<$0>()) {
|
||||||
|
_result_$1 = reqType;
|
||||||
|
} else {
|
||||||
|
$_parser.emitError($_loc, "invalid kind of type specified");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
|
||||||
|
// If the attribute self type parameter is required, emit code that emits an
|
||||||
|
// error if the trailing type was not parsed.
|
||||||
|
const char *const selfTypeRequired = R"( else {
|
||||||
|
$_parser.emitError($_loc, "expected a trailing type");
|
||||||
|
return {};
|
||||||
|
})";
|
||||||
|
|
||||||
|
os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName());
|
||||||
|
if (!param.isOptional())
|
||||||
|
os << tgfmt(selfTypeRequired, &ctx);
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
void DefFormat::genParser(MethodBody &os) {
|
void DefFormat::genParser(MethodBody &os) {
|
||||||
FmtContext ctx;
|
FmtContext ctx;
|
||||||
ctx.addSubst("_parser", "odsParser");
|
ctx.addSubst("_parser", "odsParser");
|
||||||
|
@ -262,8 +295,6 @@ void DefFormat::genParser(MethodBody &os) {
|
||||||
// a loop (parsers return FailureOr anyways).
|
// a loop (parsers return FailureOr anyways).
|
||||||
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
|
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
|
||||||
for (const AttrOrTypeParameter ¶m : params) {
|
for (const AttrOrTypeParameter ¶m : params) {
|
||||||
if (isa<AttributeSelfTypeParameter>(param))
|
|
||||||
continue;
|
|
||||||
os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
|
os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
|
||||||
param.getCppStorageType(), param.getName());
|
param.getCppStorageType(), param.getName());
|
||||||
}
|
}
|
||||||
|
@ -281,7 +312,9 @@ void DefFormat::genParser(MethodBody &os) {
|
||||||
// Emit an assert for each mandatory parameter. Triggering an assert means
|
// Emit an assert for each mandatory parameter. Triggering an assert means
|
||||||
// the generated parser is incorrect (i.e. there is a bug in this code).
|
// the generated parser is incorrect (i.e. there is a bug in this code).
|
||||||
for (const AttrOrTypeParameter ¶m : params) {
|
for (const AttrOrTypeParameter ¶m : params) {
|
||||||
if (param.isOptional() || isa<AttributeSelfTypeParameter>(param))
|
if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(¶m))
|
||||||
|
genAttrSelfTypeParser(os, ctx, *selfTypeParam);
|
||||||
|
if (param.isOptional())
|
||||||
continue;
|
continue;
|
||||||
os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
|
os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
|
||||||
}
|
}
|
||||||
|
@ -306,11 +339,10 @@ void DefFormat::genParser(MethodBody &os) {
|
||||||
else
|
else
|
||||||
selfOs << param.getCppStorageType() << "()";
|
selfOs << param.getCppStorageType() << "()";
|
||||||
selfOs << "))";
|
selfOs << "))";
|
||||||
} else if (isa<AttributeSelfTypeParameter>(param)) {
|
|
||||||
selfOs << tgfmt("$_type", &ctx);
|
|
||||||
} else {
|
} else {
|
||||||
selfOs << formatv("(*_result_{0})", param.getName());
|
selfOs << formatv("(*_result_{0})", param.getName());
|
||||||
}
|
}
|
||||||
|
ctx.addSubst(param.getName(), selfOs.str());
|
||||||
os << param.getCppType() << "("
|
os << param.getCppType() << "("
|
||||||
<< tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
|
<< tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
|
||||||
<< ")";
|
<< ")";
|
||||||
|
|
|
@ -578,7 +578,8 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
|
||||||
// Populate substitutions for attributes.
|
// Populate substitutions for attributes.
|
||||||
auto &op = emitHelper.getOp();
|
auto &op = emitHelper.getOp();
|
||||||
for (const auto &namedAttr : op.getAttributes())
|
for (const auto &namedAttr : op.getAttributes())
|
||||||
ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
|
ctx.addSubst(namedAttr.name,
|
||||||
|
emitHelper.getOp().getGetterName(namedAttr.name) + "()");
|
||||||
|
|
||||||
// Populate substitutions for named operands.
|
// Populate substitutions for named operands.
|
||||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||||
|
@ -1756,7 +1757,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
|
||||||
if (namedAttr.attr.isTypeAttr()) {
|
if (namedAttr.attr.isTypeAttr()) {
|
||||||
resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
|
resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
|
||||||
} else {
|
} else {
|
||||||
resultType = "attr.getValue().getType()";
|
resultType = "attr.getValue().cast<::mlir::TypedAttr>().getType()";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Operands
|
// Operands
|
||||||
|
@ -2416,7 +2417,8 @@ void OpEmitter::genTypeInterfaceMethods() {
|
||||||
} else {
|
} else {
|
||||||
auto *attr =
|
auto *attr =
|
||||||
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
|
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
|
||||||
body << "attributes.get(\"" << attr->name << "\").getType()";
|
body << "attributes.get(\"" << attr->name
|
||||||
|
<< "\").cast<::mlir::TypedAttr>().getType()";
|
||||||
}
|
}
|
||||||
body << ";\n";
|
body << ";\n";
|
||||||
}
|
}
|
||||||
|
|
|
@ -237,16 +237,19 @@ TEST(SparseElementsAttrTest, GetZero) {
|
||||||
|
|
||||||
// Only index (0, 0) contains an element, others are supposed to return
|
// Only index (0, 0) contains an element, others are supposed to return
|
||||||
// the zero/empty value.
|
// the zero/empty value.
|
||||||
auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}];
|
auto zeroIntValue =
|
||||||
EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
|
sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
|
||||||
|
EXPECT_EQ(zeroIntValue.getInt(), 0);
|
||||||
EXPECT_TRUE(zeroIntValue.getType() == intTy);
|
EXPECT_TRUE(zeroIntValue.getType() == intTy);
|
||||||
|
|
||||||
auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}];
|
auto zeroFloatValue =
|
||||||
EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
|
sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
|
||||||
|
EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
|
||||||
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
|
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
|
||||||
|
|
||||||
auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}];
|
auto zeroStringValue =
|
||||||
EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
|
sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
|
||||||
|
EXPECT_TRUE(zeroStringValue.getValue().empty());
|
||||||
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
|
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue