update classification

li-xl 2022-10-17 22:25:04 +08:00
parent 6dc7bb6220
commit b65f1c2036
13 changed files with 303 additions and 152 deletions

.gitignore

@@ -3,7 +3,6 @@ __pycache__
build
dist
*.egg-info
*.txt
*.npy
*.log
*.zip


@@ -12,7 +12,8 @@ The latest JSparse can be installed by
```bash
cd python
python setup.py install
python setup.py install # or
python setup.py develop
```
## Getting Started
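After installation, a quick smoke test can confirm the build. This is a minimal sketch, assuming a `(batch, x, y, z)` integer index layout; the `SparseTensor` and `spnn.Conv3d` names come from the code changed later in this commit:

```python
import jittor as jt
from jsparse import SparseTensor
from jsparse import nn as spnn

# Two active voxels with 3-dim features; the (batch, x, y, z)
# index layout is an assumption for illustration.
indices = jt.array([[0, 0, 0, 0], [0, 1, 0, 2]]).int32()
values = jt.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = SparseTensor(indices=indices, values=values)

conv = spnn.Conv3d(3, 16, kernel_size=3)
y = conv(x)  # output is again a SparseTensor, now with 16 channels
```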


@@ -1,4 +1,3 @@
from ast import arg
import jittor as jt
from jittor import nn
import numpy as np
@@ -49,6 +48,9 @@ class ModelNet40H5(Dataset):
    ):
        super().__init__(batch_size=batch_size, num_workers=num_workers,shuffle=shuffle)
        assert mode in ['test','train']
        if not os.path.exists(data_root):
            import wget
            # NOTE: the commit calls wget.download() with no argument; the
            # canonical ModelNet40 HDF5 archive URL is assumed here.
            wget.download("https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip")
            # download_modelnet40_dataset()
        self.data, self.label = self.load_data(data_root, mode)
        self.transform = transform
@@ -104,17 +106,6 @@ class ModelNet40H5(Dataset):
        labels = np.concatenate(labels,axis=0)
        tensor = SparseTensor(indices=jt.array(indices), values=jt.array(feats))
        return tensor,labels

class ConvBNReLU(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3,stride=1):
        super().__init__()
        self.conv1 = nn.Sequential(
            spnn.Conv3d(in_channels,out_channels,kernel_size=kernel_size,stride=stride),
            spnn.BatchNorm(out_channels),
            spnn.ReLU()
        )

    def execute(self,x):
        return self.conv1(x)

class VoxelCNN(nn.Module):
    def __init__(self,in_channels,out_channels):
@@ -125,15 +116,15 @@ class VoxelCNN(nn.Module):
            spnn.ReLU()
        )
        self.convs = nn.Sequential(
            ConvBNReLU(64,128),
            ConvBNReLU(128,128,kernel_size=2,stride=2),
            ConvBNReLU(128,256),
            ConvBNReLU(256,256,kernel_size=2,stride=2),
            ConvBNReLU(256,512),
            ConvBNReLU(512,512,kernel_size=2,stride=2),
            ConvBNReLU(512,1024),
            ConvBNReLU(1024,1024,kernel_size=2,stride=2),
            ConvBNReLU(1024,1024)
            spnn.SparseConvBlock(64,128,kernel_size=3),
            spnn.SparseConvBlock(128,128,kernel_size=2,stride=2),
            spnn.SparseConvBlock(128,256,kernel_size=3),
            spnn.SparseConvBlock(256,256,kernel_size=2,stride=2),
            spnn.SparseConvBlock(256,512,kernel_size=3),
            spnn.SparseConvBlock(512,512,kernel_size=2,stride=2),
            spnn.SparseConvBlock(512,1024,kernel_size=3),
            spnn.SparseConvBlock(1024,1024,kernel_size=2,stride=2),
            spnn.SparseConvBlock(1024,1024,kernel_size=3),
        )
        self.global_pool = spnn.GlobalPool(op="max")
@@ -191,7 +182,7 @@ def main():
    args = parse_args()
    train_loader = ModelNet40H5(
        data_root="data/modelnet40_ply_hdf5_2048",
        data_root=args.data_root,
        batch_size=args.batch_size,
        mode='train',
        transform=CoordinateTransformation(),
@@ -199,7 +190,7 @@ def main():
        shuffle=True,
        voxel_size=args.voxel_size)
    test_loader = ModelNet40H5(
        data_root="data/modelnet40_ply_hdf5_2048",
        data_root=args.data_root,
        batch_size=args.batch_size,
        mode='test',
        num_workers=args.num_workers,
@@ -212,7 +203,7 @@ def main():
        train(model,train_loader,optimizer)
        scheduler.step()
        acc = test(model,test_loader)
        print(f"epoch:{epoch},acc:{acc}")
        print(f"epoch:{epoch},acc:{acc},lr:{optimizer.lr}")

if __name__ == "__main__":
    main()
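The hunks above replace the hard-coded dataset path with `args.data_root`. For reference, a sketch of the `parse_args` this implies — it covers only the flags actually referenced in the diff, and the defaults are assumptions:

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='ModelNet40 voxel classification')
    # Only flags referenced in the diff; defaults are assumptions.
    parser.add_argument('--data_root', default='data/modelnet40_ply_hdf5_2048')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--voxel_size', type=float, default=0.05)
    return parser.parse_args()
```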


@@ -0,0 +1,68 @@
from typing import List, Tuple, Union

from jittor import nn

from jsparse import SparseTensor
from jsparse.nn import SparseConvBlock, SparseResBlock

__all__ = ['SparseResNet21D']

class SparseResNet(nn.ModuleList):
    def __init__(
        self,
        blocks: List[Tuple[int, int, Union[int, Tuple[int, ...]],
                           Union[int, Tuple[int, ...]]]],
        *,
        in_channels: int = 4,
        width_multiplier: float = 1.0,
    ) -> None:
        super().__init__()
        self.blocks = blocks
        self.in_channels = in_channels
        self.width_multiplier = width_multiplier

        for num_blocks, out_channels, kernel_size, stride in blocks:
            out_channels = int(out_channels * width_multiplier)
            # Note: reuses the name `blocks` for the per-stage layer list;
            # the outer loop's iterator is already bound, so this is safe.
            blocks = []
            for index in range(num_blocks):
                if index == 0:
                    blocks.append(
                        SparseConvBlock(
                            in_channels,
                            out_channels,
                            kernel_size,
                            stride=stride,
                        ))
                else:
                    blocks.append(
                        SparseResBlock(
                            in_channels,
                            out_channels,
                            kernel_size,
                        ))
                in_channels = out_channels
            self.append(nn.Sequential(*blocks))

    # Jittor modules use execute() rather than forward().
    def execute(self, x: SparseTensor) -> List[SparseTensor]:
        outputs = []
        for module in self:
            x = module(x)
            outputs.append(x)
        return outputs

class SparseResNet21D(SparseResNet):
    def __init__(self, **kwargs) -> None:
        super().__init__(
            blocks=[
                (3, 16, 3, 1),
                (3, 32, 3, 2),
                (3, 64, 3, 2),
                (3, 128, 3, 2),
                (1, 128, (1, 3, 1), (1, 2, 1)),
            ],
            **kwargs,
        )
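A minimal sketch of driving the new backbone. The import path and the voxel index layout are assumptions; the per-stage list output follows from `execute` above:

```python
import jittor as jt
from jsparse import SparseTensor
from jsparse.backbones import SparseResNet21D  # import path assumed

model = SparseResNet21D(in_channels=4)
# Toy input: 1024 random voxels with 4-dim features ((batch, x, y, z) layout assumed).
indices = jt.randint(0, 32, (1024, 4))
values = jt.rand(1024, 4)
outputs = model(SparseTensor(indices=indices, values=values))
for stage, out in enumerate(outputs):
    print(stage, out.values.shape)  # one SparseTensor per stage
```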

jsparse/backbones/unet.py

@@ -0,0 +1,122 @@
from typing import List

# The commit has `from torch import nn` here, presumably a leftover from the
# torch-based original; jittor's nn is assumed to be the intent in this repo.
from jittor import nn

import jsparse
from jsparse import SparseTensor
from jsparse import nn as spnn
from jsparse.nn import SparseConvBlock, SparseConvTransposeBlock, SparseResBlock

__all__ = ['SparseResUNet42']

class SparseResUNet(nn.Module):
    def __init__(
        self,
        stem_channels: int,
        encoder_channels: List[int],
        decoder_channels: List[int],
        *,
        in_channels: int = 4,
        width_multiplier: float = 1.0,
    ) -> None:
        super().__init__()
        self.stem_channels = stem_channels
        self.encoder_channels = encoder_channels
        self.decoder_channels = decoder_channels
        self.in_channels = in_channels
        self.width_multiplier = width_multiplier

        num_channels = [stem_channels] + encoder_channels + decoder_channels
        num_channels = [int(width_multiplier * nc) for nc in num_channels]

        self.stem = nn.Sequential(
            spnn.Conv3d(in_channels, num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(),
            spnn.Conv3d(num_channels[0], num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(),
        )

        # TODO(Zhijian): the current implementation of encoder and decoder
        # is hard-coded for 4 encoder stages and 4 decoder stages. We should
        # work on a more generic implementation in the future.
        self.encoders = nn.ModuleList()
        for k in range(4):
            self.encoders.append(
                nn.Sequential(
                    SparseConvBlock(
                        num_channels[k],
                        num_channels[k],
                        2,
                        stride=2,
                    ),
                    SparseResBlock(num_channels[k], num_channels[k + 1], 3),
                    SparseResBlock(num_channels[k + 1], num_channels[k + 1], 3),
                ))

        self.decoders = nn.ModuleList()
        for k in range(4):
            self.decoders.append(
                nn.ModuleDict({
                    'upsample':
                        SparseConvTransposeBlock(
                            num_channels[k + 4],
                            num_channels[k + 5],
                            2,
                            stride=2,
                        ),
                    'fuse':
                        nn.Sequential(
                            SparseResBlock(
                                num_channels[k + 5] + num_channels[3 - k],
                                num_channels[k + 5],
                                3,
                            ),
                            SparseResBlock(
                                num_channels[k + 5],
                                num_channels[k + 5],
                                3,
                            ),
                        )
                }))

    def _unet_forward(
        self,
        x: SparseTensor,
        encoders: nn.ModuleList,
        decoders: nn.ModuleList,
    ) -> List[SparseTensor]:
        if not encoders and not decoders:
            return [x]

        # downsample
        xd = encoders[0](x)

        # inner recursion
        outputs = self._unet_forward(xd, encoders[1:], decoders[:-1])
        yd = outputs[-1]

        # upsample and fuse
        u = decoders[-1]['upsample'](yd)
        y = decoders[-1]['fuse'](jsparse.cat([u, x]))

        return [x] + outputs + [y]

    # Jittor modules use execute() rather than forward().
    def execute(self, x: SparseTensor) -> List[SparseTensor]:
        return self._unet_forward(self.stem(x), self.encoders, self.decoders)

class SparseResUNet42(SparseResUNet):
    def __init__(self, **kwargs) -> None:
        super().__init__(
            stem_channels=32,
            encoder_channels=[32, 64, 128, 256],
            decoder_channels=[256, 128, 96, 96],
            **kwargs,
        )
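The decoder channel arithmetic above is easy to misread, so here is a short worked check of the `num_channels` indexing for `SparseResUNet42` at `width_multiplier=1.0`, computed directly from the constructor arguments:

```python
# num_channels = [stem] + encoder + decoder, exactly as built in __init__.
num_channels = [32] + [32, 64, 128, 256] + [256, 128, 96, 96]
for k in range(4):
    up_in, up_out = num_channels[k + 4], num_channels[k + 5]
    skip = num_channels[3 - k]  # encoder feature fused via the skip connection
    print(f"decoder {k}: upsample {up_in}->{up_out}, fuse {up_out + skip}->{up_out}")
# decoder 0: upsample 256->256, fuse 384->256
# decoder 1: upsample 256->128, fuse 192->128
# decoder 2: upsample 128->96,  fuse 128->96
# decoder 3: upsample 96->96,   fuse 128->96
```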


@@ -5,7 +5,6 @@ from .count import *
from .devoxelize import *
from .downsample import *
from .hash import *
from .pooling import *
from .query import *
from .voxelize import *
from .unique import *


@@ -206,7 +206,8 @@ class Convolution(Function):
        assert input.size(1) == weight.size(1)
        self.save_vars = input, weight, nbmaps.numpy(), nbsizes.numpy(), transposed
        # self.save_vars = input, weight, nbmaps.numpy(), nbsizes.numpy(), transposed
        self.save_vars = (input, weight, nbmaps, nbsizes, transposed)
        t = 1 if transposed else 0
        return jt.code(output_size, input.dtype, [input, weight, nbmaps, nbsizes],
@@ -314,8 +315,8 @@ class Convolution(Function):
        grad_output: jt.Var
    ):
        input, weight, nbmaps, nbsizes, transposed = self.save_vars
        nbmaps = jt.array(nbmaps)
        nbsizes = jt.array(nbsizes)
        # nbmaps = jt.array(nbmaps)
        # nbsizes = jt.array(nbsizes)
        grad_input, grad_weight = jt.code([input.shape, weight.shape], [input.dtype, weight.dtype], [input, weight, nbmaps, nbsizes, grad_output],
                                          cuda_header=self.cuda_header,


@@ -1,114 +0,0 @@
import jittor as jt

from typing import Union, Optional, Tuple
from jittor.misc import _pair, _triple
from jittor import nn

from jsparse import SparseTensor
from jsparse.nn import functional as F
from jsparse.nn.utils import get_kernel_offsets

__all__ = ['max_pool']

def apply_pool(
    input: jt.Var,
    nbmaps: jt.Var,
    nbsizes: jt.Var,
    sizes: Tuple[int, int],
    transposed: bool = False,
    method: str = 'max'
) -> jt.Var:
    if not transposed:
        output = jt.zeros((sizes[1], input.size(-1)), dtype=input.dtype)
    else:
        output = jt.zeros((sizes[0], input.size(-1)), dtype=input.dtype)

    kernel_volume = nbsizes.size(0)
    in_channels = input.size(1)
    out_size = output.size(0)
    cur_offset = 0
    for i in range(kernel_volume):
        n_active_feats = int(nbsizes[i])
        t = 1 if transposed else 0
        in_buffer_activated = input.reindex([n_active_feats, in_channels], ['@e0(i0)', 'i1'],
                                            extras=[nbmaps[cur_offset:cur_offset + n_active_feats, t]])
        output = jt.maximum(output, in_buffer_activated.reindex_reduce(method, [out_size, in_channels], ['@e0(i0)', 'i1'],
                                            extras=[nbmaps[cur_offset:cur_offset + n_active_feats, 1 - t]]))
        # output = output.scatter_(0, nbmaps[cur_offset:cur_offset + n_active_feats, 1 - t],
        #                          in_buffer_activated, reduce=method)
        cur_offset += n_active_feats
    return output

def max_pool(
    input: SparseTensor,
    kernel_size: Union[int, Tuple[int, ...]] = 1,
    stride: Union[int, Tuple[int, ...]] = 1,
    dilation: Union[int, Tuple[int, ...]] = 1,
    transposed: bool = False
) -> SparseTensor:
    kernel_size = _triple(kernel_size)
    stride = _triple(stride)
    dilation = _triple(dilation)

    if (kernel_size == _triple(1) and stride == _triple(1) and dilation == _triple(1)):
        return input
    elif not transposed:
        output_stride = tuple(input.stride[k] * stride[k] for k in range(3))

        if output_stride in input.cmaps:
            output_indices = input.cmaps[output_stride]
        elif all(stride[k] == 1 for k in range(3)):
            output_indices = input.indices
        else:
            output_indices = F.spdownsample(
                input.indices, stride, kernel_size, input.stride,
            )

        if (input.stride, kernel_size, stride, dilation) not in input.kmaps:
            offsets = get_kernel_offsets(
                kernel_size,
                stride=input.stride,
                dilation=dilation
            )
            references = F.sphash(input.indices)         # (N,)
            queries = F.sphash(output_indices, offsets)  # (|K|, N)
            results = F.spquery(queries, references)     # (|K|, N)

            nbsizes = jt.sum(results != -1, dim=1)
            nbmaps = jt.nonzero(results != -1)

            indices = nbmaps[:, 0] * results.size(1) + nbmaps[:, 1]
            nbmaps[:, 0] = results.view(-1)[indices]

            input.kmaps[(input.stride, kernel_size, stride, dilation)] = [
                nbmaps, nbsizes, (input.indices.shape[0], output_indices.shape[0])
            ]

        output_values = apply_pool(
            input.values,
            *input.kmaps[(input.stride, kernel_size, stride, dilation)],
            transposed,
        )
    else:
        output_stride = tuple(input.stride[k] // stride[k] for k in range(3))
        output_indices = input.cmaps[output_stride]
        output_values = apply_pool(
            input.values,
            *input.kmaps[(output_stride, kernel_size, stride, dilation)],
            transposed,
        )

    output = SparseTensor(
        indices=output_indices,
        values=output_values,
        stride=output_stride,
        size=input.size
    )
    output.cmaps = input.cmaps
    output.cmaps.setdefault(output_stride, output_indices)
    output.kmaps = input.kmaps
    return output
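For reference, the core of the deleted `apply_pool` is a gather/scatter-max over a neighbor map: column `t` of `nbmaps` selects input rows, column `1 - t` selects the output rows they reduce into. A toy sketch of that pattern in plain Jittor ops (sizes and map values are made up; this is not the library API):

```python
import jittor as jt

# Three input rows feeding two output rows; nbmaps[i] = (input_row, output_row).
feats = jt.array([[1.0], [5.0], [3.0]])
nbmaps = [(0, 0), (1, 0), (2, 1)]

out = jt.full((2, 1), float('-inf'))
for src, dst in nbmaps:
    out[dst] = jt.maximum(out[dst], feats[src])  # scatter-max into the output row
print(out)  # [[5.], [3.]]
```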


@@ -3,4 +3,5 @@ from .conv import *
# from .crop import *
from .norm import *
from .pooling import *
from .modules import *
from .modules import *
from .blocks import *


@@ -0,0 +1,87 @@
from typing import List, Tuple, Union

import numpy as np
from jittor import nn

from jsparse import SparseTensor
from jsparse import nn as spnn

__all__ = ['SparseConvBlock', 'SparseConvTransposeBlock', 'SparseResBlock']

class SparseConvBlock(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 dilation: int = 1) -> None:
        super().__init__(
            spnn.Conv3d(in_channels,
                        out_channels,
                        kernel_size,
                        stride=stride,
                        dilation=dilation),
            spnn.BatchNorm(out_channels),
            spnn.ReLU(),
        )

class SparseConvTransposeBlock(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 dilation: int = 1) -> None:
        super().__init__(
            spnn.Conv3d(in_channels,
                        out_channels,
                        kernel_size,
                        stride=stride,
                        dilation=dilation,
                        transposed=True),
            spnn.BatchNorm(out_channels),
            spnn.ReLU(),
        )

class SparseResBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 dilation: int = 1) -> None:
        super().__init__()
        self.main = nn.Sequential(
            spnn.Conv3d(in_channels,
                        out_channels,
                        kernel_size,
                        dilation=dilation,
                        stride=stride),
            spnn.BatchNorm(out_channels),
            spnn.ReLU(),
            spnn.Conv3d(out_channels,
                        out_channels,
                        kernel_size,
                        dilation=dilation),
            spnn.BatchNorm(out_channels),
        )

        if in_channels != out_channels or np.prod(stride) != 1:
            self.shortcut = nn.Sequential(
                spnn.Conv3d(in_channels, out_channels, 1, stride=stride),
                spnn.BatchNorm(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.relu = spnn.ReLU()

    # Jittor modules use execute() rather than forward().
    def execute(self, x: SparseTensor) -> SparseTensor:
        x = self.relu(self.main(x) + self.shortcut(x))
        return x
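`SparseResBlock` switches between an identity shortcut and a projected one depending on shape. A short sketch of both cases, using the constructor defined above:

```python
from jsparse.nn import SparseResBlock

# Same channels, stride 1: the shortcut is nn.Identity().
block_a = SparseResBlock(64, 64, kernel_size=3)

# Channel or stride change: the shortcut becomes a 1x1x1 conv + BatchNorm
# so the main-path and shortcut outputs can be summed.
block_b = SparseResBlock(64, 128, kernel_size=3, stride=2)
```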


@@ -1,12 +1,5 @@
from ast import Global
import jittor as jt
from jittor import nn

from jsparse import SparseTensor
from jsparse.nn.functional import max_pool

MaxPool = jt.make_module(max_pool)

class GlobalPool(nn.Module):
    def __init__(self,op="max"):
        super().__init__()

requirements.txt

@@ -0,0 +1,3 @@
jittor
wget
h5py