pipeline-data-transfer: remove dead tag allocs and improve test coverage for replaceAllMemRefUsesWith / pipeline-data-transfer

- address remaining comments from PR tensorflow/mlir#87 to improve test coverage
  for pipeline-data-transfer / replaceAllMemRefUsesWith
- remove dead tag allocs the same way dead allocs of the replaced buffers are removed

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#106

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/106 from bondhugula:followup 9e868666d047e8d43e5f82f43e4093b838c710fa
PiperOrigin-RevId: 267144774
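
At its core, the patch applies to tag memrefs the same dead-alloc cleanup that was already performed for the replaced data buffers: erase the defining alloc when the old memref has no remaining uses, or when its single remaining use is a matching dealloc; deeper cases are left to canonicalization. A minimal sketch of that shared pattern against the MLIR APIs of this revision (the free-standing helper `eraseIfDead` is a name invented for illustration, not code from the patch):

// Erase the defining alloc of `memref` if it is dead: either it has no
// remaining uses, or its sole remaining use is a matching dealloc.
static void eraseIfDead(Value *memref) {
  auto *allocInst = memref->getDefiningOp();
  if (!allocInst)
    return; // e.g. a function argument; nothing to erase.
  if (memref->use_empty()) {
    allocInst->erase();
  } else if (memref->hasOneUse()) {
    // A lone dealloc is also dead once the alloc goes away.
    if (auto dealloc = dyn_cast<DeallocOp>(*memref->user_begin())) {
      dealloc.erase();
      allocInst->erase();
    }
  }
}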

@@ -574,11 +574,12 @@ LogicalResult AffineDataCopyGeneration::generateCopy(
   prevOfBegin = std::prev(begin);

   // *Only* those uses within the range [begin, end) of 'block' are replaced.
-  replaceAllMemRefUsesWith(memref, fastMemRef,
-                           /*extraIndices=*/{}, indexRemap,
-                           /*extraOperands=*/regionSymbols,
-                           /*domInstFilter=*/&*begin,
-                           /*postDomInstFilter=*/&*postDomFilter);
+  if (failed(replaceAllMemRefUsesWith(memref, fastMemRef,
+                                      /*extraIndices=*/{}, indexRemap,
+                                      /*extraOperands=*/regionSymbols,
+                                      /*domInstFilter=*/&*begin,
+                                      /*postDomInstFilter=*/&*postDomFilter)))
+    llvm_unreachable("memref replacement guaranteed to succeed here");

   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);


@@ -293,7 +293,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
     } else if (oldMemRef->hasOneUse()) {
       if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
         dealloc.erase();
-        oldMemRef->getDefiningOp()->erase();
+        allocInst->erase();
       }
     }
   }
@@ -308,11 +308,18 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
-    // If the old tag has no more uses, remove its 'dead' alloc if it was
-    // alloc'ed.
-    if (oldTagMemRef->use_empty())
-      if (auto *allocInst = oldTagMemRef->getDefiningOp())
-        allocInst->erase();
+    // If the old tag has no uses or a single dealloc use, remove it.
+    // (canonicalization handles more complex cases).
+    if (auto *tagAllocInst = oldTagMemRef->getDefiningOp()) {
+      if (oldTagMemRef->use_empty()) {
+        tagAllocInst->erase();
+      } else if (oldTagMemRef->hasOneUse()) {
+        if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef->user_begin())) {
+          dealloc.erase();
+          tagAllocInst->erase();
+        }
+      }
+    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
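
To see why the new `hasOneUse` branch is needed: the updated tests below now `dealloc` their tag memrefs, so after double buffering the old tag is not use-empty; its one remaining use is that `dealloc`, which the `use_empty()` check alone would leave behind. An abridged illustration in the style of the tests (hypothetical IR, surrounding definitions elided, not taken verbatim from the test file):

// Input: a DMA tag with an explicit dealloc, as in the updated tests.
%tag = alloc() : memref<1xi32>
affine.for %i = 0 to 8 {
  affine.dma_start %A[%i], %Ah[%i], %tag[%c0], %num_elts
    : memref<256xf32>, memref<32xf32, 1>, memref<1xi32>
  affine.dma_wait %tag[%c0], %num_elts : memref<1xi32>
}
dealloc %tag : memref<1xi32>

// After pipelining, all dereferencing uses are rewritten to a new
// double-buffered tag (memref<2x1xi32>); the old %tag's sole remaining
// use is the dealloc, so the alloc/dealloc pair above is erased.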


@@ -288,7 +288,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
   for (auto *op : opsToReplace) {
     if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
                                         indexRemap, extraOperands)))
-      assert(false && "memref replacement guaranteed to succeed here");
+      llvm_unreachable("memref replacement guaranteed to succeed here");
   }
   return success();
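
The switch from `assert(false && ...)` to `llvm_unreachable(...)` in both files is more than cosmetic: `assert` compiles away entirely under NDEBUG, silently falling through an impossible failure path, while `llvm_unreachable` still marks the path unreachable in release builds and traps in assert-enabled ones. A self-contained sketch of the idiom, with `replacement()` as a hypothetical stand-in for the utility call:

#include "llvm/Support/ErrorHandling.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

LogicalResult replacement(); // hypothetical stand-in for replaceAllMemRefUsesWith

void caller() {
  // The result is checked even though failure is impossible by construction;
  // llvm_unreachable both documents and (in debug builds) enforces the invariant.
  if (failed(replacement()))
    llvm_unreachable("replacement guaranteed to succeed here");
}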


@@ -26,6 +26,8 @@ func @loop_nest_dma() {
       "do_more_compute"(%i, %j) : (index, index) -> ()
     }
   }
+  dealloc %tag : memref<1 x f32>
+  dealloc %Ah : memref<32 x f32, (d0) -> (d0), 1>
   return
 }
 // CHECK: %{{.*}} = alloc() : memref<256xf32>
@@ -77,9 +79,12 @@ func @loop_step(%arg0: memref<512xf32>,
       : memref<512xf32>, memref<4xf32, 1>, memref<1xi32>
     affine.dma_wait %2[%c0], %c4 : memref<1xi32>
     "compute"(%i0) : (index) -> ()
+    dealloc %2 : memref<1xi32>
+    dealloc %1 : memref<4xf32, 1>
   }
   return
 }
+// CHECK: [[BUF:%[0-9]+]] = alloc() : memref<2x4xf32, 1>
 // CHECK: [[TAG:%[0-9]+]] = alloc() : memref<2x1xi32>
 // CHECK-NEXT: affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[(%{{.*}} floordiv 4) mod 2, 0], [[TAG]][(%{{.*}} floordiv 4) mod 2, 0], %{{.*}} : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32>
 // CHECK-NEXT: affine.for %{{.*}} = 4 to 512 step 4 {
@@ -93,21 +98,22 @@ func @loop_step(%arg0: memref<512xf32>,
 // CHECK-NEXT: %{{.*}} = affine.apply [[FLOOR_MOD_2]]([[SHIFTED]])
 // CHECK: affine.dma_wait [[TAG]][(%{{.*}} floordiv 4) mod 2, 0], %{{.*}} : memref<2x1xi32>
 // CHECK-NEXT: "compute"(%{{.*}}) : (index) -> ()
-// CHECK: return
+// CHECK-NEXT: dealloc [[TAG]] : memref<2x1xi32>
+// CHECK-NEXT: dealloc [[BUF]] : memref<2x4xf32, 1>
+// CHECK-NEXT: return
 // CHECK-NEXT: }

 // -----

-#map0 = (d0, d1) -> (d0, d1)
 #map1 = (d0, d1) -> ((d0 * 2048 + d1 * 256) floordiv 32)
 #map2 = (d0) -> ((d0 * 2048) floordiv 32)

 // CHECK-LABEL: func @loop_dma_nested(%{{.*}}: memref<512x32xvector<8xf32>
-func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<512x32xvector<8xf32>, #map0>, %arg2: memref<512x32xvector<8xf32>, #map0>) {
+func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>>, %arg1: memref<512x32xvector<8xf32>>, %arg2: memref<512x32xvector<8xf32>>) {
   %num_elts = constant 256 : index
   %c0 = constant 0 : index
-  %0 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
-  %1 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
-  %2 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
+  %0 = alloc() : memref<64x4xvector<8xf32>, 2>
+  %1 = alloc() : memref<64x4xvector<8xf32>, 2>
+  %2 = alloc() : memref<64x4xvector<8xf32>, 2>
   %3 = alloc() : memref<2xi32>
   %4 = alloc() : memref<2xi32>
   %5 = alloc() : memref<2xi32>
@@ -118,7 +124,7 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
 // CHECK: affine.for %{{.*}} = 1 to 8 {
   affine.for %i0 = 0 to 8 {
     %6 = affine.apply #map2(%i0)
-    affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+    affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
     affine.dma_wait %5[%c0], %num_elts : memref<2xi32>
     // Steady state for DMA overlap on arg2
 // CHECK: affine.dma_start %{{.*}}[
@@ -134,8 +140,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
     affine.for %i1 = 0 to 8 {
       %7 = affine.apply #map1(%i0, %i1)
       %8 = affine.apply #map2(%i1)
-      affine.dma_start %arg0[%7, %c0], %0[%c0, %c0], %3[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
-      affine.dma_start %arg1[%8, %c0], %1[%c0, %c0], %4[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+      affine.dma_start %arg0[%7, %c0], %0[%c0, %c0], %3[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
+      affine.dma_start %arg1[%8, %c0], %1[%c0, %c0], %4[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32>
       affine.dma_wait %3[%c0], %num_elts : memref<2xi32>
       affine.dma_wait %4[%c0], %num_elts : memref<2xi32>
       // Steady state for DMA overlap on arg0, arg1
@@ -175,15 +181,21 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref<
 // CHECK: affine.dma_wait [[TAG_ARG1_NESTED]]
 // CHECK: affine.for %{{.*}} = 0 to 4 {
   }
+  dealloc %5 : memref<2xi32>
+  dealloc %4 : memref<2xi32>
+  dealloc %3 : memref<2xi32>
+  dealloc %2 : memref<64x4xvector<8xf32>, 2>
+  dealloc %1 : memref<64x4xvector<8xf32>, 2>
+  dealloc %0 : memref<64x4xvector<8xf32>, 2>
   return
 // CHECK: }
-// CHECK-DAG: dealloc [[TAG_ARG1_NESTED]] : memref<2x2xi32>
-// CHECK-DAG: dealloc [[TAG_ARG0_NESTED]] : memref<2x2xi32>
-// CHECK-DAG: dealloc [[BUF_ARG1_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
-// CHECK-DAG: dealloc [[BUF_ARG0_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
-// CHECK-DAG: dealloc [[TAG_ARG2]] : memref<2x2xi32>
-// CHECK-DAG: dealloc [[BUF_ARG2]] : memref<2x64x4xvector<8xf32>, 2>
-// CHECK: return
+// CHECK-DAG: dealloc [[TAG_ARG1_NESTED]] : memref<2x2xi32>
+// CHECK-DAG: dealloc [[TAG_ARG0_NESTED]] : memref<2x2xi32>
+// CHECK-DAG: dealloc [[BUF_ARG1_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
+// CHECK-DAG: dealloc [[BUF_ARG0_NESTED]] : memref<2x64x4xvector<8xf32>, 2>
+// CHECK-DAG: dealloc [[TAG_ARG2]] : memref<2x2xi32>
+// CHECK-DAG: dealloc [[BUF_ARG2]] : memref<2x64x4xvector<8xf32>, 2>
+// CHECK-NEXT: return
 }
// -----
@@ -211,8 +223,14 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) {
     affine.dma_start %2[%c0, %c0], %arg2[%6, %c0], %5[%c0], %num_elts : memref<64x4xvector<8xf32>, 2>, memref<512x32xvector<8xf32>>, memref<2xi32>
     affine.dma_wait %5[%c0], %num_elts : memref<2xi32>
   } // CHECK: }
-  return // CHECK: return
+  dealloc %5 : memref<2xi32>
+  dealloc %4 : memref<2xi32>
+  dealloc %3 : memref<2xi32>
+  dealloc %2 : memref<64x4xvector<8xf32>, 2>
+  dealloc %1 : memref<64x4xvector<8xf32>, 2>
+  dealloc %0 : memref<64x4xvector<8xf32>, 2>
+  return
 }
// -----
@@ -235,12 +253,43 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) {
     // escaping use; no DMA pipelining / double buffering will be done.
     "foo"(%Av) : (memref<32 x 32 x f32, 2>) -> ()
   }
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
   return
 // CHECK: "foo"(%{{[0-9]+}}) : (memref<32x32xf32, 2>) -> ()
 // CHECK: }
 // CHECK: return
 }
+
+// -----
+
+// CHECK-LABEL: func @escaping_tag
+func @escaping_tag(%arg0: memref<512 x 32 x f32>) {
+  %c32 = constant 32 : index
+  %num_elt = constant 512 : index
+  %zero = constant 0 : index
+  %Av = alloc() : memref<32 x 32 x f32, 2>
+  %tag = alloc() : memref<1 x i32>
+  // CHECK-NOT: affine.dma_start
+  // CHECK: affine.for %{{.*}} = 0 to 16 {
+  affine.for %kTT = 0 to 16 {
+    affine.dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %tag[%zero], %num_elt :
+      memref<512 x 32 x f32>,
+      memref<32 x 32 x f32, 2>, memref<1 x i32>
+    affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+    // escaping use; no DMA pipelining / double buffering will be done.
+    "foo"(%tag) : (memref<1 x i32>) -> ()
+  }
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
+  return
+// CHECK: "foo"(%{{[0-9]+}}) : (memref<1xi32>) -> ()
+// CHECK: }
+// CHECK: return
+}

 // -----

 // CHECK-LABEL: func @live_out_use
@@ -261,6 +310,8 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
   }
   // Use live out of 'affine.for' op; no DMA pipelining will be done.
   %v = affine.load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
+  dealloc %tag : memref<1 x i32>
+  dealloc %Av : memref<32 x 32 x f32, 2>
   return %v : f32
 // CHECK: %{{[0-9]+}} = affine.load %{{[0-9]+}}[%{{.*}}, %{{.*}}] : memref<32x32xf32, 2>
 // CHECK: return
@@ -289,6 +340,7 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
       memref<? x ? x f32, 2>, memref<1 x i32>
     affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32>
   }
+  dealloc %Av : memref<? x ? x f32, 2>
   return
 // CHECK-NEXT: affine.for %{{.*}} = 1 to 16 {
 // CHECK: affine.dma_start %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}} mod 2, 0, 0], %{{.*}}[%{{.*}} mod 2, 0], %{{.*}}
@@ -299,10 +351,11 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
 }

 // Memref replacement will fail here due to a non-dereferencing use. However,
-// no incorrect transformation is performed since replaceAllMemRefUsesWith
-// checks for escaping uses before performing any replacement.
-// CHECK-LABEL: func @escaping_use
-func @escaping_use() {
+// no incorrect transformation is performed in spite of one of the uses being a
+// dereferencing one since replaceAllMemRefUsesWith checks for escaping uses
+// before performing any replacement.
+// CHECK-LABEL: func @escaping_and_indexed_use_mix
+func @escaping_and_indexed_use_mix() {
   %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
   %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
   %tag = alloc() : memref<1 x f32>
@@ -317,9 +370,11 @@ func @escaping_use() {
     %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
     "foo"(%v) : (f32) -> ()
   }
+  dealloc %A : memref<256 x f32, (d0) -> (d0), 0>
+  dealloc %Ah : memref<32 x f32, (d0) -> (d0), 1>
   return
 }
-// No replacement
+// No replacement.
 // CHECK: affine.for %{{.*}} = 0 to 8 {
 // CHECK-NEXT: affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}
 // CHECK-NEXT: affine.dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32>
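
The renamed test documents the safety property it relies on: replaceAllMemRefUsesWith first vets every use of the old memref and returns failure() before mutating anything if any use cannot be rewritten, such as the escaping operand of the "foo" call above. A static helper of roughly this shape existed in the utils at this revision; the code below is an approximate reconstruction for illustration only, with header paths that may not match this exact revision:

#include "mlir/Dialect/AffineOps/AffineOps.h" // path approximate for this revision

using namespace mlir;

// A memref use is rewritable only if the operation dereferences it through an
// affine access; anything else (like a plain call operand) escapes.
static bool isMemRefDereferencingOp(Operation &op) {
  return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
         isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op);
}

Because the escaping check runs over all uses before any rewrite, the mixed escaping/dereferencing test keeps its IR fully intact rather than ending up half-transformed.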