[mlir] create gpu memset op

Create a gpu memset op and corresponding CUDA and ROCm wrappers.

Reviewed By: herhut, lorenrose1013

Differential Revision: https://reviews.llvm.org/D107548
parent bb51f76fb1 · commit 361458b1ce
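Before the diff, a quick sketch of the new op's surface syntax in both its blocking and asynchronous forms. This simply mirrors the example in the op description and the tests added below; `%dst`, `%value`, and `%dep` are placeholder SSA names.

```mlir
// Blocking form: fill every element of %dst with the scalar %value.
gpu.memset %dst, %value : memref<3x7xf32>, f32

// Asynchronous form: executes once %dep has completed and returns a
// !gpu.async.token that later ops (e.g. gpu.wait) can depend on.
%token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
```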
@@ -901,6 +901,42 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
  let hasFolder = 1;
}

def GPU_MemsetOp : GPU_Op<"memset",
    [GPU_AsyncOpInterface, AllElementTypesMatch<["dst", "value"]>]> {

  let summary = "GPU memset operation";

  let description = [{
    The `gpu.memset` operation sets the content of memref to a scalar value.

    The op does not execute before all async dependencies have finished
    executing.

    If the `async` keyword is present, the op is executed asynchronously (i.e.
    it does not block until the execution has finished on the device). In
    that case, it returns a !gpu.async.token.

    Example:

    ```mlir
    %token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
    ```
  }];

  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                   Arg<AnyMemRef, "", [MemWrite]>:$dst,
                   Arg<AnyType, "">:$value);
  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);

  let assemblyFormat = [{
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $dst`,` $value `:` type($dst)`,` type($value) attr-dict
  }];
  // MemsetOp is fully verified by traits.
  let verifier = [{ return success(); }];
  let hasFolder = 1;
}

def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
    [MemoryEffects<[MemRead]>]>{
@@ -79,6 +79,18 @@ public:
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
@@ -165,6 +177,12 @@ protected:
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
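For orientation, `memsetCallBuilder` above materializes a plain LLVM-dialect call to the new runtime wrapper. A rough sketch of the emitted call follows; the exact pointer and index types are an assumption (pre-opaque-pointer `!llvm.ptr<i8>` and an `i64` index on a 64-bit target), and `%dst`, `%value`, `%count`, `%stream` stand for the already-lowered operands.

```mlir
// Hypothetical lowered form: %dst was bitcast to a plain pointer, %value
// bitcast to i32, %count is the element count, %stream the async token.
llvm.call @mgpuMemset32(%dst, %value, %count, %stream)
    : (!llvm.ptr<i8>, i32, i64, !llvm.ptr<i8>) -> ()
```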
@@ -308,6 +326,20 @@ private:
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
@@ -757,14 +789,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());
-
-  Value numElements =
-      memRefType.hasStaticShape()
-          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
-          // For identity layouts (verified above), the number of elements is
-          // stride[0] * size[0].
-          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
-                                         srcDesc.size(rewriter, loc, 0));
+  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
@@ -787,6 +812,40 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();
  auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
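Note the 32-bit guard in the pattern above: only i32/f32 values are lowered (an f32 is bitcast to i32 before the call); for any other element width the pattern reports a match failure and the `gpu.memset` op is left in place. A small sketch, using hypothetical values:

```mlir
// Matches the pattern: f32 is a 32-bit scalar, so it lowers to mgpuMemset32.
%t1 = gpu.memset async [%t0] %dst_f32, %val_f32 : memref<7xf32, 1>, f32

// Does not match: f16 is not a 32-bit scalar, so the op stays untouched
// ("value must be a 32 bit scalar").
%t2 = gpu.memset async [%t0] %dst_f16, %val_f16 : memref<7xf16, 1>, f16
```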
@@ -803,6 +862,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
@@ -1079,6 +1079,11 @@ LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
  return foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//
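The new folder reuses `foldMemRefCast`, so a `memref.cast` feeding `gpu.memset` folds away, just as it already does for `gpu.memcpy`. A sketch of the effect (the folded form is inferred from the behavior exercised by the canonicalization test below):

```mlir
// Before folding: the destination goes through a memref.cast.
%0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
gpu.memset %0, %value : memref<?xf32>, f32

// After folding: the cast is bypassed and the statically shaped operand
// is used directly.
gpu.memset %arg0, %value : memref<10xf32>, f32
```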
@@ -141,13 +141,19 @@ extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}

-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
                             CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

/// Helper functions for writing mlir example code

// Allows to register byte array with the CUDA runtime. Helpful until we have
@@ -133,12 +133,17 @@ extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
  HIP_REPORT_IF_ERROR(hipFree(ptr));
}

-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           hipStream_t stream) {
  HIP_REPORT_IF_ERROR(
      hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
}

extern "C" void mgpuMemset32(void *dst, int value, size_t count,
                             hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
                                        value, count, stream));
}
/// Helper functions for writing mlir example code

// Allows to register byte array with the ROCM runtime. Helpful until we have
@@ -0,0 +1,19 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s

module attributes {gpu.container_module} {

  // CHECK: func @foo
  func @foo(%dst : memref<7xf32, 1>, %value : f32) {
    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
    %t0 = gpu.wait async
    // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
    // CHECK: %[[value:.*]] = llvm.bitcast
    // CHECK: %[[dst:.*]] = llvm.bitcast
    // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
    gpu.wait [%t1]
    return
  }
}
@@ -6,7 +6,16 @@ func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
  // CHECK: gpu.memcpy
  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
  %1 = memref.cast %arg1 : memref<10xf32> to memref<?xf32>
-  gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
+  gpu.memcpy %0, %1 : memref<?xf32>, memref<?xf32>
  return
}

// CHECK-LABEL: @memset_after_cast
func @memset_after_cast(%arg0: memref<10xf32>, %arg1: f32) {
  // CHECK-NOT: memref.cast
  // CHECK: gpu.memset
  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
  gpu.memset %0, %arg1 : memref<?xf32>, f32
  return
}
@@ -467,6 +467,13 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {

// -----

func @memset_incompatible_shape(%dst : memref<?xf32>, %value : i32) {
  // expected-error @+1 {{'gpu.memset' op failed to verify that all of {dst, value} have same element type}}
  gpu.memset %dst, %value : memref<?xf32>, i32
}

// -----

func @mmamatrix_invalid_shape(){
  %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
  %i = constant 16 : index
@@ -195,6 +195,17 @@ module attributes {gpu.container_module} {
    return
  }

  func @memset(%dst : memref<3x7xf32>, %value : f32) {
    // CHECK-LABEL: func @memset
    // CHECK: gpu.memset {{.*}}, {{.*}} : memref<3x7xf32>, f32
    gpu.memset %dst, %value : memref<3x7xf32>, f32
    // CHECK: %[[t0:.*]] = gpu.wait async
    %0 = gpu.wait async
    // CHECK: {{.*}} = gpu.memset async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, f32
    %1 = gpu.memset async [%0] %dst, %value : memref<3x7xf32>, f32
    return
  }

  func @mmamatrix_valid_element_type(){
    // CHECK-LABEL: func @mmamatrix_valid_element_type
    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>