[mlir][spirv] Refactor vendor op definitions

Use dedicated vendor op classes/categories. This is so that we can later
change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`.

Issue: https://github.com/llvm/llvm-project/issues/56863
This commit is contained in:
Jakub Kuderski 2022-09-02 14:59:42 -04:00
parent 6a378b38ff
commit b8bea837f3
9 changed files with 88 additions and 76 deletions

View File

@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
// ----- // -----
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> { def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
let summary = "TBD"; let summary = "TBD";
let description = [{ let description = [{
@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
3) store the New Value back through Pointer. 3) store the New Value back through Pointer.
The instructions result is the Original Value. The instruction's result is the Original Value.
Result Type must be a floating-point type scalar. Result Type must be a floating-point type scalar.

View File

@ -15,7 +15,7 @@
// ----- // -----
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV", def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
// ----- // -----
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{ let description = [{
@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// ----- // -----
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> { [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
// ----- // -----
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> { def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{ let description = [{

View File

@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
// ----- // -----
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
let summary = "See extension SPV_KHR_shader_ballot"; let summary = "See extension SPV_KHR_shader_ballot";
let description = [{ let description = [{
@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// ----- // -----
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> { def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
let summary = "See extension SPV_INTEL_subgroups"; let summary = "See extension SPV_INTEL_subgroups";
let description = [{ let description = [{
@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
// ----- // -----
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> { def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
let summary = "See extension SPV_INTEL_subgroups"; let summary = "See extension SPV_INTEL_subgroups";
let description = [{ let description = [{

View File

@ -15,7 +15,7 @@
// ----- // -----
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL", def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
// ----- // -----
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> { def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{
@ -119,7 +119,7 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
// ----- // -----
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL", def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> { [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
@ -182,7 +182,7 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
// ----- // -----
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> { def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
let summary = "See extension SPV_INTEL_joint_matrix"; let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{ let description = [{

View File

@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
// ----- // -----
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> { def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
let summary = "TBD"; let summary = "TBD";
let description = [{ let description = [{

View File

@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
// spv.AtomicFAddEXTOp // spv.AtomicFAddEXTOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::AtomicFAddEXTOp::verify() { LogicalResult spirv::EXTAtomicFAddOp::verify() {
return ::verifyAtomicUpdateOp<FloatType>(getOperation()); return ::verifyAtomicUpdateOp<FloatType>(getOperation());
} }
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser, ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
return ::parseAtomicUpdateOp(parser, result, true); return ::parseAtomicUpdateOp(parser, result, true);
} }
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) { void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
::printAtomicUpdateOp(*this, p); ::printAtomicUpdateOp(*this, p);
} }
@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
// spv.SubgroupBlockReadINTEL // spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser, ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
// Parse the storage class specification // Parse the storage class specification
spirv::StorageClass storageClass; spirv::StorageClass storageClass;
@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) { void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << " : " << getType(); printer << " " << ptr() << " : " << getType();
} }
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure(); return failure();
@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
// spv.SubgroupBlockWriteINTEL // spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser, ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
// Parse the storage class specification // Parse the storage class specification
spirv::StorageClass storageClass; spirv::StorageClass storageClass;
@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) { void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << ", " << value() << " : " << value().getType(); printer << " " << ptr() << ", " << value() << " : " << value().getType();
} }
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure(); return failure();
@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
// spv.CooperativeMatrixLoadNV // spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser, ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32); Type strideType = parser.getBuilder().getIntegerType(32);
@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) { void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
// Print optional memory access attribute. // Print optional memory access attribute.
if (auto memAccess = memory_access()) if (auto memAccess = memory_access())
@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
return success(); return success();
} }
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(), return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
result().getType()); result().getType());
} }
@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
// spv.CooperativeMatrixStoreNV // spv.CooperativeMatrixStoreNV
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser, ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) { OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo; SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32); Type strideType = parser.getBuilder().getIntegerType(32);
@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
return success(); return success();
} }
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) { void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << object() << ", " << stride() << ", " printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
<< columnmajor(); << columnmajor();
// Print optional memory access attribute. // Print optional memory access attribute.
@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
} }
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(), return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
object().getType()); object().getType());
} }
@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.c().getType() != op.result().getType()) if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type"); return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>(); auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success(); return success();
} }
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() { LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAdd(*this); return verifyCoopMatrixMulAdd(*this);
} }
@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
// spv.JointMatrixLoadINTEL // spv.JointMatrixLoadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixLoadINTELOp::verify() { LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(), return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType()); result().getType());
} }
@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
// spv.JointMatrixStoreINTEL // spv.JointMatrixStoreINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult spirv::JointMatrixStoreINTELOp::verify() { LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(), return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType()); object().getType());
} }
@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
// spv.JointMatrixMadINTEL // spv.JointMatrixMadINTEL
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) { static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.c().getType() != op.result().getType()) if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type"); return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>(); auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
return success(); return success();
} }
LogicalResult spirv::JointMatrixMadINTELOp::verify() { LogicalResult spirv::INTELJointMatrixMadOp::verify() {
return verifyJointMatrixMad(*this); return verifyJointMatrixMad(*this);
} }

View File

@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const { PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0); Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
op, op->getResult(0).getType(), predicate); op, op->getResult(0).getType(), predicate);
return success(); return success();
} }

View File

@ -23,13 +23,17 @@ file_name=$1
baseclass=$2 baseclass=$2
case $baseclass in case $baseclass in
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp) Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
| LogicalBinaryOp | LogicalUnaryOp \
| CastOp | ControlFlowOp | StructureOp \
| AtomicUpdateOp | AtomicUpdateWithValueOp \
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
;; ;;
*) *)
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*" echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec" echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
echo "<baseclass> must be one of " \ echo "<baseclass> must be one of " \
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)" "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
exit 1; exit 1;
;; ;;
esac esac

View File

@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
'{{\n let summary = {summary};\n\n let description = ' '{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n') '[{{\n{description}}}];{availability}\n')
else: else:
fmt_str = ('def SPV_{opname_src}Op : ' fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> ' 'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
'{{\n let summary = {summary};\n\n let description = ' '{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n') '[{{\n{description}}}];{availability}\n')
vendor_name = ''
inst_category = existing_info.get('inst_category', 'Op') inst_category = existing_info.get('inst_category', 'Op')
if inst_category == 'Op': if inst_category == 'Op':
fmt_str +='\n let arguments = (ins{args});\n\n'\ fmt_str +='\n let arguments = (ins{args});\n\n'\
' let results = (outs{results});\n' ' let results = (outs{results});\n'
elif inst_category.endswith('VendorOp'):
vendor_name = inst_category.split('VendorOp')[0].upper()
assert len(vendor_name) != 0, 'Invalid instruction category'
fmt_str +='{extras}'\ fmt_str +='{extras}'\
'}}\n' '}}\n'
@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opname_src = instruction['opname'] opname_src = instruction['opname']
if opname.startswith('Op'): if opname.startswith('Op'):
opname_src = opname_src[2:] opname_src = opname_src[2:]
if len(vendor_name) > 0:
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
opname_src = opname_src[:-len(vendor_name)]
category_args = existing_info.get('category_args', '') category_args = existing_info.get('category_args', '')
@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
# Format summary. If the summary can fit in the same line, we print it out # Format summary. If the summary can fit in the same line, we print it out
# as a "-quoted string; otherwise, wrap the lines using "[{...}]". # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
summary = summary.strip(); summary = summary.strip()
if len(summary) + len(' let summary = "";') <= 80: if len(summary) + len(' let summary = "";') <= 80:
summary = '"{}"'.format(summary) summary = '"{}"'.format(summary)
else: else:
@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opcode=instruction['opcode'], opcode=instruction['opcode'],
category_args=category_args, category_args=category_args,
inst_category=inst_category, inst_category=inst_category,
vendor_name=vendor_name,
traits=existing_info.get('traits', ''), traits=existing_info.get('traits', ''),
summary=summary, summary=summary,
description=description, description=description,