[mlir] create gpu memset op

Create a gpu memset op and corresponding CUDA and ROCm wrappers.
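
A minimal usage sketch (operand names are illustrative; the async form matches the example in the op's description):

```mlir
// Blocking form: set every element of %dst to %value.
gpu.memset %dst, %value : memref<3x7xf32>, f32

// Async form: returns a !gpu.async.token once dependencies are honored.
%token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
```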

Reviewed By: herhut, lorenrose1013

Differential Revision: https://reviews.llvm.org/D107548
Authored by Loren Maggiore on 2021-09-04 08:03:33 +02:00, committed by Christian Sigg
parent bb51f76fb1
commit 361458b1ce
9 changed files with 169 additions and 11 deletions


@@ -901,6 +901,42 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
  let hasFolder = 1;
}

def GPU_MemsetOp : GPU_Op<"memset",
    [GPU_AsyncOpInterface, AllElementTypesMatch<["dst", "value"]>]> {

  let summary = "GPU memset operation";

  let description = [{
    The `gpu.memset` operation sets the content of a memref to a scalar value.

    The op does not execute before all async dependencies have finished
    executing.

    If the `async` keyword is present, the op is executed asynchronously (i.e.
    it does not block until the execution has finished on the device). In
    that case, it returns a !gpu.async.token.

    Example:

    ```mlir
    %token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
    ```
  }];

  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                   Arg<AnyMemRef, "", [MemWrite]>:$dst,
                   Arg<AnyType, "">:$value);
  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);

  let assemblyFormat = [{
    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
    $dst`,` $value `:` type($dst)`,` type($value) attr-dict
  }];
  // MemsetOp is fully verified by traits.
  let verifier = [{ return success(); }];

  let hasFolder = 1;
}

def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
    [MemoryEffects<[MemRead]>]>{


@@ -79,6 +79,18 @@ public:
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
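
As an aside, a hedged worked example of the identity-layout fact the comment above relies on (hypothetical function and shapes, not part of this change):

```mlir
// A memref<?x?xf32> with identity layout and runtime shape 4x5 carries
// sizes = [4, 5] and strides = [5, 1] in its descriptor, so
// stride[0] * size[0] = 5 * 4 = 20, the total element count.
func @memset_dynamic(%dst : memref<?x?xf32>, %value : f32) {
  gpu.memset %dst, %value : memref<?x?xf32>, f32
  return
}
```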
@@ -165,6 +177,12 @@ protected:
       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
        llvmIntPtrType /* intptr_t sizeBytes */,
        llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};
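
For illustration, the external declaration this builder materializes would look roughly as follows in the LLVM dialect (a sketch assuming a 64-bit index width; the actual types come from the type converter):

```mlir
// Hypothetical declaration created on first use by FunctionCallBuilder:
llvm.func @mgpuMemset32(!llvm.ptr<i8>, i32, i64, !llvm.ptr<i8>)
```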
/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -308,6 +326,20 @@ private:
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
@@ -757,14 +789,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());
-  Value numElements =
-      memRefType.hasStaticShape()
-          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
-          // For identity layouts (verified above), the number of elements is
-          // stride[0] * size[0].
-          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
-                                         srcDesc.size(rewriter, loc, 0));
+  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
@@ -787,6 +812,40 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}
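
Put together, the pattern emits IR along these lines for the static 7-element case (an illustrative sketch; the value names and the i64 index width are assumptions, and the FileCheck test below is authoritative):

```mlir
%n     = llvm.mlir.constant(7 : index) : i64   // numElements
%value = llvm.bitcast %v : f32 to i32          // reinterpret the 32-bit scalar
%dst   = llvm.bitcast %aligned : !llvm.ptr<f32> to !llvm.ptr<i8>
llvm.call @mgpuMemset32(%dst, %value, %n, %stream)
    : (!llvm.ptr<i8>, i32, i64, !llvm.ptr<i8>) -> ()
```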
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();

@@ -803,6 +862,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);


@@ -1079,6 +1079,11 @@ LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
  return foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
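
In effect, the folder bypasses memref.cast ops feeding gpu.memset, as in this sketch (mirroring the canonicalization test further down; names are illustrative):

```mlir
// Before folding:
%0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
gpu.memset %0, %v : memref<?xf32>, f32

// After folding, the cast is dead and the op uses the source directly:
gpu.memset %arg0, %v : memref<10xf32>, f32
```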
//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//


@@ -141,13 +141,19 @@ extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}

-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
                             CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}
/// Helper functions for writing mlir example code
// Allows to register byte array with the CUDA runtime. Helpful until we have


@@ -133,12 +133,17 @@ extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
  HIP_REPORT_IF_ERROR(hipFree(ptr));
}

-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           hipStream_t stream) {
  HIP_REPORT_IF_ERROR(
      hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
}

extern "C" void mgpuMemset32(void *dst, int value, size_t count,
                             hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
                                        value, count, stream));
}
/// Helper functions for writing mlir example code
// Allows to register byte array with the ROCM runtime. Helpful until we have


@@ -0,0 +1,19 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s

module attributes {gpu.container_module} {

  // CHECK: func @foo
  func @foo(%dst : memref<7xf32, 1>, %value : f32) {
    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
    %t0 = gpu.wait async
    // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
    // CHECK: %[[value:.*]] = llvm.bitcast
    // CHECK: %[[dst:.*]] = llvm.bitcast
    // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
    gpu.wait [%t1]
    return
  }
}


@@ -6,7 +6,16 @@ func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
  // CHECK: gpu.memcpy
  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
  %1 = memref.cast %arg1 : memref<10xf32> to memref<?xf32>
-  gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
+  gpu.memcpy %0, %1 : memref<?xf32>, memref<?xf32>
  return
}

// CHECK-LABEL: @memset_after_cast
func @memset_after_cast(%arg0: memref<10xf32>, %arg1: f32) {
  // CHECK-NOT: memref.cast
  // CHECK: gpu.memset
  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
  gpu.memset %0, %arg1 : memref<?xf32>, f32
  return
}


@@ -467,6 +467,13 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {

// -----

func @memset_incompatible_type(%dst : memref<?xf32>, %value : i32) {
  // expected-error @+1 {{'gpu.memset' op failed to verify that all of {dst, value} have same element type}}
  gpu.memset %dst, %value : memref<?xf32>, i32
}

// -----

func @mmamatrix_invalid_shape(){
  %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
  %i = constant 16 : index


@@ -195,6 +195,17 @@ module attributes {gpu.container_module} {
    return
  }

  func @memset(%dst : memref<3x7xf32>, %value : f32) {
    // CHECK-LABEL: func @memset
    // CHECK: gpu.memset {{.*}}, {{.*}} : memref<3x7xf32>, f32
    gpu.memset %dst, %value : memref<3x7xf32>, f32
    // CHECK: %[[t0:.*]] = gpu.wait async
    %0 = gpu.wait async
    // CHECK: {{.*}} = gpu.memset async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, f32
    %1 = gpu.memset async [%0] %dst, %value : memref<3x7xf32>, f32
    return
  }

  func @mmamatrix_valid_element_type(){
    // CHECK-LABEL: func @mmamatrix_valid_element_type
    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>