[mlir][spirv] Refactor vendor op definitions

Use dedicated vendor op classes/categories. This is so that we can later
change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`.

Issue: https://github.com/llvm/llvm-project/issues/56863
This commit is contained in:
Jakub Kuderski 2022-09-02 14:59:42 -04:00
parent 6a378b38ff
commit b8bea837f3
9 changed files with 88 additions and 76 deletions

View File

@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
// ----- // -----
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> { def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
let summary = "TBD"; let summary = "TBD";
let description = [{ let description = [{
@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
3) store the New Value back through Pointer. 3) store the New Value back through Pointer.
The instructions result is the Original Value. The instruction's result is the Original Value.
Result Type must be a floating-point type scalar. Result Type must be a floating-point type scalar.

View File

@ -15,7 +15,7 @@
// ----- // -----
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV", def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
// ----- // -----
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{ let description = [{
@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// ----- // -----
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> { [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
// ----- // -----
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> { def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{ let description = [{

View File

@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
// ----- // -----
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
let summary = "See extension SPV_KHR_shader_ballot"; let summary = "See extension SPV_KHR_shader_ballot";
let description = [{ let description = [{
@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// ----- // -----
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> { def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
let summary = "See extension SPV_INTEL_subgroups"; let summary = "See extension SPV_INTEL_subgroups";
let description = [{ let description = [{
@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
// ----- // -----
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> { def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
let summary = "See extension SPV_INTEL_subgroups"; let summary = "See extension SPV_INTEL_subgroups";
let description = [{ let description = [{

View File

@ -15,12 +15,12 @@
// ----- // -----
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL", def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{
Return number of components owned by the current work-item in Return number of components owned by the current work-item in
a joint matrix. a joint matrix.
Result Type must be an 32-bit unsigned integer type scalar. Result Type must be an 32-bit unsigned integer type scalar.
@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
// ----- // -----
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> { def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{
@ -68,26 +68,26 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL. Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
Pointer is the pointer to load through. It specifies start of memory region where Pointer is the pointer to load through. It specifies start of memory region where
elements of the matrix are stored and arranged according to Layout. elements of the matrix are stored and arranged according to Layout.
Stride is the number of elements in memory between beginnings of successive rows, Stride is the number of elements in memory between beginnings of successive rows,
columns (or words) in the result. It must be a scalar integer type. columns (or words) in the result. It must be a scalar integer type.
Layout indicates how the values loaded from memory are arranged. It must be the Layout indicates how the values loaded from memory are arranged. It must be the
result of a constant instruction. result of a constant instruction.
Scope is syncronization scope for operation on the matrix. It must be the result Scope is syncronization scope for operation on the matrix. It must be the result
of a constant instruction with scalar integer type. of a constant instruction with scalar integer type.
If present, any Memory Operands must begin with a memory operand literal. If not If present, any Memory Operands must begin with a memory operand literal. If not
present, it is the same as specifying the memory operand None. present, it is the same as specifying the memory operand None.
#### Example: #### Example:
```mlir ```mlir
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} : {memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) -> (!spv.ptr<i32, CrossWorkgroup>, i32) ->
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
``` ```
}]; }];
@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
// ----- // -----
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL", def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> { [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{
Multiply matrix A by matrix B and add matrix C to the result Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is of the multiplication: A*B+C. Here A is a M x K matrix, B is
a K x N matrix and C is a M x N matrix. a K x N matrix and C is a M x N matrix.
Behavior is undefined if sizes of operands do not meet the Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be conditions above. All operands and the Result Type must be
OpTypeJointMatrixINTEL. OpTypeJointMatrixINTEL.
A must be a OpTypeJointMatrixINTEL whose Component Type is a A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count signed numerical type, Row Count equals to M and Column Count
equals to K equals to K
B must be a OpTypeJointMatrixINTEL whose Component Type is a B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count signed numerical type, Row Count equals to K and Column Count
equals to N equals to N
C and Result Type must be a OpTypeJointMatrixINTEL with Row C and Result Type must be a OpTypeJointMatrixINTEL with Row
Count equals to M and Column Count equals to N Count equals to M and Column Count equals to N
Scope is syncronization scope for operation on the matrix. Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar It must be the result of a constant instruction with scalar
integer type. integer type.
#### Example: #### Example:
```mlir ```mlir
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
``` ```
@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
// ----- // -----
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> { def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{
Store a matrix through a pointer. Store a matrix through a pointer.
Pointer is the pointer to store through. It specifies Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must start of memory region where elements of the matrix must
be stored and arranged according to Layout. be stored and arranged according to Layout.
Object is the matrix to store. It must be Object is the matrix to store. It must be
OpTypeJointMatrixINTEL. OpTypeJointMatrixINTEL.
Stride is the number of elements in memory between beginnings Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must of successive rows, columns (or words) of the Object. It must
be a scalar integer type. be a scalar integer type.
Layout indicates how the values stored to memory are arranged. Layout indicates how the values stored to memory are arranged.
It must be the result of a constant instruction. It must be the result of a constant instruction.
Scope is syncronization scope for operation on the matrix. Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar It must be the result of a constant instruction with scalar
integer type. integer type.
If present, any Memory Operands must begin with a memory operand If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory literal. If not present, it is the same as specifying the memory
operand None. operand None.
#### Example: #### Example:
```mlir ```mlir
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
``` ```

View File

@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
// ----- // -----
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> { def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
let summary = "TBD"; let summary = "TBD";
let description = [{ let description = [{

View File

@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
// spv.AtomicFAddEXTOp // spv.AtomicFAddEXTOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::AtomicFAddEXTOp::verify() { LogicalResult spirv::EXTAtomicFAddOp::verify() {
return ::verifyAtomicUpdateOp<FloatType>(getOperation()); return ::verifyAtomicUpdateOp<FloatType>(getOperation());
} }
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser, ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
return ::parseAtomicUpdateOp(parser, result, true); return ::parseAtomicUpdateOp(parser, result, true);
} }
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) { void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
::printAtomicUpdateOp(*this, p); ::printAtomicUpdateOp(*this, p);
} }
@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
// spv.SubgroupBlockReadINTEL // spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser, ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
// Parse the storage class specification // Parse the storage class specification
spirv::StorageClass storageClass; spirv::StorageClass storageClass;
@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) { void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << " : " << getType(); printer << " " << ptr() << " : " << getType();
} }
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure(); return failure();
@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
// spv.SubgroupBlockWriteINTEL // spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser, ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
// Parse the storage class specification // Parse the storage class specification
spirv::StorageClass storageClass; spirv::StorageClass storageClass;
@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) { void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << ", " << value() << " : " << value().getType(); printer << " " << ptr() << ", " << value() << " : " << value().getType();
} }
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure(); return failure();
@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
// spv.CooperativeMatrixLoadNV // spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser, ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32); Type strideType = parser.getBuilder().getIntegerType(32);
@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) { void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
// Print optional memory access attribute. // Print optional memory access attribute.
if (auto memAccess = memory_access()) if (auto memAccess = memory_access())
@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
return success(); return success();
} }
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(), return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
result().getType()); result().getType());
} }
@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
// spv.CooperativeMatrixStoreNV // spv.CooperativeMatrixStoreNV
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser, ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo; SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32); Type strideType = parser.getBuilder().getIntegerType(32);
@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) { void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << object() << ", " << stride() << ", " printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
<< columnmajor(); << columnmajor();
// Print optional memory access attribute. // Print optional memory access attribute.
@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
} }
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(), return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
object().getType()); object().getType());
} }
@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.c().getType() != op.result().getType()) if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type"); return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>(); auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success(); return success();
} }
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAdd(*this); return verifyCoopMatrixMulAdd(*this);
} }
@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
// spv.JointMatrixLoadINTEL // spv.JointMatrixLoadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixLoadINTELOp::verify() { LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(), return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType()); result().getType());
} }
@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
// spv.JointMatrixStoreINTEL // spv.JointMatrixStoreINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixStoreINTELOp::verify() { LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(), return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType()); object().getType());
} }
@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
// spv.JointMatrixMadINTEL // spv.JointMatrixMadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) { static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.c().getType() != op.result().getType()) if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type"); return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>(); auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
return success(); return success();
} }
LogicalResult spirv::JointMatrixMadINTELOp::verify() { LogicalResult spirv::INTELJointMatrixMadOp::verify() {
return verifyJointMatrixMad(*this); return verifyJointMatrixMad(*this);
} }

View File

@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const { PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0); Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
op, op->getResult(0).getType(), predicate); op, op->getResult(0).getType(), predicate);
return success(); return success();
} }

View File

@ -23,13 +23,17 @@ file_name=$1
baseclass=$2 baseclass=$2
case $baseclass in case $baseclass in
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp) Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
| LogicalBinaryOp | LogicalUnaryOp \
| CastOp | ControlFlowOp | StructureOp \
| AtomicUpdateOp | AtomicUpdateWithValueOp \
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
;; ;;
*) *)
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*" echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec" echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
echo "<baseclass> must be one of " \ echo "<baseclass> must be one of " \
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)" "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
exit 1; exit 1;
;; ;;
esac esac

View File

@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
'{{\n let summary = {summary};\n\n let description = ' '{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n') '[{{\n{description}}}];{availability}\n')
else: else:
fmt_str = ('def SPV_{opname_src}Op : ' fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> ' 'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
'{{\n let summary = {summary};\n\n let description = ' '{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n') '[{{\n{description}}}];{availability}\n')
vendor_name = ''
inst_category = existing_info.get('inst_category', 'Op') inst_category = existing_info.get('inst_category', 'Op')
if inst_category == 'Op': if inst_category == 'Op':
fmt_str +='\n let arguments = (ins{args});\n\n'\ fmt_str +='\n let arguments = (ins{args});\n\n'\
' let results = (outs{results});\n' ' let results = (outs{results});\n'
elif inst_category.endswith('VendorOp'):
vendor_name = inst_category.split('VendorOp')[0].upper()
assert len(vendor_name) != 0, 'Invalid instruction category'
fmt_str +='{extras}'\ fmt_str +='{extras}'\
'}}\n' '}}\n'
@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opname_src = instruction['opname'] opname_src = instruction['opname']
if opname.startswith('Op'): if opname.startswith('Op'):
opname_src = opname_src[2:] opname_src = opname_src[2:]
if len(vendor_name) > 0:
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
opname_src = opname_src[:-len(vendor_name)]
category_args = existing_info.get('category_args', '') category_args = existing_info.get('category_args', '')
@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
# Format summary. If the summary can fit in the same line, we print it out # Format summary. If the summary can fit in the same line, we print it out
# as a "-quoted string; otherwise, wrap the lines using "[{...}]". # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
summary = summary.strip(); summary = summary.strip()
if len(summary) + len(' let summary = "";') <= 80: if len(summary) + len(' let summary = "";') <= 80:
summary = '"{}"'.format(summary) summary = '"{}"'.format(summary)
else: else:
@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opcode=instruction['opcode'], opcode=instruction['opcode'],
category_args=category_args, category_args=category_args,
inst_category=inst_category, inst_category=inst_category,
vendor_name=vendor_name,
traits=existing_info.get('traits', ''), traits=existing_info.get('traits', ''),
summary=summary, summary=summary,
description=description, description=description,