[mlir][spirv] Refactor vendor op definitions
Use dedicated vendor op classes/categories. This is so that we can later change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`. Issue: https://github.com/llvm/llvm-project/issues/56863
This commit is contained in:
parent
6a378b38ff
commit
b8bea837f3
|
@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
|
def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
|
||||||
let summary = "TBD";
|
let summary = "TBD";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
|
||||||
|
|
||||||
3) store the New Value back through Pointer.
|
3) store the New Value back through Pointer.
|
||||||
|
|
||||||
The instruction’s result is the Original Value.
|
The instruction's result is the Original Value.
|
||||||
|
|
||||||
Result Type must be a floating-point type scalar.
|
Result Type must be a floating-point type scalar.
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
|
||||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
|
||||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
|
def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
|
||||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
|
||||||
let summary = "See extension SPV_KHR_shader_ballot";
|
let summary = "See extension SPV_KHR_shader_ballot";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
|
def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
|
||||||
let summary = "See extension SPV_INTEL_subgroups";
|
let summary = "See extension SPV_INTEL_subgroups";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
|
def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
|
||||||
let summary = "See extension SPV_INTEL_subgroups";
|
let summary = "See extension SPV_INTEL_subgroups";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
|
def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
|
||||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -119,7 +119,7 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
|
def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
|
||||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
|
def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
|
||||||
let summary = "See extension SPV_INTEL_joint_matrix";
|
let summary = "See extension SPV_INTEL_joint_matrix";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
|
def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
|
||||||
let summary = "TBD";
|
let summary = "TBD";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
|
||||||
// spv.AtomicFAddEXTOp
|
// spv.AtomicFAddEXTOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult spirv::AtomicFAddEXTOp::verify() {
|
LogicalResult spirv::EXTAtomicFAddOp::verify() {
|
||||||
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
|
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
|
ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
return ::parseAtomicUpdateOp(parser, result, true);
|
return ::parseAtomicUpdateOp(parser, result, true);
|
||||||
}
|
}
|
||||||
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
|
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
|
||||||
::printAtomicUpdateOp(*this, p);
|
::printAtomicUpdateOp(*this, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
|
||||||
// spv.SubgroupBlockReadINTEL
|
// spv.SubgroupBlockReadINTEL
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
|
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
// Parse the storage class specification
|
// Parse the storage class specification
|
||||||
spirv::StorageClass storageClass;
|
spirv::StorageClass storageClass;
|
||||||
|
@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
|
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
|
||||||
printer << " " << ptr() << " : " << getType();
|
printer << " " << ptr() << " : " << getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
|
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
|
||||||
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
|
||||||
// spv.SubgroupBlockWriteINTEL
|
// spv.SubgroupBlockWriteINTEL
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
|
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
// Parse the storage class specification
|
// Parse the storage class specification
|
||||||
spirv::StorageClass storageClass;
|
spirv::StorageClass storageClass;
|
||||||
|
@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
|
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
|
||||||
printer << " " << ptr() << ", " << value() << " : " << value().getType();
|
printer << " " << ptr() << ", " << value() << " : " << value().getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
|
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
|
||||||
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
|
||||||
// spv.CooperativeMatrixLoadNV
|
// spv.CooperativeMatrixLoadNV
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
|
ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
|
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
|
||||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||||
|
@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
|
void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
|
||||||
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
|
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
|
||||||
// Print optional memory access attribute.
|
// Print optional memory access attribute.
|
||||||
if (auto memAccess = memory_access())
|
if (auto memAccess = memory_access())
|
||||||
|
@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
|
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
|
||||||
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
||||||
result().getType());
|
result().getType());
|
||||||
}
|
}
|
||||||
|
@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
|
||||||
// spv.CooperativeMatrixStoreNV
|
// spv.CooperativeMatrixStoreNV
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
|
ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
|
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
|
||||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||||
|
@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
|
void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
|
||||||
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
|
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
|
||||||
<< columnmajor();
|
<< columnmajor();
|
||||||
// Print optional memory access attribute.
|
// Print optional memory access attribute.
|
||||||
|
@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
|
||||||
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
|
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
|
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
|
||||||
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
|
||||||
object().getType());
|
object().getType());
|
||||||
}
|
}
|
||||||
|
@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
|
||||||
if (op.c().getType() != op.result().getType())
|
if (op.c().getType() != op.result().getType())
|
||||||
return op.emitOpError("result and third operand must have the same type");
|
return op.emitOpError("result and third operand must have the same type");
|
||||||
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
|
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
|
||||||
|
@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
|
LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
|
||||||
return verifyCoopMatrixMulAdd(*this);
|
return verifyCoopMatrixMulAdd(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
|
||||||
// spv.JointMatrixLoadINTEL
|
// spv.JointMatrixLoadINTEL
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
|
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
|
||||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||||
result().getType());
|
result().getType());
|
||||||
}
|
}
|
||||||
|
@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
|
||||||
// spv.JointMatrixStoreINTEL
|
// spv.JointMatrixStoreINTEL
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
|
LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
|
||||||
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
|
||||||
object().getType());
|
object().getType());
|
||||||
}
|
}
|
||||||
|
@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
|
||||||
// spv.JointMatrixMadINTEL
|
// spv.JointMatrixMadINTEL
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
|
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
|
||||||
if (op.c().getType() != op.result().getType())
|
if (op.c().getType() != op.result().getType())
|
||||||
return op.emitOpError("result and third operand must have the same type");
|
return op.emitOpError("result and third operand must have the same type");
|
||||||
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
|
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
|
||||||
|
@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult spirv::JointMatrixMadINTELOp::verify() {
|
LogicalResult spirv::INTELJointMatrixMadOp::verify() {
|
||||||
return verifyJointMatrixMad(*this);
|
return verifyJointMatrixMad(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const {
|
PatternRewriter &rewriter) const {
|
||||||
Value predicate = op->getOperand(0);
|
Value predicate = op->getOperand(0);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
|
rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
|
||||||
op, op->getResult(0).getType(), predicate);
|
op, op->getResult(0).getType(), predicate);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,13 +23,17 @@ file_name=$1
|
||||||
baseclass=$2
|
baseclass=$2
|
||||||
|
|
||||||
case $baseclass in
|
case $baseclass in
|
||||||
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
|
Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
|
||||||
|
| LogicalBinaryOp | LogicalUnaryOp \
|
||||||
|
| CastOp | ControlFlowOp | StructureOp \
|
||||||
|
| AtomicUpdateOp | AtomicUpdateWithValueOp \
|
||||||
|
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
|
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
|
||||||
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
|
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
|
||||||
echo "<baseclass> must be one of " \
|
echo "<baseclass> must be one of " \
|
||||||
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
|
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
|
||||||
exit 1;
|
exit 1;
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
|
@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
||||||
'{{\n let summary = {summary};\n\n let description = '
|
'{{\n let summary = {summary};\n\n let description = '
|
||||||
'[{{\n{description}}}];{availability}\n')
|
'[{{\n{description}}}];{availability}\n')
|
||||||
else:
|
else:
|
||||||
fmt_str = ('def SPV_{opname_src}Op : '
|
fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
|
||||||
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
|
'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
|
||||||
'{{\n let summary = {summary};\n\n let description = '
|
'{{\n let summary = {summary};\n\n let description = '
|
||||||
'[{{\n{description}}}];{availability}\n')
|
'[{{\n{description}}}];{availability}\n')
|
||||||
|
|
||||||
|
vendor_name = ''
|
||||||
inst_category = existing_info.get('inst_category', 'Op')
|
inst_category = existing_info.get('inst_category', 'Op')
|
||||||
if inst_category == 'Op':
|
if inst_category == 'Op':
|
||||||
fmt_str +='\n let arguments = (ins{args});\n\n'\
|
fmt_str +='\n let arguments = (ins{args});\n\n'\
|
||||||
' let results = (outs{results});\n'
|
' let results = (outs{results});\n'
|
||||||
|
elif inst_category.endswith('VendorOp'):
|
||||||
|
vendor_name = inst_category.split('VendorOp')[0].upper()
|
||||||
|
assert len(vendor_name) != 0, 'Invalid instruction category'
|
||||||
|
|
||||||
fmt_str +='{extras}'\
|
fmt_str +='{extras}'\
|
||||||
'}}\n'
|
'}}\n'
|
||||||
|
@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
||||||
opname_src = instruction['opname']
|
opname_src = instruction['opname']
|
||||||
if opname.startswith('Op'):
|
if opname.startswith('Op'):
|
||||||
opname_src = opname_src[2:]
|
opname_src = opname_src[2:]
|
||||||
|
if len(vendor_name) > 0:
|
||||||
|
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
|
||||||
|
opname_src = opname_src[:-len(vendor_name)]
|
||||||
|
|
||||||
category_args = existing_info.get('category_args', '')
|
category_args = existing_info.get('category_args', '')
|
||||||
|
|
||||||
|
@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
||||||
|
|
||||||
# Format summary. If the summary can fit in the same line, we print it out
|
# Format summary. If the summary can fit in the same line, we print it out
|
||||||
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
|
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
|
||||||
summary = summary.strip();
|
summary = summary.strip()
|
||||||
if len(summary) + len(' let summary = "";') <= 80:
|
if len(summary) + len(' let summary = "";') <= 80:
|
||||||
summary = '"{}"'.format(summary)
|
summary = '"{}"'.format(summary)
|
||||||
else:
|
else:
|
||||||
|
@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
|
||||||
opcode=instruction['opcode'],
|
opcode=instruction['opcode'],
|
||||||
category_args=category_args,
|
category_args=category_args,
|
||||||
inst_category=inst_category,
|
inst_category=inst_category,
|
||||||
|
vendor_name=vendor_name,
|
||||||
traits=existing_info.get('traits', ''),
|
traits=existing_info.get('traits', ''),
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
|
|
Loading…
Reference in New Issue