fix async fetch with update queue

Dun Liang 2020-06-30 12:01:53 +08:00
parent 225a8f4944
commit 1e9d29c2ae
22 changed files with 341 additions and 63 deletions

View File

@ -101,6 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error);
#endif
#endif
template <typename T>
void peek(T result, char const *const func, const char *const file,
          int const line) {
    if (result) {
        // DEVICE_RESET
        LOGe << "Peek CUDA error at" << file >> ":" >> line << " code="
            >> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
            << func;
    }
}

template <typename T>
void check(T result, char const *const func, const char *const file,
           int const line) {
@ -116,6 +127,7 @@ void check(T result, char const *const func, const char *const file,
// This will output the proper CUDA error strings in the event
// that a CUDA host call returns an error
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__)
// This will output the proper error string when calling cudaGetLastError
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)

View File

@ -83,7 +83,7 @@ mpi_initer() {
    MPI_CHECK(MPI_Init(NULL, NULL));
    MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
    MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
    // calculating localRank based on hostname, which is used in selecting a GPU
    uint64_t hostHashs[mpi_world_rank];
    char hostname[1024];

View File

@ -321,13 +321,54 @@ def attrs(var):
    }
Var.attrs = attrs

def fetch(vars, func, *args, **kw):
    core.fetch(vars, lambda *results: func(*results, *args, **kw))

def fetch_var(var, func, *args, **kw):
    core.fetch([var], lambda a: func(a, *args, **kw))
Var.fetch = fetch_var
del fetch_var

def fetch(*args):
    ''' Async fetch vars with function closure.

Example 1::

    for img,label in your_dataset:
        pred = your_model(img)
        loss = critic(pred, label)
        acc = accuracy(pred, label)
        jt.fetch(acc, loss,
            lambda acc, loss:
                print(f"loss:{loss} acc:{acc}")
        )

Example 2::

    for i,(img,label) in enumerate(your_dataset):
        pred = your_model(img)
        loss = critic(pred, label)
        acc = accuracy(pred, label)
        # variable i will be bound into the function closure
        jt.fetch(i, acc, loss,
            lambda i, acc, loss:
                print(f"#{i}, loss:{loss} acc:{acc}")
        )
    '''
    assert len(args)>=1
    func = args[-1]
    assert callable(func)
    args = list(args[:-1])
    if len(args)>0 and isinstance(args[0], Sequence) \
        and len(args[0])>=1 and isinstance(args[0][0], Var):
        raise TypeError("jt.Var should not be inside a list or tuple.")
    var_map = []
    variables = []
    for i, v in enumerate(args):
        if isinstance(v, Var):
            variables.append(v)
            var_map.append(i)
            args[i] = None
    def callback(*results):
        for i,v in enumerate(results):
            args[var_map[i]] = v
        func(*args)
    core.ops.fetch(variables, callback)
Var.fetch = fetch
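The splice logic above is easy to miss: non-Var arguments stay bound in the closure, and each fetched result drops back into the positional slot its Var came from. A minimal standalone sketch of that mechanism (stub classes and a synchronous stand-in for core.ops.fetch, not Jittor's actual internals):

# --- editor's sketch, not part of the diff ---
class FakeVar:
    # stand-in for jt.Var; fetching yields .data
    def __init__(self, data):
        self.data = data

def fake_core_fetch(variables, callback):
    # stand-in for core.ops.fetch: deliver results synchronously
    callback(*[v.data for v in variables])

def fetch(*args):
    func = args[-1]
    args = list(args[:-1])
    var_map, variables = [], []
    for i, v in enumerate(args):
        if isinstance(v, FakeVar):
            variables.append(v)
            var_map.append(i)     # remember which slot to refill
            args[i] = None
    def callback(*results):
        for i, r in enumerate(results):
            args[var_map[i]] = r  # splice results back by position
        func(*args)
    fake_core_fetch(variables, callback)

# batch index 7 rides along untouched; the Var slot is refilled:
fetch(7, FakeVar([2, 4, 6]), lambda i, a: print(i, a))  # prints: 7 [2, 4, 6]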
def display_memory_info():
    import inspect, os
@ -574,6 +615,7 @@ def jittor_exit():
        pass
    else:
        core.sync_all(True)
        core.cleanup()

atexit.register(jittor_exit)
Var.__str__ = lambda x: str(x.data)

View File

@ -345,7 +345,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
        with open(os.path.join(jittor_path, header), encoding='utf8') as f:
            src = f.read()
        # XxxXxxOp(args)
        res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S)
        res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S)
        assert len(res) >= 1, "Wrong op args in " + header
        # register op
        cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
@ -908,14 +908,14 @@ with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
f.write(jit_src)
cc_flags += f' -I{cache_path} '
# gen pyjt
pyjt_compiler.compile(cache_path, jittor_path)
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
# initialize order:
# 1. registers
# 2. generate source
# 3. op_utils
# 4. other
files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines()
files2 = pyjt_gen_src
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
at_beginning = [
    "src/ops/op_utils.cc",

View File

@ -849,6 +849,7 @@ def compile(cache_path, jittor_path):
    headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
              [ os.path.join(cache_path, h) for h in headers2 ]
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()
@ -866,6 +867,7 @@ def compile(cache_path, jittor_path):
        if not check: continue
        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
#include "pyjt/numpy.h"
@ -888,3 +890,5 @@ def compile(cache_path, jittor_path):
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names

View File

@ -60,6 +60,7 @@ class TestArray(unittest.TestCase):
        for i in range(3):
            x = jt.array(im)
            b = net(x)
            b.fetch(lambda b: None)
            b.sync()
        jt.sync(device_sync=True)
@ -70,6 +71,7 @@ class TestArray(unittest.TestCase):
            x = jt.array(im)
            b = net(x)
            b.fetch(lambda b: results.append(b))
            b.sync()
        # del c
        jt.sync(device_sync=True)
        t2 = time.time() - time_start

View File

@ -13,7 +13,10 @@ class TestFetcher(unittest.TestCase):
        a = jt.array([1,2,3])
        a = a*2
        v = []
        jt.fetch([a], lambda a: v.append(a))
        jt.fetch(a, lambda a: v.append(a))
        jt.fetch(1, 2, 3, a,
            lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray))
        )
        jt.sync_all(True)
        assert len(v)==1 and (v[0]==[2,4,6]).all()

View File

@ -99,6 +99,7 @@ class TestResizeAndCrop(unittest.TestCase):
        test_case(20, [1024, 1024], [1.2, 1.8][mid])
        test_case(20, [1024, 666], [0.8,1.0][mid])

    @unittest.skipIf(torch is None, "no torch found")
    def test_resize(self):
        import torch.nn.functional as F
        x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
@ -108,11 +109,13 @@ class TestResizeAndCrop(unittest.TestCase):
            jnn.Resize((r_size, r_size), 'bilinear', align_corners),
            lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear', align_corners=align_corners))

    @unittest.skipIf(torch is None, "no torch found")
    def test_upsample(self):
        arr = np.random.randn(2,3,224,224)
        check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
        check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))

    @unittest.skipIf(torch is None, "no torch found")
    def test_pixelshuffle(self):
        arr = np.random.randn(2,4,224,224)
        check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))

View File

@ -64,16 +64,16 @@ class TestResnet(unittest.TestCase):
        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
        for batch_idx, (data, target) in enumerate(self.train_loader):
            output = mnist_net(data)
            loss = nn.cross_entropy_loss(output, target)
            # train step
            with jt.log_capture_scope(
                log_silent=1,
                log_v=1, log_vprefix="op.cc=100,exe=10",
            ) as logs:
                output = mnist_net(data)
                loss = nn.cross_entropy_loss(output, target)
                SGD.step(loss)
            def callback(loss, output, target, batch_idx):
            def callback(batch_idx, loss, output, target):
                # print train info
                global prev
                pred = np.argmax(output, axis=1)
@ -83,13 +83,13 @@ class TestResnet(unittest.TestCase):
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                    .format(0, batch_idx, 600, 1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
                # prev = time.time()
            jt.fetch([loss, output, target], callback, batch_idx)
            jt.fetch(batch_idx, loss, output, target, callback)
            log_conv = find_log_with_re(logs,
                "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
            log_matmul = find_log_with_re(logs,
                "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
            if batch_idx:
            if batch_idx > 2:
                assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))
            mem_used = jt.flags.stat_allocator_total_alloc_byte \
@ -114,15 +114,13 @@ class TestResnet(unittest.TestCase):
        # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
        # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
        # print(jt.core.number_of_lived_vars(), mem_used)
        jt.display_memory_info()
        # if jt.in_mpi:
        #     assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
        # else:
        #     assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
        if jt.in_mpi:
            assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars()
        else:
            assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars()
        jt.sync_all(True)
        assert np.mean(loss_list[-50:])<0.3
        assert np.mean(loss_list[-50:])<0.5
        assert np.mean(acc_list[-50:])>0.8
if __name__ == "__main__":

View File

@ -77,7 +77,7 @@ class TestVGGClass(unittest.TestCase):
                acc_list.append(acc)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
                    .format(0, batch_idx, 100, 1. * batch_idx, loss[0], acc))
            jt.fetch([loss, output, target], callback, batch_idx)
            jt.fetch(batch_idx, loss, output, target, callback)
            log_conv = find_log_with_re(logs,
                "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")

script/tmpi (new executable file, 117 lines)
View File

@ -0,0 +1,117 @@
#!/bin/bash
# Copyright 2013 Benedikt Morbach <moben@exherbo.org>
# Distributed under the terms of the GNU General Public License v2
# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them

additional_vars=( LD_LIBRARY_PATH LD_PRELOAD )
export "${additional_vars[@]}"

usage() {
    echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.'
    echo ''
    echo 'Usage:'
    echo '    tmpi [number] [command]'
    echo ''
    echo 'You need to pass at least two arguments.'
    echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.'
    echo 'If you call this script from outside tmux and your command contains important whitespace then you need to apply two levels of quoting to preserve it.'
    echo ''
    echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:'
    echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example'
    echo ''
    echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)'
}

check_tools() {
    tools=( tmux mpirun )
    for tool in "${tools[@]}"; do
        if ! which ${tool}; then
            echo "You need to install ${tool} to run this script."
        fi
    done
}

if [[ ${#} -lt 2 ]]; then
    usage
    exit 1
fi

if [[ -z ${TMUX} ]]; then
    # it seems we aren't in a tmux session.
    # start a new one so that our window doesn't end up in some other session and we have to search for it.
    # actually start a new server with '-L' to ensure that our environment carries over.
    socket=$(mktemp --dry-run tmpi.XXXX)
    exec tmux -L ${socket} new-session "${0} ${*}"
fi

if [[ ${1} == runmpi ]] ; then
    # we are being started as one of many processes by mpirun.
    shift

    # start the processes in the order of their rank.
    # this avoids races, as we have to push the variables into tmux's environment.
    # it has the nice side-effect that the panes are also ordered by rank.
    while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do
        sleep 0.02
    done

    # get all the variables that mpirun starts us with so that we can pass them through.
    mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug | cut -d '=' -f1 ) )
    mpi_vars+=( "${additional_vars[@]}" )

    # add the variables to tmux's session environment.
    # we can't just export them because the process will be started as a child of tmux, not us.
    for var in "${mpi_vars[@]}"; do
        tmux set-environment -t ${session} "${var}" "${!var}"
    done

    x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") )
    pid=${x[0]}
    pane=${x[1]}

    for var in "${mpi_vars[@]}"; do
        tmux set-environment -t ${session} -u "${var}"
    done

    # kill the dummy pane that opened the new window
    [[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null

    # set the window to tiled mode.
    # have to do this after every new pane is spawned because otherwise the splits get
    # smaller and smaller until tmux refuses to open new panes, despite plenty of space being left.
    tmux select-layout -t ${pane} tiled &> /dev/null

    # let the next process start
    echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock

    # don't exit here as mpirun needs to be kept alive and it would also exit.
    while [[ -d /proc/${pid} ]]; do
        sleep 1
    done
else
    # we are the parent and set everything up before we start ourselves a bunch of times via mpirun.
    processes=${1}
    self=${0}
    shift

    # create an empty new dummy window which we will later split up for the mpi processes.
    x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') )
    export dummy=${x[0]}
    export window=${x[1]}
    export session=${x[2]}

    # synchronize input to all panes.
    tmux set-window-option -t ${window} synchronize-panes on &> /dev/null
    tmux set-window-option -t ${window} remain-on-exit on &> /dev/null

    # always start with rank 0.
    echo 0 > /tmp/tmpi.lock

    # re-execute ourselves to spawn off the processes.
    echo mpirun -np ${processes} ${self} runmpi "${@}"
    mpirun -np ${processes} ${self} runmpi "${@}"
fi
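The only subtle part of the script is the rank handshake through /tmp/tmpi.lock: each process spins until the file holds its own rank, does its tmux setup, then writes rank+1 to release the next one. A toy Python rendition of that protocol (threads instead of MPI ranks, hypothetical file name):

# --- editor's sketch, not part of the diff ---
import os, tempfile, threading, time

lock_path = os.path.join(tempfile.gettempdir(), "tmpi_demo.lock")

def worker(rank):
    # spin until the lock file names our rank (same idea as the script)
    while True:
        try:
            with open(lock_path) as f:
                if int(f.read().strip() or -1) == rank:
                    break
        except ValueError:
            pass  # caught the file mid-write; retry
        time.sleep(0.02)
    print(f"rank {rank}: set up pane")   # ordered side effects happen here
    with open(lock_path, "w") as f:
        f.write(str(rank + 1))           # let the next rank proceed

with open(lock_path, "w") as f:
    f.write("0")                         # always start with rank 0
threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads: t.start()
for t in threads: t.join()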

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
    name='jittor',
    version='1.1.4.9',
    version='1.1.5.0',
    # scripts=[],
    author="Jittor Group",
    author_email="ran.donglang@gmail.com",

View File

@ -9,7 +9,6 @@
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include "mem/allocator/cuda_dual_allocator.h"
#include "fetcher.h"
#include "event_queue.h"
#endif
#include "misc/cuda_flags.h"
@ -26,6 +25,9 @@ namespace jittor {
Executor exe;

// from fetch_op.cc
extern list<VarPtr> fetcher_to_free;

void Executor::run_sync(vector<Var*> vars, bool device_sync) {
    auto allocator = get_allocator();
    this->allocator = allocator;
@ -33,22 +35,43 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
    int op_num = 0;
    vector<Node*> bfs_q;
    bfs_q.reserve(vars.size());
    auto nodes = (vector<Node*>*)&vars;
    int start_var_num = 0;
    for (Var* v : vars)
        if (!v->is_finished())
            start_var_num++;
    bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
        node->custom_data = 0;
        if (node->is_finished())
            return false;
        op_num += !node->is_var();
        return true;
    });
    {
        // get all nodes that need to be executed
        auto t = ++Node::tflag_count;
        for (Var* v : vars)
            if (!v->is_finished() && v->tflag != t) {
                v->tflag = t;
                start_var_num++;
                bfs_q.push_back(v);
            }
        for (int i=0; i<bfs_q.size(); i++) {
            auto node = bfs_q[i];
            op_num += !node->is_var();
            for (auto i : node->_inputs)
                if (i.node->tflag != t && !i.node->is_finished()) {
                    i.node->tflag = t;
                    bfs_q.push_back(i.node);
                }
            // this var has been fetched
            if (node->flags.get(NodeFlags::_fetch)) {
                for (auto& n : node->_outputs) {
                    // if not in queue and is a fetch op
                    if (n.node->tflag != t &&
                        !n.node->is_finished() &&
                        n.node->flags.get(NodeFlags::_fetch)) {
                        n.node->tflag = t;
                        bfs_q.push_back(n.node);
                    }
                }
            }
        }
    }
    auto tt = Node::tflag_count;
    vector<Op*> ops;
    vector<Var*> all_vars;
    ops.reserve(op_num);
    all_vars.reserve(bfs_q.size() - op_num);
    for (Node* node : bfs_q)
        if (!node->is_var()) {
            node->custom_data = ops.size();
@ -391,7 +414,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
            outputs_bk.push_back(var);
        op->finish_pending_liveness();
        for (Var* var : outputs_bk)
            // var->finish_pending_liveness();
            var->finish_pending_liveness();
    } catch (const std::exception& e) {
        // log memory info
@ -410,6 +432,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
    }
    LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
    for (Var* v : vars) ASSERT(v->mem_ptr);
    // clean fetcher free buffer
    fetcher_to_free.clear();
#ifdef HAS_CUDA
    if (device_sync && use_cuda) {
        last_is_cuda = false;
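What the new traversal buys: a fetched var now drags its pending fetch ops into the batch, so buffered callbacks fire together with the run that computes their inputs. A toy Python model of the marking pass (simplified node structure, not Jittor's):

# --- editor's sketch, not part of the diff ---
class Node:
    tflag_count = 0
    def __init__(self, name, fetch=False):
        self.name, self.fetch = name, fetch
        self.inputs, self.outputs = [], []
        self.finished = False
        self.tflag = 0

def collect(vars):
    """Breadth-first: pull in unfinished inputs, plus pending
    fetch ops hanging off any var flagged as fetched."""
    Node.tflag_count += 1
    t = Node.tflag_count
    q = [v for v in vars if not v.finished]
    for v in q:
        v.tflag = t
    i = 0
    while i < len(q):
        node = q[i]; i += 1
        for inp in node.inputs:
            if inp.tflag != t and not inp.finished:
                inp.tflag = t
                q.append(inp)
        if node.fetch:  # this var has been fetched
            for out in node.outputs:
                if out.tflag != t and not out.finished and out.fetch:
                    out.tflag = t
                    q.append(out)
    return q

x = Node("x", fetch=True)          # a fetched var
fx = Node("fetch_x", fetch=True)   # its pending fetch op
x.outputs.append(fx); fx.inputs.append(x)
y = Node("y"); y.inputs.append(x); x.outputs.append(y)
print([n.name for n in collect([y])])  # ['y', 'x', 'fetch_x']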

View File

@ -27,7 +27,7 @@ VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
    auto dx = op->grad(out, dout, x, x_index);
    if (x->loop_options)
        dx->loop_options = x->loop_options;
    return move(dx);
    return dx;
}

inline static void assign_attrs(Var* a, Var* b) {

View File

@ -11,6 +11,7 @@
#include "init.h"
#include "ops/op_register.h"
#include "var.h"
namespace jittor {
@ -21,6 +22,15 @@ unique_ptr<std::default_random_engine> eng;
vector<set_seed_callback> callbacks;
int current_seed;

// from fetch_op.cc
extern list<VarPtr> fetcher;
extern list<VarPtr> fetcher_to_free;

void cleanup() {
    fetcher_to_free.clear();
    fetcher.clear();
}

static void init_cuda_devices() {
#ifdef HAS_CUDA
    int count=0;

View File

@ -20,4 +20,8 @@ void add_set_seed_callback(set_seed_callback callback);
extern "C"
std::default_random_engine* get_random_engine();
// things need to be clean before python exit
// @pyjt(cleanup)
void cleanup();
} // jittor

View File

@ -95,7 +95,7 @@ struct DelayFree final : Allocator {
    void free(void* mem_ptr, size_t size, const size_t& allocation) override {
        using namespace cuda_dual_local;
        allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator);
        checkCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
        peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
    }

    void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {

View File

@ -24,7 +24,9 @@ struct NodeFlags {
    _finished=1,
    // bit2: stop grad
    _stop_grad=2,
    _n=3,
    // bit3: is fetch
    _fetch=3,
    _n=4,

    // var related flags
    _force_fuse=_n+0,
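The enum packs flags as bit indices (bit 1 finished, bit 2 stop-grad, the new bit 3 fetch, with var-related flags following _n). A quick Python model of get/set over such a layout — a sketch, not Jittor's actual accessors:

# --- editor's sketch, not part of the diff ---
_finished, _stop_grad, _fetch = 1, 2, 3  # bit indices as in the enum above

class Flags:
    def __init__(self):
        self.data = 0
    def set(self, bit, value=1):
        # clear the bit, then write the new value into it
        self.data = (self.data & ~(1 << bit)) | (value << bit)
    def get(self, bit):
        return (self.data >> bit) & 1

f = Flags()
f.set(_fetch)
assert f.get(_fetch) == 1 and f.get(_stop_grad) == 0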

View File

@ -32,9 +32,9 @@ Init() {
}

~Init() {
    if (!get_device_count()) return;
    checkCudaErrors(cudaDeviceSynchronize());
    checkCudaErrors(cudaStreamDestroy(stream));
    checkCudaErrors(cudaEventDestroy(event));
    peekCudaErrors(cudaDeviceSynchronize());
    peekCudaErrors(cudaStreamDestroy(stream));
    peekCudaErrors(cudaEventDestroy(event));
}
} init;

View File

@ -1,5 +1,7 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor.
// Authors: Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@ -12,8 +14,9 @@
#include "mem/allocator/cuda_dual_allocator.h"
#include "event_queue.h"
#endif
#include "fetcher.h"
#include "ops/fetch_op.h"
#include "mem/allocator.h"
#include "executor.h"
namespace jittor {
@ -49,31 +52,68 @@ Init() {
    // do not call deleter on exit
    for (auto& f : fetch_tasks)
        f.func.deleter = nullptr;
    checkCudaErrors(cudaDeviceSynchronize());
    checkCudaErrors(cudaStreamDestroy(stream));
    checkCudaErrors(cudaEventDestroy(event));
    peekCudaErrors(cudaDeviceSynchronize());
    peekCudaErrors(cudaStreamDestroy(stream));
    peekCudaErrors(cudaEventDestroy(event));
}
};
} ;

}
using namespace fetcher_local;
#endif

void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {

list<VarPtr> fetcher;
// this list will be freed at each execution
list<VarPtr> fetcher_to_free;

FetchOp::FetchOp(vector<Var*>&& inputs, FetchFunc&& func)
    : fetch_vars(inputs), func(move(func)) {
#ifdef HAS_CUDA
    static Init init;
    // stream needs to be created after nccl plugin
    static Init init_fetch;
#endif
    sync(vh);
    vector<Allocation> allocations(vh.size());
    vector<ArrayArgs> arrays(vh.size());
    VarPtr vp(0, ns_int);
    outputs_holder.emplace_back(vp);
    fetcher.emplace_front(move(vp));
    fetcher_iter = fetcher.begin();
    bool all_finished = true;
    for (auto v : fetch_vars)
        if (!v->is_finished()) {
            all_finished = false;
            v->flags.set(NodeFlags::_stop_fuse);
            v->flags.set(NodeFlags::_fetch);
        }
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_fetch);
    flags.set(NodeFlags::_stop_grad);
    fetcher_iter->ptr->flags.set(NodeFlags::_fetch);
    // fetcher_to_free.clear();
    if (all_finished) {
        // if all finished, run immediately
        run();
    }
    // if too many fetchers are buffered, force flush
    while (fetcher.size() > 20) {
        LOGvvvv << "too many fetchers(">>fetcher.size() >>
            ") are buffered, force flush";
        exe.run_sync({fetcher.back().ptr}, false);
    }
}

void FetchOp::run() {
    vector<Allocation> allocations(fetch_vars.size());
    vector<ArrayArgs> arrays(fetch_vars.size());
#ifdef HAS_CUDA
    bool has_cuda_memcpy = false;
    event_queue.flush();
#endif
    for (int i=0; i<vh.size(); i++) {
        auto v = vh[i]->var;
    LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars;
    int i = 0;
    for (auto v : fetch_vars) {
        auto& allocation = allocations[i];
#ifdef HAS_CUDA
        if (v->allocator->is_cuda()) {
            checkCudaErrors(cudaEventRecord(event, 0));
@ -98,6 +138,7 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
        arrays[i].ptr = allocation.ptr;
        arrays[i].shape = v->shape;
        arrays[i].dtype = v->dtype();
        i++;
    }
#ifdef HAS_CUDA
if (has_cuda_memcpy) {
@ -109,6 +150,8 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
        FetchResult fr{move(func), move(allocations), move(arrays)};
        fr.call();
    }
    fetcher_to_free.emplace_front(move(*fetcher_iter));
    fetcher.erase(fetcher_iter);
}
} // jittor
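The lifecycle the constructor implements: run at once if every input is already finished, otherwise park the op at the front of `fetcher`; once more than 20 are buffered, sync the oldest to force its callback, and recycle freed entries through `fetcher_to_free`. A toy Python model of that queue discipline (hypothetical names, synchronous stand-in for run_sync):

# --- editor's sketch, not part of the diff ---
from collections import deque

MAX_BUFFERED = 20          # same threshold as the C++ above

fetcher = deque()          # pending fetch ops, newest at the front
fetcher_to_free = []       # drained at each execution

def run_sync(ops):
    # stand-in for the executor: just run the given ops
    for op in list(ops):
        op.run()

class ToyFetchOp:
    def __init__(self, inputs, func):
        self.inputs, self.func = inputs, func
        fetcher.appendleft(self)
        if all(v.finished for v in inputs):
            self.run()                      # all inputs ready: fire now
        while len(fetcher) > MAX_BUFFERED:
            run_sync([fetcher[-1]])         # force-flush the oldest

    def run(self):
        self.func([v.data for v in self.inputs])
        fetcher.remove(self)
        fetcher_to_free.append(self)        # defer the actual free

class V:
    def __init__(self, data, finished=True):
        self.data, self.finished = data, finished

ToyFetchOp([V(1), V(2)], lambda xs: print("fetched", xs))  # fetched [1, 2]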

View File

@ -5,8 +5,9 @@
// ***************************************************************
#pragma once
#include <functional>
#include "common.h"
#include "var_holder.h"
#include "op.h"
#include "var.h"
#include "mem/allocator.h"
#include "ops/array_op.h"
namespace jittor {
@ -42,7 +43,15 @@ struct FetchResult {
    inline void call() { func.callback(this); }
};

// @pyjt(fetch)
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func);

struct FetchOp final : Op {
    vector<Var*> fetch_vars;
    FetchFunc func;
    list<VarPtr>::iterator fetcher_iter;
    FetchOp(vector<Var*>&& inputs, FetchFunc&& func);
    const char* name() const override { return "fetch"; }
    void run() override;
};

} // jittor

View File

@ -97,6 +97,9 @@ ArrayArgs VarHolder::fetch_sync() {
    return {var->mem_ptr, var->shape, var->dtype()};
}

// from fetch_op.cc
extern list<VarPtr> fetcher;

void sync_all(bool device_sync) {
    vector<Var*> vars;
    vars.reserve(VarHolder::hold_vars.size());
@ -104,6 +107,8 @@ void sync_all(bool device_sync) {
        if (!v->var->_outputs.size())
            vars.push_back(v->var);
    }
    for (auto& v : fetcher)
        vars.push_back(v.ptr);
    graph_check();
    exe.run_sync(vars, device_sync); // need sync at last
    graph_check();