pipeline-data-transfer: remove dead tag allocs and improve test coverage for replaceAllMemRefUsesWith / pipeline-data-transfer

- Address remaining comments from PR tensorflow/mlir#87 for better test
  coverage for pipeline-data-transfer/replaceAllMemRefUsesWith.
- Remove dead tag allocs the same way they are removed for the replaced
  buffers.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#106

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/106 from bondhugula:followup 9e868666d047e8d43e5f82f43e4093b838c710fa
PiperOrigin-RevId: 267144774
commit 8c9dc690eb
parent 6395229509
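The removal idiom the message refers to -- erase the defining alloc when the old memref is either completely unused or only deallocated after replacement -- is the same one this change now applies to DMA tag buffers. A minimal standalone sketch of that idiom follows; the helper name eraseIfTriviallyDead is hypothetical and not part of this patch, and only the member functions already visible in the hunks below are used.

// Sketch only: erase a memref's defining alloc when the memref has no
// remaining uses, or when its single remaining use is a dealloc (in which
// case the dealloc is erased as well). More complex cases are left to
// canonicalization, mirroring the comment added in the patch.
static void eraseIfTriviallyDead(Value *memRef) {
  Operation *allocInst = memRef->getDefiningOp();
  if (!allocInst)
    return; // Not alloc'ed locally (e.g. a function argument); nothing to do.
  if (memRef->use_empty()) {
    allocInst->erase();
  } else if (memRef->hasOneUse()) {
    if (auto dealloc = dyn_cast<DeallocOp>(*memRef->user_begin())) {
      dealloc.erase();
      allocInst->erase();
    }
  }
}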
@@ -574,11 +574,12 @@ LogicalResult AffineDataCopyGeneration::generateCopy(
   prevOfBegin = std::prev(begin);
 
   // *Only* those uses within the range [begin, end) of 'block' are replaced.
-  replaceAllMemRefUsesWith(memref, fastMemRef,
-                           /*extraIndices=*/{}, indexRemap,
-                           /*extraOperands=*/regionSymbols,
-                           /*domInstFilter=*/&*begin,
-                           /*postDomInstFilter=*/&*postDomFilter);
+  if (failed(replaceAllMemRefUsesWith(memref, fastMemRef,
+                                      /*extraIndices=*/{}, indexRemap,
+                                      /*extraOperands=*/regionSymbols,
+                                      /*domInstFilter=*/&*begin,
+                                      /*postDomInstFilter=*/&*postDomFilter)))
+    llvm_unreachable("memref replacement guaranteed to succeed here");
 
   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
 
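Both this hunk and the Utils change further down replace a silent or debug-only failure path with llvm_unreachable once the result of replaceAllMemRefUsesWith is known to be trivially successful. A small illustrative sketch of the difference follows; the mustSucceed name is made up, only the failed()/llvm_unreachable idiom is the point.

// Sketch: checking a LogicalResult that is expected to always succeed.
LogicalResult mustSucceed(); // hypothetical callee

void caller() {
  // With NDEBUG, assert(false && "...") compiles to nothing and execution
  // would just continue past the check. llvm_unreachable aborts with the
  // message in assertion-enabled builds and tells the optimizer the path
  // cannot be reached in release builds.
  if (failed(mustSucceed()))
    llvm_unreachable("mustSucceed() is guaranteed to succeed here");
}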
@@ -293,7 +293,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
     } else if (oldMemRef->hasOneUse()) {
       if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
         dealloc.erase();
-        oldMemRef->getDefiningOp()->erase();
+        allocInst->erase();
       }
     }
   }
@@ -308,11 +308,18 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
       return;
     }
-    // If the old tag has no more uses, remove its 'dead' alloc if it was
-    // alloc'ed.
-    if (oldTagMemRef->use_empty())
-      if (auto *allocInst = oldTagMemRef->getDefiningOp())
-        allocInst->erase();
+    // If the old tag has no uses or a single dealloc use, remove it.
+    // (canonicalization handles more complex cases).
+    if (auto *tagAllocInst = oldTagMemRef->getDefiningOp()) {
+      if (oldTagMemRef->use_empty()) {
+        tagAllocInst->erase();
+      } else if (oldTagMemRef->hasOneUse()) {
+        if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef->user_begin())) {
+          dealloc.erase();
+          tagAllocInst->erase();
+        }
+      }
+    }
   }
 
   // Double buffering would have invalidated all the old DMA start/wait insts.
@@ -288,7 +288,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
   for (auto *op : opsToReplace) {
     if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
                                         indexRemap, extraOperands)))
-      assert(false && "memref replacement guaranteed to succeed here");
+      llvm_unreachable("memref replacement guaranteed to succeed here");
   }
 
   return success();
@@ -26,6 +26,8 @@ func @loop_nest_dma() {
       "do_more_compute"(%i, %j) : (index, index) -> ()
     }
   }
+  dealloc %tag : memref<1 x f32>
+  dealloc %Ah : memref<32 x f32, (d0) -> (d0), 1>
   return
 }
 // CHECK: %{{.*}} = alloc() : memref<256xf32>
@@ -77,9 +79,12 @@ func @loop_step(%arg0: memref<512xf32>,
                     : memref<512xf32>, memref<4xf32, 1>, memref<1xi32>
     affine.dma_wait %2[%c0], %c4 : memref<1xi32>
     "compute"(%i0) : (index) -> ()
+    dealloc %2 : memref<1xi32>
+    dealloc %1 : memref<4xf32, 1>
   }
   return
 }
 // CHECK: [[BUF:%[0-9]+]] = alloc() : memref<2x4xf32, 1>
+// CHECK: [[TAG:%[0-9]+]] = alloc() : memref<2x1xi32>
 // CHECK-NEXT: affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[(%{{.*}} floordiv 4) mod 2, 0], [[TAG]][(%{{.*}} floordiv 4) mod 2, 0], %{{.*}} : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32>
 // CHECK-NEXT: affine.for %{{.*}} = 4 to 512 step 4 {
@@ -93,21 +98,22 @@ func @loop_step(%arg0: memref<512xf32>,
 // CHECK-NEXT: %{{.*}} = affine.apply [[FLOOR_MOD_2]]([[SHIFTED]])
 // CHECK: affine.dma_wait [[TAG]][(%{{.*}} floordiv 4) mod 2, 0], %{{.*}} : memref<2x1xi32>
 // CHECK-NEXT: "compute"(%{{.*}}) : (index) -> ()
-// CHECK: return
+// CHECK-NEXT: dealloc [[TAG]] : memref<2x1xi32>
+// CHECK-NEXT: dealloc [[BUF]] : memref<2x4xf32, 1>
+// CHECK-NEXT: return
 // CHECK-NEXT: }
 
 // -----
 
 #map0 = (d0, d1) -> (d0, d1)
 #map1 = (d0, d1) -> ((d0 * 2048 + d1 * 256) floordiv 32)
 #map2 = (d0) -> ((d0 * 2048) floordiv 32)
 // CHECK-LABEL: func @loop_dma_nested(%{{.*}}: memref<512x32xvector<8xf32>
-func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<512x32xvector<8xf32>, #map0>, %arg2: memref<512x32xvector<8xf32>, #map0>) {
+func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>>, %arg1: memref<512x32xvector<8xf32>>, %arg2: memref<512x32xvector<8xf32>>) {
   %num_elts = constant 256 : index
   %c0 = constant 0 : index
-  %0 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
-  %1 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
-  %2 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
+  %0 = alloc() : memref<64x4xvector<8xf32>, 2>
+  %1 = alloc() : memref<64x4xvector<8xf32>, 2>
+  %2 = alloc() : memref<64x4xvector<8xf32>, 2>
   %3 = alloc() : memref<2xi32>
   %4 = alloc() : memref<2xi32>
   %5 = alloc() : memref<2xi32>
@@ -118,7 +124,7 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
   // CHECK: affine.for %{{.*}} = 1 to 8 {
   affine.for %i0 = 0 to 8 {
     %6 = affine.apply #map2(%i0)
-    affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+    affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
     affine.dma_wait %5[%c0], %num_elts : memref<2xi32>
     // Steady state for DMA overlap on arg2
     // CHECK: affine.dma_start %{{.*}}[
@@ -134,8 +140,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
     affine.for %i1 = 0 to 8 {
       %7 = affine.apply #map1(%i0, %i1)
       %8 = affine.apply #map2(%i1)
-      affine.dma_start %arg0[%7, %c0], %0[%c0, %c0], %3[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
-      affine.dma_start %arg1[%8, %c0], %1[%c0, %c0], %4[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+      affine.dma_start %arg0[%7, %c0], %0[%c0, %c0], %3[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
+      affine.dma_start %arg1[%8, %c0], %1[%c0, %c0], %4[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
       affine.dma_wait %3[%c0], %num_elts : memref<2xi32>
       affine.dma_wait %4[%c0], %num_elts : memref<2xi32>
       // Steady state for DMA overlap on arg0, arg1
@@ -175,15 +181,21 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
 // CHECK: affine.dma_wait [[TAG_ARG1_NESTED]]
 // CHECK: affine.for %{{.*}} = 0 to 4 {
   }
+  dealloc %5 : memref<2xi32>
+  dealloc %4 : memref<2xi32>
+  dealloc %3 : memref<2xi32>
+  dealloc %2 : memref<64x4xvector<8xf32>, 2>
+  dealloc %1 : memref<64x4xvector<8xf32>, 2>
+  dealloc %0 : memref<64x4xvector<8xf32>, 2>
   return
 // CHECK: }
 // CHECK-DAG: dealloc [[TAG_ARG1_NESTED]] : memref<2x2xi32>
 // CHECK-DAG: dealloc [[TAG_ARG0_NESTED]] : memref<2x2xi32>
 // CHECK-DAG: dealloc [[BUF_ARG1_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
 // CHECK-DAG: dealloc [[BUF_ARG0_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
 // CHECK-DAG: dealloc [[TAG_ARG2]] : memref<2x2xi32>
 // CHECK-DAG: dealloc [[BUF_ARG2]] : memref<2x64x4xvector<8xf32>, 2>
-// CHECK: return
+// CHECK-NEXT: return
 }
 
 // -----
@@ -211,8 +223,14 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) {
 
     affine.dma_start %2[%c0, %c0], %arg2[%6, %c0], %5[%c0], %num_elts : memref<64x4xvector<8xf32>, 2>, memref<512x32xvector<8xf32>>, memref<2xi32>
     affine.dma_wait %5[%c0], %num_elts : memref<2xi32>
   } // CHECK: }
-  return // CHECK: return
-}
+  dealloc %5 : memref<2xi32>
+  dealloc %4 : memref<2xi32>
+  dealloc %3 : memref<2xi32>
+  dealloc %2 : memref<64x4xvector<8xf32>, 2>
+  dealloc %1 : memref<64x4xvector<8xf32>, 2>
+  dealloc %0 : memref<64x4xvector<8xf32>, 2>
+  return
+}
 
 // -----
@@ -235,12 +253,43 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) {
     // escaping use; no DMA pipelining / double buffering will be done.
     "foo"(%Av) : (memref<32 x 32 x f32, 2>) -> ()
   }
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
   return
 // CHECK: "foo"(%{{[0-9]+}}) : (memref<32x32xf32, 2>) -> ()
 // CHECK: }
 // CHECK: return
 }
+
+// -----
+
+// CHECK-LABEL: func @escaping_tag
+func @escaping_tag(%arg0: memref<512 x 32 x f32>) {
+  %c32 = constant 32 : index
+  %num_elt = constant 512 : index
+  %zero = constant 0 : index
+  %Av = alloc() : memref<32 x 32 x f32, 2>
+  %tag = alloc() : memref<1 x i32>
+
+  // CHECK-NOT: affine.dma_start
+  // CHECK: affine.for %{{.*}} = 0 to 16 {
+  affine.for %kTT = 0 to 16 {
+    affine.dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %tag[%zero], %num_elt :
+      memref<512 x 32 x f32>,
+      memref<32 x 32 x f32, 2>, memref<1 x i32>
+    affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+    // escaping use; no DMA pipelining / double buffering will be done.
+    "foo"(%tag) : (memref<1 x i32>) -> ()
+  }
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
+  return
+// CHECK: "foo"(%{{[0-9]+}}) : (memref<1xi32>) -> ()
+// CHECK: }
+// CHECK: return
+}
 
 // -----
 
 // CHECK-LABEL: func @live_out_use
@@ -261,6 +310,8 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
   }
   // Use live out of 'affine.for' op; no DMA pipelining will be done.
   %v = affine.load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
   return %v : f32
 // CHECK: %{{[0-9]+}} = affine.load %{{[0-9]+}}[%{{.*}}, %{{.*}}] : memref<32x32xf32, 2>
 // CHECK: return
@@ -289,6 +340,7 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
               memref<? x ? x f32, 2>, memref<1 x i32>
     affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32>
   }
+  dealloc %Av : memref<? x ? x f32, 2>
   return
 // CHECK-NEXT: affine.for %{{.*}} = 1 to 16 {
 // CHECK: affine.dma_start %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}} mod 2, 0, 0], %{{.*}}[%{{.*}} mod 2, 0], %{{.*}}
@@ -299,10 +351,11 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
 }
 
 // Memref replacement will fail here due to a non-dereferencing use. However,
-// no incorrect transformation is performed since replaceAllMemRefUsesWith
-// checks for escaping uses before performing any replacement.
-// CHECK-LABEL: func @escaping_use
-func @escaping_use() {
+// no incorrect transformation is performed in spite of one of the uses being a
+// dereferencing one since replaceAllMemRefUsesWith checks for escaping uses
+// before performing any replacement.
+// CHECK-LABEL: func @escaping_and_indexed_use_mix
+func @escaping_and_indexed_use_mix() {
   %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
   %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
   %tag = alloc() : memref<1 x f32>
@@ -317,9 +370,11 @@ func @escaping_use() {
     %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
     "foo"(%v) : (f32) -> ()
   }
+  dealloc %A : memref<256 x f32, (d0) -> (d0), 0>
+  dealloc %Ah : memref<32 x f32, (d0) -> (d0), 1>
   return
 }
-// No replacement
+// No replacement.
 // CHECK: affine.for %{{.*}} = 0 to 8 {
 // CHECK-NEXT: affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}
 // CHECK-NEXT: affine.dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32>