[MLIR][SPIRV] Add intel joint matrix ops
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D131586
This commit is contained in:
parent
c63f2581f4
commit
b8f62dc22a
|
@ -27,12 +27,12 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
|
|||
// In addition to normal types arithmetic instructions can support cooperative
|
||||
// matrix.
|
||||
let arguments = (ins
|
||||
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
|
||||
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
|
||||
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
|
||||
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
|
||||
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
|
||||
);
|
||||
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
}
|
||||
|
|
|
@ -64,6 +64,27 @@ def SPV_CooperativeMatrixPropertiesNVArrayAttr :
|
|||
TypedArrayAttrBase<SPV_CooperativeMatrixPropertiesNVAttr,
|
||||
"CooperativeMatrixPropertiesNV array attribute">;
|
||||
|
||||
// Description of the supported joint matrix operations. See
|
||||
// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
|
||||
def SPV_JointMatrixPropertiesINTELAttr :
|
||||
SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
|
||||
let parameters = (ins
|
||||
"int":$m_size,
|
||||
"int":$n_size,
|
||||
"int":$k_size,
|
||||
"mlir::Type":$a_type,
|
||||
"mlir::Type":$b_type,
|
||||
"mlir::Type":$c_type,
|
||||
"mlir::Type":$result_type,
|
||||
"mlir::spirv::ScopeAttr":$scope
|
||||
);
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
def SPV_JointMatrixPropertiesINTELArrayAttr :
|
||||
TypedArrayAttrBase<SPV_JointMatrixPropertiesINTELAttr,
|
||||
"JointMatrixPropertiesINTEL array attribute">;
|
||||
|
||||
// This attribute specifies the limits for various resources on the target
|
||||
// architecture.
|
||||
//
|
||||
|
|
|
@ -387,6 +387,7 @@ def SPV_INTEL_debug_module : I32EnumAttrCase<"SPV_INTEL_de
|
|||
def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
|
||||
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
|
||||
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
|
||||
def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
|
||||
|
||||
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
|
||||
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
|
||||
|
@ -443,7 +444,7 @@ def SPV_ExtensionAttr :
|
|||
SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
|
||||
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
|
||||
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
|
||||
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
|
||||
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
|
||||
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
|
||||
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
|
||||
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
|
||||
|
@ -1390,6 +1391,12 @@ def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderS
|
|||
];
|
||||
}
|
||||
|
||||
def SPV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
|
||||
list<Availability> availability = [
|
||||
Extension<[SPV_INTEL_joint_matrix]>
|
||||
];
|
||||
}
|
||||
|
||||
def SPV_CapabilityAttr :
|
||||
SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
|
||||
SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
|
||||
|
@ -1481,7 +1488,7 @@ def SPV_CapabilityAttr :
|
|||
SPV_C_UniformTexelBufferArrayNonUniformIndexing,
|
||||
SPV_C_StorageTexelBufferArrayNonUniformIndexing,
|
||||
SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV,
|
||||
SPV_C_ShaderStereoViewNV
|
||||
SPV_C_ShaderStereoViewNV, SPV_C_JointMatrixINTEL
|
||||
]>;
|
||||
|
||||
def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
|
||||
|
@ -3981,6 +3988,16 @@ def SPV_SamplerUseAttr: SPV_I32EnumAttr<
|
|||
"image_sampler_use_info",
|
||||
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
|
||||
|
||||
def SPV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
|
||||
def SPV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
|
||||
def SPV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
|
||||
def SPV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>;
|
||||
|
||||
def SPV_MatrixLayoutAttr :
|
||||
SPV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
|
||||
SPV_ML_ColumnMajor, SPV_ML_RowMajor, SPV_ML_PackedA, SPV_ML_PackedB
|
||||
]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4013,6 +4030,8 @@ def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
|||
def SPV_IsCooperativeMatrixType :
|
||||
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
|
||||
def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">;
|
||||
def SPV_IsJointMatrixType :
|
||||
CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">;
|
||||
def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
|
||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
||||
|
@ -4043,6 +4062,8 @@ def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
|
|||
"any SPIR-V cooperative matrix type">;
|
||||
def SPV_AnyImage : DialectType<SPIRV_Dialect, SPV_IsImageType,
|
||||
"any SPIR-V image type">;
|
||||
def SPV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPV_IsJointMatrixType,
|
||||
"any SPIR-V joint matrix type">;
|
||||
def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
|
||||
"any SPIR-V matrix type">;
|
||||
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
|
||||
|
@ -4057,11 +4078,12 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
|
|||
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
|
||||
def SPV_Composite :
|
||||
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
||||
SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
|
||||
SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>;
|
||||
def SPV_Type : AnyTypeOf<[
|
||||
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
|
||||
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
||||
SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage
|
||||
SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix,
|
||||
SPV_AnySampledImage
|
||||
]>;
|
||||
|
||||
def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
|
||||
|
@ -4072,6 +4094,11 @@ class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
|
|||
"$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()",
|
||||
"Cooperative Matrix">;
|
||||
|
||||
class SPV_JointMatrixOfType<list<Type> allowedTypes> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, SPV_IsJointMatrixType,
|
||||
"$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()",
|
||||
"Joint Matrix">;
|
||||
|
||||
class SPV_ScalarOrVectorOf<Type type> :
|
||||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
|
||||
|
||||
|
@ -4079,6 +4106,14 @@ class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
|
|||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
|
||||
SPV_CoopMatrixOfType<[type]>]>;
|
||||
|
||||
class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
|
||||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
|
||||
SPV_JointMatrixOfType<[type]>]>;
|
||||
|
||||
class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
|
||||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
|
||||
SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>;
|
||||
|
||||
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
|
||||
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
|
||||
|
||||
|
@ -4311,6 +4346,11 @@ def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINT
|
|||
def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
|
||||
def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
|
||||
def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
|
||||
def SPV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
|
||||
def SPV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
|
||||
def SPV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
|
||||
def SPV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
|
||||
def SPV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
|
||||
|
||||
def SPV_OpcodeAttr :
|
||||
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
|
||||
|
@ -4376,7 +4416,10 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
|
||||
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
|
||||
SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL,
|
||||
SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT
|
||||
SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT,
|
||||
SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL,
|
||||
SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL,
|
||||
SPV_OC_OpTypejointMatrixWorkItemLengthINTEL
|
||||
]>;
|
||||
|
||||
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
|
|
@ -23,11 +23,11 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
|
|||
!listconcat(traits,
|
||||
[NoSideEffect, SameOperandsAndResultShape])> {
|
||||
let arguments = (ins
|
||||
SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
|
||||
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
|
||||
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `to` type($result)
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the op definition spec of joint matrix multiply extension ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
|
||||
#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
|
||||
[NoSideEffect]> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Return number of components owned by the current work-item in
|
||||
a joint matrix.
|
||||
|
||||
Result Type must be an 32-bit unsigned integer type scalar.
|
||||
|
||||
Type is a joint matrix type.
|
||||
|
||||
``` {.ebnf}
|
||||
joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL
|
||||
` : ` joint-matrix-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<Subgroup, i32, 8, 16>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` $type";
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_6>,
|
||||
Extension<[SPV_INTEL_joint_matrix]>,
|
||||
Capability<[SPV_C_JointMatrixINTEL]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
TypeAttr:$type
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Int32:$result
|
||||
);
|
||||
let hasVerifier = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Load a matrix through a pointer.
|
||||
|
||||
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
|
||||
|
||||
Pointer is the pointer to load through. It specifies start of memory region where
|
||||
elements of the matrix are stored and arranged according to Layout.
|
||||
|
||||
Stride is the number of elements in memory between beginnings of successive rows,
|
||||
columns (or words) in the result. It must be a scalar integer type.
|
||||
|
||||
Layout indicates how the values loaded from memory are arranged. It must be the
|
||||
result of a constant instruction.
|
||||
|
||||
Scope is syncronization scope for operation on the matrix. It must be the result
|
||||
of a constant instruction with scalar integer type.
|
||||
|
||||
If present, any Memory Operands must begin with a memory operand literal. If not
|
||||
present, it is the same as specifying the memory operand None.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} :
|
||||
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
|
||||
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_6>,
|
||||
Extension<[SPV_INTEL_joint_matrix]>,
|
||||
Capability<[SPV_C_JointMatrixINTEL]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$scope,
|
||||
SPV_MatrixLayoutAttr:$layout,
|
||||
SPV_AnyPtr:$pointer,
|
||||
SPV_Integer:$stride,
|
||||
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
|
||||
OptionalAttr<I32Attr>:$alignment
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyJointMatrix:$result
|
||||
);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
|
||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Multiply matrix A by matrix B and add matrix C to the result
|
||||
of the multiplication: A*B+C. Here A is a M x K matrix, B is
|
||||
a K x N matrix and C is a M x N matrix.
|
||||
|
||||
Behavior is undefined if sizes of operands do not meet the
|
||||
conditions above. All operands and the Result Type must be
|
||||
OpTypeJointMatrixINTEL.
|
||||
|
||||
A must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to M and Column Count
|
||||
equals to K
|
||||
|
||||
B must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to K and Column Count
|
||||
equals to N
|
||||
|
||||
C and Result Type must be a OpTypeJointMatrixINTEL with Row
|
||||
Count equals to M and Column Count equals to N
|
||||
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
integer type.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
|
||||
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
|
||||
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
|
||||
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
```
|
||||
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$scope operands attr-dict`:` type($a) `,` type($b) `->` type($c)
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_6>,
|
||||
Extension<[SPV_INTEL_joint_matrix]>,
|
||||
Capability<[SPV_C_JointMatrixINTEL]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$scope,
|
||||
SPV_AnyJointMatrix:$a,
|
||||
SPV_AnyJointMatrix:$b,
|
||||
SPV_AnyJointMatrix:$c
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyJointMatrix:$result
|
||||
);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Store a matrix through a pointer.
|
||||
|
||||
Pointer is the pointer to store through. It specifies
|
||||
start of memory region where elements of the matrix must
|
||||
be stored and arranged according to Layout.
|
||||
|
||||
Object is the matrix to store. It must be
|
||||
OpTypeJointMatrixINTEL.
|
||||
|
||||
Stride is the number of elements in memory between beginnings
|
||||
of successive rows, columns (or words) of the Object. It must
|
||||
be a scalar integer type.
|
||||
|
||||
Layout indicates how the values stored to memory are arranged.
|
||||
It must be the result of a constant instruction.
|
||||
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
integer type.
|
||||
|
||||
If present, any Memory Operands must begin with a memory operand
|
||||
literal. If not present, it is the same as specifying the memory
|
||||
operand None.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
|
||||
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
```
|
||||
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$scope $layout operands attr-dict `:` `(` type(operands) `)`
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_6>,
|
||||
Extension<[SPV_INTEL_joint_matrix]>,
|
||||
Capability<[SPV_C_JointMatrixINTEL]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$scope,
|
||||
SPV_MatrixLayoutAttr:$layout,
|
||||
SPV_AnyPtr:$pointer,
|
||||
SPV_AnyJointMatrix:$object,
|
||||
SPV_Integer:$stride,
|
||||
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
|
||||
OptionalAttr<I32Attr>:$alignment
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
|
|
@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
|
|||
include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace detail {
|
|||
struct ArrayTypeStorage;
|
||||
struct CooperativeMatrixTypeStorage;
|
||||
struct ImageTypeStorage;
|
||||
struct JointMatrixTypeStorage;
|
||||
struct MatrixTypeStorage;
|
||||
struct PointerTypeStorage;
|
||||
struct RuntimeArrayTypeStorage;
|
||||
|
@ -420,6 +421,33 @@ public:
|
|||
Optional<StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
// SPIR-V joint matrix type
|
||||
class JointMatrixINTELType
|
||||
: public Type::TypeBase<JointMatrixINTELType, CompositeType,
|
||||
detail::JointMatrixTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
|
||||
unsigned columns, MatrixLayout matrixLayout);
|
||||
Type getElementType() const;
|
||||
|
||||
/// Return the scope of the joint matrix.
|
||||
Scope getScope() const;
|
||||
/// return the number of rows of the matrix.
|
||||
unsigned getRows() const;
|
||||
/// return the number of columns of the matrix.
|
||||
unsigned getColumns() const;
|
||||
|
||||
/// return the layout of the matrix
|
||||
MatrixLayout getMatrixLayout() const;
|
||||
|
||||
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage = llvm::None);
|
||||
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
// SPIR-V matrix type
|
||||
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
|
||||
detail::MatrixTypeStorage> {
|
||||
|
|
|
@ -348,6 +348,39 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
|
|||
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
|
||||
}
|
||||
|
||||
// joint-matrix-type ::= `!spv.jointmatrix` `<`rows `x` columns `x` element-type
|
||||
// `,` layout `,` scope`>`
|
||||
static Type parseJointMatrixType(SPIRVDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
SMLoc countLoc = parser.getCurrentLocation();
|
||||
if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
|
||||
return Type();
|
||||
|
||||
if (dims.size() != 2) {
|
||||
parser.emitError(countLoc, "expected rows and columns size");
|
||||
return Type();
|
||||
}
|
||||
|
||||
auto elementTy = parseAndVerifyType(dialect, parser);
|
||||
if (!elementTy)
|
||||
return Type();
|
||||
MatrixLayout matrixLayout;
|
||||
if (parser.parseComma() ||
|
||||
parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
|
||||
return Type();
|
||||
Scope scope;
|
||||
if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
|
||||
return Type();
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
|
||||
matrixLayout);
|
||||
}
|
||||
|
||||
// TODO: Reorder methods to be utilities first and parse*Type
|
||||
// methods in alphabetical order
|
||||
//
|
||||
|
@ -753,6 +786,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
|
|||
return parseArrayType(*this, parser);
|
||||
if (keyword == "coopmatrix")
|
||||
return parseCooperativeMatrixType(*this, parser);
|
||||
if (keyword == "jointmatrix")
|
||||
return parseJointMatrixType(*this, parser);
|
||||
if (keyword == "image")
|
||||
return parseImageType(*this, parser);
|
||||
if (keyword == "ptr")
|
||||
|
@ -859,6 +894,13 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
|
|||
os << ">";
|
||||
}
|
||||
|
||||
static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
|
||||
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
|
||||
os << type.getElementType() << ", "
|
||||
<< stringifyMatrixLayout(type.getMatrixLayout());
|
||||
os << ", " << stringifyScope(type.getScope()) << ">";
|
||||
}
|
||||
|
||||
static void print(MatrixType type, DialectAsmPrinter &os) {
|
||||
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
|
||||
os << ">";
|
||||
|
@ -866,9 +908,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
|
|||
|
||||
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
|
||||
ImageType, SampledImageType, StructType, MatrixType>(
|
||||
[&](auto type) { print(type, os); })
|
||||
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
PointerType, RuntimeArrayType, ImageType, SampledImageType,
|
||||
StructType, MatrixType>([&](auto type) { print(type, os); })
|
||||
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
|
||||
}
|
||||
|
||||
|
|
|
@ -436,6 +436,13 @@ static LogicalResult verifyCastOp(Operation *op,
|
|||
resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
|
||||
}
|
||||
|
||||
if (auto jointMatrixType =
|
||||
operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
|
||||
operandType = jointMatrixType.getElementType();
|
||||
resultType =
|
||||
resultType.cast<spirv::JointMatrixINTELType>().getElementType();
|
||||
}
|
||||
|
||||
auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
|
||||
auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
|
||||
auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
|
||||
|
@ -1637,6 +1644,17 @@ LogicalResult spirv::CompositeConstructOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
|
||||
if (constituents.size() != 1)
|
||||
return emitOpError("has incorrect number of operands: expected ")
|
||||
<< "1, but provided " << constituents.size();
|
||||
if (jointType.getElementType() != constituents.front().getType())
|
||||
return emitOpError("operand type mismatch: expected operand type ")
|
||||
<< jointType.getElementType() << ", but provided "
|
||||
<< constituents.front().getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (constituents.size() == cType.getNumElements()) {
|
||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||
if (constituents[index].getType() != cType.getElementType(index)) {
|
||||
|
@ -3893,6 +3911,70 @@ LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
|
|||
return verifyCoopMatrixMulAdd(*this);
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
|
||||
Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
|
||||
if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
|
||||
return op->emitError(
|
||||
"Pointer must point to a scalar or vector type but provided ")
|
||||
<< pointeeType;
|
||||
spirv::StorageClass storage =
|
||||
pointer.cast<spirv::PointerType>().getStorageClass();
|
||||
if (storage != spirv::StorageClass::Workgroup &&
|
||||
storage != spirv::StorageClass::CrossWorkgroup)
|
||||
return op->emitError("Pointer storage class must be Workgroup or "
|
||||
"CrossWorkgroup but provided ")
|
||||
<< stringifyStorageClass(storage);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.JointMatrixLoadINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
|
||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||
result().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.JointMatrixStoreINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
|
||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||
object().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.JointMatrixMadINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
|
||||
if (op.c().getType() != op.result().getType())
|
||||
return op.emitOpError("result and third operand must have the same type");
|
||||
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
|
||||
auto typeB = op.b().getType().cast<spirv::JointMatrixINTELType>();
|
||||
auto typeC = op.c().getType().cast<spirv::JointMatrixINTELType>();
|
||||
auto typeR = op.result().getType().cast<spirv::JointMatrixINTELType>();
|
||||
if (typeA.getRows() != typeR.getRows() ||
|
||||
typeA.getColumns() != typeB.getRows() ||
|
||||
typeB.getColumns() != typeR.getColumns())
|
||||
return op.emitOpError("matrix size must match");
|
||||
if (typeR.getScope() != typeA.getScope() ||
|
||||
typeR.getScope() != typeB.getScope() ||
|
||||
typeR.getScope() != typeC.getScope())
|
||||
return op.emitOpError("matrix scope must match");
|
||||
if (typeA.getElementType() != typeB.getElementType() ||
|
||||
typeR.getElementType() != typeC.getElementType())
|
||||
return op.emitOpError("matrix element type must match");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult spirv::JointMatrixMadINTELOp::verify() {
|
||||
return verifyJointMatrixMad(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.MatrixTimesScalar
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4150,6 +4232,8 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
|
|||
|
||||
if (cType.isa<spirv::CooperativeMatrixNVType>())
|
||||
return emitError("unsupported composite type ") << cType;
|
||||
if (cType.isa<spirv::JointMatrixINTELType>())
|
||||
return emitError("unsupported composite type ") << cType;
|
||||
if (constituents.size() != cType.getNumElements())
|
||||
return emitError("has incorrect number of operands: expected ")
|
||||
<< cType.getNumElements() << ", but provided "
|
||||
|
|
|
@ -89,9 +89,9 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
|
|||
bool CompositeType::classof(Type type) {
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return isValid(vectorType);
|
||||
return type
|
||||
.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
|
||||
spirv::RuntimeArrayType, spirv::StructType>();
|
||||
return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
|
||||
spirv::JointMatrixINTELType, spirv::MatrixType,
|
||||
spirv::RuntimeArrayType, spirv::StructType>();
|
||||
}
|
||||
|
||||
bool CompositeType::isValid(VectorType type) {
|
||||
|
@ -110,7 +110,8 @@ bool CompositeType::isValid(VectorType type) {
|
|||
|
||||
Type CompositeType::getElementType(unsigned index) const {
|
||||
return TypeSwitch<Type, Type>(*this)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
|
||||
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
RuntimeArrayType, VectorType>(
|
||||
[](auto type) { return type.getElementType(); })
|
||||
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
|
||||
.Case<StructType>(
|
||||
|
@ -132,6 +133,10 @@ unsigned CompositeType::getNumElements() const {
|
|||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
||||
}
|
||||
if (isa<JointMatrixINTELType>()) {
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::JointMatrix type");
|
||||
}
|
||||
if (isa<RuntimeArrayType>()) {
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||
|
@ -140,15 +145,16 @@ unsigned CompositeType::getNumElements() const {
|
|||
}
|
||||
|
||||
bool CompositeType::hasCompileTimeKnownNumElements() const {
|
||||
return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
|
||||
return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
RuntimeArrayType>();
|
||||
}
|
||||
|
||||
void CompositeType::getExtensions(
|
||||
SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage) {
|
||||
TypeSwitch<Type>(*this)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
|
||||
StructType>(
|
||||
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
MatrixType, RuntimeArrayType, StructType>(
|
||||
[&](auto type) { type.getExtensions(extensions, storage); })
|
||||
.Case<VectorType>([&](VectorType type) {
|
||||
return type.getElementType().cast<ScalarType>().getExtensions(
|
||||
|
@ -161,8 +167,8 @@ void CompositeType::getCapabilities(
|
|||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<StorageClass> storage) {
|
||||
TypeSwitch<Type>(*this)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
|
||||
StructType>(
|
||||
.Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
MatrixType, RuntimeArrayType, StructType>(
|
||||
[&](auto type) { type.getCapabilities(capabilities, storage); })
|
||||
.Case<VectorType>([&](VectorType type) {
|
||||
auto vecSize = getNumElements();
|
||||
|
@ -255,6 +261,74 @@ void CooperativeMatrixNVType::getCapabilities(
|
|||
capabilities.push_back(ref);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// JointMatrixType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
|
||||
|
||||
static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<JointMatrixTypeStorage>())
|
||||
JointMatrixTypeStorage(key);
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
|
||||
}
|
||||
|
||||
JointMatrixTypeStorage(const KeyTy &key)
|
||||
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
|
||||
columns(std::get<2>(key)), scope(std::get<4>(key)),
|
||||
matrixLayout(std::get<3>(key)) {}
|
||||
|
||||
Type elementType;
|
||||
unsigned rows;
|
||||
unsigned columns;
|
||||
Scope scope;
|
||||
MatrixLayout matrixLayout;
|
||||
};
|
||||
|
||||
JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
|
||||
unsigned rows, unsigned columns,
|
||||
MatrixLayout matrixLayout) {
|
||||
return Base::get(elementType.getContext(), elementType, rows, columns,
|
||||
matrixLayout, scope);
|
||||
}
|
||||
|
||||
Type JointMatrixINTELType::getElementType() const {
|
||||
return getImpl()->elementType;
|
||||
}
|
||||
|
||||
Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
|
||||
|
||||
unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
|
||||
|
||||
unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
|
||||
|
||||
MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
|
||||
return getImpl()->matrixLayout;
|
||||
}
|
||||
|
||||
void JointMatrixINTELType::getExtensions(
|
||||
SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
|
||||
static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
|
||||
ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
|
||||
extensions.push_back(ref);
|
||||
}
|
||||
|
||||
void JointMatrixINTELType::getCapabilities(
|
||||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
static const Capability caps[] = {Capability::JointMatrixINTEL};
|
||||
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
|
||||
capabilities.push_back(ref);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ImageType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1172,6 +1246,7 @@ void MatrixType::getCapabilities(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SPIRVDialect::registerTypes() {
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
|
||||
PointerType, RuntimeArrayType, SampledImageType, StructType>();
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, JointMatrixINTELType,
|
||||
MatrixType, PointerType, RuntimeArrayType, SampledImageType,
|
||||
StructType>();
|
||||
}
|
||||
|
|
|
@ -168,6 +168,8 @@ LogicalResult spirv::Deserializer::processInstruction(
|
|||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpTypeForwardPointer:
|
||||
return processTypeForwardPointer(operands);
|
||||
case spirv::Opcode::OpTypeJointMatrixINTEL:
|
||||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpConstant:
|
||||
return processConstant(operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstant:
|
||||
|
|
|
@ -730,6 +730,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
|
|||
return processCooperativeMatrixType(operands);
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
return processFunctionType(operands);
|
||||
case spirv::Opcode::OpTypeJointMatrixINTEL:
|
||||
return processJointMatrixType(operands);
|
||||
case spirv::Opcode::OpTypeImage:
|
||||
return processImageType(operands);
|
||||
case spirv::Opcode::OpTypeSampledImage:
|
||||
|
@ -888,6 +890,40 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 6) {
|
||||
return emitError(unknownLoc, "OpTypeJointMatrix must have element "
|
||||
"type and row x column parameters");
|
||||
}
|
||||
|
||||
Type elementTy = getType(operands[1]);
|
||||
if (!elementTy) {
|
||||
return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
|
||||
<< operands[1];
|
||||
}
|
||||
|
||||
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
|
||||
if (!scope) {
|
||||
return emitError(unknownLoc,
|
||||
"OpTypeJointMatrix references undefined scope <id> ")
|
||||
<< operands[5];
|
||||
}
|
||||
auto matrixLayout =
|
||||
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
|
||||
if (!matrixLayout) {
|
||||
return emitError(unknownLoc,
|
||||
"OpTypeJointMatrix references undefined scope <id> ")
|
||||
<< operands[4];
|
||||
}
|
||||
unsigned rows = getConstantInt(operands[2]).getInt();
|
||||
unsigned columns = getConstantInt(operands[3]).getInt();
|
||||
|
||||
typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
|
||||
elementTy, scope.value(), rows, columns, matrixLayout.value());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 2) {
|
||||
|
|
|
@ -257,6 +257,8 @@ private:
|
|||
|
||||
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processImageType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
|
||||
|
|
|
@ -598,6 +598,27 @@ LogicalResult Serializer::prepareBasicType(
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto jointMatrixType = type.dyn_cast<spirv::JointMatrixINTELType>()) {
|
||||
uint32_t elementTypeID = 0;
|
||||
if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
|
||||
elementTypeID, serializationCtx))) {
|
||||
return failure();
|
||||
}
|
||||
typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
|
||||
auto getConstantOp = [&](uint32_t id) {
|
||||
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
|
||||
return prepareConstantInt(loc, attr);
|
||||
};
|
||||
operands.push_back(elementTypeID);
|
||||
operands.push_back(getConstantOp(jointMatrixType.getRows()));
|
||||
operands.push_back(getConstantOp(jointMatrixType.getColumns()));
|
||||
operands.push_back(getConstantOp(
|
||||
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
|
||||
operands.push_back(
|
||||
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
|
||||
uint32_t elementTypeID = 0;
|
||||
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_load
|
||||
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @joint_matrix_load_memaccess
|
||||
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
|
||||
spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_store
|
||||
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
|
||||
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_store_memaccess
|
||||
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
|
||||
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_length
|
||||
spv.func @joint_matrix_length() -> i32 "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
|
||||
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_muladd
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_add
|
||||
spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_sub
|
||||
spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_sdiv
|
||||
spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_udiv
|
||||
spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fadd
|
||||
spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fsub
|
||||
spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fdiv
|
||||
spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_access_chain
|
||||
spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
|
||||
%0 = spv.Constant 0: i32
|
||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
|
||||
%1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
|
||||
spv.ReturnValue %1 : !spv.ptr<f32, Function>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}}
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// expected-error @+1 {{matrix element type must match}}
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
|
||||
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32) "None" {
|
||||
// expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
|
||||
// CHECK-LABEL: @joint_matrix_load
|
||||
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_load_memaccess
|
||||
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_store
|
||||
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
|
||||
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_store_memaccess
|
||||
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
|
||||
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_length
|
||||
spv.func @joint_matrix_length() -> i32 "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_muladd
|
||||
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_add
|
||||
spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_sub
|
||||
spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_sdiv
|
||||
spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_udiv
|
||||
spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
%r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fadd
|
||||
spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fsub
|
||||
spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_fdiv
|
||||
spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
%r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @joint_matrix_access_chain
|
||||
spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
|
||||
%0 = spv.Constant 0: i32
|
||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
|
||||
%1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
|
||||
spv.ReturnValue %1 : !spv.ptr<f32, Function>
|
||||
}
|
||||
}
|
|
@ -518,7 +518,8 @@ static void emitAttributeSerialization(const Attribute &attr,
|
|||
os << tabs
|
||||
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
|
||||
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
|
||||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
|
||||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
|
||||
attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
|
||||
// These two enums are encoded as <id> to constant values in SPIR-V blob,
|
||||
// but we directly use the constant value as attribute in SPIR-V dialect. So
|
||||
// need to handle them separately from normal enum attributes.
|
||||
|
@ -810,7 +811,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
|
|||
StringRef words, StringRef wordIndex,
|
||||
raw_ostream &os) {
|
||||
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
|
||||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
|
||||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
|
||||
attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
|
||||
// These two enums are encoded as <id> to constant values in SPIR-V blob,
|
||||
// but we directly use the constant value as attribute in SPIR-V dialect. So
|
||||
// need to handle them separately from normal enum attributes.
|
||||
|
|
Loading…
Reference in New Issue