update classification
This commit is contained in:
parent
6dc7bb6220
commit
b65f1c2036
|
@ -3,7 +3,6 @@ __pycache__
|
||||||
build
|
build
|
||||||
dist
|
dist
|
||||||
*.egg-info
|
*.egg-info
|
||||||
*.txt
|
|
||||||
*.npy
|
*.npy
|
||||||
*.log
|
*.log
|
||||||
*.zip
|
*.zip
|
||||||
|
|
|
@ -12,7 +12,8 @@ The latest JSparse can be installed by
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd python
|
cd python
|
||||||
python setup.py install
|
python setup.py install # or
|
||||||
|
python setup.py develop
|
||||||
```
|
```
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from ast import arg
|
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
from jittor import nn
|
from jittor import nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -49,6 +48,9 @@ class ModelNet40H5(Dataset):
|
||||||
):
|
):
|
||||||
super().__init__(batch_size=batch_size, num_workers=num_workers,shuffle=shuffle)
|
super().__init__(batch_size=batch_size, num_workers=num_workers,shuffle=shuffle)
|
||||||
assert mode in ['test','train']
|
assert mode in ['test','train']
|
||||||
|
if not os.path.exists(data_root):
|
||||||
|
import wget
|
||||||
|
wget.download()
|
||||||
# download_modelnet40_dataset()
|
# download_modelnet40_dataset()
|
||||||
self.data, self.label = self.load_data(data_root, mode)
|
self.data, self.label = self.load_data(data_root, mode)
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
@ -105,17 +107,6 @@ class ModelNet40H5(Dataset):
|
||||||
tensor = SparseTensor(indices=jt.array(indices), values=jt.array(feats))
|
tensor = SparseTensor(indices=jt.array(indices), values=jt.array(feats))
|
||||||
return tensor,labels
|
return tensor,labels
|
||||||
|
|
||||||
class ConvBNReLU(nn.Module):
|
|
||||||
def __init__(self,in_channels,out_channels,kernel_size=3,stride=1):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = nn.Sequential(
|
|
||||||
spnn.Conv3d(in_channels,out_channels,kernel_size=kernel_size,stride=stride),
|
|
||||||
spnn.BatchNorm(out_channels),
|
|
||||||
spnn.ReLU()
|
|
||||||
)
|
|
||||||
def execute(self,x):
|
|
||||||
return self.conv1(x)
|
|
||||||
|
|
||||||
class VoxelCNN(nn.Module):
|
class VoxelCNN(nn.Module):
|
||||||
def __init__(self,in_channels,out_channels):
|
def __init__(self,in_channels,out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -125,15 +116,15 @@ class VoxelCNN(nn.Module):
|
||||||
spnn.ReLU()
|
spnn.ReLU()
|
||||||
)
|
)
|
||||||
self.convs = nn.Sequential(
|
self.convs = nn.Sequential(
|
||||||
ConvBNReLU(64,128),
|
spnn.SparseConvBlock(64,128,kernel_size=3),
|
||||||
ConvBNReLU(128,128,kernel_size=2,stride=2),
|
spnn.SparseConvBlock(128,128,kernel_size=2,stride=2),
|
||||||
ConvBNReLU(128,256),
|
spnn.SparseConvBlock(128,256,kernel_size=3),
|
||||||
ConvBNReLU(256,256,kernel_size=2,stride=2),
|
spnn.SparseConvBlock(256,256,kernel_size=2,stride=2),
|
||||||
ConvBNReLU(256,512),
|
spnn.SparseConvBlock(256,512, kernel_size=3),
|
||||||
ConvBNReLU(512,512,kernel_size=2,stride=2),
|
spnn.SparseConvBlock(512,512,kernel_size=2,stride=2),
|
||||||
ConvBNReLU(512,1024),
|
spnn.SparseConvBlock(512,1024,kernel_size=3),
|
||||||
ConvBNReLU(1024,1024,kernel_size=2,stride=2),
|
spnn.SparseConvBlock(1024,1024,kernel_size=2,stride=2),
|
||||||
ConvBNReLU(1024,1024)
|
spnn.SparseConvBlock(1024,1024,kernel_size=3),
|
||||||
)
|
)
|
||||||
self.global_pool = spnn.GlobalPool(op="max")
|
self.global_pool = spnn.GlobalPool(op="max")
|
||||||
|
|
||||||
|
@ -191,7 +182,7 @@ def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
train_loader = ModelNet40H5(
|
train_loader = ModelNet40H5(
|
||||||
data_root="data/modelnet40_ply_hdf5_2048",
|
data_root=args.data_root,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
mode='train',
|
mode='train',
|
||||||
transform=CoordinateTransformation(),
|
transform=CoordinateTransformation(),
|
||||||
|
@ -199,7 +190,7 @@ def main():
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
voxel_size=args.voxel_size)
|
voxel_size=args.voxel_size)
|
||||||
test_loader = ModelNet40H5(
|
test_loader = ModelNet40H5(
|
||||||
data_root="data/modelnet40_ply_hdf5_2048",
|
data_root=args.data_root,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
mode='test',
|
mode='test',
|
||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
|
@ -212,7 +203,7 @@ def main():
|
||||||
train(model,train_loader,optimizer)
|
train(model,train_loader,optimizer)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
acc = test(model,test_loader)
|
acc = test(model,test_loader)
|
||||||
print(f"epoch:{epoch},acc:{acc}")
|
print(f"epoch:{epoch},acc:{acc},lr:{optimizer.lr}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
|
@ -0,0 +1,68 @@
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
|
from jittor import nn
|
||||||
|
|
||||||
|
from jsparse import SparseTensor
|
||||||
|
from jsparse.nn import SparseConvBlock, SparseResBlock
|
||||||
|
|
||||||
|
__all__ = ['SparseResNet21D']
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResNet(nn.ModuleList):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
blocks: List[Tuple[int, int, Union[int, Tuple[int, ...]],
|
||||||
|
Union[int, Tuple[int, ...]]]],
|
||||||
|
*,
|
||||||
|
in_channels: int = 4,
|
||||||
|
width_multiplier: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = blocks
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.width_multiplier = width_multiplier
|
||||||
|
|
||||||
|
for num_blocks, out_channels, kernel_size, stride in blocks:
|
||||||
|
out_channels = int(out_channels * width_multiplier)
|
||||||
|
blocks = []
|
||||||
|
for index in range(num_blocks):
|
||||||
|
if index == 0:
|
||||||
|
blocks.append(
|
||||||
|
SparseConvBlock(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
blocks.append(
|
||||||
|
SparseResBlock(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
))
|
||||||
|
in_channels = out_channels
|
||||||
|
self.append(nn.Sequential(*blocks))
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor) -> List[SparseTensor]:
|
||||||
|
outputs = []
|
||||||
|
for module in self:
|
||||||
|
x = module(x)
|
||||||
|
outputs.append(x)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResNet21D(SparseResNet):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__(
|
||||||
|
blocks=[
|
||||||
|
(3, 16, 3, 1),
|
||||||
|
(3, 32, 3, 2),
|
||||||
|
(3, 64, 3, 2),
|
||||||
|
(3, 128, 3, 2),
|
||||||
|
(1, 128, (1, 3, 1), (1, 2, 1)),
|
||||||
|
],
|
||||||
|
**kwargs,
|
||||||
|
)
|
|
@ -0,0 +1,122 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import jsparse
|
||||||
|
from jsparse import SparseTensor
|
||||||
|
from jsparse import nn as spnn
|
||||||
|
|
||||||
|
from jsparse.nn import SparseConvBlock, SparseConvTransposeBlock, SparseResBlock
|
||||||
|
|
||||||
|
__all__ = ['SparseResUNet42']
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResUNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stem_channels: int,
|
||||||
|
encoder_channels: List[int],
|
||||||
|
decoder_channels: List[int],
|
||||||
|
*,
|
||||||
|
in_channels: int = 4,
|
||||||
|
width_multiplier: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.stem_channels = stem_channels
|
||||||
|
self.encoder_channels = encoder_channels
|
||||||
|
self.decoder_channels = decoder_channels
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.width_multiplier = width_multiplier
|
||||||
|
|
||||||
|
num_channels = [stem_channels] + encoder_channels + decoder_channels
|
||||||
|
num_channels = [int(width_multiplier * nc) for nc in num_channels]
|
||||||
|
|
||||||
|
self.stem = nn.Sequential(
|
||||||
|
spnn.Conv3d(in_channels, num_channels[0], 3),
|
||||||
|
spnn.BatchNorm(num_channels[0]),
|
||||||
|
spnn.ReLU(),
|
||||||
|
spnn.Conv3d(num_channels[0], num_channels[0], 3),
|
||||||
|
spnn.BatchNorm(num_channels[0]),
|
||||||
|
spnn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(Zhijian): the current implementation of encoder and decoder
|
||||||
|
# is hard-coded for 4 encoder stages and 4 decoder stages. We should
|
||||||
|
# work on a more generic implementation in the future.
|
||||||
|
|
||||||
|
self.encoders = nn.ModuleList()
|
||||||
|
for k in range(4):
|
||||||
|
self.encoders.append(
|
||||||
|
nn.Sequential(
|
||||||
|
SparseConvBlock(
|
||||||
|
num_channels[k],
|
||||||
|
num_channels[k],
|
||||||
|
2,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
SparseResBlock(num_channels[k], num_channels[k + 1], 3),
|
||||||
|
SparseResBlock(num_channels[k + 1], num_channels[k + 1], 3),
|
||||||
|
))
|
||||||
|
|
||||||
|
self.decoders = nn.ModuleList()
|
||||||
|
for k in range(4):
|
||||||
|
self.decoders.append(
|
||||||
|
nn.ModuleDict({
|
||||||
|
'upsample':
|
||||||
|
SparseConvTransposeBlock(
|
||||||
|
num_channels[k + 4],
|
||||||
|
num_channels[k + 5],
|
||||||
|
2,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
'fuse':
|
||||||
|
nn.Sequential(
|
||||||
|
SparseResBlock(
|
||||||
|
num_channels[k + 5] + num_channels[3 - k],
|
||||||
|
num_channels[k + 5],
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
SparseResBlock(
|
||||||
|
num_channels[k + 5],
|
||||||
|
num_channels[k + 5],
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
|
||||||
|
def _unet_forward(
|
||||||
|
self,
|
||||||
|
x: SparseTensor,
|
||||||
|
encoders: nn.ModuleList,
|
||||||
|
decoders: nn.ModuleList,
|
||||||
|
) -> List[SparseTensor]:
|
||||||
|
if not encoders and not decoders:
|
||||||
|
return [x]
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
xd = encoders[0](x)
|
||||||
|
|
||||||
|
# inner recursion
|
||||||
|
outputs = self._unet_forward(xd, encoders[1:], decoders[:-1])
|
||||||
|
yd = outputs[-1]
|
||||||
|
|
||||||
|
# upsample and fuse
|
||||||
|
u = decoders[-1]['upsample'](yd)
|
||||||
|
y = decoders[-1]['fuse'](jsparse.cat([u, x]))
|
||||||
|
|
||||||
|
return [x] + outputs + [y]
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor) -> List[SparseTensor]:
|
||||||
|
return self._unet_forward(self.stem(x), self.encoders, self.decoders)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResUNet42(SparseResUNet):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__(
|
||||||
|
stem_channels=32,
|
||||||
|
encoder_channels=[32, 64, 128, 256],
|
||||||
|
decoder_channels=[256, 128, 96, 96],
|
||||||
|
**kwargs,
|
||||||
|
)
|
|
@ -5,7 +5,6 @@ from .count import *
|
||||||
from .devoxelize import *
|
from .devoxelize import *
|
||||||
from .downsample import *
|
from .downsample import *
|
||||||
from .hash import *
|
from .hash import *
|
||||||
from .pooling import *
|
|
||||||
from .query import *
|
from .query import *
|
||||||
from .voxelize import *
|
from .voxelize import *
|
||||||
from .unique import *
|
from .unique import *
|
||||||
|
|
|
@ -206,7 +206,8 @@ class Convolution(Function):
|
||||||
|
|
||||||
assert input.size(1) == weight.size(1)
|
assert input.size(1) == weight.size(1)
|
||||||
|
|
||||||
self.save_vars = input, weight, nbmaps.numpy(), nbsizes.numpy(), transposed
|
# self.save_vars = input, weight, nbmaps.numpy(), nbsizes.numpy(), transposed
|
||||||
|
self.save_vars = (input, weight, nbmaps, nbsizes, transposed)
|
||||||
t = 1 if transposed else 0
|
t = 1 if transposed else 0
|
||||||
|
|
||||||
return jt.code(output_size, input.dtype, [input, weight, nbmaps, nbsizes],
|
return jt.code(output_size, input.dtype, [input, weight, nbmaps, nbsizes],
|
||||||
|
@ -314,8 +315,8 @@ class Convolution(Function):
|
||||||
grad_output: jt.Var
|
grad_output: jt.Var
|
||||||
):
|
):
|
||||||
input, weight, nbmaps, nbsizes, transposed = self.save_vars
|
input, weight, nbmaps, nbsizes, transposed = self.save_vars
|
||||||
nbmaps = jt.array(nbmaps)
|
# nbmaps = jt.array(nbmaps)
|
||||||
nbsizes = jt.array(nbsizes)
|
# nbsizes = jt.array(nbsizes)
|
||||||
|
|
||||||
grad_input, grad_weight = jt.code([input.shape, weight.shape], [input.dtype, weight.dtype], [input, weight, nbmaps, nbsizes, grad_output],
|
grad_input, grad_weight = jt.code([input.shape, weight.shape], [input.dtype, weight.dtype], [input, weight, nbmaps, nbsizes, grad_output],
|
||||||
cuda_header=self.cuda_header,
|
cuda_header=self.cuda_header,
|
||||||
|
|
|
@ -1,114 +0,0 @@
|
||||||
import jittor as jt
|
|
||||||
from typing import Union, Optional, Tuple
|
|
||||||
from jittor.misc import _pair, _triple
|
|
||||||
from jittor import nn
|
|
||||||
from jsparse import SparseTensor
|
|
||||||
from jsparse.nn import functional as F
|
|
||||||
from jsparse.nn.utils import get_kernel_offsets
|
|
||||||
|
|
||||||
__all__ = ['max_pool']
|
|
||||||
|
|
||||||
def apply_pool(
|
|
||||||
input: jt.Var,
|
|
||||||
nbmaps: jt.Var,
|
|
||||||
nbsizes: jt.Var,
|
|
||||||
sizes: Tuple[int, int],
|
|
||||||
transposed: bool = False,
|
|
||||||
method: str = 'max'
|
|
||||||
) -> jt.Var:
|
|
||||||
if not transposed:
|
|
||||||
output = jt.zeros((sizes[1], input.size(-1)), dtype=input.dtype)
|
|
||||||
else:
|
|
||||||
output = jt.zeros((sizes[0], input.size(-1)), dtype=input.dtype)
|
|
||||||
|
|
||||||
kernel_volume = nbsizes.size(0)
|
|
||||||
in_channels = input.size(1)
|
|
||||||
out_size = output.size(0)
|
|
||||||
cur_offset = 0
|
|
||||||
for i in range(kernel_volume):
|
|
||||||
n_active_feats = int(nbsizes[i])
|
|
||||||
t = 1 if transposed else 0
|
|
||||||
|
|
||||||
in_buffer_activated = input.reindex([n_active_feats, in_channels], ['@e0(i0)', 'i1'],
|
|
||||||
extras=[nbmaps[cur_offset:cur_offset + n_active_feats, t]])
|
|
||||||
|
|
||||||
output = jt.maximum(output, in_buffer_activated.reindex_reduce(method, [out_size, in_channels], ['@e0(i0)', 'i1'],
|
|
||||||
extras=[nbmaps[cur_offset:cur_offset + n_active_feats, 1-t]]))
|
|
||||||
|
|
||||||
#output = output.scatter_(0, nbmaps[cur_offset:cur_offset + n_active_feats, 1 - t],
|
|
||||||
# in_buffer_activated, reduce=method)
|
|
||||||
|
|
||||||
cur_offset += n_active_feats
|
|
||||||
return output
|
|
||||||
|
|
||||||
def max_pool(
|
|
||||||
input: SparseTensor,
|
|
||||||
kernel_size: Union[int, Tuple[int, ...]] = 1,
|
|
||||||
stride: Union[int, Tuple[int, ...]] = 1,
|
|
||||||
dilation: Union[int, Tuple[int, ...]] = 1,
|
|
||||||
transposed: bool = False
|
|
||||||
) -> SparseTensor:
|
|
||||||
kernel_size = _triple(kernel_size)
|
|
||||||
stride = _triple(stride)
|
|
||||||
dilation = _triple(dilation)
|
|
||||||
|
|
||||||
if (kernel_size == _triple(1) and stride == _triple(1) and dilation == _triple(1)):
|
|
||||||
return input
|
|
||||||
elif not transposed:
|
|
||||||
output_stride = tuple(input.stride[k] * stride[k] for k in range(3))
|
|
||||||
|
|
||||||
if output_stride in input.cmaps:
|
|
||||||
output_indices = input.cmaps[output_stride]
|
|
||||||
elif all(stride[k] == 1 for k in range(3)):
|
|
||||||
output_indices = input.indices
|
|
||||||
else:
|
|
||||||
output_indices = F.spdownsample(
|
|
||||||
input.indices, stride, kernel_size, input.stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input.stride, kernel_size, stride, dilation) not in input.kmaps:
|
|
||||||
offsets = get_kernel_offsets(
|
|
||||||
kernel_size,
|
|
||||||
stride=input.stride,
|
|
||||||
dilation=dilation
|
|
||||||
)
|
|
||||||
references = F.sphash(input.indices) # (N,)
|
|
||||||
queries = F.sphash(output_indices, offsets) # (|K|, N)
|
|
||||||
results = F.spquery(queries, references) # (|K|, N)
|
|
||||||
|
|
||||||
nbsizes = jt.sum(results != -1, dim=1)
|
|
||||||
nbmaps = jt.nonzero(results != -1)
|
|
||||||
|
|
||||||
indices = nbmaps[:, 0] * results.size(1) + nbmaps[:, 1]
|
|
||||||
nbmaps[:, 0] = results.view(-1)[indices]
|
|
||||||
|
|
||||||
input.kmaps[(input.stride, kernel_size, stride, dilation)] = [
|
|
||||||
nbmaps, nbsizes, (input.indices.shape[0], output_indices.shape[0])
|
|
||||||
]
|
|
||||||
|
|
||||||
output_values = apply_pool(
|
|
||||||
input.values,
|
|
||||||
*input.kmaps[(input.stride, kernel_size, stride, dilation)],
|
|
||||||
transposed,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output_stride = tuple(input.stride[k] // stride[k] for k in range(3))
|
|
||||||
output_indices = input.cmaps[output_stride]
|
|
||||||
output_values = apply_pool(
|
|
||||||
input.values,
|
|
||||||
*input.kmaps[(output_stride, kernel_size, stride, dilation)],
|
|
||||||
transposed,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = SparseTensor(
|
|
||||||
indices=output_indices,
|
|
||||||
values=output_values,
|
|
||||||
stride=output_stride,
|
|
||||||
size=input.size
|
|
||||||
)
|
|
||||||
output.cmaps = input.cmaps
|
|
||||||
output.cmaps.setdefault(output_stride, output_indices)
|
|
||||||
output.kmaps = input.kmaps
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
|
@ -4,3 +4,4 @@ from .conv import *
|
||||||
from .norm import *
|
from .norm import *
|
||||||
from .pooling import *
|
from .pooling import *
|
||||||
from .modules import *
|
from .modules import *
|
||||||
|
from .blocks import *
|
|
@ -0,0 +1,87 @@
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from jittor import nn
|
||||||
|
|
||||||
|
from jsparse import SparseTensor
|
||||||
|
from jsparse import nn as spnn
|
||||||
|
|
||||||
|
__all__ = ['SparseConvBlock', 'SparseConvTransposeBlock', 'SparseResBlock']
|
||||||
|
|
||||||
|
|
||||||
|
class SparseConvBlock(nn.Sequential):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, List[int], Tuple[int, ...]],
|
||||||
|
stride: Union[int, List[int], Tuple[int, ...]] = 1,
|
||||||
|
dilation: int = 1) -> None:
|
||||||
|
super().__init__(
|
||||||
|
spnn.Conv3d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation),
|
||||||
|
spnn.BatchNorm(out_channels),
|
||||||
|
spnn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseConvTransposeBlock(nn.Sequential):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, List[int], Tuple[int, ...]],
|
||||||
|
stride: Union[int, List[int], Tuple[int, ...]] = 1,
|
||||||
|
dilation: int = 1) -> None:
|
||||||
|
super().__init__(
|
||||||
|
spnn.Conv3d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
transposed=True),
|
||||||
|
spnn.BatchNorm(out_channels),
|
||||||
|
spnn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, List[int], Tuple[int, ...]],
|
||||||
|
stride: Union[int, List[int], Tuple[int, ...]] = 1,
|
||||||
|
dilation: int = 1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.main = nn.Sequential(
|
||||||
|
spnn.Conv3d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation=dilation,
|
||||||
|
stride=stride),
|
||||||
|
spnn.BatchNorm(out_channels),
|
||||||
|
spnn.ReLU(),
|
||||||
|
spnn.Conv3d(out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation=dilation),
|
||||||
|
spnn.BatchNorm(out_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_channels != out_channels or np.prod(stride) != 1:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
spnn.Conv3d(in_channels, out_channels, 1, stride=stride),
|
||||||
|
spnn.BatchNorm(out_channels),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.shortcut = nn.Identity()
|
||||||
|
|
||||||
|
self.relu = spnn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
x = self.relu(self.main(x) + self.shortcut(x))
|
||||||
|
return x
|
|
@ -1,12 +1,5 @@
|
||||||
from ast import Global
|
|
||||||
import jittor as jt
|
|
||||||
from jittor import nn
|
from jittor import nn
|
||||||
|
|
||||||
from jsparse import SparseTensor
|
|
||||||
from jsparse.nn.functional import max_pool
|
|
||||||
|
|
||||||
MaxPool = jt.make_module(max_pool)
|
|
||||||
|
|
||||||
class GlobalPool(nn.Module):
|
class GlobalPool(nn.Module):
|
||||||
def __init__(self,op="max"):
|
def __init__(self,op="max"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
jittor
|
||||||
|
wget
|
||||||
|
h5py
|
Loading…
Reference in New Issue