Function customized backward and no grad

Dun Liang 2020-07-13 22:50:29 +08:00
parent 7cbed2b1ab
commit f2bf93ae56
18 changed files with 710 additions and 50 deletions

View File

@ -59,6 +59,11 @@ class flag_scope(_call_no_record_scope):
        for k,v in self.flags_bk.items():
            setattr(flags, k, v)

class no_grad(flag_scope):
    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags
        jt_flags["no_grad"] = 1
single_log_capture = None
class log_capture_scope(_call_no_record_scope):
@ -559,6 +564,92 @@ class Module:
        for p in self.parameters():
            p.update(p.mpi_broadcast(root))

class Function(Module):
    ''' Function Module for customized backward operations.

    Example 1 (a Function can have multiple inputs and multiple outputs, and the user
    can store values for the backward computation)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x

        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 9

    Example 2 (a Function's grad can return None for an input that needs no gradient,
    and the incoming gradient can also be None)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                assert grads[1] is None
                return grads[0] * self.y, None

        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0
    '''
    def __call__(self, *args, **kw):
        args2 = list(args)
        kw = dict(kw)
        taped_inputs = []
        taped_outputs = []
        for i,v in enumerate(args2):
            if isinstance(v, Var):
                v = v.tape()
                args2[i] = v
                taped_inputs.append(v)
        for k,v in kw.items():
            if isinstance(v, Var):
                v = v.tape()
                kw[k] = v
                taped_inputs.append(v)
        res = self.execute(*args2, **kw)
        if isinstance(res, Var):
            res = res.tape()
            taped_outputs.append(res)
        else:
            assert isinstance(res, Sequence)
            res = list(res)
            for i,v in enumerate(res):
                if isinstance(v, Var):
                    v = v.tape()
                    res[i] = v
                    taped_outputs.append(v)
        # tape the outputs and inputs together so that
        # backward treats them as one operator
        tape_together(taped_inputs, taped_outputs, lambda args: self.grad(args))
        return res

    def dfs(self, parents, k, callback, callback_leave=None):
        pass


def make_module(func, exec_n_args=1):
    class MakeModule(Module):
        def __init__(self, *args, **kw):
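
For quick reference, here is a minimal usage sketch of the no_grad scope added above. Entering the scope sets the no_grad flag, so Vars created inside it are stop-grad; this mirrors the test_no_grad case added further down:

    import jittor as jt

    a = jt.array(1.0)
    with jt.no_grad():
        b = a
        for i in range(10):
            b = b.clone() + 1   # intermediates are created with gradients disabled
    assert b.data == 11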

View File

@ -12,7 +12,7 @@ import numpy as np
class TestClone(unittest.TestCase):
    def test(self):
        jt.clean()
        b = a = jt.array(1)
        b = a = jt.array(1.0)
        for i in range(10):
            b = b.clone()
            if i==5: c=b

View File

@ -0,0 +1,166 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from collections.abc import Sequence, Mapping
from .test_core import expect_error
from jittor import Function
class TestFunction(unittest.TestCase):
    def test1(self):
        class MyFunc(Function):
            def execute(self, x):
                return x+1
            def grad(self, grads):
                return grads[0]-2
        a = jt.ones(1)
        func = MyFunc()
        b = func(a)
        da = jt.grad(b, a)
        assert da.data == -1

    def test2(self):
        class MyFunc(Function):
            def execute(self, x):
                self.x = x
                return x+1
            def grad(self, grads):
                return (grads[0]-2) * self.x
        a = jt.ones(1) * 10
        func = MyFunc()
        b = func(a)
        da = jt.grad(b, a)
        assert da.data == -10

    def test_grad_not_match_error(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y
            def grad(self, grads):
                return (grads[0]-2) * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c = func(a, b)
        expect_error(lambda: jt.grad(c, [a, b]))

    def test_multi_grads(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y
            def grad(self, grads):
                return (grads[0]-2) * self.y, (grads[0]-2) * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c = func(a, b)
        da, db = jt.grad(c, [a, b])
        assert da.data == -4
        assert db.data == -3

    def test_multi_grads_none(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y
            def grad(self, grads):
                return (grads[0]-2) * self.y, None
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c = func(a, b)
        da, db = jt.grad(c, [a, b])
        assert da.data == -4
        assert db.data == 0

    def test_multi_grads_multi_out(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 9

    def test_multi_grads_multi_out_stop_grad_0(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                return grads[0] * self.y, grads[1] * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        b.stop_grad()
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

    def test_multi_grads_multi_out_stop_grad_1(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                assert grads[1] is None
                return grads[0] * self.y, None
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

    def test_multi_grads_multi_out2(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                res = (grads[0] * self.y, grads[1] * self.x)
                print(res)
                return res
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4, da.data
        assert db.data == 9

if __name__ == "__main__":
    unittest.main()

View File

@ -151,7 +151,16 @@ class TestGrad(unittest.TestCase):
        self.assertEqual(dx.data, 4*2**3)
        self.assertEqual(ddx.data, 4*3*2**2)
        self.assertEqual(dddx.data, 4*3*2*2**1)

    def test_no_grad(self):
        a = jt.array(1.0)
        with jt.no_grad():
            b = a
            for i in range(10):
                b = b.clone() + 1
        assert b.data == 11
        jt.clean()
        assert jt.liveness_info()["lived_vars"] == 2

if __name__ == "__main__":
    unittest.main()

View File

@ -91,15 +91,15 @@ for cc_type in ["g++", "clang"]:
# compress source
# tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__
# mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
assert os.system(f"cd {polish_path} && tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__")==0
assert os.system(f"cd {polish_path} && tar --exclude=build --exclude=.git --exclude=.ipynb_checkpoints --exclude=__pycache__ -cvzf build/jittor.tgz . ")==0
# rsync to build-server
jittor_web_base_dir = "Documents/jittor-blog/assets/"
jittor_web_build_dir = jittor_web_base_dir + "build/"
assert os.system(f"rsync -avPu {polish_path}/build/ jittor-web:{jittor_web_build_dir}")==0
assert os.system(f"ssh jittor@166.111.68.30 Documents/jittor-blog.git/hooks/post-update")==0
assert os.system(f"ssh jittor-web Documents/jittor-blog.git/hooks/post-update")==0
# push to github
assert os.system(f"cd {polish_path} && git push -f origin master")==0
# assert os.system(f"cd {polish_path} && git push -f origin master")==0
# push to pip

View File

@ -1 +1 @@
08f4ca8b2c0a2978cd3fbc9a3a6e76bd1463ca12
b27082f9444a4e627f7dfc574d0114302ba27b5e

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.5.4',
version='1.1.5.5',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",

View File

@ -35,6 +35,125 @@ inline static void assign_attrs(Var* a, Var* b) {
a->flags.set(NodeFlags::_stop_fuse);
}
void tape_together(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
) {
auto tapes = new Tapes();
tapes->total = tapes->ref = taped_inputs.size() + taped_outputs.size();
tapes->callback = move(grad_callback);
tapes->flags.set(NodeFlags::_grads);
for (int i=0; i<taped_inputs.size(); i++) {
auto v = taped_inputs[i]->var;
auto op = (TapeOp*)v->input();
ASSERT(op);
op->flags.set(NodeFlags::_tape);
tapes->_inputs.emplace_back(op->inputs().front());
op->tapes = tapes;
}
for (int i=0; i<taped_outputs.size(); i++) {
auto v = taped_outputs[i]->var;
auto op = (TapeOp*)v->input();
ASSERT(op);
op->flags.set(NodeFlags::_tape);
tapes->_outputs.emplace_back(v,0);
op->tapes = tapes;
}
}
template <typename Func>
void bfs_backward_with_tape(vector<Node*>& queue, Func&& func) {
auto t = ++Node::tflag_count;
size_t i=0;
for (Node* node : queue) node->tflag = t;
while (i < queue.size()) {
Node* node = queue[i++];
for (auto i : node->_inputs) {
auto inode = i.node;
if (inode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)inode)->tapes;
inode = t;
ASSERT(t->ref == t->total);
}
if (inode->tflag != t && func(inode)) {
inode->tflag = t;
queue.push_back(inode);
}
}
}
}
template <typename Func>
void bfs_backward_with_tape(vector<Node*>& seed, vector<Node*>& queue, Func&& func) {
for (Node* node : seed)
if (func(node)) queue.push_back(node);
bfs_backward_with_tape(queue, func);
}
template <typename Func>
void bfs_forward_with_tape(vector<Node*>& queue, Func&& func) {
auto t = ++Node::tflag_count;
size_t i=0;
for (Node* node : queue) node->tflag = t;
while (i < queue.size()) {
Node* node = queue[i++];
for (auto o : node->_outputs) {
auto onode = o.node;
if (onode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)onode)->tapes;
ASSERT(t->ref == t->total) << t->ref << t->total;
onode = t;
}
if (onode->tflag != t && func(onode)) {
onode->tflag = t;
queue.push_back(onode);
}
}
}
}
template <typename Func>
void toplogical_sort_backward_with_tape(vector<Node*>& nodes, vector<Node*>& sorted, Func&& func) {
auto t = ++Node::tflag_count;
sorted.reserve(nodes.size());
for (auto node : nodes) node->tflag = t;
for (auto node : nodes) {
auto& deps = node->custom_data;
deps = 0;
for (auto o : node->_outputs) {
auto onode = o.node;
if (onode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)onode)->tapes;
onode = t;
}
if (onode->tflag == t)
deps++;
}
if (deps == 0) sorted.push_back(node);
}
size_t i=0;
while (i < sorted.size()) {
Node* node = sorted[i++];
for (auto i : node->_inputs) {
auto inode = i.node;
if (inode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)inode)->tapes;
inode = t;
}
if (inode->tflag == t) {
inode->custom_data--;
if (inode->custom_data == 0)
sorted.push_back(inode);
}
}
func(node);
}
ASSERTop(nodes.size(),==,sorted.size());
}
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGvv << "loss:" >> loss << "targets:" >> targets;
CHECK(loss->is_float()) << "Loss should be float";
@ -44,13 +163,13 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
vector<Node*> ts(targets.begin(), targets.end());
// bfs visit find all successors of targets
LOGvv << "Size of successors:" << ts.size();
bfs_forward(ts, [](Node*){ return true; });
bfs_forward_with_tape(ts, [](Node*){ return true; });
vector<Node*> gnodes;
gnodes.reserve(ts.size());
auto nt = Node::tflag_count;
if (loss->tflag == nt)
gnodes.push_back(loss);
bfs_backward(gnodes, [&](Node* node) {
bfs_backward_with_tape(gnodes, [&](Node* node) {
if (node->tflag != nt)
return false;
if (node->is_stop_grad())
@ -63,58 +182,145 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGvv << "Size of grad nodes:" << gnodes.size();
vector<Node*> sorted;
toplogical_sort_backward(gnodes, sorted, [](Node*){});
toplogical_sort_backward_with_tape(gnodes, sorted, [](Node*){});
nt = Node::tflag_count;
vector<Var*> gvars;
gvars.reserve(sorted.size());
for (Node* node : sorted)
if (node->is_var())
gvars.push_back(node->var());
if (node->is_var()) {
Var* v = node->var();
v->custom_data = gvars.size();
gvars.push_back(v);
}
LOGvv << "Size of grad vars:" << gvars.size();
vector<VarPtr> grads(gvars.size());
vector<VarPtr> results(targets.size());
for (size_t i=0; i<gvars.size(); i++)
gvars[i]->custom_data = i;
vector<int> target_id(targets.size());
for (int i=0; i<targets.size(); i++) {
Var* var = targets[i];
target_id[i] = (var->tflag == nt) ?
var->custom_data : -1;
}
if (grads.size()) {
grads[0] = make_number(1.f, loss);
assign_attrs(grads[0].ptr, loss);
registe_node_trace_grad(grads[0].ptr, loss, 0);
}
vector<pair<Node*, int64>> id_buffer;
id_buffer.reserve(sorted.size()+10);
// back up the ids stored in custom data
for (int i=1; i<gvars.size(); i++) {
Var* var = gvars[i];
for (auto it : var->outputs_with_index()) {
Op* op = it.op;
if (op->flags.get(NodeFlags::_tape)) {
op = ((TapeOp*)op)->tapes;
}
auto index = it.index;
if (op->tflag != nt) continue;
id_buffer.emplace_back(op, index);
for (size_t i=0; i<gvars.size(); i++) {
// backward together
if (op->flags.get(NodeFlags::_grads)) {
// don't backward this op again next time
op->tflag = 0;
for (Var* out : op->outputs()) {
id_buffer.emplace_back(
out,
out->tflag == nt ? out->custom_data : -1);
}
for (Var* in : op->inputs()) {
id_buffer.emplace_back(
in,
in->tflag == nt ? in->custom_data : -1);
}
} else {
// single var backward
for (Var* out : op->outputs()) {
id_buffer.emplace_back(
out,
out->tflag == nt ? out->custom_data : -1);
}
}
}
// end of var output
id_buffer.emplace_back(nullptr, 0);
}
// real backward construction from the previously backed-up ids
int j=0;
for (int i=1; i<gvars.size(); i++,j++) {
Var* var = gvars[i];
auto& grad = grads[i];
#ifdef PREVENT_LARGE_FUSED_OP
int gsum = 0;
#endif
if (i==0) {
grad = make_number(1.f, loss);
assign_attrs(grad.ptr, loss);
registe_node_trace_grad(grad.ptr, loss, 0);
} else
for (auto it : var->outputs_with_index()) {
Op* op = it.op;
auto index = it.index;
if (op->tflag != nt) continue;
for (Var* out : op->outputs()) {
if (out->tflag != nt) continue;
Var* dout = grads[out->custom_data];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)
grad = move(dvar);
else if (dvar) {
grad = make_binary(grad, dvar, ns_add);
#ifdef PREVENT_LARGE_FUSED_OP
gsum ++;
if (gsum>=PREVENT_LARGE_FUSED_OP) {
// TODO: this is a dirty fix for
// stopping fuse lots of op together,
// try to find a better solution
grad->flags.set(NodeFlags::_stop_fuse);
// dump "for (auto it : var->outputs_with_index())"
while (id_buffer[j].first) {
Op* op = id_buffer[j].first->op();
auto index = id_buffer[j].second;
j++;
auto n_o = op->outputs().size();
if (op->flags.get(NodeFlags::_grads)) {
// backward together
auto n_i = op->inputs().size();
Var* douts[n_o];
VarPtr dins[n_i];
// dump "for (Var* out : op->outputs())"
for (int i=0; i<n_o; i++,j++) {
auto id = id_buffer[j].second;
if (id>=0) {
douts[i] = grads[id];
} else
douts[i] = nullptr;
}
op->grads(douts, dins);
// dump "for (Var* in : op->inputs())"
for (int i=0; i<n_i; i++,j++) {
auto id = id_buffer[j].second;
if (id>=0) {
auto& din = dins[i];
auto& grad = grads[id];
if (din && grad) {
grad = make_binary(grad, din, ns_add);
} else
grad = move(din);
}
}
} else {
// single var backward
// dump "for (Var* out : op->outputs())"
for (int i=0; i<n_o; i++,j++) {
auto id = id_buffer[j].second;
auto out = id_buffer[j].first->var();
if (id<0) continue;
Var* dout = grads[id];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)
grad = move(dvar);
else if (dvar) {
grad = make_binary(grad, dvar, ns_add);
#ifdef PREVENT_LARGE_FUSED_OP
gsum ++;
if (gsum>=PREVENT_LARGE_FUSED_OP) {
// TODO: this is a dirty fix for
// stopping fuse lots of op together,
// try to find a better solution
grad->flags.set(NodeFlags::_stop_fuse);
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
}
}
@ -123,8 +329,9 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
for (size_t i=0; i<results.size(); i++) {
Var* var = targets[i];
VarPtr& grad = results[i];
if (var->tflag == nt)
grad = move(grads[var->custom_data]);
auto id = target_id[i];
if (id>=0)
grad = move(grads[id]);
if (!grad) {
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
grad = make_number(0.f, var);
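
The rewrite above splits backward construction into a record pass and a replay pass: pass one walks each grad var's consumer ops and writes (node, id) pairs into id_buffer while the slot ids cached in custom_data are still valid; pass two replays that flat record, calling the grouped grads() hook once per tape group (ops flagged _grads) and the per-variable grad() hook otherwise, accumulating with add. A schematic Python rendering of that control flow (illustrative only, not the C++ above; consumers, slot_of, grads_hook, grad_hook, add and is_grads_op are invented stand-ins):

    def build_backward(gvars, grads, consumers, slot_of, grads_hook, grad_hook, add):
        # pass 1: record consumer ops and the grad slots they touch while the
        # ids cached in custom_data are still valid (the real code also clears
        # the tflag of a _grads op so its tape group is recorded only once)
        id_buffer = []
        for var in gvars[1:]:
            for op, index in consumers(var):
                grouped = op.is_grads_op                  # analogue of NodeFlags::_grads
                outs = [slot_of(o) for o in op.outputs]   # negative id means "not needed"
                ins = [slot_of(inp) for inp in op.inputs] if grouped else []
                id_buffer.append((op, index, grouped, outs, ins))
            id_buffer.append(None)                        # end-of-var marker
        # pass 2: replay the record; grad vars created here can no longer
        # invalidate the ids captured in pass 1
        it = iter(id_buffer)
        for i, var in enumerate(gvars[1:], start=1):
            for rec in it:
                if rec is None:
                    break
                op, index, grouped, outs, ins = rec
                if grouped:                               # Tapes: one grads() call per group
                    douts = [grads[s] if s >= 0 else None for s in outs]
                    dins = grads_hook(op, douts)          # ends up in the user's Function.grad
                    for s, din in zip(ins, dins):
                        if s >= 0 and din is not None:
                            grads[s] = add(grads[s], din) if grads[s] is not None else din
                else:                                     # ordinary op: per-variable Op::grad
                    for s in outs:
                        if s >= 0:
                            dvar = grad_hook(op, grads[s], var, index)
                            if dvar is not None:
                                grads[i] = add(grads[i], dvar) if grads[i] is not None else dvar
        return grads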

View File

@ -3,10 +3,18 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "ops/tape_op.h"
#include "common.h"
namespace jittor {
vector<VarPtr> grad(Var* loss, vector<Var*> targets);
// @pyjt(tape_together)
void tape_together(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
);
} // jittor

View File

@ -44,6 +44,10 @@ struct NodeFlags {
_vary_shape=_n+3,
// bit4~5: op type
_op_type=_n+4, _op_type_nbits=2,
// bit6: is tape op
_tape=_n+6,
// bit7: backprop grads at once
_grads=_n+7,
};
inline void set(Flags f, int a=1, int nbits=1) {

View File

@ -46,6 +46,10 @@ VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
return nullptr;
}
void Op::grads(Var** douts, VarPtr* dins) {
LOGw << "Grads of" << name() << "return zeros";
}
Var* Op::create_output(NanoVector shape, NanoString dtype) {
VarPtr vp(shape, dtype);
Var* output = vp.ptr;

View File

@ -37,6 +37,7 @@ struct Op : Node {
~Op();
virtual VarPtr grad(Var* out, Var* dout, Var* v, int v_index);
virtual void grads(Var** douts, VarPtr* dins);
virtual void infer_shape() {}
virtual void run() {};
virtual void jit_prepare() {};

src/ops/tape_op.cc (new file, 51 lines)
View File

@ -0,0 +1,51 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "ops/array_op.h"
#include "ops/op_register.h"
#include "ops/tape_op.h"
namespace jittor {
static auto make_tape = get_op_info("tape")
.get_constructor<VarPtr, Var*>();
TapeOp::TapeOp(Var* x) : tapes(nullptr) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
auto y = create_output(nullptr, x->dtype());
if (x->name.ptr)
y->name = x->name;
}
TapeOp::~TapeOp() {
if (tapes) {
if (! --tapes->ref) {
tapes->_inputs.clear();
tapes->_outputs.clear();
delete tapes;
}
}
}
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return dout;
}
void TapeOp::infer_shape() {
auto x = inputs().front();
auto y = outputs().front();
y->set_shape(x->shape);
y->share_with(x);
}
void Tapes::grads(Var** douts, VarPtr* dins) {
callback.func(_outputs.size(), douts, _inputs.size(), dins);
}
} // jittor

src/ops/tape_op.h (new file, 54 lines)
View File

@ -0,0 +1,54 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <functional>
#include "op.h"
#include "var_holder.h"
namespace jittor {
struct Tapes;
struct GradCallback {
typedef jittor::VarHolder VarHolder;
typedef VarHolder* VarHolderPtr;
std::function<void(int,Var**,int,VarPtr*)> func;
std::function<void()> deleter;
inline ~GradCallback() { if (deleter) deleter(); }
GradCallback(const GradCallback&) = delete;
GradCallback() = default;
GradCallback(GradCallback&& other) : func(other.func), deleter(other.deleter) {
other.func = nullptr;
other.deleter = nullptr;
};
GradCallback(std::function<void(int,Var**,int,VarPtr*)> && func, std::function<void()>&& deleter)
: func(move(func)), deleter(move(deleter)) {};
void operator =(GradCallback&& other) { this->~GradCallback(); new (this) GradCallback(move(other)); }
};
struct TapeOp final : Op {
Tapes* tapes;
TapeOp(Var* x);
~TapeOp();
const char* name() const override { return "tape"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
};
struct Tapes final : Op {
int ref, total;
GradCallback callback;
const char* name() const override { return "tapes"; }
void grads(Var** douts, VarPtr* dins) override;
};
} // jittor

View File

@ -590,4 +590,62 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
return func;
}
struct GradCallback;
DEF_IS(GradCallback, bool) is_type(PyObject* obj) {
return PyCallable_Check(obj);
}
DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
// PyObject_Call
Py_INCREF(obj);
T func(
// callback
[obj](int n_o, Var** douts, int n_i, VarPtr* dins) {
PyObjHolder list(PyTuple_New(n_o));
for (int i=0; i<n_o; i++) {
if (douts[i]) {
PyTuple_SET_ITEM(list.obj, i,
to_py_object(new typename T::VarHolder(douts[i])));
} else {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(list.obj, i, Py_None);
}
}
PyObjHolder args(PyTuple_New(1));
PyTuple_SET_ITEM(args.obj, 0, list.release());
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
auto check = [&](int i, PyObject* obj) {
if (obj == Py_None) {
dins[i] = nullptr;
} else {
CHECK(Py_TYPE(obj) == &PyjtVarHolder.ht_type) << "returned grad("<<Py_TYPE(obj)->tp_name<<") is not jittor variable";
auto vh = from_py_object<typename T::VarHolderPtr>(obj);
dins[i] = vh->var;
}
};
if (!is_seq) {
CHECKop(n_i,==,1) << "returned grad size not match";
check(0, ret.obj);
} else {
auto size = Py_SIZE(ret.obj);
CHECKop(n_i,==,size) << "returned grad size not match";
auto arr = PySequence_Fast_ITEMS(ret.obj);
for (int i=0; i<size; i++) {
auto oi = arr[i];
check(i, oi);
}
}
},
// deleter
[obj]() {
Py_DECREF(obj);
}
);
return func;
}
} // jittor
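
Seen from the Python side, the converter above fixes the return contract of Function.grad: with exactly one taped input, grad may return a bare Var or None; otherwise it must return a list or tuple with one entry per taped input, where None entries fall back to zero gradients and a length mismatch raises an error (this is what test_grad_not_match_error exercises). A small sketch consistent with the tests in this commit (MulFunc is just an illustrative name):

    import jittor as jt
    from jittor import Function

    class MulFunc(Function):
        def execute(self, x, y):
            self.y = y
            return x * y
        def grad(self, grads):
            # one entry per taped input; None becomes a zero gradient
            return grads[0] * self.y, None

    a, b = jt.array(3.0), jt.array(4.0)
    da, db = jt.grad(MulFunc()(a, b), [a, b])
    assert da.data == 4
    assert db.data == 0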

View File

@ -46,7 +46,7 @@ JIT_TEST(sfrl_allocator_time) {
std::chrono::steady_clock::now().time_since_epoch()).count();
LOGvv << "Use time " << float(end - begin) / 1000 << "ms\n";
ASSERT(float(end - begin) / 1000 < tasks[i].time_limit);
ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit);
}
}
@ -84,7 +84,7 @@ JIT_TEST(sfrl_allocator_share) {
std::chrono::steady_clock::now().time_since_epoch()).count();
LOGvvv << "Use time " << float(end - begin) / 1000 << "ms\n";
ASSERT(float(end - begin) / 1000 < tasks[i].time_limit);
ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit);
}
}

View File

@ -220,6 +220,10 @@ bool check_vlog(const char* fileline, int verbose);
#define LOGrrrr LOGvvvv >> jittor::red
#define LOGyyyy LOGvvvv >> jittor::yellow
#define LOGir LOGi >> jittor::red
#define LOGig LOGi >> jittor::green
#define LOGiy LOGi >> jittor::yellow
void system_with_check(const char* cmd);
} // jittor

View File

@ -17,11 +17,14 @@ int64_t Var::number_of_lived_vars = 0;
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
"Override the default loop transfrom options");
DEFINE_FLAG(bool, no_grad, 0,
"No grad for all jittor Var creation");
Var::Var(NanoVector shape, NanoString dtype)
: shape(shape),
loop_options(compile_options) {
flags.set(NodeFlags::_var, 1);
flags.set(NodeFlags::_stop_grad, !dtype.is_float() || no_grad);
ns = dtype;
ASSERT(ns.is_dtype());
number_of_lived_vars++;