JSparse/voxelize_test.py

96 lines
2.6 KiB
Python

import time
import math
import numpy as np
import torch
import jittor as jt
import jittor.nn as nn
from jittor import init
from jittor.misc import _pair, _triple
from itertools import repeat
from typing import List, Tuple, Union
from JSparse import SparseTensor
from JSparse import PointTensor
from JSparse.utils import make_ntuple
from JSparse.nn import functional as F
from JSparse.nn.utils import get_kernel_offsets
from JSparse.nn.functional import Convolution
import torchsparse
from torchsparse import nn as spnn
jt.flags.use_cuda = 1
in_channels = 3
out_channels = 64
kernel_size = 3
stride = 1
dilation = 1
groups = 1
bias = False
transposed = False
kernel_size = _triple(kernel_size)
stride = _triple(stride)
dilation = _triple(dilation)
kernel_volume = int(np.prod(kernel_size))
N = 10
coords = np.random.uniform(0, 10, size=(N, 4))
feats = np.random.randn(N, 3)
labels = np.random.choice(5, N)
print(coords.shape)
print(feats.shape)
coo = jt.Var(coords)
val = jt.Var(feats)
size = (10, 10, 10)
fan = (out_channels if transposed else in_channels) * kernel_volume
std = 1 / math.sqrt(fan)
if kernel_volume > 1:
weight = init.uniform([kernel_volume, in_channels, out_channels], 'float32', -std, std)
else:
weight = init.uniform([in_channels, out_channels], 'float32')
if bias:
bias = init.uniform([out_channels], "float32", -std, std)
else:
bias = None
x = SparseTensor(coo, val, 1, size)
z = PointTensor(x.values, x.indices.float())
pc_hash = F.sphash(
jt.concat([
z.indices[:, 0].int().view(-1, 1),
jt.floor(z.indices[:, 1:] / x.stride[0]).int() * x.stride[0]
], 1))
sparse_hash = F.sphash(x.indices)
idx_query = F.spquery(pc_hash, sparse_hash).int()
counts = F.spcount(idx_query, x.indices.shape[0])
z.additional_values['idx_query'][x.stride] = idx_query
z.additional_values['counts'][x.stride] = counts
inserted_values = F.spvoxelize(z.values, idx_query, counts)
new_tensor = SparseTensor(inserted_values, x.indices, x.stride, x.size, False)
new_tensor.cmaps = x.cmaps
new_tensor.kmaps = x.kmaps
print(inserted_values)
offsets = get_kernel_offsets(kernel_size=2, stride=x.stride, dilation=1)
old_hash = F.sphash(
jt.concat([
z.indices[:, 0].int().view(-1, 1),
jt.floor(z.indices[:, 1:] / x.stride[0]).int() * x.stride[0]
], 1), offsets)
pc_hash = F.sphash(x.indices)
idx_query = F.spquery(old_hash, pc_hash).int()
weights = F.calc_ti_weights(z.indices, idx_query,
scale=x.stride[0]).t()
idx_query = idx_query.t()
new_values = F.spdevoxelize(x.values, idx_query, weights)
print(jt.grad(new_values, x.values))