forked from maxjhandsome/jittor
Function customized backward and no grad
This commit is contained in:
parent
7cbed2b1ab
commit
f2bf93ae56
|
@ -59,6 +59,11 @@ class flag_scope(_call_no_record_scope):
|
|||
for k,v in self.flags_bk.items():
|
||||
setattr(flags, k, v)
|
||||
|
||||
class no_grad(flag_scope):
    """A ``flag_scope`` whose stored flags always contain ``no_grad=1``.

    Any additional keyword flags supplied by the caller are kept next to
    the forced ``no_grad`` entry in ``self.jt_flags``.
    """

    def __init__(self, **jt_flags):
        # Force the flag first; the very same dict object is then stored.
        jt_flags.update(no_grad=1)
        self.jt_flags = jt_flags
|
||||
|
||||
single_log_capture = None
|
||||
|
||||
class log_capture_scope(_call_no_record_scope):
|
||||
|
@ -559,6 +564,92 @@ class Module:
|
|||
for p in self.parameters():
|
||||
p.update(p.mpi_broadcast(root))
|
||||
|
||||
class Function(Module):
    ''' Module whose backward pass is written by the user.

    A subclass provides ``execute`` (the forward computation) and ``grad``
    (receives the gradients of the outputs and returns the gradients of
    the inputs).

    Example 1 (a Function may take several inputs and produce several
    outputs, and may keep values on ``self`` for the backward pass)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 9

    Example 2 (``grad`` may return None for an input that has no gradient,
    and an incoming gradient may itself be None)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                assert grads[1] is None
                return grads[0] * self.y, None
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

    '''
    def __call__(self, *args, **kw):
        ''' Run ``execute`` on taped copies of the Var arguments.

        Every Var found among the positional and keyword arguments is
        replaced by ``var.tape()``; the same is done for every Var that
        ``execute`` returns.  A single Var result is returned as a Var,
        any other result must be a Sequence and is returned as a list.
        '''
        inputs_on_tape = []
        outputs_on_tape = []

        def _taped(value, bucket):
            # Non-Var values travel through unchanged and are not recorded.
            if not isinstance(value, Var):
                return value
            recorded = value.tape()
            bucket.append(recorded)
            return recorded

        # Positional arguments are recorded first, keyword arguments after.
        positional = [_taped(arg, inputs_on_tape) for arg in args]
        keyword = {key: _taped(val, inputs_on_tape) for key, val in kw.items()}

        res = self.execute(*positional, **keyword)
        if isinstance(res, Var):
            res = _taped(res, outputs_on_tape)
        else:
            assert isinstance(res, Sequence)
            res = [_taped(item, outputs_on_tape) for item in res]

        # Register inputs and outputs as one group so that the backward
        # pass handles them as a single operator driven by self.grad.
        tape_together(inputs_on_tape, outputs_on_tape, lambda args: self.grad(args))
        return res

    def dfs(self, parents, k, callback, callback_leave=None):
        ''' Do nothing; the arguments are accepted and ignored. '''
        return None
