# JSparse/python/jsparse/nn/functional/devoxelize.py
import jittor as jt
from jittor import Function

from python.jsparse import SparseTensor

__all__ = ['calc_ti_weights', 'spdevoxelize']


def calc_ti_weights(
    indices: jt.Var,
    idx_query: jt.Var,
    scale: float = 1
) -> jt.Var:
    """Compute trilinear interpolation weights of each point with respect
    to its eight neighboring voxels.

    ``indices`` holds (batch, x, y, z) point coordinates, so only columns
    1..3 enter the weight computation; ``idx_query`` marks missing
    neighbors with -1. Returns an (8, N) weight matrix whose columns are
    normalized to sum to (approximately) 1.
    """
    with jt.no_grad():
        p = indices
        # Lower corner of the voxel cell containing each point ...
        if scale != 1:
            pf = jt.floor(indices / scale) * scale
        else:
            pf = jt.floor(indices)
        # ... and the opposite (upper) corner, one cell width away.
        pc = pf + scale

        x = p[:, 1].view(-1, 1)
        y = p[:, 2].view(-1, 1)
        z = p[:, 3].view(-1, 1)

        xf = pf[:, 1].view(-1, 1).float()
        yf = pf[:, 2].view(-1, 1).float()
        zf = pf[:, 3].view(-1, 1).float()

        xc = pc[:, 1].view(-1, 1).float()
        yc = pc[:, 2].view(-1, 1).float()
        zc = pc[:, 3].view(-1, 1).float()

        # Standard trilinear weights: each corner's weight is the volume
        # of the sub-box diagonally opposite to it.
        w0 = (xc - x) * (yc - y) * (zc - z)
        w1 = (xc - x) * (yc - y) * (z - zf)
        w2 = (xc - x) * (y - yf) * (zc - z)
        w3 = (xc - x) * (y - yf) * (z - zf)
        w4 = (x - xf) * (yc - y) * (zc - z)
        w5 = (x - xf) * (yc - y) * (z - zf)
        w6 = (x - xf) * (y - yf) * (zc - z)
        w7 = (x - xf) * (y - yf) * (z - zf)

        w = jt.concat([w0, w1, w2, w3, w4, w5, w6, w7], dim=1).t()
        if scale != 1:
            w /= scale ** 3
        # Zero out weights of missing neighbors, then renormalize each
        # column so the remaining weights still sum to 1.
        w[idx_query == -1] = 0
        w /= jt.sum(w, dim=0) + 1e-8
    return w


class Devoxelize(Function):
    def execute(
        self,
        values: jt.Var,
        idx_query: jt.Var,
        weights: jt.Var
    ) -> jt.Var:
        # N = idx_query.shape[0] points, c = values.shape[1] channels;
        # the output holds one interpolated feature row per point.
        output = jt.code(
            (idx_query.shape[0], values.shape[1]), jt.float32,
            [values, idx_query, weights],
            cuda_src="""
                __global__ void devoxelize_forward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(values, in0)
                    @alias(idx_query, in1)
                    @alias(weights, in2)
                    int index = blockIdx.x * blockDim.x + threadIdx.x;
                    int i = index / values_shape1;
                    int j = index % values_shape1;
                    if (i < idx_query_shape0) {
                        // Accumulate locally and write once: jt.code outputs
                        // are not guaranteed to be zero-initialized.
                        float acc = 0;
                        for (int k = 0; k < 8; ++k) {
                            int idx = @idx_query(i, k);
                            float cur_values = (idx >= 0) ? @values(idx, j) : 0;
                            acc += @weights(i, k) * cur_values;
                        }
                        @out(i, j) = acc;
                    }
                }
                // One block per point, one thread per channel (assumes the
                // channel count fits within a single block).
                devoxelize_forward_kernel<<<out_shape0, out_shape1>>>(@ARGS);
            """,
            cpu_src="""
                @alias(values, in0)
                @alias(idx_query, in1)
                @alias(weights, in2)
                #pragma omp parallel for
                for (int i = 0; i < idx_query_shape0; ++i) {
                    for (int j = 0; j < values_shape1; ++j) {
                        float acc = 0;
                        for (int k = 0; k < 8; ++k) {
                            int idx = @idx_query(i, k);
                            float cur_values = (idx >= 0) ? @values(idx, j) : 0;
                            acc += @weights(i, k) * cur_values;
                        }
                        // Each (i, j) is owned by exactly one thread, so the
                        // write needs no atomic protection.
                        @out(i, j) = acc;
                    }
                }
            """
        )
        self.save_vars = (idx_query, weights, values.shape[0])
        return output
    def grad(self, grad_output: jt.Var):
        idx_query, weights, input_size = self.save_vars
        # Gradient w.r.t. the voxel values has shape (num_voxels, channels),
        # i.e. grad_output.shape[1] channels, not shape[0] points.
        grad_values = jt.code(
            (input_size, grad_output.shape[1]), jt.float32,
            [idx_query, weights, grad_output],
            cuda_header="""
                #include <stdio.h>
                #include <stdlib.h>
                #include <cuda_runtime.h>
            """,
            cuda_src="""
                __global__ void devoxelize_backward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(idx_query, in0)
                    @alias(weights, in1)
                    @alias(grad_output, in2)
                    int index = blockIdx.x * blockDim.x + threadIdx.x;
                    int c = grad_output_shape1;
                    int i = index / c;
                    int j = index % c;
                    if (i < grad_output_shape0) {
                        float cur_grad_output = @grad_output(i, j);
                        #pragma unroll
                        for (int k = 0; k < 8; ++k) {
                            int idx = @idx_query(i, k);
                            if (idx >= 0) {
                                // Several points may scatter into the same
                                // voxel row, hence the atomic add.
                                atomicAdd(&@out(idx, j), @weights(i, k) * cur_grad_output);
                            }
                        }
                    }
                }
                @alias(grad_output, in2)
                // jt.code outputs are not guaranteed to be zero-initialized,
                // so clear the buffer before scattering with atomicAdd.
                cudaMemsetAsync(out_p, 0, out->size);
                devoxelize_backward_kernel<<<grad_output_shape0, grad_output_shape1>>>(@ARGS);
            """,
            cpu_src="""
                @alias(idx_query, in0)
                @alias(weights, in1)
                @alias(grad_output, in2)
                // Zero the buffer first; jt.code outputs are not guaranteed
                // to be zero-initialized.
                for (int i = 0; i < out_shape0; ++i)
                    for (int j = 0; j < out_shape1; ++j)
                        @out(i, j) = 0;
                for (int i = 0; i < grad_output_shape0; ++i) {
                    #pragma omp parallel for
                    for (int j = 0; j < grad_output_shape1; ++j) {
                        float cur_grad_output = @grad_output(i, j);
                        for (int k = 0; k < 8; ++k) {
                            int idx = @idx_query(i, k);
                            // Skip missing neighbors: idx == -1 would index
                            // out of bounds. Threads are parallel over j, so
                            // each (idx, j) write is owned by one thread.
                            if (idx >= 0) {
                                @out(idx, j) += @weights(i, k) * cur_grad_output;
                            }
                        }
                    }
                }
            """
        )
        return grad_values, None, None


def spdevoxelize(
    values: jt.Var,
    idx_query: jt.Var,
    weights: jt.Var
) -> jt.Var:
    """Trilinearly interpolate voxel features back to points.

    Args:
        values: (num_voxels, channels) voxel feature matrix.
        idx_query: (num_points, 8) indices of the eight neighboring voxels
            of each point, with -1 marking a missing neighbor.
        weights: (num_points, 8) interpolation weights; note that
            ``calc_ti_weights`` returns them in the transposed (8, N) layout.

    Returns:
        A (num_points, channels) matrix of interpolated point features.
    """
    return Devoxelize.apply(values, idx_query, weights)
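

# The block below is an illustrative smoke test, not part of the library API:
# a minimal sketch that fabricates a tiny voxel grid and neighbor indices to
# show how ``spdevoxelize`` is called. All shapes and values here are
# assumptions for demonstration only.
if __name__ == '__main__':
    num_voxels, num_points, channels = 16, 4, 8

    # Random voxel features.
    values = jt.random((num_voxels, channels))
    # Fabricated neighbor indices; a real pipeline obtains these from a
    # hash-table query, with -1 marking a missing neighbor.
    idx_query = jt.randint(0, num_voxels, shape=(num_points, 8))
    # Uniform weights over the eight neighbors, purely for the sketch; real
    # weights come from calc_ti_weights (transposed to (N, 8)).
    weights = jt.ones((num_points, 8)) / 8.0

    output = spdevoxelize(values, idx_query, weights)
    print(output.shape)  # expected: [num_points, channels]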