[mlir][spirv] Refactor vendor op definitions

Use dedicated vendor op classes/categories. This is so that we can later
change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`.

Issue: https://github.com/llvm/llvm-project/issues/56863
This commit is contained in:
Jakub Kuderski 2022-09-02 14:59:42 -04:00
parent 6a378b38ff
commit b8bea837f3
9 changed files with 88 additions and 76 deletions

View File

@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
// -----
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
let summary = "TBD";
let description = [{
@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
3) store the New Value back through Pointer.
The instructions result is the Original Value.
The instruction's result is the Original Value.
Result Type must be a floating-point type scalar.

View File

@ -15,7 +15,7 @@
// -----
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
[NoSideEffect]> {
let summary = "See extension SPV_NV_cooperative_matrix";
@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
// -----
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{
@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// -----
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix";
@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
// -----
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{

View File

@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
// -----
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
let summary = "See extension SPV_KHR_shader_ballot";
let description = [{
@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// -----
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{
@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
// -----
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{

View File

@ -15,12 +15,12 @@
// -----
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
[NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Return number of components owned by the current work-item in
Return number of components owned by the current work-item in
a joint matrix.
Result Type must be an 32-bit unsigned integer type scalar.
@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
// -----
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
@ -68,26 +68,26 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
Pointer is the pointer to load through. It specifies start of memory region where
Pointer is the pointer to load through. It specifies start of memory region where
elements of the matrix are stored and arranged according to Layout.
Stride is the number of elements in memory between beginnings of successive rows,
Stride is the number of elements in memory between beginnings of successive rows,
columns (or words) in the result. It must be a scalar integer type.
Layout indicates how the values loaded from memory are arranged. It must be the
Layout indicates how the values loaded from memory are arranged. It must be the
result of a constant instruction.
Scope is syncronization scope for operation on the matrix. It must be the result
Scope is syncronization scope for operation on the matrix. It must be the result
of a constant instruction with scalar integer type.
If present, any Memory Operands must begin with a memory operand literal. If not
If present, any Memory Operands must begin with a memory operand literal. If not
present, it is the same as specifying the memory operand None.
#### Example:
```mlir
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
```
}];
@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
// -----
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is
Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is
a K x N matrix and C is a M x N matrix.
Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be
Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be
OpTypeJointMatrixINTEL.
A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count
A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count
equals to K
B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count
B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count
equals to N
C and Result Type must be a OpTypeJointMatrixINTEL with Row
C and Result Type must be a OpTypeJointMatrixINTEL with Row
Count equals to M and Column Count equals to N
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.
#### Example:
```mlir
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
```
@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
// -----
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Store a matrix through a pointer.
Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must
Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must
be stored and arranged according to Layout.
Object is the matrix to store. It must be
Object is the matrix to store. It must be
OpTypeJointMatrixINTEL.
Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must
Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must
be a scalar integer type.
Layout indicates how the values stored to memory are arranged.
Layout indicates how the values stored to memory are arranged.
It must be the result of a constant instruction.
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory
operand None.
#### Example:
```mlir
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
```

View File

@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
// -----
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
let summary = "TBD";
let description = [{

View File

@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
// spv.AtomicFAddEXTOp
//===----------------------------------------------------------------------===//
LogicalResult spirv::AtomicFAddEXTOp::verify() {
LogicalResult spirv::EXTAtomicFAddOp::verify() {
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
}
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return ::parseAtomicUpdateOp(parser, result, true);
}
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
::printAtomicUpdateOp(*this, p);
}
@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
// spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
return success();
}
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << " : " << getType();
}
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();
@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
// spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
return success();
}
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << ", " << value() << " : " << value().getType();
}
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();
@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
// spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
return success();
}
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
// Print optional memory access attribute.
if (auto memAccess = memory_access())
@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
return success();
}
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
result().getType());
}
@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
// spv.CooperativeMatrixStoreNV
//===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
return success();
}
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
<< columnmajor();
// Print optional memory access attribute.
@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
}
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
object().getType());
}
@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
//===----------------------------------------------------------------------===//
static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success();
}
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAdd(*this);
}
@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
// spv.JointMatrixLoadINTEL
//===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType());
}
@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
// spv.JointMatrixStoreINTEL
//===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType());
}
@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
// spv.JointMatrixMadINTEL
//===----------------------------------------------------------------------===//
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
return success();
}
LogicalResult spirv::JointMatrixMadINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixMadOp::verify() {
return verifyJointMatrixMad(*this);
}

View File

@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
op, op->getResult(0).getType(), predicate);
return success();
}

View File

@ -23,13 +23,17 @@ file_name=$1
baseclass=$2
case $baseclass in
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
| LogicalBinaryOp | LogicalUnaryOp \
| CastOp | ControlFlowOp | StructureOp \
| AtomicUpdateOp | AtomicUpdateWithValueOp \
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
;;
*)
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
echo "<baseclass> must be one of " \
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
exit 1;
;;
esac

View File

@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')
else:
fmt_str = ('def SPV_{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')
vendor_name = ''
inst_category = existing_info.get('inst_category', 'Op')
if inst_category == 'Op':
fmt_str +='\n let arguments = (ins{args});\n\n'\
' let results = (outs{results});\n'
elif inst_category.endswith('VendorOp'):
vendor_name = inst_category.split('VendorOp')[0].upper()
assert len(vendor_name) != 0, 'Invalid instruction category'
fmt_str +='{extras}'\
'}}\n'
@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opname_src = instruction['opname']
if opname.startswith('Op'):
opname_src = opname_src[2:]
if len(vendor_name) > 0:
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
opname_src = opname_src[:-len(vendor_name)]
category_args = existing_info.get('category_args', '')
@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
# Format summary. If the summary can fit in the same line, we print it out
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
summary = summary.strip();
summary = summary.strip()
if len(summary) + len(' let summary = "";') <= 80:
summary = '"{}"'.format(summary)
else:
@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opcode=instruction['opcode'],
category_args=category_args,
inst_category=inst_category,
vendor_name=vendor_name,
traits=existing_info.get('traits', ''),
summary=summary,
description=description,