polish argmax

Dun Liang 2020-06-07 14:43:07 +08:00
parent a2736164fb
commit 40ca973107
8 changed files with 26 additions and 15 deletions

View File

@@ -20,7 +20,7 @@
namespace jittor {
#ifndef JIT
-CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims)
+CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims)
: x(x), offsets(offsets), op(op), keepdims(keepdims) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
@@ -56,7 +56,7 @@ void CubArgReduceOp::infer_shape() {
void CubArgReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Toffsets", offsets->dtype());
add_jit_define("FUNC", op=="min" ? "ArgMin" : "ArgMax");
add_jit_define("FUNC", op==ns_minimum ? "ArgMin" : "ArgMax");
}
#else // JIT
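With op held as a NanoString instead of a std::string, jit_prepare picks the cub kernel by comparing interned tokens: ns_minimum selects ArgMin, anything else ArgMax. A minimal sketch of exercising both branches from Python, assuming a CUDA build (the flags above register this op as CUDA-only) and the usual two-output return of arg_reduce:

import jittor as jt

jt.flags.use_cuda = 1                        # CubArgReduceOp only registers a CUDA kernel
x = jt.random((3, 4))
imin, vmin = x.arg_reduce("min", 1, False)   # op == ns_minimum -> FUNC=ArgMin
imax, vmax = x.arg_reduce("max", 1, False)   # any other op     -> FUNC=ArgMax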

View File

@@ -14,10 +14,10 @@ namespace jittor {
struct CubArgReduceOp : Op {
Var* x, * offsets, * y, * y_key;
-string op;
+NanoString op;
bool keepdims;
// @attrs(multiple_outputs)
-CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims);
+CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;

View File

@@ -295,6 +295,14 @@ Var.masked_fill = masked_fill
def sqr(x): return x*x
Var.sqr = sqr
+def argmax(x, dim:int, keepdims:bool=False):
+    return x.arg_reduce("max", dim, keepdims)
+Var.argmax = argmax
+def argmin(x, dim:int, keepdims:bool=False):
+    return x.arg_reduce("min", dim, keepdims)
+Var.argmin = argmin
def attrs(var):
return {
"is_stop_fuse": var.is_stop_fuse(),

View File

@@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
-version='1.1.4.3',
+version='1.1.4.4',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",

View File

@@ -149,6 +149,8 @@ static void init_ns() {
FOR_ALL_NS(INIT_NS);
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
NanoString::__string_to_ns["sum"] = ns_add;
NanoString::__string_to_ns["min"] = ns_minimum;
NanoString::__string_to_ns["max"] = ns_maximum;
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
}
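init_ns fills the string-to-NanoString lookup table at startup, and this hunk registers the short spellings "min" and "max" as aliases for the canonical reduce ops, matching the existing "sum" -> ns_add entry. Any argument that is converted through NanoString now accepts either spelling; a sketch of the equivalence, using arg_reduce as the consumer:

import jittor as jt

x = jt.random((3, 4))
i1, v1 = x.arg_reduce("min", 0, False)      # "min"     -> ns_minimum
i2, v2 = x.arg_reduce("minimum", 0, False)  # "minimum" -> ns_minimum, same op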

View File

@@ -34,18 +34,16 @@ static auto make_reshape = get_op_info("reshape")
static auto make_reindex_reduce = get_op_info("reindex_reduce")
.get_constructor<VarPtr, Var*, NanoString, NanoVector, vector<string>&&, vector<string>&&, vector<Var*>&&>();
-ArgReduceOp::ArgReduceOp(Var* x, string op, int dim, bool keepdims)
+ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
: x(x), op(op), dim(dim), keepdims(keepdims) {
if (this->dim == -1)
this->dim = x->shape.size() - 1;
dim = this->dim;
#ifdef HAS_CUDA
if (use_cuda) {
static std::vector<VarPtr>(*cub_arg_reduce)(Var*, Var*, string, bool) = nullptr;
if (!cub_arg_reduce && has_op("cub_arg_reduce")) {
cub_arg_reduce = get_op_info("cub_arg_reduce")
.get_constructor<std::vector<VarPtr>, Var*, Var*, string, bool>();
}
static auto cub_arg_reduce = has_op("cub_arg_reduce") ?
get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>()
: nullptr;
if (cub_arg_reduce) {
if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
int dims = x->shape.size();
@@ -162,7 +160,7 @@ void ArgReduceOp::jit_prepare() {
add_jit_define("YDIM", JK::hex1(y->shape.size()));
add_jit_define("KEEPDIMS", keepdims ? 1 : 0);
add_jit_define("DIM", JK::hex1(dim));
add_jit_define("CMP", op=="min" ? "<" : ">");
add_jit_define("CMP", op==ns_minimum ? "<" : ">");
}
#else // JIT
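Two details of ArgReduceOp show up in this hunk: dim == -1 is normalized to the last axis before any dispatch, and the cub constructor lookup is collapsed from a lazy nullptr-check into a one-time static initializer, used only when the cub op exists (CUDA builds); otherwise the JIT kernel is compiled with CMP as "<" or ">". A sketch of the dim handling, independent of backend:

import jittor as jt

x = jt.random((2, 3, 4))
idx, val = x.arg_reduce("max", -1, False)   # dim == -1 is normalized to the last axis
# keepdims=False drops the reduced axis, so idx and val have shape [2, 3]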

View File

@@ -14,11 +14,11 @@ namespace jittor {
struct ArgReduceOp : Op {
Var* x, * y, * y_key;
-string op;
+NanoString op;
int dim;
bool keepdims;
// @attrs(multiple_outputs)
-ArgReduceOp(Var* x, string op, int dim, bool keepdims);
+ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
static VarPtr get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y);
void infer_shape() override;

View File

@@ -61,7 +61,10 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
reduce_mask |= 1<<dim;
}
}
-y = create_output(nullptr, binary_dtype_infer(ns, x, x));
+if (x->dtype() == ns_bool && ns == ns_add)
+    y = create_output(nullptr, ns_int32);
+else
+    y = create_output(nullptr, binary_dtype_infer(ns, x, x));
}
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims)
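The output dtype of a reduce used to come unconditionally from binary_dtype_infer(ns, x, x); the new branch special-cases a sum over bool input to produce int32, so summing a boolean mask counts its True entries rather than collapsing into another bool, the usual (a == b).sum() counting idiom. A minimal sketch of the behaviour this enables:

import jittor as jt

mask = jt.array([True, False, True, True])
n = mask.sum()   # with the new branch the result is int32: n == 3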