forked from maxjhandsome/jittor
fix async fetch with update queue
This commit is contained in:
parent
225a8f4944
commit
1e9d29c2ae
|
@ -101,6 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error);
|
|||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void peek(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
if (result) {
|
||||
// DEVICE_RESET
|
||||
LOGe << "Peek CUDA error at" << file >> ":" >> line << " code="
|
||||
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
|
||||
<< func;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
|
@ -116,6 +127,7 @@ void check(T result, char const *const func, const char *const file,
|
|||
// This will output the proper CUDA error strings in the event
|
||||
// that a CUDA host call returns an error
|
||||
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
|
||||
#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__)
|
||||
|
||||
// This will output the proper error string when calling cudaGetLastError
|
||||
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)
|
||||
|
|
|
@ -83,7 +83,7 @@ mpi_initer() {
|
|||
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||
|
||||
|
||||
//calculating localRank based on hostname which is used in selecting a GPU
|
||||
uint64_t hostHashs[mpi_world_rank];
|
||||
char hostname[1024];
|
||||
|
|
|
@ -321,13 +321,54 @@ def attrs(var):
|
|||
}
|
||||
Var.attrs = attrs
|
||||
|
||||
def fetch(vars, func, *args, **kw):
|
||||
core.fetch(vars, lambda *results: func(*results, *args, **kw))
|
||||
def fetch(*args):
|
||||
''' Async fetch vars with function closure.
|
||||
|
||||
Example 1::
|
||||
|
||||
def fetch_var(var, func, *args, **kw):
|
||||
core.fetch([var], lambda a: func(a, *args, **kw))
|
||||
Var.fetch = fetch_var
|
||||
del fetch_var
|
||||
for img,label in enumerate(your_dataset):
|
||||
pred = your_model(img)
|
||||
loss = critic(pred, label)
|
||||
acc = accuracy(pred, label)
|
||||
jt.fetch(acc, loss,
|
||||
lambda acc, loss:
|
||||
print(f"loss:{loss} acc:{acc}"
|
||||
)
|
||||
|
||||
Example 2::
|
||||
|
||||
for i,(img,label) in enumerate(your_dataset):
|
||||
pred = your_model(img)
|
||||
loss = critic(pred, label)
|
||||
acc = accuracy(pred, label)
|
||||
# variable i will be bind into function closure
|
||||
jt.fetch(i, acc, loss,
|
||||
lambda i, acc, loss:
|
||||
print(f"#{i}, loss:{loss} acc:{acc}"
|
||||
)
|
||||
'''
|
||||
assert len(args)>=1
|
||||
func = args[-1]
|
||||
assert callable(func)
|
||||
args = list(args[:-1])
|
||||
if len(args)>0 and isinstance(args[0], Sequence) \
|
||||
and len(args[0])>=1 and isinstance(args[0][0], Var):
|
||||
raise TypeError("jt.Var should not inside a list or tuple.")
|
||||
|
||||
var_map = []
|
||||
variables = []
|
||||
for i, v in enumerate(args):
|
||||
if isinstance(v, Var):
|
||||
variables.append(v)
|
||||
var_map.append(i)
|
||||
args[i] = None
|
||||
def callback(*results):
|
||||
for i,v in enumerate(results):
|
||||
args[var_map[i]] = v
|
||||
func(*args)
|
||||
core.ops.fetch(variables, callback)
|
||||
|
||||
Var.fetch = fetch
|
||||
|
||||
def display_memory_info():
|
||||
import inspect, os
|
||||
|
@ -574,6 +615,7 @@ def jittor_exit():
|
|||
pass
|
||||
else:
|
||||
core.sync_all(True)
|
||||
core.cleanup()
|
||||
atexit.register(jittor_exit)
|
||||
|
||||
Var.__str__ = lambda x: str(x.data)
|
||||
|
|
|
@ -345,7 +345,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
with open(os.path.join(jittor_path, header), encoding='utf8') as f:
|
||||
src = f.read()
|
||||
# XxxXxxOp(args)
|
||||
res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S)
|
||||
res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S)
|
||||
assert len(res) >= 1, "Wrong op args in " + header
|
||||
# registe op
|
||||
cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
|
||||
|
@ -908,14 +908,14 @@ with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
|||
f.write(jit_src)
|
||||
cc_flags += f' -I{cache_path} '
|
||||
# gen pyjt
|
||||
pyjt_compiler.compile(cache_path, jittor_path)
|
||||
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
||||
|
||||
# initialize order:
|
||||
# 1. registers
|
||||
# 2. generate source
|
||||
# 3. op_utils
|
||||
# 4. other
|
||||
files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines()
|
||||
files2 = pyjt_gen_src
|
||||
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
||||
at_beginning = [
|
||||
"src/ops/op_utils.cc",
|
||||
|
|
|
@ -849,6 +849,7 @@ def compile(cache_path, jittor_path):
|
|||
headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
|
||||
[ os.path.join(cache_path, h) for h in headers2 ]
|
||||
basenames = []
|
||||
pyjt_names = []
|
||||
for h in headers:
|
||||
with open(h, 'r') as f:
|
||||
src = f.read()
|
||||
|
@ -866,6 +867,7 @@ def compile(cache_path, jittor_path):
|
|||
if not check: continue
|
||||
|
||||
basenames.append(basename)
|
||||
pyjt_names.append(fname)
|
||||
|
||||
code = f"""
|
||||
#include "pyjt/numpy.h"
|
||||
|
@ -888,3 +890,5 @@ def compile(cache_path, jittor_path):
|
|||
LOG.vvvv(code)
|
||||
with open(fname, "w") as f:
|
||||
f.write(code)
|
||||
pyjt_names.append(fname)
|
||||
return pyjt_names
|
||||
|
|
|
@ -60,6 +60,7 @@ class TestArray(unittest.TestCase):
|
|||
for i in range(3):
|
||||
x = jt.array(im)
|
||||
b = net(x)
|
||||
b.fetch(lambda b: None)
|
||||
b.sync()
|
||||
jt.sync(device_sync=True)
|
||||
|
||||
|
@ -70,6 +71,7 @@ class TestArray(unittest.TestCase):
|
|||
x = jt.array(im)
|
||||
b = net(x)
|
||||
b.fetch(lambda b: results.append(b))
|
||||
b.sync()
|
||||
# del c
|
||||
jt.sync(device_sync=True)
|
||||
t2 = time.time() - time_start
|
||||
|
|
|
@ -13,7 +13,10 @@ class TestFetcher(unittest.TestCase):
|
|||
a = jt.array([1,2,3])
|
||||
a = a*2
|
||||
v = []
|
||||
jt.fetch([a], lambda a: v.append(a))
|
||||
jt.fetch(a, lambda a: v.append(a))
|
||||
jt.fetch(1, 2, 3, a,
|
||||
lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray))
|
||||
)
|
||||
jt.sync_all(True)
|
||||
assert len(v)==1 and (v[0]==[2,4,6]).all()
|
||||
|
||||
|
|
|
@ -99,6 +99,7 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
||||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||
|
||||
@unittest.skipIf(torch is None, "no torch found")
|
||||
def test_resize(self):
|
||||
import torch.nn.functional as F
|
||||
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
||||
|
@ -108,11 +109,13 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
||||
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
||||
|
||||
@unittest.skipIf(torch is None, "no torch found")
|
||||
def test_upsample(self):
|
||||
arr = np.random.randn(2,3,224,224)
|
||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||
|
||||
@unittest.skipIf(torch is None, "no torch found")
|
||||
def test_pixelshuffle(self):
|
||||
arr = np.random.randn(2,4,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||
|
|
|
@ -64,16 +64,16 @@ class TestResnet(unittest.TestCase):
|
|||
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
|
||||
|
||||
for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||
output = mnist_net(data)
|
||||
loss = nn.cross_entropy_loss(output, target)
|
||||
|
||||
# train step
|
||||
with jt.log_capture_scope(
|
||||
log_silent=1,
|
||||
log_v=1, log_vprefix="op.cc=100,exe=10",
|
||||
) as logs:
|
||||
output = mnist_net(data)
|
||||
loss = nn.cross_entropy_loss(output, target)
|
||||
SGD.step(loss)
|
||||
def callback(loss, output, target, batch_idx):
|
||||
def callback(batch_idx, loss, output, target):
|
||||
# print train info
|
||||
global prev
|
||||
pred = np.argmax(output, axis=1)
|
||||
|
@ -83,13 +83,13 @@ class TestResnet(unittest.TestCase):
|
|||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
|
||||
.format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
|
||||
# prev = time.time()
|
||||
jt.fetch([loss, output, target], callback, batch_idx)
|
||||
|
||||
jt.fetch(batch_idx, loss, output, target, callback)
|
||||
|
||||
log_conv = find_log_with_re(logs,
|
||||
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
||||
log_matmul = find_log_with_re(logs,
|
||||
"Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
|
||||
if batch_idx:
|
||||
if batch_idx > 2:
|
||||
assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))
|
||||
|
||||
mem_used = jt.flags.stat_allocator_total_alloc_byte \
|
||||
|
@ -114,15 +114,13 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
# print(jt.core.number_of_lived_vars(), mem_used)
|
||||
jt.display_memory_info()
|
||||
# if jt.in_mpi:
|
||||
# assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
||||
# else:
|
||||
# assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
||||
if jt.in_mpi:
|
||||
assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars()
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.3
|
||||
assert np.mean(loss_list[-50:])<0.5
|
||||
assert np.mean(acc_list[-50:])>0.8
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -77,7 +77,7 @@ class TestVGGClass(unittest.TestCase):
|
|||
acc_list.append(acc)
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
|
||||
.format(0, batch_idx, 100,1. * batch_idx, loss[0], acc))
|
||||
jt.fetch([loss, output, target], callback, batch_idx)
|
||||
jt.fetch(batch_idx, loss, output, target, callback)
|
||||
|
||||
log_conv = find_log_with_re(logs,
|
||||
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2013 Benedikt Morbach <moben@exherbo.org>
|
||||
# Distributed under the terms of the GNU General Public License v2
|
||||
|
||||
# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them
|
||||
|
||||
additional_vars=( LD_LIBRARY_PATH LD_PRELOAD )
|
||||
export "${additional_vars[@]}"
|
||||
|
||||
usage() {
|
||||
echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.'
|
||||
echo ''
|
||||
echo 'Usage:'
|
||||
echo ' tmpi [number] [command]'
|
||||
echo ''
|
||||
echo 'You need to pass at least two arguments.'
|
||||
echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.'
|
||||
echo 'If you call this script from outside tmux and your command contains important whitespace then you need to appy two levels of quoting to preserve it.'
|
||||
echo ''
|
||||
echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:'
|
||||
echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example'
|
||||
echo ''
|
||||
echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)'
|
||||
}
|
||||
|
||||
check_tools() {
|
||||
tools=( tmux mpirun )
|
||||
|
||||
for tool in "${tools[@]}"; do
|
||||
if ! which ${tool}; then
|
||||
echo "You need to install ${tool} to run this script."
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [[ ${#} -lt 2 ]]; then
|
||||
usage
|
||||
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z ${TMUX} ]]; then
|
||||
# it seems we aren't in a tmux session.
|
||||
# start a new one so that our window doesn't end up in some other session and we have to search it.
|
||||
# actually start a new server with '-L' to ensure that our environment carries over.
|
||||
socket=$(mktemp --dry-run tmpi.XXXX)
|
||||
exec tmux -L ${socket} new-session "${0} ${*}"
|
||||
fi
|
||||
|
||||
if [[ ${1} == runmpi ]] ; then
|
||||
# we are being started as one of many processes by mpirun.
|
||||
shift
|
||||
|
||||
# start the processes in the order of their rank.
|
||||
# this avoids races, as we have to push the variables in tmux' environment.
|
||||
# it has the nice side-effect that the panes are also ordered by rank.
|
||||
while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do
|
||||
sleep 0.02
|
||||
done
|
||||
|
||||
# get all the variables that mpirun starts us with so that we can pass them through.
|
||||
mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug | cut -d '=' -f1 ) )
|
||||
mpi_vars+=( "${additional_vars[@]}" )
|
||||
|
||||
# add the variables to tmux' session environment.
|
||||
# we can't just export them because the process will be started as a child of tmux, not us.
|
||||
for var in "${mpi_vars[@]}"; do
|
||||
tmux set-environment -t ${session} "${var}" "${!var}"
|
||||
done
|
||||
|
||||
x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") )
|
||||
pid=${x[0]}
|
||||
pane=${x[1]}
|
||||
|
||||
for var in "${mpi_vars[@]}"; do
|
||||
tmux set-environment -t ${session} -u "${var}"
|
||||
done
|
||||
|
||||
# kill the dummy pane that opened the new window
|
||||
[[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null
|
||||
|
||||
# set the window to tiled mode.
|
||||
# have to do this after every new pane is spawned because otherwise the splits get
|
||||
# smaller and smaller until tmux refuses to open new panes, despite plenty of space being left.
|
||||
tmux select-layout -t ${pane} tiled &> /dev/null
|
||||
|
||||
# let the next process start
|
||||
echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock
|
||||
|
||||
# don't exit here as mpirun needs to be kept alive and it would also exit.
|
||||
while [[ -d /proc/${pid} ]]; do
|
||||
sleep 1
|
||||
done
|
||||
else
|
||||
# we are the parent and set everything up before we start ourselves a bunch of times via mpirun.
|
||||
processes=${1}
|
||||
self=${0}
|
||||
shift
|
||||
|
||||
# create an empty new dummy window which we sill later split up for the mpi processes.
|
||||
x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') )
|
||||
export dummy=${x[0]}
|
||||
export window=${x[1]}
|
||||
export session=${x[2]}
|
||||
|
||||
# syncronize input to all panes.
|
||||
tmux set-window-option -t ${window} synchronize-panes on &> /dev/null
|
||||
tmux set-window-option -t ${window} remain-on-exit on &> /dev/null
|
||||
|
||||
# always start with rank 0.
|
||||
echo 0 > /tmp/tmpi.lock
|
||||
|
||||
# re-execute ourself to spawn of the processes.
|
||||
echo mpirun -np ${processes} ${self} runmpi "${@}"
|
||||
mpirun -np ${processes} ${self} runmpi "${@}"
|
||||
fi
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.4.9',
|
||||
version='1.1.5.0',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "fetcher.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
#include "misc/cuda_flags.h"
|
||||
|
@ -26,6 +25,9 @@ namespace jittor {
|
|||
|
||||
Executor exe;
|
||||
|
||||
// from fetch_op.cc
|
||||
extern list<VarPtr> fetcher_to_free;
|
||||
|
||||
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||
auto allocator = get_allocator();
|
||||
this->allocator = allocator;
|
||||
|
@ -33,22 +35,43 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
int op_num = 0;
|
||||
vector<Node*> bfs_q;
|
||||
bfs_q.reserve(vars.size());
|
||||
auto nodes = (vector<Node*>*)&vars;
|
||||
int start_var_num = 0;
|
||||
for (Var* v : vars)
|
||||
if (!v->is_finished())
|
||||
start_var_num++;
|
||||
bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
|
||||
node->custom_data = 0;
|
||||
if (node->is_finished())
|
||||
return false;
|
||||
op_num += !node->is_var();
|
||||
return true;
|
||||
});
|
||||
{
|
||||
// get all nodes need to be executed
|
||||
auto t = ++Node::tflag_count;
|
||||
for (Var* v : vars)
|
||||
if (!v->is_finished() && v->tflag != t) {
|
||||
v->tflag = t;
|
||||
start_var_num++;
|
||||
bfs_q.push_back(v);
|
||||
}
|
||||
for (int i=0; i<bfs_q.size(); i++) {
|
||||
auto node = bfs_q[i];
|
||||
op_num += !node->is_var();
|
||||
for (auto i : node->_inputs)
|
||||
if (i.node->tflag != t && !i.node->is_finished()) {
|
||||
i.node->tflag = t;
|
||||
bfs_q.push_back(i.node);
|
||||
}
|
||||
// this var has been fetched
|
||||
if (node->flags.get(NodeFlags::_fetch)) {
|
||||
for (auto& n : node->_outputs) {
|
||||
// if not in queue and is fetch op
|
||||
if (n.node->tflag != t &&
|
||||
!n.node->is_finished() &&
|
||||
n.node->flags.get(NodeFlags::_fetch)) {
|
||||
n.node->tflag = t;
|
||||
bfs_q.push_back(n.node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto tt = Node::tflag_count;
|
||||
vector<Op*> ops;
|
||||
vector<Var*> all_vars;
|
||||
ops.reserve(op_num);
|
||||
all_vars.reserve(bfs_q.size() - op_num);
|
||||
for (Node* node : bfs_q)
|
||||
if (!node->is_var()) {
|
||||
node->custom_data = ops.size();
|
||||
|
@ -391,7 +414,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
outputs_bk.push_back(var);
|
||||
op->finish_pending_liveness();
|
||||
for (Var* var : outputs_bk)
|
||||
// var->finish_pending_liveness();
|
||||
var->finish_pending_liveness();
|
||||
} catch (const std::exception& e) {
|
||||
// log memory info
|
||||
|
@ -410,6 +432,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
}
|
||||
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
|
||||
for (Var* v : vars) ASSERT(v->mem_ptr);
|
||||
// clean fetcher free buffer
|
||||
fetcher_to_free.clear();
|
||||
#ifdef HAS_CUDA
|
||||
if (device_sync && use_cuda) {
|
||||
last_is_cuda = false;
|
||||
|
|
|
@ -27,7 +27,7 @@ VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
|
|||
auto dx = op->grad(out, dout, x, x_index);
|
||||
if (x->loop_options)
|
||||
dx->loop_options = x->loop_options;
|
||||
return move(dx);
|
||||
return dx;
|
||||
}
|
||||
|
||||
inline static void assign_attrs(Var* a, Var* b) {
|
||||
|
|
10
src/init.cc
10
src/init.cc
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "init.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "var.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -21,6 +22,15 @@ unique_ptr<std::default_random_engine> eng;
|
|||
vector<set_seed_callback> callbacks;
|
||||
int current_seed;
|
||||
|
||||
// fron fetch_op.cc
|
||||
extern list<VarPtr> fetcher;
|
||||
extern list<VarPtr> fetcher_to_free;
|
||||
|
||||
void cleanup() {
|
||||
fetcher_to_free.clear();
|
||||
fetcher.clear();
|
||||
}
|
||||
|
||||
static void init_cuda_devices() {
|
||||
#ifdef HAS_CUDA
|
||||
int count=0;
|
||||
|
|
|
@ -20,4 +20,8 @@ void add_set_seed_callback(set_seed_callback callback);
|
|||
extern "C"
|
||||
std::default_random_engine* get_random_engine();
|
||||
|
||||
// things need to be clean before python exit
|
||||
// @pyjt(cleanup)
|
||||
void cleanup();
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -95,7 +95,7 @@ struct DelayFree final : Allocator {
|
|||
void free(void* mem_ptr, size_t size, const size_t& allocation) override {
|
||||
using namespace cuda_dual_local;
|
||||
allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator);
|
||||
checkCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
|
||||
peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0));
|
||||
}
|
||||
|
||||
void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {
|
||||
|
|
|
@ -24,7 +24,9 @@ struct NodeFlags {
|
|||
_finished=1,
|
||||
// bit2: stop grad
|
||||
_stop_grad=2,
|
||||
_n=3,
|
||||
// bit3: is fetch
|
||||
_fetch=3,
|
||||
_n=4,
|
||||
|
||||
// var related flags
|
||||
_force_fuse=_n+0,
|
||||
|
|
|
@ -32,9 +32,9 @@ Init() {
|
|||
}
|
||||
~Init() {
|
||||
if (!get_device_count()) return;
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
checkCudaErrors(cudaStreamDestroy(stream));
|
||||
checkCudaErrors(cudaEventDestroy(event));
|
||||
peekCudaErrors(cudaDeviceSynchronize());
|
||||
peekCudaErrors(cudaStreamDestroy(stream));
|
||||
peekCudaErrors(cudaEventDestroy(event));
|
||||
}
|
||||
} init;
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
@ -12,8 +14,9 @@
|
|||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
#include "fetcher.h"
|
||||
#include "ops/fetch_op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "executor.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -49,31 +52,68 @@ Init() {
|
|||
// do not call deleter on exit
|
||||
for (auto& f : fetch_tasks)
|
||||
f.func.deleter = nullptr;
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
checkCudaErrors(cudaStreamDestroy(stream));
|
||||
checkCudaErrors(cudaEventDestroy(event));
|
||||
peekCudaErrors(cudaDeviceSynchronize());
|
||||
peekCudaErrors(cudaStreamDestroy(stream));
|
||||
peekCudaErrors(cudaEventDestroy(event));
|
||||
}
|
||||
};
|
||||
} ;
|
||||
|
||||
}
|
||||
using namespace fetcher_local;
|
||||
|
||||
#endif
|
||||
|
||||
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
||||
list<VarPtr> fetcher;
|
||||
// this list will be free at each execution
|
||||
list<VarPtr> fetcher_to_free;
|
||||
|
||||
FetchOp::FetchOp(vector<Var*>&& inputs, FetchFunc&& func)
|
||||
: fetch_vars(inputs), func(move(func)) {
|
||||
#ifdef HAS_CUDA
|
||||
static Init init;
|
||||
// stream needs to be created after nccl plugin
|
||||
static Init init_fetch;
|
||||
#endif
|
||||
sync(vh);
|
||||
vector<Allocation> allocations(vh.size());
|
||||
vector<ArrayArgs> arrays(vh.size());
|
||||
VarPtr vp(0, ns_int);
|
||||
outputs_holder.emplace_back(vp);
|
||||
fetcher.emplace_front(move(vp));
|
||||
fetcher_iter = fetcher.begin();
|
||||
bool all_finished = true;
|
||||
for (auto v : fetch_vars)
|
||||
if (!v->is_finished()) {
|
||||
all_finished = false;
|
||||
v->flags.set(NodeFlags::_stop_fuse);
|
||||
v->flags.set(NodeFlags::_fetch);
|
||||
}
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
flags.set(NodeFlags::_fetch);
|
||||
flags.set(NodeFlags::_stop_grad);
|
||||
fetcher_iter->ptr->flags.set(NodeFlags::_fetch);
|
||||
// fetcher_to_free.clear();
|
||||
if (all_finished) {
|
||||
// if all finished, run immediately
|
||||
run();
|
||||
}
|
||||
// if too many fetchers are bufferd, force flush
|
||||
while (fetcher.size() > 20) {
|
||||
LOGvvvv << "too many fetchers(">>fetcher.size() >>
|
||||
") are bufferd, force flush";
|
||||
exe.run_sync({fetcher.back().ptr}, false);
|
||||
}
|
||||
}
|
||||
|
||||
void FetchOp::run() {
|
||||
vector<Allocation> allocations(fetch_vars.size());
|
||||
vector<ArrayArgs> arrays(fetch_vars.size());
|
||||
#ifdef HAS_CUDA
|
||||
bool has_cuda_memcpy = false;
|
||||
event_queue.flush();
|
||||
#endif
|
||||
for (int i=0; i<vh.size(); i++) {
|
||||
auto v = vh[i]->var;
|
||||
LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars;
|
||||
int i = 0;
|
||||
for (auto v : fetch_vars) {
|
||||
auto& allocation = allocations[i];
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
if (v->allocator->is_cuda()) {
|
||||
checkCudaErrors(cudaEventRecord(event, 0));
|
||||
|
@ -98,6 +138,7 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
|||
arrays[i].ptr = allocation.ptr;
|
||||
arrays[i].shape = v->shape;
|
||||
arrays[i].dtype = v->dtype();
|
||||
i++;
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (has_cuda_memcpy) {
|
||||
|
@ -109,6 +150,8 @@ void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
|||
FetchResult fr{move(func), move(allocations), move(arrays)};
|
||||
fr.call();
|
||||
}
|
||||
fetcher_to_free.emplace_front(move(*fetcher_iter));
|
||||
fetcher.erase(fetcher_iter);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -5,8 +5,9 @@
|
|||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "common.h"
|
||||
#include "var_holder.h"
|
||||
#include "op.h"
|
||||
#include "var.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "ops/array_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
@ -42,7 +43,15 @@ struct FetchResult {
|
|||
inline void call() { func.callback(this); }
|
||||
};
|
||||
|
||||
// @pyjt(fetch)
|
||||
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func);
|
||||
struct FetchOp final : Op {
|
||||
vector<Var*> fetch_vars;
|
||||
FetchFunc func;
|
||||
list<VarPtr>::iterator fetcher_iter;
|
||||
|
||||
} // jittor
|
||||
FetchOp(vector<Var*>&& inputs, FetchFunc&& func);
|
||||
|
||||
const char* name() const override { return "fetch"; }
|
||||
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -97,6 +97,9 @@ ArrayArgs VarHolder::fetch_sync() {
|
|||
return {var->mem_ptr, var->shape, var->dtype()};
|
||||
}
|
||||
|
||||
// from fetch_op.cc
|
||||
extern list<VarPtr> fetcher;
|
||||
|
||||
void sync_all(bool device_sync) {
|
||||
vector<Var*> vars;
|
||||
vars.reserve(VarHolder::hold_vars.size());
|
||||
|
@ -104,6 +107,8 @@ void sync_all(bool device_sync) {
|
|||
if (!v->var->_outputs.size())
|
||||
vars.push_back(v->var);
|
||||
}
|
||||
for (auto& v :fetcher)
|
||||
vars.push_back(v.ptr);
|
||||
graph_check();
|
||||
exe.run_sync(vars, device_sync); //need sync at last
|
||||
graph_check();
|
||||
|
|
Loading…
Reference in New Issue