# JSparse/python/jsparse/nn/functional/voxelize.py

import jittor as jt
from jittor import Function
from jsparse import SparseTensor

__all__ = ['spvoxelize']


class Voxelize(Function):
    def execute(
        self,
        values: jt.Var,
        idx_query: jt.Var,
        counts: jt.Var
    ) -> jt.Var:
        # values:    (N, c)   per-point features
        # idx_query: (N,)     voxel index of each point
        # counts:    (N1,)    number of points in each voxel
        # out:       (N1, c)  mean feature of the points in each voxel
        output = jt.code(
            (counts.shape[0], values.shape[1]), "float32",
            [values, idx_query, counts],
            cuda_header="""
                #include <stdio.h>
                #include <stdlib.h>
                #include <cuda_runtime.h>
            """,
            cuda_src="""
                __global__ void voxelize_forward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(values, in0)
                    @alias(idx_query, in1)
                    @alias(counts, in2)
                    // One thread per (point, channel) pair.
                    int index = blockDim.x * blockIdx.x + threadIdx.x;
                    int c = values_shape1;
                    int i = index / c;
                    int j = index % c;
                    if (i < values_shape0) {
                        int pos = @idx_query(i);
                        if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0)
                            return;
                        atomicAdd(&@out(pos, j), @values(i, j) / (float)@counts(pos));
                    }
                }
                @alias(values, in0)
                // jt.code outputs are not zero-initialized, so clear the
                // accumulator before scattering into it.
                cudaMemsetAsync(out_p, 0, out->size);
                // One block per point, one thread per channel (assumes the
                // channel count stays within the 1024-thread block limit).
                voxelize_forward_kernel<<<values_shape0, values_shape1>>>(@ARGS);
            """,
            cpu_src="""
                @alias(values, in0)
                @alias(idx_query, in1)
                @alias(counts, in2)
                // jt.code outputs are not zero-initialized, so clear the
                // accumulator before scattering into it.
                for (int i = 0; i < out_shape0; ++i)
                    for (int j = 0; j < out_shape1; ++j)
                        @out(i, j) = 0;
                for (int i = 0; i < values_shape0; ++i) {
                    int pos = @idx_query(i);
                    if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0)
                        continue;
                    #pragma omp parallel for
                    for (int j = 0; j < values_shape1; ++j) {
                        @out(pos, j) += @values(i, j) / (float)@counts(pos);
                    }
                }
            """
        )
        self.save_vars = idx_query, counts, values.shape[0]
        return output

    def grad(self, grad_output: jt.Var):
        idx_query, counts, input_size = self.save_vars
        # Scatter the voxel gradients back to the points: row i of the result
        # is grad_output[idx_query[i]] / counts[idx_query[i]].
        grad_values = jt.code(
            (input_size, grad_output.shape[1]), jt.float32,
            [idx_query, counts, grad_output],
            cuda_header="""
                #include <stdio.h>
                #include <stdlib.h>
                #include <cuda_runtime.h>
            """,
            cuda_src="""
                __global__ void voxelize_backward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(idx_query, in0)
                    @alias(counts, in1)
                    @alias(grad_output, in2)
                    // One thread per (point, channel) pair.
                    int index = blockDim.x * blockIdx.x + threadIdx.x;
                    int i = index / grad_output_shape1;
                    int j = index % grad_output_shape1;
                    if (i < out_shape0) {
                        int pos = @idx_query(i);
                        if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0) {
                            @out(i, j) = 0;  // points outside any voxel get no gradient
                            return;
                        }
                        // Row i belongs to point i (matching the CPU path below);
                        // each thread owns its (i, j) entry, so no atomics are needed.
                        @out(i, j) = @grad_output(pos, j) / (float)@counts(pos);
                    }
                }
                voxelize_backward_kernel<<<out_shape0, out_shape1>>>(@ARGS);
            """,
            cpu_src="""
                @alias(idx_query, in0)
                @alias(counts, in1)
                @alias(grad_output, in2)
                for (int i = 0; i < out_shape0; ++i) {
                    int pos = @idx_query(i);
                    bool valid = pos >= 0 && pos < counts_shape0 && @counts(pos) != 0;
                    #pragma omp parallel for
                    for (int j = 0; j < grad_output_shape1; ++j) {
                        @out(i, j) = valid ? @grad_output(pos, j) / (float)@counts(pos) : 0.0f;
                    }
                }
            """
        )
        return grad_values, None, None
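

# For reference, the forward pass above is equivalent to this NumPy sketch
# (illustration only, not part of the original module): each output row is
# the count-normalized sum of the point features mapped onto it.
def _spvoxelize_reference(values, idx_query, counts):
    import numpy as np
    out = np.zeros((len(counts), values.shape[1]), dtype=np.float32)
    for i, pos in enumerate(idx_query):
        if 0 <= pos < len(counts) and counts[pos] != 0:
            out[pos] += values[i] / counts[pos]
    return out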


def spvoxelize(
    values: jt.Var,
    idx_query: jt.Var,
    counts: jt.Var
) -> jt.Var:
    # Average the features of all points that fall into the same voxel.
    return Voxelize.apply(values, idx_query, counts)
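

# A minimal usage sketch with made-up data (idx_query and counts would
# normally come from hashing quantized point coordinates; the sizes below
# are illustrative, not from the JSparse test suite):
if __name__ == '__main__':
    import numpy as np

    N, c, N1 = 8, 4, 3
    values = jt.array(np.random.rand(N, c).astype(np.float32))  # point features
    idx = np.random.randint(0, N1, size=N).astype(np.int32)     # point -> voxel
    idx_query = jt.array(idx)
    counts = jt.array(np.bincount(idx, minlength=N1).astype(np.int32))

    out = spvoxelize(values, idx_query, counts)
    print(out.shape)  # [N1, c]: per-voxel mean features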