forked from OSchip/llvm-project
[mlir][spirv] Drop experimental LinalgToSPIRV pass
This experimental pass is unused and obsolete. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D139056
This commit is contained in:
parent
b948a9f40f
commit
9ad215bb3d
|
@ -1,28 +0,0 @@
|
|||
//===- LinalgToSPIRV.h - Linalg to SPIR-V Patterns --------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Provides patterns to convert Linalg dialect to SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
|
||||
#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class SPIRVTypeConverter;
|
||||
class RewritePatternSet;
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating Linalg ops to
|
||||
/// SPIR-V ops.
|
||||
void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
|
|
@ -1,29 +0,0 @@
|
|||
//===- LinalgToSPIRVPass.h - Linalg to SPIR-V Passes -----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Provides passes to convert Linalg dialect to SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
|
||||
#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
#define GEN_PASS_DECL_CONVERTLINALGTOSPIRV
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
/// Creates and returns a pass to convert Linalg ops to SPIR-V ops.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLinalgToSPIRVPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
|
|
@ -31,7 +31,6 @@
|
|||
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
|
||||
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
|
||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
|
||||
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
|
||||
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
|
||||
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
|
|
|
@ -485,20 +485,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
|
|||
let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LinalgToSPIRV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
|
||||
let summary = "Convert Linalg dialect to SPIR-V dialect";
|
||||
let description = [{
|
||||
This pass converts supported Linalg ops to SPIR-V ops. It's quite
|
||||
experimental and are expected to migrate to other proper conversions.
|
||||
}];
|
||||
let constructor = "mlir::createLinalgToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MathToLibm
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -20,7 +20,6 @@ add_subdirectory(GPUToSPIRV)
|
|||
add_subdirectory(GPUToVulkan)
|
||||
add_subdirectory(IndexToLLVM)
|
||||
add_subdirectory(LinalgToLLVM)
|
||||
add_subdirectory(LinalgToSPIRV)
|
||||
add_subdirectory(LinalgToStandard)
|
||||
add_subdirectory(LLVMCommon)
|
||||
add_subdirectory(MathToFuncs)
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
add_mlir_conversion_library(MLIRLinalgToSPIRV
|
||||
LinalgToSPIRV.cpp
|
||||
LinalgToSPIRVPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLinalgDialect
|
||||
MLIRLinalgUtils
|
||||
MLIRPass
|
||||
MLIRSPIRVDialect
|
||||
MLIRSPIRVConversion
|
||||
MLIRSupport
|
||||
)
|
|
@ -1,209 +0,0 @@
|
|||
//===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
|
||||
/// location invocation ID. This function will create necessary operations with
|
||||
/// `builder` at the proper region containing `op`.
|
||||
static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
|
||||
Location loc, OpBuilder *builder) {
|
||||
assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
|
||||
Value invocation = spirv::getBuiltinVariableValue(
|
||||
op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
|
||||
Type xType = invocation.getType().cast<ShapedType>().getElementType();
|
||||
return builder->create<spirv::CompositeExtractOp>(
|
||||
loc, xType, invocation, builder->getI32ArrayAttr({dim}));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reduction (single workgroup)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
|
||||
/// that the linalg.generic op is performing reduction with a workload size that
|
||||
/// can fit in one workgroup.
|
||||
struct SingleWorkgroupReduction final
|
||||
: public OpConversionPattern<linalg::GenericOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
/// Matches the given linalg.generic op as performing reduction and returns
|
||||
/// the binary op kind if successful.
|
||||
static Optional<linalg::RegionMatcher::BinaryOpKind>
|
||||
matchAsPerformingReduction(linalg::GenericOp genericOp);
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Optional<linalg::RegionMatcher::BinaryOpKind>
|
||||
SingleWorkgroupReduction::matchAsPerformingReduction(
|
||||
linalg::GenericOp genericOp) {
|
||||
Operation *op = genericOp.getOperation();
|
||||
|
||||
// Make sure the linalg.generic is working on memrefs.
|
||||
if (!genericOp.hasBufferSemantics())
|
||||
return llvm::None;
|
||||
|
||||
// Make sure this is reduction with one input and one output.
|
||||
if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
|
||||
return llvm::None;
|
||||
|
||||
auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
|
||||
auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
|
||||
|
||||
// Make sure the original input has one dimension.
|
||||
if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
|
||||
return llvm::None;
|
||||
// Make sure the original output has one element.
|
||||
if (!originalOutputType.hasStaticShape() ||
|
||||
originalOutputType.getNumElements() != 1)
|
||||
return llvm::None;
|
||||
|
||||
if (!genericOp.hasSingleReductionLoop())
|
||||
return llvm::None;
|
||||
|
||||
auto indexingMaps = genericOp.getIndexingMapsArray();
|
||||
if (indexingMaps.size() != 2)
|
||||
return llvm::None;
|
||||
|
||||
// TODO: create utility functions for these checks in Linalg
|
||||
// and use them.
|
||||
auto inputMap = indexingMaps[0];
|
||||
auto outputMap = indexingMaps[1];
|
||||
// The indexing map for the input should be `(i) -> (i)`.
|
||||
if (inputMap != AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
|
||||
return llvm::None;
|
||||
// The indexing map for the input should be `(i) -> (0)`.
|
||||
if (outputMap !=
|
||||
AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
|
||||
return llvm::None;
|
||||
|
||||
return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
|
||||
}
|
||||
|
||||
LogicalResult SingleWorkgroupReduction::matchAndRewrite(
|
||||
linalg::GenericOp genericOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Operation *op = genericOp.getOperation();
|
||||
auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
|
||||
auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
|
||||
|
||||
auto binaryOpKind = matchAsPerformingReduction(genericOp);
|
||||
if (!binaryOpKind)
|
||||
return failure();
|
||||
|
||||
// Query the shader interface for local workgroup size to make sure the
|
||||
// invocation configuration fits with the input memref's shape.
|
||||
DenseI32ArrayAttr workgroupSize = spirv::lookupLocalWorkGroupSize(genericOp);
|
||||
if (!workgroupSize)
|
||||
return failure();
|
||||
|
||||
if (workgroupSize.asArrayRef()[0] != originalInputType.getDimSize(0))
|
||||
return failure();
|
||||
if (llvm::any_of(workgroupSize.asArrayRef().drop_front(),
|
||||
[](int size) { return size != 1; }))
|
||||
return failure();
|
||||
|
||||
// TODO: Query the target environment to make sure the current
|
||||
// workload fits in a local workgroup.
|
||||
|
||||
Value convertedInput = adaptor.getOperands()[0];
|
||||
Value convertedOutput = adaptor.getOperands()[1];
|
||||
Location loc = genericOp.getLoc();
|
||||
|
||||
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
|
||||
auto indexType = typeConverter->getIndexType();
|
||||
|
||||
// Get the invocation ID.
|
||||
Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
|
||||
&rewriter);
|
||||
|
||||
// TODO: Load to Workgroup storage class first.
|
||||
|
||||
// Get the input element accessed by this invocation.
|
||||
Value inputElementPtr = spirv::getElementPtr(
|
||||
*typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
|
||||
Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
|
||||
|
||||
// Perform the group reduction operation.
|
||||
Value groupOperation;
|
||||
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \
|
||||
case linalg::RegionMatcher::BinaryOpKind::opKind: { \
|
||||
groupOperation = rewriter.create<spirv::spvOp>( \
|
||||
loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \
|
||||
spirv::GroupOperation::Reduce, inputElement, \
|
||||
/*cluster_size=*/nullptr); \
|
||||
} break
|
||||
switch (*binaryOpKind) {
|
||||
CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
|
||||
}
|
||||
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
|
||||
|
||||
// Get the output element accessed by this reduction.
|
||||
Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
|
||||
SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
|
||||
Value outputElementPtr =
|
||||
spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
|
||||
zeroIndices, loc, rewriter);
|
||||
|
||||
// Write out the final reduction result. This should be only conducted by one
|
||||
// invocation. We use spirv.GroupNonUniformElect to find the invocation with
|
||||
// the lowest ID.
|
||||
//
|
||||
// ```
|
||||
// if (spirv.GroupNonUniformElect) { output = ... }
|
||||
// ```
|
||||
|
||||
Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
|
||||
loc, spirv::Scope::Subgroup);
|
||||
|
||||
auto createAtomicOp = [&](OpBuilder &builder) {
|
||||
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \
|
||||
case linalg::RegionMatcher::BinaryOpKind::opKind: { \
|
||||
builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
|
||||
spirv::MemorySemantics::AcquireRelease, \
|
||||
groupOperation); \
|
||||
} break
|
||||
switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
|
||||
#undef CREATE_ATOMIC_BIN_OP
|
||||
};
|
||||
|
||||
spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);
|
||||
|
||||
rewriter.eraseOp(genericOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
//===- LinalgToSPIRVPass.cpp - Linalg to SPIR-V Passes --------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
|
||||
|
||||
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_CONVERTLINALGTOSPIRV
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A pass converting MLIR Linalg ops into SPIR-V ops.
|
||||
class LinalgToSPIRVPass
|
||||
: public impl::ConvertLinalgToSPIRVBase<LinalgToSPIRVPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LinalgToSPIRVPass::runOnOperation() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
SPIRVConversionTarget::get(targetAttr);
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
RewritePatternSet patterns(context);
|
||||
populateLinalgToSPIRVPatterns(typeConverter, patterns);
|
||||
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
|
||||
|
||||
// Allow builtin ops.
|
||||
target->addLegalOp<ModuleOp>();
|
||||
target->addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
|
||||
if (failed(applyFullConversion(module, *target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgToSPIRVPass() {
|
||||
return std::make_unique<LinalgToSPIRVPass>();
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-linalg-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Single workgroup reduction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#single_workgroup_reduction_trait = {
|
||||
iterator_types = ["reduction"],
|
||||
indexing_maps = [
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> (0)>
|
||||
]
|
||||
}
|
||||
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
|
||||
} {
|
||||
|
||||
// CHECK: spirv.GlobalVariable
|
||||
// CHECK-SAME: built_in("LocalInvocationId")
|
||||
|
||||
// CHECK: @single_workgroup_reduction
|
||||
// CHECK-SAME: (%[[INPUT:.+]]: !spirv.ptr{{.+}}, %[[OUTPUT:.+]]: !spirv.ptr{{.+}})
|
||||
|
||||
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[ID:.+]] = spirv.Load "Input" %{{.+}} : vector<3xi32>
|
||||
// CHECK: %[[X:.+]] = spirv.CompositeExtract %[[ID]][0 : i32]
|
||||
|
||||
// CHECK: %[[INPTR:.+]] = spirv.AccessChain %[[INPUT]][%[[ZERO]], %[[X]]]
|
||||
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[INPTR]] : i32
|
||||
// CHECK: %[[ADD:.+]] = spirv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
|
||||
|
||||
// CHECK: %[[OUTPTR:.+]] = spirv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
|
||||
// CHECK: %[[ELECT:.+]] = spirv.GroupNonUniformElect <Subgroup> : i1
|
||||
|
||||
// CHECK: spirv.mlir.selection {
|
||||
// CHECK: spirv.BranchConditional %[[ELECT]], ^bb1, ^bb2
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: spirv.AtomicIAdd "Device" "AcquireRelease" %[[OUTPTR]], %[[ADD]]
|
||||
// CHECK: spirv.Branch ^bb2
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: spirv.mlir.merge
|
||||
// CHECK: }
|
||||
// CHECK: spirv.Return
|
||||
|
||||
func.func @single_workgroup_reduction(%input: memref<16xi32, #spirv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spirv.storage_class<StorageBuffer>>) attributes {
|
||||
spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>
|
||||
} {
|
||||
linalg.generic #single_workgroup_reduction_trait
|
||||
ins(%input : memref<16xi32, #spirv.storage_class<StorageBuffer>>)
|
||||
outs(%output : memref<1xi32, #spirv.storage_class<StorageBuffer>>) {
|
||||
^bb(%in: i32, %out: i32):
|
||||
%sum = arith.addi %in, %out : i32
|
||||
linalg.yield %sum : i32
|
||||
}
|
||||
spirv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Missing shader entry point ABI
|
||||
|
||||
#single_workgroup_reduction_trait = {
|
||||
iterator_types = ["reduction"],
|
||||
indexing_maps = [
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> (0)>
|
||||
]
|
||||
}
|
||||
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
|
||||
} {
|
||||
func.func @single_workgroup_reduction(%input: memref<16xi32, #spirv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spirv.storage_class<StorageBuffer>>) {
|
||||
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
|
||||
linalg.generic #single_workgroup_reduction_trait
|
||||
ins(%input : memref<16xi32, #spirv.storage_class<StorageBuffer>>)
|
||||
outs(%output : memref<1xi32, #spirv.storage_class<StorageBuffer>>) {
|
||||
^bb(%in: i32, %out: i32):
|
||||
%sum = arith.addi %in, %out : i32
|
||||
linalg.yield %sum : i32
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Mismatch between shader entry point ABI and input memref shape
|
||||
|
||||
#single_workgroup_reduction_trait = {
|
||||
iterator_types = ["reduction"],
|
||||
indexing_maps = [
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> (0)>
|
||||
]
|
||||
}
|
||||
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
|
||||
} {
|
||||
func.func @single_workgroup_reduction(%input: memref<16xi32, #spirv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spirv.storage_class<StorageBuffer>>) attributes {
|
||||
spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>
|
||||
} {
|
||||
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
|
||||
linalg.generic #single_workgroup_reduction_trait
|
||||
ins(%input : memref<16xi32, #spirv.storage_class<StorageBuffer>>)
|
||||
outs(%output : memref<1xi32, #spirv.storage_class<StorageBuffer>>) {
|
||||
^bb(%in: i32, %out: i32):
|
||||
%sum = arith.addi %in, %out : i32
|
||||
linalg.yield %sum : i32
|
||||
}
|
||||
spirv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Unsupported multi-dimension input memref
|
||||
|
||||
#single_workgroup_reduction_trait = {
|
||||
iterator_types = ["parallel", "reduction"],
|
||||
indexing_maps = [
|
||||
affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i)>
|
||||
]
|
||||
}
|
||||
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
|
||||
} {
|
||||
func.func @single_workgroup_reduction(%input: memref<16x8xi32, #spirv.storage_class<StorageBuffer>>, %output: memref<16xi32, #spirv.storage_class<StorageBuffer>>) attributes {
|
||||
spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 8, 1]>
|
||||
} {
|
||||
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
|
||||
linalg.generic #single_workgroup_reduction_trait
|
||||
ins(%input : memref<16x8xi32, #spirv.storage_class<StorageBuffer>>)
|
||||
outs(%output : memref<16xi32, #spirv.storage_class<StorageBuffer>>) {
|
||||
^bb(%in: i32, %out: i32):
|
||||
%sum = arith.addi %in, %out : i32
|
||||
linalg.yield %sum : i32
|
||||
}
|
||||
spirv.Return
|
||||
}
|
||||
}
|
|
@ -2749,7 +2749,6 @@ cc_library(
|
|||
":GPUToVulkanTransforms",
|
||||
":IndexToLLVM",
|
||||
":LinalgToLLVM",
|
||||
":LinalgToSPIRV",
|
||||
":LinalgToStandard",
|
||||
":MathToFuncs",
|
||||
":MathToLLVM",
|
||||
|
@ -6765,7 +6764,6 @@ cc_library(
|
|||
":LinalgDialect",
|
||||
":LinalgPassIncGen",
|
||||
":LinalgToLLVM",
|
||||
":LinalgToSPIRV",
|
||||
":LinalgToStandard",
|
||||
":LinalgTransformOps",
|
||||
":LinalgTransforms",
|
||||
|
@ -8123,31 +8121,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgToSPIRV",
|
||||
srcs = glob([
|
||||
"lib/Conversion/LinalgToSPIRV/*.cpp",
|
||||
"lib/Conversion/LinalgToSPIRV/*.h",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"include/mlir/Conversion/LinalgToSPIRV/*.h",
|
||||
]),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ConversionPassIncGen",
|
||||
":DialectUtils",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
":LinalgDialect",
|
||||
":LinalgTransforms",
|
||||
":LinalgUtils",
|
||||
":Pass",
|
||||
":SPIRVConversion",
|
||||
":SPIRVDialect",
|
||||
":TransformUtils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgDialect",
|
||||
srcs = glob(["lib/Dialect/Linalg/IR/*.cpp"]),
|
||||
|
|
Loading…
Reference in New Issue