[MLIR][SPIRV] Add intel joint matrix ops

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D131586
This commit is contained in:
Nirvedh Meshram 2022-08-04 22:23:27 +00:00
parent c63f2581f4
commit b8f62dc22a
17 changed files with 891 additions and 26 deletions

View File

@ -27,12 +27,12 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
);
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

View File

@ -64,6 +64,27 @@ def SPV_CooperativeMatrixPropertiesNVArrayAttr :
TypedArrayAttrBase<SPV_CooperativeMatrixPropertiesNVAttr,
"CooperativeMatrixPropertiesNV array attribute">;
// Description of the supported joint matrix operations. See
// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
def SPV_JointMatrixPropertiesINTELAttr :
SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
let parameters = (ins
"int":$m_size,
"int":$n_size,
"int":$k_size,
"mlir::Type":$a_type,
"mlir::Type":$b_type,
"mlir::Type":$c_type,
"mlir::Type":$result_type,
"mlir::spirv::ScopeAttr":$scope
);
let assemblyFormat = "`<` struct(params) `>`";
}
def SPV_JointMatrixPropertiesINTELArrayAttr :
TypedArrayAttrBase<SPV_JointMatrixPropertiesINTELAttr,
"JointMatrixPropertiesINTEL array attribute">;
// This attribute specifies the limits for various resources on the target
// architecture.
//

View File

@ -387,6 +387,7 @@ def SPV_INTEL_debug_module : I32EnumAttrCase<"SPV_INTEL_de
def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@ -443,7 +444,7 @@ def SPV_ExtensionAttr :
SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@ -1390,6 +1391,12 @@ def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderS
];
}
def SPV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
list<Availability> availability = [
Extension<[SPV_INTEL_joint_matrix]>
];
}
def SPV_CapabilityAttr :
SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
@ -1481,7 +1488,7 @@ def SPV_CapabilityAttr :
SPV_C_UniformTexelBufferArrayNonUniformIndexing,
SPV_C_StorageTexelBufferArrayNonUniformIndexing,
SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV,
SPV_C_ShaderStereoViewNV
SPV_C_ShaderStereoViewNV, SPV_C_JointMatrixINTEL
]>;
def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@ -3981,6 +3988,16 @@ def SPV_SamplerUseAttr: SPV_I32EnumAttr<
"image_sampler_use_info",
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
def SPV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
def SPV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
def SPV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
def SPV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>;
def SPV_MatrixLayoutAttr :
SPV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
SPV_ML_ColumnMajor, SPV_ML_RowMajor, SPV_ML_PackedA, SPV_ML_PackedB
]>;
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
@ -4013,6 +4030,8 @@ def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsCooperativeMatrixType :
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">;
def SPV_IsJointMatrixType :
CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">;
def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
@ -4043,6 +4062,8 @@ def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPV_AnyImage : DialectType<SPIRV_Dialect, SPV_IsImageType,
"any SPIR-V image type">;
def SPV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPV_IsJointMatrixType,
"any SPIR-V joint matrix type">;
def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
"any SPIR-V matrix type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
@ -4057,11 +4078,12 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
def SPV_Composite :
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>;
def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage
SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix,
SPV_AnySampledImage
]>;
def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@ -4072,6 +4094,11 @@ class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
"$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()",
"Cooperative Matrix">;
class SPV_JointMatrixOfType<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, SPV_IsJointMatrixType,
"$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()",
"Joint Matrix">;
class SPV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
@ -4079,6 +4106,14 @@ class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
SPV_CoopMatrixOfType<[type]>]>;
class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
SPV_JointMatrixOfType<[type]>]>;
class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>;
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
@ -4311,6 +4346,11 @@ def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINT
def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
def SPV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
def SPV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
def SPV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
def SPV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
def SPV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@ -4376,7 +4416,10 @@ def SPV_OpcodeAttr :
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL,
SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT
SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT,
SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL,
SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL,
SPV_OC_OpTypejointMatrixWorkItemLengthINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

View File

@ -23,11 +23,11 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultShape])> {
let arguments = (ins
SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
);
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)

View File

