forked from jittor/jittor
jtorch compatible
This commit is contained in:
parent
437a720500
commit
6fbffb7c62
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.6.9'
|
__version__ = '1.3.6.10'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -291,7 +291,7 @@ def median(x,dim=None,keepdim=False, keepdims=False):
|
||||||
if dim is None:
|
if dim is None:
|
||||||
x = x.reshape(-1)
|
x = x.reshape(-1)
|
||||||
dim=0
|
dim=0
|
||||||
_,x = x.argsort(dim)
|
_,x = jt.argsort(x, dim)
|
||||||
slices = [slice(None) for i in range(dim-1)]
|
slices = [slice(None) for i in range(dim-1)]
|
||||||
k = (x.shape[dim]-1)//2
|
k = (x.shape[dim]-1)//2
|
||||||
if keepdim:
|
if keepdim:
|
||||||
|
|
|
@ -498,9 +498,9 @@ def softmax(x, dim=None, log=False):
|
||||||
return code_softmax.softmax_v1(x, log)
|
return code_softmax.softmax_v1(x, log)
|
||||||
if dim is None: dim = ()
|
if dim is None: dim = ()
|
||||||
if log:
|
if log:
|
||||||
a = x-x.max(dim, keepdims=True)
|
a = x - jt.max(x, dim, keepdims=True)
|
||||||
return a - a.exp().sum(dim, keepdims=True).log()
|
return a - a.exp().sum(dim, keepdims=True).log()
|
||||||
x = (x-x.max(dim, keepdims=True)).exp()
|
x = (x - jt.max(x, dim, keepdims=True)).exp()
|
||||||
return x / x.sum(dim, keepdims=True)
|
return x / x.sum(dim, keepdims=True)
|
||||||
jt.Var.softmax = softmax
|
jt.Var.softmax = softmax
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue