add voxelize and devoxelize, modify the dist structure

parent 6197e72261
commit 64f5c7549d
@@ -1 +0,0 @@
-from .sparse import *
@@ -0,0 +1 @@
+from .sparse import *
@@ -1,10 +1,10 @@
 from .activation import *
 from .conv import *
-# from .count import *
+from .count import *
 # from .crop import *
-# from .devoxelize import *
+from .devoxelize import *
 from .downsample import *
 from .hash import *
 from .pooling import *
 from .query import *
-# from .voxelize import *
+from .voxelize import *
@ -1,8 +1,8 @@
|
|||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn.utils import fapply
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn.utils import fapply
|
||||
|
||||
__all__ = ['relu', 'leaky_relu']
|
||||
# __all__ = ['relu', 'leaky_relu', 'ReLU', 'LeakyReLU']
|
|
@ -4,10 +4,10 @@ import jittor as jt
|
|||
from jittor import Function
|
||||
from jittor.misc import _pair, _triple
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn import functional as F
|
||||
from python.nn.utils import get_kernel_offsets
|
||||
from python import make_ntuple
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn import functional as F
|
||||
from python.jsparse.nn.utils import get_kernel_offsets
|
||||
from python.jsparse import make_ntuple
|
||||
|
||||
__all__ = ['conv3d', 'Convolution']
|
||||
|
||||
|
@ -319,7 +319,7 @@ def conv3d(
|
|||
nbmaps, nbsizes, (input.indices.shape[0], output_indices.shape[0])
|
||||
]
|
||||
|
||||
output_values = Convolution(
|
||||
output_values = Convolution.apply(
|
||||
input.values,
|
||||
weight,
|
||||
*input.kmaps[(input.stride, kernel_size, stride, dilation)],
|
||||
|
@ -328,7 +328,7 @@ def conv3d(
|
|||
else:
|
||||
output_stride = tuple(input.stride[k] // stride[k] for k in range(3))
|
||||
output_indices = input.cmaps[output_stride]
|
||||
output_values = Convolution(
|
||||
output_values = Convolution.apply(
|
||||
input.values,
|
||||
weight,
|
||||
*input.kmaps[(output_stride, kernel_size, stride, dilation)],
|
|
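
Note: the two hunks above switch from calling the Convolution class directly to
Convolution.apply(...). In Jittor, custom ops subclass jittor.Function and are
invoked through apply so that execute/grad are wired into autograd. A minimal
self-contained sketch (the Scale op is illustrative, not part of this repo):

    import jittor as jt
    from jittor import Function

    class Scale(Function):
        # toy custom op: y = x * s, with a hand-written gradient
        def execute(self, x, s):
            self.s = s
            return x * s

        def grad(self, grad_output):
            # gradient w.r.t. x; None for the non-Var argument s
            return grad_output * self.s, None

    x = jt.float32([1.0, 2.0, 3.0])
    y = Scale.apply(x, 2.0)
    print(jt.grad(y, x))  # [2. 2. 2.]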
@ -0,0 +1,30 @@
|
|||
import jittor as jt
|
||||
|
||||
def spcount(idx_query: jt.Var, num: int) -> jt.Var:
|
||||
return jt.code((num,), jt.int32, [idx_query],
|
||||
cuda_src="""
|
||||
__global__ void count_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(idx_query, in0)
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int cur_idx = @idx_query(i);
|
||||
if (i < idx_query_shape0 && cur_idx >= 0) {
|
||||
atomicAdd(out_p + cur_idx, 1);
|
||||
}
|
||||
}
|
||||
@alias(idx_query, in0)
|
||||
count_kernel<<<(idx_query_shape0 + 511) / 512, 512>>>(@ARGS);
|
||||
""",
|
||||
cpu_src="""
|
||||
@alias(idx_query, in0)
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < idx_query_shape0; ++ i ) {
|
||||
int cur_idx = @idx_query(i);
|
||||
if (cur_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
#pragma omp atomic
|
||||
@out(cur_idx) ++;
|
||||
}
|
||||
"""
|
||||
)
|
|
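
Note: spcount builds a per-voxel histogram: for each point it increments the
count of the voxel row it maps to, skipping entries of -1 (no match). A hedged
usage sketch, assuming the package layout above is importable:

    import jittor as jt
    from python.jsparse.nn.functional import spcount

    idx_query = jt.int32([0, 2, 2, -1, 1])  # per-point voxel row, -1 = miss
    counts = spcount(idx_query, 3)
    print(counts.numpy())                   # [1 1 2]; the -1 entry is ignored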
@ -0,0 +1,166 @@
|
|||
import jittor as jt
|
||||
from jittor import Function
|
||||
|
||||
from python.jsparse import SparseTensor
|
||||
|
||||
__all__ = ['calc_ti_weights', 'spdevoxelize']
|
||||
|
||||
def calc_ti_weights(
|
||||
indices: jt.Var,
|
||||
idx_query: jt.Var,
|
||||
scale: float = 1
|
||||
) -> jt.Var:
|
||||
with jt.no_grad():
|
||||
p = indices
|
||||
if scale != 1:
|
||||
pf = jt.floor(indices / scale) * scale
|
||||
else:
|
||||
pf = jt.floor(indices)
|
||||
pc = pf + scale
|
||||
|
||||
x = p[:, 1].view(-1, 1)
|
||||
y = p[:, 2].view(-1, 1)
|
||||
z = p[:, 3].view(-1, 1)
|
||||
|
||||
xf = pf[:, 1].view(-1, 1).float()
|
||||
yf = pf[:, 2].view(-1, 1).float()
|
||||
zf = pf[:, 3].view(-1, 1).float()
|
||||
|
||||
xc = pc[:, 1].view(-1, 1).float()
|
||||
yc = pc[:, 2].view(-1, 1).float()
|
||||
zc = pc[:, 3].view(-1, 1).float()
|
||||
|
||||
w0 = (xc - x) * (yc - y) * (zc - z)
|
||||
w1 = (xc - x) * (yc - y) * (z - zf)
|
||||
w2 = (xc - x) * (y - yf) * (zc - z)
|
||||
w3 = (xc - x) * (y - yf) * (z - zf)
|
||||
w4 = (x - xf) * (yc - y) * (zc - z)
|
||||
w5 = (x - xf) * (yc - y) * (z - zf)
|
||||
w6 = (x - xf) * (y - yf) * (zc - z)
|
||||
w7 = (x - xf) * (y - yf) * (z - zf)
|
||||
|
||||
w = jt.concat([w0, w1, w2, w3, w4, w5, w6, w7], dim=1).t()
|
||||
if scale != 1:
|
||||
w /= scale ** 3
|
||||
w[idx_query == -1] = 0
|
||||
w /= jt.sum(w, dim=0) + 1e-8
|
||||
return w
|
||||
|
||||
|
||||
class Devoxelize(Function):
|
||||
def execute(
|
||||
self,
|
||||
values: jt.Var,
|
||||
idx_query: jt.Var,
|
||||
weights: jt.Var
|
||||
) -> jt.Var:
|
||||
# c = values_shape1
|
||||
# N = idx_query_shape0
|
||||
output = jt.code((idx_query.shape[0], values.shape[1]), jt.float32, [values, idx_query, weights],
|
||||
cuda_src="""
|
||||
__global__ void devoxelize_forward_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(values, in0)
|
||||
@alias(idx_query, in1)
|
||||
@alias(weights, in2)
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int i = index / values_shape1;
|
||||
int j = index % values_shape1;
|
||||
|
||||
if (i < idx_query_shape0) {
|
||||
float cur_values = 0;
|
||||
for (int k = 0; k < 8; ++ k ) {
|
||||
int idx = @idx_query(i, k);
|
||||
cur_values = (idx >= 0) ? @values(idx, j) : 0;
|
||||
@out(i, j) += @weights(i, k) * cur_values;
|
||||
}
|
||||
}
|
||||
}
|
||||
devoxelize_forward_kernel<<<out_shape0, out_shape1>>>(@ARGS);
|
||||
""",
|
||||
cpu_src="""
|
||||
@alias(values, in0)
|
||||
@alias(idx_query, in1)
|
||||
@alias(weights, in2)
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < idx_query_shape0; ++ i ) {
|
||||
for (int j = 0; j < values_shape1; ++ j ) {
|
||||
float cur_values = 0;
|
||||
for (int k = 0; k < 8; ++ k ) {
|
||||
int idx = @idx_query(i, k);
|
||||
cur_values = (idx >= 0) ? @values(idx, j) : 0;
|
||||
#pragma omp atomic
|
||||
@out(i ,j) += @weights(i, k) * cur_values;
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
self.save_vars = (idx_query, weights, values.shape[0])
|
||||
return output
|
||||
|
||||
+    def grad(self, grad_output: jt.Var):
+        idx_query, weights, input_size = self.save_vars
+
+        grad_values = jt.code((input_size, grad_output.shape[1]), jt.float32, [idx_query, weights, grad_output],
+            cuda_header="""
+                #include <stdio.h>
+                #include <stdlib.h>
+                #include <cuda_runtime.h>
+            """,
+            cuda_src="""
+                __global__ void devoxelize_backward_kernel(@ARGS_DEF) {
+                    @PRECALC
+                    @alias(idx_query, in0)
+                    @alias(weights, in1)
+                    @alias(grad_output, in2)
+
+                    int index = blockIdx.x * blockDim.x + threadIdx.x;
+                    int c = grad_output_shape1;
+                    int i = index / c;
+                    int j = index % c;
+
+                    if (i < grad_output_shape0) {
+                        float cur_grad_output = @grad_output(i, j);
+
+                        #pragma unroll
+                        for (int k = 0; k < 8; ++k) {
+                            int idx = @idx_query(i, k);
+                            if (idx >= 0) {
+                                atomicAdd(&@out(idx, j), @weights(i, k) * cur_grad_output);
+                            }
+                        }
+                    }
+                }
+                @alias(grad_output, in2)
+                devoxelize_backward_kernel<<<grad_output_shape0, grad_output_shape1>>>(@ARGS);
+            """,
+            cpu_src="""
+                @alias(idx_query, in0)
+                @alias(weights, in1)
+                @alias(grad_output, in2)
+
+                for (int i = 0; i < grad_output_shape0; ++i) {
+                    #pragma omp parallel for
+                    for (int j = 0; j < grad_output_shape1; ++j) {
+                        float cur_grad_output = @grad_output(i, j);
+                        for (int k = 0; k < 8; ++k) {
+                            int idx = @idx_query(i, k);
+                            if (idx < 0) continue;
+                            #pragma omp atomic
+                            @out(idx, j) += @weights(i, k) * cur_grad_output;
+                        }
+                    }
+                }
+            """
+        )
+        return grad_values, None, None
+
+def spdevoxelize(
+    values: jt.Var,
+    idx_query: jt.Var,
+    weights: jt.Var
+) -> jt.Var:
+    return Devoxelize.apply(values, idx_query, weights)
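
Note: calc_ti_weights is plain trilinear interpolation: each point receives 8
weights, one per corner of its enclosing voxel cell, and spdevoxelize gathers
voxel features with those weights. A self-contained NumPy sketch of the weight
formula (illustration only, not the repo API):

    import numpy as np

    p = np.array([0.3, 0.6, 0.9])           # point inside the unit cell
    pf = np.floor(p)                         # "floor" corner
    pc = pf + 1.0                            # opposite corner
    (x, y, z), (xf, yf, zf), (xc, yc, zc) = p, pf, pc
    w = np.array([
        (xc - x) * (yc - y) * (zc - z),      # weight of corner (xf, yf, zf)
        (xc - x) * (yc - y) * (z - zf),
        (xc - x) * (y - yf) * (zc - z),
        (xc - x) * (y - yf) * (z - zf),
        (x - xf) * (yc - y) * (zc - z),
        (x - xf) * (yc - y) * (z - zf),
        (x - xf) * (y - yf) * (zc - z),
        (x - xf) * (y - yf) * (z - zf),      # weight of corner (xc, yc, zc)
    ])
    assert np.isclose(w.sum(), 1.0)          # the 8 corner weights sum to 1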
@ -3,8 +3,8 @@ from typing import Tuple, Union
|
|||
import jittor as jt
|
||||
from jittor.misc import _pair, _triple
|
||||
|
||||
from python.nn.utils import get_kernel_offsets
|
||||
from python.utils import make_ntuple, trunc
|
||||
from python.jsparse.nn.utils import get_kernel_offsets
|
||||
from python.jsparse.utils import make_ntuple, trunc
|
||||
|
||||
__all__ = ['spdownsample']
|
||||
|
|
@ -62,7 +62,7 @@ def sphash(indices: jt.Var,
|
|||
#include <iostream>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@alias(indices, in0)
|
||||
@alias(offsets, in1)
|
||||
""",
|
|
@ -1,6 +1,6 @@
|
|||
import jittor as jt
|
||||
|
||||
from python import SparseTensor
|
||||
from python.jsparse import SparseTensor
|
||||
|
||||
__all__ = ['global_avg_pool', 'global_max_pool']
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
import jittor as jt
|
||||
from jittor import Function
|
||||
|
||||
from python.jsparse import SparseTensor
|
||||
|
||||
__all__ = ['spvoxelize']
|
||||
|
||||
class Voxelize(Function):
|
||||
def execute(
|
||||
self,
|
||||
values: jt.Var,
|
||||
idx_query: jt.Var,
|
||||
counts: jt.Var
|
||||
) -> jt.Var:
|
||||
# N = values_shape0
|
||||
# c = values_shape1
|
||||
# N1 = counts_shape0
|
||||
# out: N1 x c
|
||||
output = jt.code((counts.shape[0], values.shape[1]), "float32", [values, idx_query, counts],
|
||||
cuda_header="""
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <cuda_runtime.h>
|
||||
""",
|
||||
cuda_src="""
|
||||
__global__ void voxelize_forward_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(values, in0)
|
||||
@alias(idx_query, in1)
|
||||
@alias(counts, in2)
|
||||
|
||||
int index = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int c = values_shape1;
|
||||
int i = index / c;
|
||||
int j = index % c;
|
||||
|
||||
if (i < values_shape0) {
|
||||
int pos = @idx_query(i);
|
||||
if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0) return;
|
||||
atomicAdd(&@out(pos, j), @values(i, j) / (float)(@counts(pos)));
|
||||
}
|
||||
}
|
||||
@alias(values, in0)
|
||||
voxelize_forward_kernel<<< values_shape0, values_shape1 >>>(@ARGS);
|
||||
""",
|
||||
cpu_src="""
|
||||
@alias(values, in0)
|
||||
@alias(idx_query, in1)
|
||||
@alias(counts, in2)
|
||||
for (int i = 0; i < values_shape0; ++ i ) {
|
||||
int pos = @idx_query(i);
|
||||
if (@counts(pos) == 0)
|
||||
continue;
|
||||
#pragma omp parallel for
|
||||
for (int j = 0; j < values_shape1; ++ j ) {
|
||||
#pragma omp atomic
|
||||
@out(pos, j) += @values(i, j) / (float)@counts(pos);
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
self.save_vars = idx_query, counts, values.shape[0]
|
||||
return output
|
||||
|
||||
+    def grad(self, grad_output: jt.Var):
+        idx_query, counts, input_size = self.save_vars
+
+        grad_values = jt.code((input_size, grad_output.shape[1]), jt.float32, [idx_query, counts, grad_output],
+            cuda_header="""
+                #include <stdio.h>
+                #include <stdlib.h>
+                #include <cuda_runtime.h>
+            """,
+            cuda_src="""
+                __global__ void voxelize_backward_kernel(@ARGS_DEF) {
+                    @PRECALC
+                    @alias(idx_query, in0)
+                    @alias(counts, in1)
+                    @alias(grad_output, in2)
+                    int index = blockDim.x * blockIdx.x + threadIdx.x;
+                    int i = index / grad_output_shape1;
+                    int j = index % grad_output_shape1;
+                    if (i < out_shape0) {
+                        int pos = @idx_query(i);
+                        if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0) return;
+                        @out(i, j) = @grad_output(pos, j) / (float)@counts(pos);
+                    }
+                }
+
+                voxelize_backward_kernel<<<out_shape0, out_shape1>>>(@ARGS);
+            """,
+            cpu_src="""
+                @alias(idx_query, in0)
+                @alias(counts, in1)
+                @alias(grad_output, in2)
+
+                for (int i = 0; i < out_shape0; ++i) {
+                    int pos = @idx_query(i);
+                    if (pos < 0 || @counts(pos) == 0) continue;
+                    #pragma omp parallel for
+                    for (int j = 0; j < grad_output_shape1; ++j) {
+                        @out(i, j) = @grad_output(pos, j) / (float)@counts(pos);
+                    }
+                }
+            """
+        )
+        return grad_values, None, None
+
+def spvoxelize(
+    values: jt.Var,
+    idx_query: jt.Var,
+    counts: jt.Var
+) -> jt.Var:
+    return Voxelize.apply(values, idx_query, counts)
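
Note: spvoxelize is a scatter-mean: every point's feature is added into its
voxel row, divided by that row's point count. A NumPy reference of what the
kernels compute (a sketch for clarity, not the CUDA path):

    import numpy as np

    values = np.array([[1., 1.], [3., 3.], [5., 5.]], dtype=np.float32)
    idx_query = np.array([0, 0, 1])   # points 0,1 -> voxel 0; point 2 -> voxel 1
    counts = np.array([2, 1])
    out = np.zeros((2, 2), dtype=np.float32)
    for i, pos in enumerate(idx_query):       # the kernel parallelizes this loop
        if pos >= 0 and counts[pos] > 0:
            out[pos] += values[i] / counts[pos]
    print(out)                                # [[2. 2.] [5. 5.]]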
@ -1,8 +1,8 @@
|
|||
import jittor as jt
|
||||
from jittor import nn
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn.functional import relu, leaky_relu
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn.functional import relu, leaky_relu
|
||||
# from python.nn.utils import fapply
|
||||
|
||||
__all__ = ['ReLU', 'LeakyReLU']
|
|
@ -7,8 +7,8 @@ from jittor import nn
|
|||
from jittor import init
|
||||
from jittor.misc import _pair, _triple
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn import functional as F
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn import functional as F
|
||||
# from utils import make_ntuple
|
||||
|
||||
__all__ = ['Conv3d']
|
|
@ -2,8 +2,8 @@ import jittor as jt
|
|||
from jittor import nn
|
||||
from numpy import kaiser
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn.utils import fapply
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn.utils import fapply
|
||||
|
||||
__all__ = ['BatchNorm', 'GroupNorm']
|
||||
|
|
@ -2,8 +2,8 @@ from ast import Global
|
|||
import jittor as jt
|
||||
from jittor import nn
|
||||
|
||||
from python import SparseTensor
|
||||
from python.nn.functional import global_avg_pool, global_max_pool
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse.nn.functional import global_avg_pool, global_max_pool
|
||||
|
||||
__all__ = ['GlobalAvgPool', 'GlobalMaxPool']
|
||||
|
|
@ -2,7 +2,7 @@ from typing import Callable
|
|||
|
||||
import jittor as jt
|
||||
|
||||
from python import SparseTensor
|
||||
from python.jsparse import SparseTensor
|
||||
|
||||
__all__ = ['fapply']
|
||||
|
|
@ -3,21 +3,21 @@ from typing import Tuple, Union
|
|||
import numpy as np
|
||||
import jittor as jt
|
||||
|
||||
from python.utils import make_ntuple, trunc
|
||||
from python.jsparse.utils import make_ntuple, trunc
|
||||
|
||||
__all__ = ['get_kernel_offsets']
|
||||
|
||||
def get_kernel_offsets(size: Union[int, Tuple[int, ...]],
|
||||
def get_kernel_offsets(kernel_size: Union[int, Tuple[int, ...]],
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
dilation: Union[int, Tuple[int, ...]] = 1) -> jt.Var:
|
||||
size = make_ntuple(size, ndim=3)
|
||||
kernel_size = make_ntuple(kernel_size, ndim=3)
|
||||
stride = make_ntuple(stride, ndim=3)
|
||||
dilation = make_ntuple(dilation, ndim=3)
|
||||
|
||||
offsets = [(np.arange(-size[k] // 2 + 1, size[k] // 2 + 1) * stride[k]
|
||||
offsets = [(np.arange(-kernel_size[k] // 2 + 1, kernel_size[k] // 2 + 1) * stride[k]
|
||||
* dilation[k]) for k in range(3)]
|
||||
|
||||
if np.prod(size) % 2 == 1:
|
||||
if np.prod(kernel_size) % 2 == 1:
|
||||
offsets = [[x, y, z] for z in offsets[2] for y in offsets[1]
|
||||
for x in offsets[0]]
|
||||
else:
|
|
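
Note: the hunk above renames size to kernel_size. For an odd kernel the offsets
are centered around zero, e.g. kernel_size=3, stride=1, dilation=1 yields the
27 neighbor offsets {-1, 0, 1}^3. A NumPy sketch of the same arithmetic
(illustration, not the repo call):

    import numpy as np

    kernel_size, stride, dilation = (3, 3, 3), (1, 1, 1), (1, 1, 1)
    per_axis = [np.arange(-kernel_size[k] // 2 + 1, kernel_size[k] // 2 + 1)
                * stride[k] * dilation[k] for k in range(3)]
    offsets = [[x, y, z] for z in per_axis[2] for y in per_axis[1]
               for x in per_axis[0]]
    assert len(offsets) == 27 and offsets[13] == [0, 0, 0]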
@ -0,0 +1,124 @@
|
|||
from itertools import count
|
||||
import numpy as np
|
||||
|
||||
import jittor as jt
|
||||
from jittor.misc import _pair, _triple
|
||||
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from python.jsparse.utils import make_ntuple, sparse_quantize, set_hash
|
||||
# from .utils.quantize import sparse_quantize
|
||||
# from indice_manager import IndiceManager
|
||||
|
||||
class SparseTensor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indices: jt.Var,
|
||||
values: jt.Var,
|
||||
stride: Union[int, Tuple[int, ...]],
|
||||
size,
|
||||
quantize=True,
|
||||
voxel_size=1,
|
||||
coalesce_mode='sum',
|
||||
indice_manager=None,
|
||||
device=None,
|
||||
):
|
||||
assert isinstance(indices, jt.Var) and isinstance(values, jt.Var)
|
||||
assert (values.ndim == 2)
|
||||
# self.indices = indices
|
||||
# self.values = values
|
||||
self.size = size
|
||||
self.ndim = indices.shape[1] - 1
|
||||
self.stride =_triple(stride)
|
||||
self.voxel_size = voxel_size
|
||||
self.coalesce_mode = coalesce_mode
|
||||
self.cmaps = {}
|
||||
self.kmaps = {}
|
||||
|
||||
##########################
|
||||
# Setup CoordsManager
|
||||
##########################
|
||||
# if indice_manager is None:
|
||||
# self.indice_manager = IndiceManager(
|
||||
# ndim=self.ndim,
|
||||
|
||||
# )
|
||||
|
||||
##########################
|
||||
# Initialize coords
|
||||
##########################
|
||||
if quantize:
|
||||
self.seed = 1
|
||||
for i in range(len(self.stride)):
|
||||
self.seed += i
|
||||
self.seed *= self.stride[i]
|
||||
self.hash_multiplier = set_hash(self.ndim, self.seed)
|
||||
|
||||
self.hash_num, self.indices, mapping, inverse_mapping, count = sparse_quantize(indices, self.hash_multiplier, self.voxel_size, return_index=True, return_inverse=True, return_count=True)
|
||||
self.inverse_mapping = inverse_mapping
|
||||
|
||||
if len(values.shape) == 1:
|
||||
out_size = (self.indices.shape[0], )
|
||||
elif len(values.shape) == 2:
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
|
||||
if self.coalesce_mode == 'sum':
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
self.values = jt.zeros(out_size, dtype=values.dtype).scatter_(0, inverse_mapping, values, reduce='add')
|
||||
elif self.coalesce_mode == 'average':
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
self.values = jt.zeros(out_size, dtype=values.dtype).scatter_(0, inverse_mapping, values, reduce='add')
|
||||
self.values /= count
|
||||
elif self.coalesce_mode == 'sample':
|
||||
self.values = values[self.indices]
|
||||
else:
|
||||
self.indices = indices
|
||||
self.values = values
|
||||
|
||||
# if indice_manager is None:
|
||||
# # TODO If set to share the indices man, use the global indices man
|
||||
|
||||
# # init the indices
|
||||
# indice_manager = Indice
|
||||
|
||||
|
||||
+    def _indices(self):
+        return self.indices
+
+    def _values(self):
+        return self.values
+
+    def _binary_operation(self, other, _binary_op):
+        assert isinstance(other, self.__class__)
+        return
+        # TODO set up the indices dict
+        # so that we do not need to merge the indice group
+        # which has already been merged
+
+        # if the indices of self and other should be merged
+
+
+class PointTensor:
+
+    def __init__(self, values, indices, idx_query=None, weights=None):
+        self.values = values
+        self.indices = indices
+        self.idx_query = idx_query if idx_query is not None else {}
+        self.weights = weights if weights is not None else {}
+        self.additional_values = {}
+        self.additional_values['idx_query'] = {}
+        self.additional_values['counts'] = {}
+
+    def detach(self):
+        self.values = self.values.detach()
+        self.indices = self.indices.detach()
+        return self
+
+    def __add__(self, other):
+        pt = PointTensor(self.values + other.values, self.indices, self.idx_query,
+                         self.weights)
+        pt.additional_values = self.additional_values
+        return pt
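
Note: the new quantize flag controls whether coordinates are hashed and
deduplicated on construction; coalesce_mode then decides how duplicate rows
merge. What 'sum' and 'average' compute, in NumPy form (a sketch; the repo does
this with jt.scatter_ through inverse_mapping):

    import numpy as np

    values = np.array([[1.], [2.], [4.]])
    inverse_mapping = np.array([0, 0, 1])    # rows 0 and 1 share a voxel
    out = np.zeros((2, 1))
    np.add.at(out, inverse_mapping, values)  # 'sum'     -> [[3.], [4.]]
    counts = np.bincount(inverse_mapping)[:, None]
    avg = out / counts                       # 'average' -> [[1.5], [4.]]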
@ -1,94 +0,0 @@
|
|||
from itertools import count
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from python.utils import make_ntuple, sparse_quantize, set_hash
|
||||
# from .utils.quantize import sparse_quantize
|
||||
# from indice_manager import IndiceManager
|
||||
|
||||
class SparseTensor:
|
||||
def __init__(
|
||||
self,
|
||||
indices: jt.Var,
|
||||
values: jt.Var,
|
||||
stride: Union[int, Tuple[int, ...]],
|
||||
size,
|
||||
voxel_size=1,
|
||||
coalesce_mode='sum',
|
||||
indice_manager=None,
|
||||
device=None,
|
||||
):
|
||||
assert isinstance(indices, jt.Var) and isinstance(values, jt.Var)
|
||||
assert (values.ndim == 2)
|
||||
# self.indices = indices
|
||||
# self.values = values
|
||||
self.size = size
|
||||
self.ndim = indices.shape[1] - 1
|
||||
self.stride = make_ntuple(stride, ndim=self.ndim)
|
||||
self.voxel_size = voxel_size
|
||||
self.coalesce_mode = coalesce_mode
|
||||
self.cmaps = {}
|
||||
self.kmaps = {}
|
||||
|
||||
##########################
|
||||
# Setup CoordsManager
|
||||
##########################
|
||||
# if indice_manager is None:
|
||||
# self.indice_manager = IndiceManager(
|
||||
# ndim=self.ndim,
|
||||
|
||||
# )
|
||||
|
||||
##########################
|
||||
# Initialize coords
|
||||
##########################
|
||||
self.seed = 1
|
||||
for i in range(len(self.stride)):
|
||||
self.seed += i
|
||||
self.seed *= self.stride[i]
|
||||
self.hash_multiplier = set_hash(self.ndim, self.seed)
|
||||
|
||||
self.hash_num, self.indices, mapping, inverse_mapping, count = sparse_quantize(indices, self.hash_multiplier, self.voxel_size, return_index=True, return_inverse=True, return_count=True)
|
||||
self.inverse_mapping = inverse_mapping
|
||||
|
||||
if len(values.shape) == 1:
|
||||
out_size = (self.indices.shape[0], )
|
||||
elif len(values.shape) == 2:
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
|
||||
if self.coalesce_mode == 'sum':
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
self.values = jt.zeros(out_size, dtype=values.dtype).scatter_(0, inverse_mapping, values, reduce='add')
|
||||
elif self.coalesce_mode == 'average':
|
||||
out_size = (self.indices.shape[0], values.shape[-1])
|
||||
self.values = jt.zeros(out_size, dtype=values.dtype).scatter_(0, inverse_mapping, values, reduce='add')
|
||||
self.values /= count
|
||||
elif self.coalesce_mode == 'sample':
|
||||
self.values = values[self.indices]
|
||||
|
||||
# if indice_manager is None:
|
||||
# # TODO If set to share the indices man, use the global indices man
|
||||
|
||||
# # init the indices
|
||||
# indice_manager = Indice
|
||||
|
||||
|
||||
def _indices(self):
|
||||
return self.indices
|
||||
|
||||
def _values(self):
|
||||
return self.values
|
||||
|
||||
def _binary_operation(self, other, _binary_op):
|
||||
assert isinstance(other, self.__class__)
|
||||
return
|
||||
# TODO set up the indices dict
|
||||
# so that wedo not need to merge the indice group
|
||||
# which has already been merged
|
||||
|
||||
# if the indices of self and other should be merged
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
from jittor import init
|
||||
from jittor.misc import _pair, _triple
|
||||
|
||||
from itertools import repeat
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from python.jsparse import SparseTensor
|
||||
from python.jsparse import PointTensor
|
||||
from python.jsparse.utils import make_ntuple
|
||||
from python.jsparse.nn import functional as F
|
||||
from python.jsparse.nn.utils import get_kernel_offsets
|
||||
from python.jsparse.nn.functional import Convolution
|
||||
|
||||
import torchsparse
|
||||
from torchsparse import nn as spnn
|
||||
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
in_channels = 3
|
||||
out_channels = 64
|
||||
kernel_size = 3
|
||||
stride = 1
|
||||
dilation = 1
|
||||
groups = 1
|
||||
bias = False
|
||||
transposed = False
|
||||
|
||||
kernel_size = _triple(kernel_size)
|
||||
stride = _triple(stride)
|
||||
dilation = _triple(dilation)
|
||||
kernel_volume = int(np.prod(kernel_size))
|
||||
|
||||
N = 10
|
||||
coords = np.random.uniform(0, 10, size=(N, 4))
|
||||
feats = np.random.randn(N, 3)
|
||||
labels = np.random.choice(5, N)
|
||||
print(coords.shape)
|
||||
print(feats.shape)
|
||||
|
||||
coo = jt.Var(coords)
|
||||
val = jt.Var(feats)
|
||||
size = (10, 10, 10)
|
||||
|
||||
fan = (out_channels if transposed else in_channels) * kernel_volume
|
||||
std = 1 / math.sqrt(fan)
|
||||
|
||||
if kernel_volume > 1:
|
||||
weight = init.uniform([kernel_volume, in_channels, out_channels], 'float32', -std, std)
|
||||
else:
|
||||
weight = init.uniform([in_channels, out_channels], 'float32')
|
||||
if bias:
|
||||
bias = init.uniform([out_channels], "float32", -std, std)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
|
||||
x = SparseTensor(coo, val, 1, size)
|
||||
z = PointTensor(x.values, x.indices.float())
|
||||
|
||||
pc_hash = F.sphash(
|
||||
jt.concat([
|
||||
z.indices[:, 0].int().view(-1, 1),
|
||||
jt.floor(z.indices[:, 1:] / x.stride[0]).int() * x.stride[0]
|
||||
], 1))
|
||||
sparse_hash = F.sphash(x.indices)
|
||||
idx_query = F.spquery(pc_hash, sparse_hash).int()
|
||||
counts = F.spcount(idx_query, x.indices.shape[0])
|
||||
z.additional_values['idx_query'][x.stride] = idx_query
|
||||
z.additional_values['counts'][x.stride] = counts
|
||||
inserted_values = F.spvoxelize(z.values, idx_query, counts)
|
||||
new_tensor = SparseTensor(inserted_values, x.indices, x.stride, x.size, False)
|
||||
new_tensor.cmaps = x.cmaps
|
||||
new_tensor.kmaps = x.kmaps
|
||||
print(inserted_values)
|
||||
|
||||
offsets = get_kernel_offsets(kernel_size=2, stride=x.stride, dilation=1)
|
||||
old_hash = F.sphash(
|
||||
jt.concat([
|
||||
z.indices[:, 0].int().view(-1, 1),
|
||||
jt.floor(z.indices[:, 1:] / x.stride[0]).int() * x.stride[0]
|
||||
], 1), offsets)
|
||||
pc_hash = F.sphash(x.indices)
|
||||
idx_query = F.spquery(old_hash, pc_hash).int()
|
||||
weights = F.calc_ti_weights(z.indices, idx_query,
|
||||
scale=x.stride[0]).t()
|
||||
idx_query = idx_query.t()
|
||||
new_values = F.spdevoxelize(x.values, idx_query, weights)
|
||||
|
||||
print(jt.grad(new_values, x.values))
|