[mlir][spirv] Refactor vendor op definitions
Use dedicated vendor op classes/categories. This is so that we can later change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`. Issue: https://github.com/llvm/llvm-project/issues/56863
This commit is contained in:
parent
6a378b38ff
commit
b8bea837f3
|
@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
|
||||
def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
|
||||
let summary = "TBD";
|
||||
|
||||
let description = [{
|
||||
|
@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
|
|||
|
||||
3) store the New Value back through Pointer.
|
||||
|
||||
The instruction’s result is the Original Value.
|
||||
The instruction's result is the Original Value.
|
||||
|
||||
Result Type must be a floating-point type scalar.
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
||||
def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
|
||||
[NoSideEffect]> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
|
@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
||||
def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
let description = [{
|
||||
|
@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
||||
def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
|
||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
|
@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
|
||||
def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
||||
def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
|
||||
let summary = "See extension SPV_KHR_shader_ballot";
|
||||
|
||||
let description = [{
|
||||
|
@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
|
||||
def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
|
||||
let summary = "See extension SPV_INTEL_subgroups";
|
||||
|
||||
let description = [{
|
||||
|
@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
|
||||
def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
|
||||
let summary = "See extension SPV_INTEL_subgroups";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
|
||||
def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
|
||||
[NoSideEffect]> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Return number of components owned by the current work-item in
|
||||
Return number of components owned by the current work-item in
|
||||
a joint matrix.
|
||||
|
||||
Result Type must be an 32-bit unsigned integer type scalar.
|
||||
|
@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
||||
def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
|
@ -68,26 +68,26 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
|||
|
||||
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
|
||||
|
||||
Pointer is the pointer to load through. It specifies start of memory region where
|
||||
Pointer is the pointer to load through. It specifies start of memory region where
|
||||
elements of the matrix are stored and arranged according to Layout.
|
||||
|
||||
Stride is the number of elements in memory between beginnings of successive rows,
|
||||
Stride is the number of elements in memory between beginnings of successive rows,
|
||||
columns (or words) in the result. It must be a scalar integer type.
|
||||
|
||||
Layout indicates how the values loaded from memory are arranged. It must be the
|
||||
Layout indicates how the values loaded from memory are arranged. It must be the
|
||||
result of a constant instruction.
|
||||
|
||||
Scope is syncronization scope for operation on the matrix. It must be the result
|
||||
Scope is syncronization scope for operation on the matrix. It must be the result
|
||||
of a constant instruction with scalar integer type.
|
||||
|
||||
If present, any Memory Operands must begin with a memory operand literal. If not
|
||||
If present, any Memory Operands must begin with a memory operand literal. If not
|
||||
present, it is the same as specifying the memory operand None.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} :
|
||||
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
|
||||
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} :
|
||||
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
|
||||
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
|
||||
```
|
||||
}];
|
||||
|
@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
|
||||
def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
|
||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Multiply matrix A by matrix B and add matrix C to the result
|
||||
of the multiplication: A*B+C. Here A is a M x K matrix, B is
|
||||
Multiply matrix A by matrix B and add matrix C to the result
|
||||
of the multiplication: A*B+C. Here A is a M x K matrix, B is
|
||||
a K x N matrix and C is a M x N matrix.
|
||||
|
||||
Behavior is undefined if sizes of operands do not meet the
|
||||
conditions above. All operands and the Result Type must be
|
||||
Behavior is undefined if sizes of operands do not meet the
|
||||
conditions above. All operands and the Result Type must be
|
||||
OpTypeJointMatrixINTEL.
|
||||
|
||||
A must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to M and Column Count
|
||||
A must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to M and Column Count
|
||||
equals to K
|
||||
|
||||
B must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to K and Column Count
|
||||
B must be a OpTypeJointMatrixINTEL whose Component Type is a
|
||||
signed numerical type, Row Count equals to K and Column Count
|
||||
equals to N
|
||||
|
||||
C and Result Type must be a OpTypeJointMatrixINTEL with Row
|
||||
C and Result Type must be a OpTypeJointMatrixINTEL with Row
|
||||
Count equals to M and Column Count equals to N
|
||||
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
integer type.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
|
||||
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
|
||||
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
|
||||
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
|
||||
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
|
||||
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
|
||||
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
|
||||
```
|
||||
|
||||
|
@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
|
||||
def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
|
||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||
|
||||
let description = [{
|
||||
Store a matrix through a pointer.
|
||||
|
||||
Pointer is the pointer to store through. It specifies
|
||||
start of memory region where elements of the matrix must
|
||||
Pointer is the pointer to store through. It specifies
|
||||
start of memory region where elements of the matrix must
|
||||
be stored and arranged according to Layout.
|
||||
|
||||
Object is the matrix to store. It must be
|
||||
Object is the matrix to store. It must be
|
||||
OpTypeJointMatrixINTEL.
|
||||
|
||||
Stride is the number of elements in memory between beginnings
|
||||
of successive rows, columns (or words) of the Object. It must
|
||||
Stride is the number of elements in memory between beginnings
|
||||
of successive rows, columns (or words) of the Object. It must
|
||||
be a scalar integer type.
|
||||
|
||||
Layout indicates how the values stored to memory are arranged.
|
||||
Layout indicates how the values stored to memory are arranged.
|
||||
It must be the result of a constant instruction.
|
||||
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
Scope is syncronization scope for operation on the matrix.
|
||||
It must be the result of a constant instruction with scalar
|
||||
integer type.
|
||||
|
||||
If present, any Memory Operands must begin with a memory operand
|
||||
literal. If not present, it is the same as specifying the memory
|
||||
If present, any Memory Operands must begin with a memory operand
|
||||
literal. If not present, it is the same as specifying the memory
|
||||
operand None.
|
||||
|
||||
#### Example:
|
||||
```mlir
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
|
||||
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
|
||||
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
|
||||
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
|
||||
```
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
|
||||
def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
|
||||
let summary = "TBD";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
|
|||
// spv.AtomicFAddEXTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::AtomicFAddEXTOp::verify() {
|
||||
LogicalResult spirv::EXTAtomicFAddOp::verify() {
|
||||
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
|
||||
ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return ::parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
|
||||
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
|
||||
::printAtomicUpdateOp(*this, p);
|
||||
}
|
||||
|
||||
|
@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
|
|||
// spv.SubgroupBlockReadINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
|
||||
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
// Parse the storage class specification
|
||||
spirv::StorageClass storageClass;
|
||||
|
@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
|
||||
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << ptr() << " : " << getType();
|
||||
}
|
||||
|
||||
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
|
||||
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
|
||||
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
||||
return failure();
|
||||
|
||||
|
@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
|
|||
// spv.SubgroupBlockWriteINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
|
||||
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
// Parse the storage class specification
|
||||
spirv::StorageClass storageClass;
|
||||
|
@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
|
||||
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << ptr() << ", " << value() << " : " << value().getType();
|
||||
}
|
||||
|
||||
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
|
||||
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
|
||||
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
||||
return failure();
|
||||
|
||||
|
@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
|
|||
// spv.CooperativeMatrixLoadNV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
|
||||
ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
|
||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||
|
@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
|
||||
void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
|
||||
// Print optional memory access attribute.
|
||||
if (auto memAccess = memory_access())
|
||||
|
@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
|
||||
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
|
||||
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
||||
result().getType());
|
||||
}
|
||||
|
@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
|
|||
// spv.CooperativeMatrixStoreNV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
|
||||
ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
|
||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||
|
@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
|
||||
void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
|
||||
<< columnmajor();
|
||||
// Print optional memory access attribute.
|
||||
|
@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
|
|||
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
|
||||
}
|
||||
|
||||
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
|
||||
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
|
||||
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
||||
object().getType());
|
||||
}
|
||||
|
@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult
|
||||
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
|
||||
if (op.c().getType() != op.result().getType())
|
||||
return op.emitOpError("result and third operand must have the same type");
|
||||
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
|
||||
|
@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
|
||||
LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
|
||||
return verifyCoopMatrixMulAdd(*this);
|
||||
}
|
||||
|
||||
|
@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
|
|||
// spv.JointMatrixLoadINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
|
||||
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
|
||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||
result().getType());
|
||||
}
|
||||
|
@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
|
|||
// spv.JointMatrixStoreINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
|
||||
LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
|
||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||
object().getType());
|
||||
}
|
||||
|
@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
|
|||
// spv.JointMatrixMadINTEL
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
|
||||
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
|
||||
if (op.c().getType() != op.result().getType())
|
||||
return op.emitOpError("result and third operand must have the same type");
|
||||
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
|
||||
|
@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult spirv::JointMatrixMadINTELOp::verify() {
|
||||
LogicalResult spirv::INTELJointMatrixMadOp::verify() {
|
||||
return verifyJointMatrixMad(*this);
|
||||
}
|
||||
|
||||
|
|
|
@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
|
|||
PatternRewriter &rewriter) const {
|
||||
Value predicate = op->getOperand(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
|
||||
rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
|
||||
op, op->getResult(0).getType(), predicate);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -23,13 +23,17 @@ file_name=$1
|
|||
baseclass=$2
|
||||
|
||||
case $baseclass in
|
||||
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
|
||||
Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
|
||||
| LogicalBinaryOp | LogicalUnaryOp \
|
||||
| CastOp | ControlFlowOp | StructureOp \
|
||||
| AtomicUpdateOp | AtomicUpdateWithValueOp \
|
||||
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
|
||||
;;
|
||||
*)
|
||||
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
|
||||
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
|
||||
echo "<baseclass> must be one of " \
|
||||
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
|
||||
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
|
||||
exit 1;
|
||||
;;
|
||||
esac
|
||||
|
|
|
@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
|||
'{{\n let summary = {summary};\n\n let description = '
|
||||
'[{{\n{description}}}];{availability}\n')
|
||||
else:
|
||||
fmt_str = ('def SPV_{opname_src}Op : '
|
||||
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
|
||||
fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
|
||||
'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
|
||||
'{{\n let summary = {summary};\n\n let description = '
|
||||
'[{{\n{description}}}];{availability}\n')
|
||||
|
||||
vendor_name = ''
|
||||
inst_category = existing_info.get('inst_category', 'Op')
|
||||
if inst_category == 'Op':
|
||||
fmt_str +='\n let arguments = (ins{args});\n\n'\
|
||||
' let results = (outs{results});\n'
|
||||
elif inst_category.endswith('VendorOp'):
|
||||
vendor_name = inst_category.split('VendorOp')[0].upper()
|
||||
assert len(vendor_name) != 0, 'Invalid instruction category'
|
||||
|
||||
fmt_str +='{extras}'\
|
||||
'}}\n'
|
||||
|
@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
|||
opname_src = instruction['opname']
|
||||
if opname.startswith('Op'):
|
||||
opname_src = opname_src[2:]
|
||||
if len(vendor_name) > 0:
|
||||
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
|
||||
opname_src = opname_src[:-len(vendor_name)]
|
||||
|
||||
category_args = existing_info.get('category_args', '')
|
||||
|
||||
|
@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
|||
|
||||
# Format summary. If the summary can fit in the same line, we print it out
|
||||
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
|
||||
summary = summary.strip();
|
||||
summary = summary.strip()
|
||||
if len(summary) + len(' let summary = "";') <= 80:
|
||||
summary = '"{}"'.format(summary)
|
||||
else:
|
||||
|
@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
|||
opcode=instruction['opcode'],
|
||||
category_args=category_args,
|
||||
inst_category=inst_category,
|
||||
vendor_name=vendor_name,
|
||||
traits=existing_info.get('traits', ''),
|
||||
summary=summary,
|
||||
description=description,
|
||||
|
|
Loading…
Reference in New Issue