|
||||
|
||||
|
||||
def make_module(func, exec_n_args=1):
|
||||
class MakeModule(Module):
|
||||
def __init__(self, *args, **kw):
|
||||
|
|
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
class TestClone(unittest.TestCase):
|
||||
def test(self):
|
||||
jt.clean()
|
||||
b = a = jt.array(1)
|
||||
b = a = jt.array(1.0)
|
||||
for i in range(10):
|
||||
b = b.clone()
|
||||
if i==5: c=b
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from collections.abc import Sequence, Mapping
|
||||
from .test_core import expect_error
|
||||
from jittor import Function
|
||||
|
||||
class TestFunction(unittest.TestCase):
    """Exercises jittor.Function subclasses with hand-written ``grad``."""

    def test1(self):
        # grad ignores the true derivative and returns dout - 2.
        class PlusOne(Function):
            def execute(self, x):
                return x+1

            def grad(self, grads):
                return grads[0]-2

        x = jt.ones(1)
        y = PlusOne()(x)
        dx = jt.grad(y, x)
        assert dx.data == -1

    def test2(self):
        # grad uses a value saved on self during execute.
        class PlusOneKeepInput(Function):
            def execute(self, x):
                self.x = x
                return x+1

            def grad(self, grads):
                return (grads[0]-2) * self.x

        x = jt.ones(1) * 10
        y = PlusOneKeepInput()(x)
        dx = jt.grad(y, x)
        assert dx.data == -10

    def test_grad_not_match_error(self):
        # Two inputs but grad returns a single value: jt.grad must fail.
        class OneGradForTwoInputs(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y

            def grad(self, grads):
                return (grads[0]-2) * self.x

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        out = OneGradForTwoInputs()(lhs, rhs)
        expect_error(lambda: jt.grad(out, [lhs, rhs]))

    def test_multi_grads(self):
        class Product(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y

            def grad(self, grads):
                return (grads[0]-2) * self.y, (grads[0]-2) * self.x

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        out = Product()(lhs, rhs)
        d_lhs, d_rhs = jt.grad(out, [lhs, rhs])
        assert d_lhs.data == -4
        assert d_rhs.data == -3

    def test_multi_grads_none(self):
        # A None gradient for the second input is reported as zero.
        class ProductNoSecondGrad(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y

            def grad(self, grads):
                return (grads[0]-2) * self.y, None

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        out = ProductNoSecondGrad()(lhs, rhs)
        d_lhs, d_rhs = jt.grad(out, [lhs, rhs])
        assert d_lhs.data == -4
        assert d_rhs.data == 0

    def test_multi_grads_multi_out(self):
        class ProductAndQuotient(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        prod, quot = ProductAndQuotient()(lhs, rhs)
        d_lhs, d_rhs = jt.grad(prod+quot*3, [lhs, rhs])
        assert d_lhs.data == 4
        assert d_rhs.data == 9

    def test_multi_grads_multi_out_stop_grad_0(self):
        # stop_grad on an input: its gradient comes back as zero.
        class ProductAndQuotient(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        rhs.stop_grad()
        prod, quot = ProductAndQuotient()(lhs, rhs)
        d_lhs, d_rhs = jt.grad(prod+quot*3, [lhs, rhs])
        assert d_lhs.data == 4
        assert d_rhs.data == 0

    def test_multi_grads_multi_out_stop_grad_1(self):
        # stop_grad on an output: grad receives None in that slot.
        class ProductAndQuotient(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                assert grads[1] is None
                return grads[0] * self.y, None

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        prod, quot = ProductAndQuotient()(lhs, rhs)
        quot.stop_grad()
        d_lhs, d_rhs = jt.grad(prod+quot*3, [lhs, rhs])
        assert d_lhs.data == 4
        assert d_rhs.data == 0

    def test_multi_grads_multi_out2(self):
        # Same as test_multi_grads_multi_out, but grad prints its result.
        class ProductAndQuotient(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grads):
                res = (grads[0] * self.y, grads[1] * self.x)
                print(res)
                return res

        lhs = jt.array(3.0)
        rhs = jt.array(4.0)
        prod, quot = ProductAndQuotient()(lhs, rhs)
        d_lhs, d_rhs = jt.grad(prod+quot*3, [lhs, rhs])
        assert d_lhs.data == 4, d_lhs.data
        assert d_rhs.data == 9
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -151,7 +151,16 @@ class TestGrad(unittest.TestCase):
|
|||
self.assertEqual(dx.data, 4*2**3)
|
||||
self.assertEqual(ddx.data, 4*3*2**2)
|
||||
self.assertEqual(dddx.data, 4*3*2*2**1)
|
||||
|
||||
|
||||
# Uses jt.no_grad() around a chain of ten `clone() + 1` steps that starts at 1.0,
# then checks the final value (11) and that jt.liveness_info() reports exactly
# two lived vars after jt.clean().
def test_no_grad(self):
|
||||
a = jt.array(1.0)
|
||||
with jt.no_grad():
|
||||
b = a
|
||||
for i in range(10):
|
||||
b = b.clone() + 1
|
||||
assert b.data == 11
|
||||
jt.clean()
|
||||
assert jt.liveness_info()["lived_vars"] == 2
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -91,15 +91,15 @@ for cc_type in ["g++", "clang"]:
|
|||
# compress source
|
||||
# tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__
|
||||
# mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
|
||||
assert os.system(f"cd {polish_path} && tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__")==0
|
||||
assert os.system(f"cd {polish_path} && tar --exclude=build --exclude=.git --exclude=.ipynb_checkpoints --exclude=__pycache__ -cvzf build/jittor.tgz . ")==0
|
||||
|
||||
# rsync to build-server
|
||||
jittor_web_base_dir = "Documents/jittor-blog/assets/"
|
||||
jittor_web_build_dir = jittor_web_base_dir + "build/"
|
||||
assert os.system(f"rsync -avPu {polish_path}/build/ jittor-web:{jittor_web_build_dir}")==0
|
||||
assert os.system(f"ssh jittor@166.111.68.30 Documents/jittor-blog.git/hooks/post-update")==0
|
||||
assert os.system(f"ssh jittor-web Documents/jittor-blog.git/hooks/post-update")==0
|
||||
|
||||
# push to github
|
||||
assert os.system(f"cd {polish_path} && git push -f origin master")==0
|
||||
# assert os.system(f"cd {polish_path} && git push -f origin master")==0
|
||||
|
||||
# push to pip
|
|
@ -1 +1 @@
|
|||
08f4ca8b2c0a2978cd3fbc9a3a6e76bd1463ca12
|
||||
b27082f9444a4e627f7dfc574d0114302ba27b5e
|
||||
|
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.5.4',
|
||||
version='1.1.5.5',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
289
src/grad.cc
289
src/grad.cc
|
@ -35,6 +35,125 @@ inline static void assign_attrs(Var* a, Var* b) {
|
|||
a->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
|
||||
void tape_together(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
) {
|
||||
auto tapes = new Tapes();
|
||||
tapes->total = tapes->ref = taped_inputs.size() + taped_outputs.size();
|
||||
tapes->callback = move(grad_callback);
|
||||
tapes->flags.set(NodeFlags::_grads);
|
||||
for (int i=0; i<taped_inputs.size(); i++) {
|
||||
auto v = taped_inputs[i]->var;
|
||||
auto op = (TapeOp*)v->input();
|
||||
ASSERT(op);
|
||||
op->flags.set(NodeFlags::_tape);
|
||||
tapes->_inputs.emplace_back(op->inputs().front());
|
||||
op->tapes = tapes;
|
||||
}
|
||||
for (int i=0; i<taped_outputs.size(); i++) {
|
||||
auto v = taped_outputs[i]->var;
|
||||
auto op = (TapeOp*)v->input();
|
||||
ASSERT(op);
|
||||
op->flags.set(NodeFlags::_tape);
|
||||
tapes->_outputs.emplace_back(v,0);
|
||||
op->tapes = tapes;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Func>
|
||||
void bfs_backward_with_tape(vector<Node*>& queue, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
size_t i=0;
|
||||
for (Node* node : queue) node->tflag = t;
|
||||
while (i < queue.size()) {
|
||||
Node* node = queue[i++];
|
||||
for (auto i : node->_inputs) {
|
||||
auto inode = i.node;
|
||||
if (inode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)inode)->tapes;
|
||||
inode = t;
|
||||
ASSERT(t->ref == t->total);
|
||||
}
|
||||
if (inode->tflag != t && func(inode)) {
|
||||
inode->tflag = t;
|
||||
queue.push_back(inode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void bfs_backward_with_tape(vector<Node*>& seed, vector<Node*>& queue, Func&& func) {
|
||||
for (Node* node : seed)
|
||||
if (func(node)) queue.push_back(node);
|
||||
bfs_backward_with_tape(queue, func);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void bfs_forward_with_tape(vector<Node*>& queue, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
size_t i=0;
|
||||
for (Node* node : queue) node->tflag = t;
|
||||
while (i < queue.size()) {
|
||||
Node* node = queue[i++];
|
||||
for (auto o : node->_outputs) {
|
||||
auto onode = o.node;
|
||||
if (onode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)onode)->tapes;
|
||||
ASSERT(t->ref == t->total) << t->ref << t->total;
|
||||
onode = t;
|
||||
}
|
||||
if (onode->tflag != t && func(onode)) {
|
||||
onode->tflag = t;
|
||||
queue.push_back(onode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Func>
|
||||
void toplogical_sort_backward_with_tape(vector<Node*>& nodes, vector<Node*>& sorted, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
sorted.reserve(nodes.size());
|
||||
for (auto node : nodes) node->tflag = t;
|
||||
for (auto node : nodes) {
|
||||
auto& deps = node->custom_data;
|
||||
deps = 0;
|
||||
for (auto o : node->_outputs) {
|
||||
auto onode = o.node;
|
||||
if (onode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)onode)->tapes;
|
||||
onode = t;
|
||||
}
|
||||
if (onode->tflag == t)
|
||||
deps++;
|
||||
}
|
||||
if (deps == 0) sorted.push_back(node);
|
||||
}
|
||||
size_t i=0;
|
||||
while (i < sorted.size()) {
|
||||
Node* node = sorted[i++];
|
||||
for (auto i : node->_inputs) {
|
||||
auto inode = i.node;
|
||||
if (inode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)inode)->tapes;
|
||||
inode = t;
|
||||
}
|
||||
if (inode->tflag == t) {
|
||||
inode->custom_data--;
|
||||
if (inode->custom_data == 0)
|
||||
sorted.push_back(inode);
|
||||
}
|
||||
}
|
||||
func(node);
|
||||
}
|
||||
ASSERTop(nodes.size(),==,sorted.size());
|
||||
}
|
||||
|
||||
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
||||
LOGvv << "loss:" >> loss << "targets:" >> targets;
|
||||
CHECK(loss->is_float()) << "Loss should be float";
|
||||
|
@ -44,13 +163,13 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
vector<Node*> ts(targets.begin(), targets.end());
|
||||
// bfs visit find all successors of targets
|
||||
LOGvv << "Size of successors:" << ts.size();
|
||||
bfs_forward(ts, [](Node*){ return true; });
|
||||
bfs_forward_with_tape(ts, [](Node*){ return true; });
|
||||
vector<Node*> gnodes;
|
||||
gnodes.reserve(ts.size());
|
||||
auto nt = Node::tflag_count;
|
||||
if (loss->tflag == nt)
|
||||
gnodes.push_back(loss);
|
||||
bfs_backward(gnodes, [&](Node* node) {
|
||||
bfs_backward_with_tape(gnodes, [&](Node* node) {
|
||||
if (node->tflag != nt)
|
||||
return false;
|
||||
if (node->is_stop_grad())
|
||||
|
@ -63,58 +182,145 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
LOGvv << "Size of grad nodes:" << gnodes.size();
|
||||
|
||||
vector<Node*> sorted;
|
||||
toplogical_sort_backward(gnodes, sorted, [](Node*){});
|
||||
toplogical_sort_backward_with_tape(gnodes, sorted, [](Node*){});
|
||||
nt = Node::tflag_count;
|
||||
vector<Var*> gvars;
|
||||
gvars.reserve(sorted.size());
|
||||
for (Node* node : sorted)
|
||||
if (node->is_var())
|
||||
gvars.push_back(node->var());
|
||||
if (node->is_var()) {
|
||||
Var* v = node->var();
|
||||
v->custom_data = gvars.size();
|
||||
gvars.push_back(v);
|
||||
}
|
||||
LOGvv << "Size of grad vars:" << gvars.size();
|
||||
|
||||
vector<VarPtr> grads(gvars.size());
|
||||
vector<VarPtr> results(targets.size());
|
||||
for (size_t i=0; i<gvars.size(); i++)
|
||||
gvars[i]->custom_data = i;
|
||||
vector<int> target_id(targets.size());
|
||||
for (int i=0; i<targets.size(); i++) {
|
||||
Var* var = targets[i];
|
||||
target_id[i] = (var->tflag == nt) ?
|
||||
var->custom_data : -1;
|
||||
}
|
||||
|
||||
if (grads.size()) {
|
||||
grads[0] = make_number(1.f, loss);
|
||||
assign_attrs(grads[0].ptr, loss);
|
||||
registe_node_trace_grad(grads[0].ptr, loss, 0);
|
||||
}
|
||||
|
||||
vector<pair<Node*, int64>> id_buffer;
|
||||
id_buffer.reserve(sorted.size()+10);
|
||||
|
||||
// back up ids in custom data
|
||||
for (int i=1; i<gvars.size(); i++) {
|
||||
Var* var = gvars[i];
|
||||
for (auto it : var->outputs_with_index()) {
|
||||
Op* op = it.op;
|
||||
if (op->flags.get(NodeFlags::_tape)) {
|
||||
op = ((TapeOp*)op)->tapes;
|
||||
}
|
||||
auto index = it.index;
|
||||
if (op->tflag != nt) continue;
|
||||
id_buffer.emplace_back(op, index);
|
||||
|
||||
for (size_t i=0; i<gvars.size(); i++) {
|
||||
// backward together
|
||||
if (op->flags.get(NodeFlags::_grads)) {
|
||||
// dont backward next time
|
||||
op->tflag = 0;
|
||||
for (Var* out : op->outputs()) {
|
||||
id_buffer.emplace_back(
|
||||
out,
|
||||
out->tflag == nt ? out->custom_data : -1);
|
||||
}
|
||||
for (Var* in : op->inputs()) {
|
||||
id_buffer.emplace_back(
|
||||
in,
|
||||
in->tflag == nt ? in->custom_data : -1);
|
||||
}
|
||||
} else {
|
||||
// single var backward
|
||||
for (Var* out : op->outputs()) {
|
||||
id_buffer.emplace_back(
|
||||
out,
|
||||
out->tflag == nt ? out->custom_data : -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
// end of var output
|
||||
id_buffer.emplace_back(nullptr, 0);
|
||||
}
|
||||
|
||||
// real backward construction from the previously backed-up ids
|
||||
int j=0;
|
||||
for (int i=1; i<gvars.size(); i++,j++) {
|
||||
Var* var = gvars[i];
|
||||
auto& grad = grads[i];
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
int gsum = 0;
|
||||
#endif
|
||||
if (i==0) {
|
||||
grad = make_number(1.f, loss);
|
||||
assign_attrs(grad.ptr, loss);
|
||||
registe_node_trace_grad(grad.ptr, loss, 0);
|
||||
} else
|
||||
for (auto it : var->outputs_with_index()) {
|
||||
Op* op = it.op;
|
||||
auto index = it.index;
|
||||
if (op->tflag != nt) continue;
|
||||
for (Var* out : op->outputs()) {
|
||||
if (out->tflag != nt) continue;
|
||||
Var* dout = grads[out->custom_data];
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
grad = move(dvar);
|
||||
else if (dvar) {
|
||||
grad = make_binary(grad, dvar, ns_add);
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
gsum ++;
|
||||
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
||||
// TODO: this is a dirty fix for
|
||||
// stopping fuse lots of op together,
|
||||
// try to find a better solution
|
||||
grad->flags.set(NodeFlags::_stop_fuse);
|
||||
// dump "for (auto it : var->outputs_with_index())"
|
||||
while (id_buffer[j].first) {
|
||||
Op* op = id_buffer[j].first->op();
|
||||
auto index = id_buffer[j].second;
|
||||
j++;
|
||||
auto n_o = op->outputs().size();
|
||||
|
||||
if (op->flags.get(NodeFlags::_grads)) {
|
||||
// backward together
|
||||
auto n_i = op->inputs().size();
|
||||
Var* douts[n_o];
|
||||
VarPtr dins[n_i];
|
||||
// dump "for (Var* out : op->outputs())"
|
||||
for (int i=0; i<n_o; i++,j++) {
|
||||
auto id = id_buffer[j].second;
|
||||
if (id>=0) {
|
||||
douts[i] = grads[id];
|
||||
} else
|
||||
douts[i] = nullptr;
|
||||
}
|
||||
op->grads(douts, dins);
|
||||
// dump "for (Var* in : op->inputs())"
|
||||
for (int i=0; i<n_i; i++,j++) {
|
||||
auto id = id_buffer[j].second;
|
||||
if (id>=0) {
|
||||
auto& din = dins[i];
|
||||
auto& grad = grads[id];
|
||||
if (din && grad) {
|
||||
grad = make_binary(grad, din, ns_add);
|
||||
} else
|
||||
grad = move(din);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// single var backward
|
||||
// dump "for (Var* out : op->outputs())"
|
||||
for (int i=0; i<n_o; i++,j++) {
|
||||
auto id = id_buffer[j].second;
|
||||
auto out = id_buffer[j].first->var();
|
||||
if (id<0) continue;
|
||||
Var* dout = grads[id];
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
grad = move(dvar);
|
||||
else if (dvar) {
|
||||
grad = make_binary(grad, dvar, ns_add);
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
gsum ++;
|
||||
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
||||
// TODO: this is a dirty fix for
|
||||
// stopping fuse lots of op together,
|
||||
// try to find a better solution
|
||||
grad->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
#endif
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, index);
|
||||
}
|
||||
#endif
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -123,8 +329,9 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
for (size_t i=0; i<results.size(); i++) {
|
||||
Var* var = targets[i];
|
||||
VarPtr& grad = results[i];
|
||||
if (var->tflag == nt)
|
||||
grad = move(grads[var->custom_data]);
|
||||
auto id = target_id[i];
|
||||
if (id>=0)
|
||||
grad = move(grads[id]);
|
||||
if (!grad) {
|
||||
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
|
||||
grad = make_number(0.f, var);
|
||||
|
|
|
@ -3,10 +3,18 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "ops/tape_op.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
vector<VarPtr> grad(Var* loss, vector<Var*> targets);
|
||||
|
||||
// @pyjt(tape_together)
|
||||
void tape_together(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
);
|
||||
|
||||
} // jittor
|
|
@ -44,6 +44,10 @@ struct NodeFlags {
|
|||
_vary_shape=_n+3,
|
||||
// bit4~5: op type
|
||||
_op_type=_n+4, _op_type_nbits=2,
|
||||
// bit6: is tape op
|
||||
_tape=_n+6,
|
||||
// bit7: backprop grads at once
|
||||
_grads=_n+7,
|
||||
};
|
||||
|
||||
inline void set(Flags f, int a=1, int nbits=1) {
|
||||
|
|
|
@ -46,6 +46,10 @@ VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Default implementation of the all-at-once gradient hook: it only logs a
// warning and does not write anything into `dins`.
void Op::grads(Var** douts, VarPtr* dins) {
|
||||
LOGw << "Grads of" << name() << "return zeros";
|
||||
}
|
||||
|
||||
Var* Op::create_output(NanoVector shape, NanoString dtype) {
|
||||
VarPtr vp(shape, dtype);
|
||||
Var* output = vp.ptr;
|
||||
|
|
1
src/op.h
1
src/op.h
|
@ -37,6 +37,7 @@ struct Op : Node {
|
|||
~Op();
|
||||
|
||||
virtual VarPtr grad(Var* out, Var* dout, Var* v, int v_index);
|
||||
virtual void grads(Var** douts, VarPtr* dins);
|
||||
virtual void infer_shape() {}
|
||||
virtual void run() {};
|
||||
virtual void jit_prepare() {};
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/tape_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Constructor handle of the op registered as "tape" (signature Var* -> VarPtr),
// looked up once during static initialisation.
// NOTE(review): not referenced anywhere in the visible part of this file.
static auto make_tape = get_op_info("tape")
|
||||
.get_constructor<VarPtr, Var*>();
|
||||
|
||||
// Create a tape op over `x`: one output with x's dtype and, when x is named,
// the same name. The op starts without a Tapes group (tapes == nullptr).
TapeOp::TapeOp(Var* x) : tapes(nullptr) {
    // flagged for both the cpu and the cuda backend
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_cpu);
    // shape is passed as nullptr here and filled in by infer_shape()
    Var* out = create_output(nullptr, x->dtype());
    if (!x->name.ptr) return;
    out->name = x->name;
}
|
||||
|
||||
// Drop this op's reference on its Tapes group; the op that releases the last
// reference clears the group's edge lists and deletes it.
TapeOp::~TapeOp() {
    if (!tapes) return;
    if (--tapes->ref != 0) return;
    tapes->_inputs.clear();
    tapes->_outputs.clear();
    delete tapes;
}
|
||||
|
||||
// Single-var gradient of a tape op: the incoming gradient `dout` is returned
// unchanged; `out`, `v` and `v_index` are not used.
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return dout;
|
||||
}
|
||||
|
||||
void TapeOp::infer_shape() {
|
||||
auto x = inputs().front();
|
||||
auto y = outputs().front();
|
||||
y->set_shape(x->shape);
|
||||
y->share_with(x);
|
||||
}
|
||||
|
||||
// All-at-once gradient of a taped group: forwards to the stored callback,
// passing the number of group outputs with `douts` and the number of group
// inputs with `dins`.
void Tapes::grads(Var** douts, VarPtr* dins) {
|
||||
callback.func(_outputs.size(), douts, _inputs.size(), dins);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,54 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "op.h"
|
||||
#include "var_holder.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct Tapes;
|
||||
|
||||
struct GradCallback {
|
||||
typedef jittor::VarHolder VarHolder;
|
||||
typedef VarHolder* VarHolderPtr;
|
||||
std::function<void(int,Var**,int,VarPtr*)> func;
|
||||
std::function<void()> deleter;
|
||||
inline ~GradCallback() { if (deleter) deleter(); }
|
||||
GradCallback(const GradCallback&) = delete;
|
||||
GradCallback() = default;
|
||||
GradCallback(GradCallback&& other) : func(other.func), deleter(other.deleter) {
|
||||
other.func = nullptr;
|
||||
other.deleter = nullptr;
|
||||
};
|
||||
GradCallback(std::function<void(int,Var**,int,VarPtr*)> && func, std::function<void()>&& deleter)
|
||||
: func(move(func)), deleter(move(deleter)) {};
|
||||
|
||||
void operator =(GradCallback&& other) { this->~GradCallback(); new (this) GradCallback(move(other)); }
|
||||
};
|
||||
|
||||
// Op with one input and one output that marks a var as belonging to a taped
// group (see tape_together), so backward can treat the group as one operator.
struct TapeOp final : Op {
|
||||
// Group this op belongs to; nullptr from construction until tape_together() sets it.
Tapes* tapes;
|
||||
TapeOp(Var* x);
|
||||
// Releases one reference on `tapes` (if any).
~TapeOp();
|
||||
|
||||
const char* name() const override { return "tape"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
};
|
||||
|
||||
|
||||
// Node standing for a whole taped group; its gradients are produced in one
// call through `callback`.
struct Tapes final : Op {
|
||||
// ref: TapeOps still pointing at this group (decremented in ~TapeOp, which
// deletes the group at zero); total: the count assigned in tape_together().
// Neither field has an initializer here.
int ref, total;
|
||||
// User-supplied gradient function, installed by tape_together().
GradCallback callback;
|
||||
const char* name() const override { return "tapes"; }
|
||||
void grads(Var** douts, VarPtr* dins) override;
|
||||
};
|
||||
|
||||
|
||||
} // jittor
|
|
@ -590,4 +590,62 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
|
|||
return func;
|
||||
}
|
||||
|
||||
|
||||
struct GradCallback;
|
||||
|
||||
// A Python object qualifies as a GradCallback when it is callable.
DEF_IS(GradCallback, bool) is_type(PyObject* obj) {
|
||||
return PyCallable_Check(obj);
|
||||
}
|
||||
|
||||
// Wrap a Python callable as a GradCallback.
//
// The callable is invoked with ONE argument: a tuple holding, per group
// output, either a var holder for the output's gradient or None. It must
// return either a single value (only when there is exactly one input) or a
// list/tuple with one entry per input; every entry is None or a jittor var.
// A reference on `obj` is taken here and dropped by the deleter.
DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
    Py_INCREF(obj);
    T func(
        // callback
        [obj](int n_o, Var** douts, int n_i, VarPtr* dins) {
            // build the tuple of output gradients (PyTuple_SET_ITEM steals the reference)
            PyObjHolder dout_tuple(PyTuple_New(n_o));
            for (int i=0; i<n_o; i++) {
                if (!douts[i]) {
                    Py_INCREF(Py_None);
                    PyTuple_SET_ITEM(dout_tuple.obj, i, Py_None);
                } else {
                    PyTuple_SET_ITEM(dout_tuple.obj, i,
                        to_py_object(new typename T::VarHolder(douts[i])));
                }
            }
            // the gradients are passed as a single positional argument
            PyObjHolder call_args(PyTuple_New(1));
            PyTuple_SET_ITEM(call_args.obj, 0, dout_tuple.release());

            PyObjHolder result(PyObject_Call(obj, call_args.obj, nullptr));

            // store one returned entry into dins[i]
            auto store_grad = [&](int i, PyObject* obj) {
                if (obj == Py_None) {
                    dins[i] = nullptr;
                    return;
                }
                CHECK(Py_TYPE(obj) == &PyjtVarHolder.ht_type) << "returned grad("<<Py_TYPE(obj)->tp_name<<") is not jittor variable";
                auto holder = from_py_object<typename T::VarHolderPtr>(obj);
                dins[i] = holder->var;
            };

            // only exact list / tuple objects are treated as a sequence of grads
            bool returned_seq = PyList_CheckExact(result.obj) || PyTuple_CheckExact(result.obj);
            if (returned_seq) {
                auto size = Py_SIZE(result.obj);
                CHECKop(n_i,==,size) << "returned grad size not match";
                auto items = PySequence_Fast_ITEMS(result.obj);
                for (int i=0; i<size; i++)
                    store_grad(i, items[i]);
            } else {
                CHECKop(n_i,==,1) << "returned grad size not match";
                store_grad(0, result.obj);
            }
        },
        // deleter
        [obj]() {
            Py_DECREF(obj);
        }
    );
    return func;
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -46,7 +46,7 @@ JIT_TEST(sfrl_allocator_time) {
|
|||
std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
|
||||
LOGvv << "Use time " << float(end - begin) / 1000 << "ms\n";
|
||||
ASSERT(float(end - begin) / 1000 < tasks[i].time_limit);
|
||||
ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,7 +84,7 @@ JIT_TEST(sfrl_allocator_share) {
|
|||
std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
|
||||
LOGvvv << "Use time " << float(end - begin) / 1000 << "ms\n";
|
||||
ASSERT(float(end - begin) / 1000 < tasks[i].time_limit);
|
||||
ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -220,6 +220,10 @@ bool check_vlog(const char* fileline, int verbose);
|
|||
#define LOGrrrr LOGvvvv >> jittor::red
|
||||
#define LOGyyyy LOGvvvv >> jittor::yellow
|
||||
|
||||
#define LOGir LOGi >> jittor::red
|
||||
#define LOGig LOGi >> jittor::green
|
||||
#define LOGiy LOGi >> jittor::yellow
|
||||
|
||||
void system_with_check(const char* cmd);
|
||||
|
||||
} // jittor
|
|
@ -17,11 +17,14 @@ int64_t Var::number_of_lived_vars = 0;
|
|||
|
||||
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
|
||||
"Override the default loop transfrom options");
|
||||
DEFINE_FLAG(bool, no_grad, 0,
|
||||
"No grad for all jittor Var creation");
|
||||
|
||||
Var::Var(NanoVector shape, NanoString dtype)
|
||||
: shape(shape),
|
||||
loop_options(compile_options) {
|
||||
flags.set(NodeFlags::_var, 1);
|
||||
flags.set(NodeFlags::_stop_grad, !dtype.is_float() || no_grad);
|
||||
ns = dtype;
|
||||
ASSERT(ns.is_dtype());
|
||||
number_of_lived_vars++;
|
||||
|
|
Loading…
Reference in New Issue