forked from maxjhandsome/jittor
polish argmax
This commit is contained in:
parent
a2736164fb
commit
40ca973107
|
@ -20,7 +20,7 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims)
|
||||
CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims)
|
||||
: x(x), offsets(offsets), op(op), keepdims(keepdims) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
|
@ -56,7 +56,7 @@ void CubArgReduceOp::infer_shape() {
|
|||
void CubArgReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Toffsets", offsets->dtype());
|
||||
add_jit_define("FUNC", op=="min" ? "ArgMin" : "ArgMax");
|
||||
add_jit_define("FUNC", op==ns_minimum ? "ArgMin" : "ArgMax");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -14,10 +14,10 @@ namespace jittor {
|
|||
|
||||
struct CubArgReduceOp : Op {
|
||||
Var* x, * offsets, * y, * y_key;
|
||||
string op;
|
||||
NanoString op;
|
||||
bool keepdims;
|
||||
// @attrs(multiple_outputs)
|
||||
CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims);
|
||||
CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
|
||||
|
|
|
@ -295,6 +295,14 @@ Var.masked_fill = masked_fill
|
|||
def sqr(x): return x*x
|
||||
Var.sqr = sqr
|
||||
|
||||
def argmax(x, dim:int, keepdims:bool=False):
|
||||
return x.arg_reduce("max", dim, keepdims)
|
||||
Var.argmax = argmax
|
||||
|
||||
def argmin(x, dim:int, keepdims:bool=False):
|
||||
return x.arg_reduce("min", dim, keepdims)
|
||||
Var.argmin = argmin
|
||||
|
||||
def attrs(var):
|
||||
return {
|
||||
"is_stop_fuse": var.is_stop_fuse(),
|
||||
|
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.4.3',
|
||||
version='1.1.4.4',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
|
@ -149,6 +149,8 @@ static void init_ns() {
|
|||
FOR_ALL_NS(INIT_NS);
|
||||
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
||||
NanoString::__string_to_ns["sum"] = ns_add;
|
||||
NanoString::__string_to_ns["min"] = ns_minimum;
|
||||
NanoString::__string_to_ns["max"] = ns_maximum;
|
||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
||||
}
|
||||
|
|
|
@ -34,18 +34,16 @@ static auto make_reshape = get_op_info("reshape")
|
|||
static auto make_reindex_reduce = get_op_info("reindex_reduce")
|
||||
.get_constructor<VarPtr, Var*, NanoString, NanoVector, vector<string>&&, vector<string>&&, vector<Var*>&&>();
|
||||
|
||||
ArgReduceOp::ArgReduceOp(Var* x, string op, int dim, bool keepdims)
|
||||
ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
|
||||
: x(x), op(op), dim(dim), keepdims(keepdims) {
|
||||
if (this->dim == -1)
|
||||
this->dim = x->shape.size() - 1;
|
||||
dim = this->dim;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static std::vector<VarPtr>(*cub_arg_reduce)(Var*, Var*, string, bool) = nullptr;
|
||||
if (!cub_arg_reduce && has_op("cub_arg_reduce")) {
|
||||
cub_arg_reduce = get_op_info("cub_arg_reduce")
|
||||
.get_constructor<std::vector<VarPtr>, Var*, Var*, string, bool>();
|
||||
}
|
||||
static auto cub_arg_reduce = has_op("cub_arg_reduce") ?
|
||||
get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>()
|
||||
: nullptr;
|
||||
if (cub_arg_reduce) {
|
||||
if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
|
||||
int dims = x->shape.size();
|
||||
|
@ -162,7 +160,7 @@ void ArgReduceOp::jit_prepare() {
|
|||
add_jit_define("YDIM", JK::hex1(y->shape.size()));
|
||||
add_jit_define("KEEPDIMS", keepdims ? 1 : 0);
|
||||
add_jit_define("DIM", JK::hex1(dim));
|
||||
add_jit_define("CMP", op=="min" ? "<" : ">");
|
||||
add_jit_define("CMP", op==ns_minimum ? "<" : ">");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -14,11 +14,11 @@ namespace jittor {
|
|||
|
||||
struct ArgReduceOp : Op {
|
||||
Var* x, * y, * y_key;
|
||||
string op;
|
||||
NanoString op;
|
||||
int dim;
|
||||
bool keepdims;
|
||||
// @attrs(multiple_outputs)
|
||||
ArgReduceOp(Var* x, string op, int dim, bool keepdims);
|
||||
ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
static VarPtr get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y);
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -61,7 +61,10 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
|
|||
reduce_mask |= 1<<dim;
|
||||
}
|
||||
}
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
if (x->dtype() == ns_bool && ns == ns_add)
|
||||
y = create_output(nullptr, ns_int32);
|
||||
else
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
}
|
||||
|
||||
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims)
|
||||
|
|
Loading…
Reference in New Issue