@ -0,0 +1,248 @@
//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the op definition spec of joint matrix multiply extension ops.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
// -----
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
[NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Return number of components owned by the current work-item in
a joint matrix.
Result Type must be an 32-bit unsigned integer type scalar.
Type is a joint matrix type.
``` {.ebnf}
joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL
` : ` joint-matrix-type
```
For example:
```
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<Subgroup, i32, 8, 16>
```
}];
let assemblyFormat = "attr-dict `:` $type";
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_6>,
Extension<[SPV_INTEL_joint_matrix]>,
Capability<[SPV_C_JointMatrixINTEL]>
];
let arguments = (ins
TypeAttr:$type
);
let results = (outs
SPV_Int32:$result
);
let hasVerifier = 0;
}
// -----
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Load a matrix through a pointer.
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
Pointer is the pointer to load through. It specifies start of memory region where
elements of the matrix are stored and arranged according to Layout.
Stride is the number of elements in memory between beginnings of successive rows,
columns (or words) in the result. It must be a scalar integer type.
Layout indicates how the values loaded from memory are arranged. It must be the
result of a constant instruction.
Scope is syncronization scope for operation on the matrix. It must be the result
of a constant instruction with scalar integer type.
If present, any Memory Operands must begin with a memory operand literal. If not
present, it is the same as specifying the memory operand None.
#### Example:
```mlir
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
```
}];
let assemblyFormat = [{
$scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_6>,
Extension<[SPV_INTEL_joint_matrix]>,
Capability<[SPV_C_JointMatrixINTEL]>
];
let arguments = (ins
SPV_ScopeAttr:$scope,
SPV_MatrixLayoutAttr:$layout,
SPV_AnyPtr:$pointer,
SPV_Integer:$stride,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
);
let results = (outs
SPV_AnyJointMatrix:$result
);
}
// -----
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is
a K x N matrix and C is a M x N matrix.
Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be
OpTypeJointMatrixINTEL.
A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count
equals to K
B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count
equals to N
C and Result Type must be a OpTypeJointMatrixINTEL with Row
Count equals to M and Column Count equals to N
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.
#### Example:
```mlir
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
```
}];
let assemblyFormat = [{
$scope operands attr-dict`:` type($a) `,` type($b) `->` type($c)
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_6>,
Extension<[SPV_INTEL_joint_matrix]>,
Capability<[SPV_C_JointMatrixINTEL]>
];
let arguments = (ins
SPV_ScopeAttr:$scope,
SPV_AnyJointMatrix:$a,
SPV_AnyJointMatrix:$b,
SPV_AnyJointMatrix:$c
);
let results = (outs
SPV_AnyJointMatrix:$result
);
}
// -----
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Store a matrix through a pointer.
Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must
be stored and arranged according to Layout.
Object is the matrix to store. It must be
OpTypeJointMatrixINTEL.
Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must
be a scalar integer type.
Layout indicates how the values stored to memory are arranged.
It must be the result of a constant instruction.
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory
operand None.
#### Example:
```mlir
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
```
}];
let assemblyFormat = [{
$scope $layout operands attr-dict `:` `(` type(operands) `)`
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_6>,
Extension<[SPV_INTEL_joint_matrix]>,
Capability<[SPV_C_JointMatrixINTEL]>
];
let arguments = (ins
SPV_ScopeAttr:$scope,
SPV_MatrixLayoutAttr:$layout,
SPV_AnyPtr:$pointer,
SPV_AnyJointMatrix:$object,
SPV_Integer:$stride,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
);
let results = (outs);
}
// -----
#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS

View File

@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"

View File

