# Activation functions (ReLU, LeakyReLU) for SparseTensor inputs.
import jittor as jt
|
|
import jittor.nn as nn
|
|
|
|
from python import SparseTensor
|
|
from python.nn.utils import fapply
|
|
|
|
# Public API of this module: the functional activations only.
__all__ = [
    "relu",
    "leaky_relu",
]

# Extended export list, to be used if the module-style wrappers
# (ReLU / LeakyReLU) are re-enabled:
# __all__ = ['relu', 'leaky_relu', 'ReLU', 'LeakyReLU']
|
|
|
|
def relu(input: SparseTensor) -> SparseTensor:
    """Apply ReLU to a sparse tensor.

    The work is delegated to ``fapply``, which receives the tensor and
    ``jittor.nn.relu`` as the function to apply. Presumably ``fapply``
    runs that function over the tensor's feature values and rewraps the
    result as a ``SparseTensor``. TODO: confirm against ``fapply``.

    Args:
        input: The sparse tensor to activate.

    Returns:
        Whatever ``fapply`` produces for this tensor and activation.
    """
    activation = nn.relu
    return fapply(input, activation)
|
|
|
|
|
|
def leaky_relu(input: SparseTensor,
               scale: float = 0.01) -> SparseTensor:
    """Apply LeakyReLU to a sparse tensor.

    Delegates to ``fapply`` with ``jittor.nn.leaky_relu`` as the function
    to apply. ``scale`` is passed to ``fapply`` as a keyword argument.
    Presumably ``fapply`` forwards it to ``nn.leaky_relu``, where it acts
    as the slope for negative inputs. TODO: confirm against ``fapply``.

    Args:
        input: The sparse tensor to activate.
        scale: Keyword value handed to ``fapply`` (default ``0.01``).

    Returns:
        Whatever ``fapply`` produces for this tensor and activation.
    """
    activation = nn.leaky_relu
    return fapply(input, activation, scale=scale)
|
|
|
|
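# Module-style wrappers around the functional activations (currently
# disabled; re-enable together with the extended __all__ list):
# Relu = jt.make_module(relu)
# ReLU = Relu
# Leaky_relu = jt.make_module(leaky_relu, 2)
# LeakyReLU = Leaky_relu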
|
|
|