# Source-viewer metadata pasted with this file: 86 lines, 3.2 KiB, Python.
import math
|
|
from typing import List, Tuple, Union
|
|
|
|
import numpy as np
|
|
import jittor as jt
|
|
from jittor import nn
|
|
from jittor import init
|
|
from jittor.misc import _pair, _triple
|
|
|
|
from python import SparseTensor
|
|
from python.nn import functional as F
|
|
# from utils import make_ntuple
|
|
|
|
__all__ = ['Conv3d']
|
|
|
|
class Conv3d(nn.Module):
    """Sparse 3D convolution layer operating on ``SparseTensor`` inputs.

    The layer owns a learnable ``weight`` (and optional ``bias``) and forwards
    them, together with its geometry settings, to ``F.conv3d`` from
    ``python.nn.functional``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel extent. An int is repeated for all three axes;
            a tuple or list is used as given (converted to a tuple).
        stride: Convolution stride, normalized the same way as ``kernel_size``.
        dilation: Kernel dilation, normalized the same way as ``kernel_size``.
        groups: Channel group count; ``in_channels`` and ``out_channels`` must
            both be divisible by it (checked with ``assert``).
        bias: If True, allocate a per-output-channel bias; otherwise
            ``self.bias`` is ``None``.
        transposed: Forwarded to ``F.conv3d``; also selects whether the output
            or input channel count is used as the fan for initialization.

    Attributes:
        weight: Shape ``(kernel_volume, in_channels, out_channels)`` when the
            kernel covers more than one cell, otherwise
            ``(in_channels, out_channels)``. Initialized uniformly in
            ``[-std, std]`` with ``std = 1 / sqrt(fan)``.
        bias: Shape ``(out_channels,)`` with the same initialization, or
            ``None``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]] = 3,
                 stride: Union[int, Tuple[int, ...]] = 1,
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 transposed: bool = False) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._to_triple(kernel_size)
        self.stride = self._to_triple(stride)
        self.dilation = self._to_triple(dilation)
        self.groups = groups
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'
        self.transposed = transposed

        # Number of cells in the kernel window (product of the per-axis sizes).
        self.kernel_volume = int(np.prod(self.kernel_size))

        # Uniform init scaled by 1/sqrt(fan); for transposed convolutions the
        # fan is taken from the output channel count instead of the input one.
        fan = (self.out_channels if self.transposed else self.in_channels) * self.kernel_volume
        std = 1 / math.sqrt(fan)

        if self.kernel_volume > 1:
            weight_shape = [self.kernel_volume, in_channels, out_channels]
        else:
            # A single-cell kernel drops the leading kernel-volume axis.
            weight_shape = [in_channels, out_channels]
        # Both layouts use the same [-std, std] range so that single-cell and
        # larger kernels are initialized consistently.
        self.weight = init.uniform(weight_shape, 'float32', -std, std)

        if bias:
            self.bias = init.uniform([out_channels], 'float32', -std, std)
        else:
            self.bias = None

    @staticmethod
    def _to_triple(value: Union[int, Tuple[int, ...], List[int]]) -> Tuple[int, ...]:
        """Return ``value`` as a tuple of per-axis values.

        Tuples and lists are converted to a tuple as-is; any other value is
        repeated for the three spatial axes.
        """
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value, value)

    def execute(self, input: SparseTensor) -> SparseTensor:
        """Apply the convolution to ``input`` via ``F.conv3d``.

        Args:
            input: Sparse input tensor. Its feature width is presumably
                expected to equal ``in_channels`` (not checked here).

        Returns:
            The ``SparseTensor`` produced by ``F.conv3d``.
        """
        return F.conv3d(input,
                        weight=self.weight,
                        kernel_size=self.kernel_size,
                        bias=self.bias,
                        stride=self.stride,
                        dilation=self.dilation,
                        groups=self.groups,
                        transposed=self.transposed)
|
|
|
|
|
|
|
|
|