@ -29,6 +29,7 @@ namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
@ -420,6 +421,33 @@ public:
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V joint matrix type
class JointMatrixINTELType
: public Type::TypeBase<JointMatrixINTELType, CompositeType,
detail::JointMatrixTypeStorage> {
public:
using Base::Base;
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
unsigned columns, MatrixLayout matrixLayout);
Type getElementType() const;
/// Return the scope of the joint matrix.
Scope getScope() const;
/// return the number of rows of the matrix.
unsigned getRows() const;
/// return the number of columns of the matrix.
unsigned getColumns() const;
/// return the layout of the matrix
MatrixLayout getMatrixLayout() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V matrix type
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
detail::MatrixTypeStorage> {

View File

@ -348,6 +348,39 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
}
// joint-matrix-type ::= `!spv.jointmatrix` `<`rows `x` columns `x` element-type
// `,` layout `,` scope`>`
static Type parseJointMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
SmallVector<int64_t, 2> dims;
SMLoc countLoc = parser.getCurrentLocation();
if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
return Type();
if (dims.size() != 2) {
parser.emitError(countLoc, "expected rows and columns size");
return Type();
}
auto elementTy = parseAndVerifyType(dialect, parser);
if (!elementTy)
return Type();
MatrixLayout matrixLayout;
if (parser.parseComma() ||
parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
return Type();
Scope scope;
if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
if (parser.parseGreater())
return Type();
return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
matrixLayout);
}
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@ -753,6 +786,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
if (keyword == "jointmatrix")
return parseJointMatrixType(*this, parser);
if (keyword == "image")
return parseImageType(*this, parser);
if (keyword == "ptr")
@ -859,6 +894,13 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << ">";
}
static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
<< stringifyMatrixLayout(type.getMatrixLayout());
os << ", " << stringifyScope(type.getScope()) << ">";
}
static void print(MatrixType type, DialectAsmPrinter &os) {
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
os << ">";
@ -866,9 +908,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
ImageType, SampledImageType, StructType, MatrixType>(
[&](auto type) { print(type, os); })
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
PointerType, RuntimeArrayType, ImageType, SampledImageType,
StructType, MatrixType>([&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}

View File

@ -436,6 +436,13 @@ static LogicalResult verifyCastOp(Operation *op,
resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
}
if (auto jointMatrixType =
operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
operandType = jointMatrixType.getElementType();
resultType =
resultType.cast<spirv::JointMatrixINTELType>().getElementType();
}
auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
@ -1637,6 +1644,17 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return success();
}
if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
if (jointType.getElementType() != constituents.front().getType())
return emitOpError("operand type mismatch: expected operand type ")
<< jointType.getElementType() << ", but provided "
<< constituents.front().getType();
return success();
}
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
@ -3893,6 +3911,70 @@ LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
return verifyCoopMatrixMulAdd(*this);
}
static LogicalResult
verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
return op->emitError(
"Pointer must point to a scalar or vector type but provided ")
<< pointeeType;
spirv::StorageClass storage =
pointer.cast<spirv::PointerType>().getStorageClass();
if (storage != spirv::StorageClass::Workgroup &&
storage != spirv::StorageClass::CrossWorkgroup)
return op->emitError("Pointer storage class must be Workgroup or "
"CrossWorkgroup but provided ")
<< stringifyStorageClass(storage);
return success();
}
//===----------------------------------------------------------------------===//
// spv.JointMatrixLoadINTEL
//===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType());
}
//===----------------------------------------------------------------------===//
// spv.JointMatrixStoreINTEL
//===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType());
}
//===----------------------------------------------------------------------===//
// spv.JointMatrixMadINTEL
//===----------------------------------------------------------------------===//
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
auto typeB = op.b().getType().cast<spirv::JointMatrixINTELType>();
auto typeC = op.c().getType().cast<spirv::JointMatrixINTELType>();
auto typeR = op.result().getType().cast<spirv::JointMatrixINTELType>();
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
return op.emitOpError("matrix size must match");
if (typeR.getScope() != typeA.getScope() ||
typeR.getScope() != typeB.getScope() ||
typeR.getScope() != typeC.getScope())
return op.emitOpError("matrix scope must match");
if (typeA.getElementType() != typeB.getElementType() ||
typeR.getElementType() != typeC.getElementType())
return op.emitOpError("matrix element type must match");
return success();
}
LogicalResult spirv::JointMatrixMadINTELOp::verify() {
return verifyJointMatrixMad(*this);
}
//===----------------------------------------------------------------------===//
// spv.MatrixTimesScalar
//===----------------------------------------------------------------------===//
@ -4150,6 +4232,8 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
if (cType.isa<spirv::CooperativeMatrixNVType>())
return emitError("unsupported composite type ") << cType;
if (cType.isa<spirv::JointMatrixINTELType>())
return emitError("unsupported composite type ") << cType;
if (constituents.size() != cType.getNumElements())
return emitError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided "

View File

@ -89,9 +89,9 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
bool CompositeType::classof(Type type) {
if (auto vectorType = type.dyn_cast<VectorType>())
return isValid(vectorType);
return type
.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
spirv::RuntimeArrayType, spirv::StructType>();
return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
spirv::JointMatrixINTELType, spirv::MatrixType,
spirv::RuntimeArrayType, spirv::StructType>();
}
bool CompositeType::isValid(VectorType type) {
@ -110,7 +110,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@ -132,6 +133,10 @@ unsigned CompositeType::getNumElements() const {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
if (isa<JointMatrixINTELType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::JointMatrix type");
}
if (isa<RuntimeArrayType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@ -140,15 +145,16 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
RuntimeArrayType>();
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
.Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
StructType>(
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return type.getElementType().cast<ScalarType>().getExtensions(
@ -161,8 +167,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
.Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
StructType>(
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@ -255,6 +261,74 @@ void CooperativeMatrixNVType::getCapabilities(
capabilities.push_back(ref);
}
//===----------------------------------------------------------------------===//
// JointMatrixType
//===----------------------------------------------------------------------===//
struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<JointMatrixTypeStorage>())
JointMatrixTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
}
JointMatrixTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
columns(std::get<2>(key)), scope(std::get<4>(key)),
matrixLayout(std::get<3>(key)) {}
Type elementType;
unsigned rows;
unsigned columns;
Scope scope;
MatrixLayout matrixLayout;
};
JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
unsigned rows, unsigned columns,
MatrixLayout matrixLayout) {
return Base::get(elementType.getContext(), elementType, rows, columns,
matrixLayout, scope);
}
Type JointMatrixINTELType::getElementType() const {
return getImpl()->elementType;
}
Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
return getImpl()->matrixLayout;
}
void JointMatrixINTELType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
extensions.push_back(ref);
}
void JointMatrixINTELType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
static const Capability caps[] = {Capability::JointMatrixINTEL};
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
capabilities.push_back(ref);
}
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
@ -1172,6 +1246,7 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, SampledImageType, StructType>();
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, JointMatrixINTELType,
MatrixType, PointerType, RuntimeArrayType, SampledImageType,
StructType>();
}

View File

@ -168,6 +168,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
return processTypeForwardPointer(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstant:

View File

@ -730,6 +730,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processCooperativeMatrixType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
return processJointMatrixType(operands);
case spirv::Opcode::OpTypeImage:
return processImageType(operands);
case spirv::Opcode::OpTypeSampledImage:
@ -888,6 +890,40 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 6) {
return emitError(unknownLoc, "OpTypeJointMatrix must have element "
"type and row x column parameters");
}
Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
<< operands[1];
}
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
if (!scope) {
return emitError(unknownLoc,
"OpTypeJointMatrix references undefined scope <id> ")
<< operands[5];
}
auto matrixLayout =
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
if (!matrixLayout) {
return emitError(unknownLoc,
"OpTypeJointMatrix references undefined scope <id> ")
<< operands[4];
}
unsigned rows = getConstantInt(operands[2]).getInt();
unsigned columns = getConstantInt(operands[3]).getInt();
typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
elementTy, scope.value(), rows, columns, matrixLayout.value());
return success();
}
LogicalResult
spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {

View File

@ -257,6 +257,8 @@ private:
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
LogicalResult processImageType(ArrayRef<uint32_t> operands);
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);

View File

@ -598,6 +598,27 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
if (auto jointMatrixType = type.dyn_cast<spirv::JointMatrixINTELType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID);
operands.push_back(getConstantOp(jointMatrixType.getRows()));
operands.push_back(getConstantOp(jointMatrixType.getColumns()));
operands.push_back(getConstantOp(
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
operands.push_back(
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
return success();
}
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,

View File

@ -0,0 +1,158 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: @joint_matrix_load
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
spv.Return
}
// -----
// CHECK-LABEL: @joint_matrix_load_memaccess
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_store
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
spv.Return
}
// CHECK-LABEL: @joint_matrix_store_memaccess
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.Return
}
// CHECK-LABEL: @joint_matrix_length
spv.func @joint_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
spv.ReturnValue %0 : i32
}
// CHECK-LABEL: @joint_matrix_muladd
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_add
spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_sub
spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_sdiv
spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_udiv
spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fadd
spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fsub
spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fdiv
spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// -----
// CHECK-LABEL: @joint_matrix_access_chain
spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
%0 = spv.Constant 0: i32
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
%1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
spv.ReturnValue %1 : !spv.ptr<f32, Function>
}
// -----
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// -----
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// -----
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// -----
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{matrix element type must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// -----
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// -----
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32) "None" {
// expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}

View File

@ -0,0 +1,102 @@
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
// CHECK-LABEL: @joint_matrix_load
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_load_memaccess
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_store
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
spv.Return
}
// CHECK-LABEL: @joint_matrix_store_memaccess
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.Return
}
// CHECK-LABEL: @joint_matrix_length
spv.func @joint_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.ReturnValue %0 : i32
}
// CHECK-LABEL: @joint_matrix_muladd
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_add
spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_sub
spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_sdiv
spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_udiv
spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fadd
spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fsub
spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_fdiv
spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
%r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
spv.Return
}
// CHECK-LABEL: @joint_matrix_access_chain
spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
%0 = spv.Constant 0: i32
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
%1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
spv.ReturnValue %1 : !spv.ptr<f32, Function>
}
}

View File

@ -518,7 +518,8 @@ static void emitAttributeSerialization(const Attribute &attr,
os << tabs
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@ -810,7 +811,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
StringRef words, StringRef wordIndex,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.