# pytorch3d/tests/implicitron/test_voxel_grid_implicit_fu... (path truncated in the extracted listing)
#
# Listing metadata: 228 lines, 8.4 KiB, Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer import ray_bundle_to_ray_points
from tests.common_testing import TestCaseMixin
class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
    """
    Tests for VoxelGridImplicitFunction: a plain forward pass, scaffold
    formation, scaffold based filtering of points and cropping.
    """

    def setUp(self) -> None:
        # Fixed seed so that the random ray bundles are reproducible.
        torch.manual_seed(42)
        expand_args_fields(VoxelGridImplicitFunction)

    def _get_simple_implicit_function(self, scaffold_res=16):
        """
        Build a small VoxelGridImplicitFunction for testing.

        The default config is overridden with a 7-feature full resolution
        density grid, an elementwise density decoder, a 2-layer MLP color
        decoder and a cubic scaffold.

        Args:
            scaffold_res: resolution of the scaffold along each of the three
                axes.

        Returns:
            The configured VoxelGridImplicitFunction instance.
        """
        default_cfg = get_default_args(VoxelGridImplicitFunction)
        custom_cfg = DictConfig(
            {
                "voxel_grid_density_args": {
                    "voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7}
                },
                "decoder_density_class_type": "ElementwiseDecoder",
                "decoder_color_class_type": "MLPDecoder",
                "decoder_color_MLPDecoder_args": {
                    "network_args": {
                        "n_layers": 2,
                        "output_dim": 3,
                        "hidden_dim": 128,
                    }
                },
                "scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res),
            }
        )
        cfg = OmegaConf.merge(default_cfg, custom_cfg)
        return VoxelGridImplicitFunction(**cfg)

    def test_forward(self) -> None:
        """
        Test one forward of VoxelGridImplicitFunction.

        Only checks that the call completes without raising; the outputs
        are not inspected.
        """
        func = self._get_simple_implicit_function()

        n_grids, n_points = 10, 9
        raybundle = ImplicitronRayBundle(
            origins=torch.randn(n_grids, 2, 3, 3),
            directions=torch.randn(n_grids, 2, 3, 3),
            lengths=torch.randn(n_grids, 2, 3, n_points),
            xys=0,
        )
        func(raybundle)

    def test_scaffold_formation(self):
        """
        Test calculating the scaffold.

        We define a custom density function and make the implicit function
        use it. After calculating the scaffold we compare the density of our
        custom density function with densities from the scaffold.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        func = self._get_simple_implicit_function().to(device)
        func.scaffold_max_pool_kernel_size = 1

        def new_density(points):
            """
            Density function which returns 1 if p > (0.5, 0.5, 0.5) or
            p < (-0.5, -0.5, -0.5) else 0.

            The output has the shape of `points` with the last dimension
            replaced by 1.
            """
            # Vectorised over all points: avoids a Python loop with one
            # tensor allocation (and one host sync) per point.
            above = (points > 0.5).all(dim=-1, keepdim=True)
            below = (points < -0.5).all(dim=-1, keepdim=True)
            return torch.logical_or(above, below).float().to(device)

        func._get_density = new_density
        func._get_scaffold(0)

        points = torch.tensor(
            [
                [0, 0, 0],
                [1, 1, 1],
                [1, 0, 0],
                [0.1, 0, 0],
                [10, 1, -1],
                [-0.8, -0.7, -0.9],
            ]
        ).to(device)
        expected = new_density(points).float().to(device)
        # Evaluate the scaffold once and reuse it for the failure message.
        actual = func.voxel_grid_scaffold(points)
        self.assertTrue(torch.allclose(actual, expected), (actual, expected))

    def test_scaffold_filtering(self, n_test_points=100):
        """
        Test that filtering points with scaffold works.

        We define a scaffold and make the implicit function use it. We also
        define new density and color functions which check that all passed
        points are not in empty space (with scaffold function). In the end
        we compare the result from the implicit function with one calculated
        in simple python, this checks that the points were merged correctly.
        """
        # Fall back to CPU like the other tests instead of requiring CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        func = self._get_simple_implicit_function().to(device)

        def scaffold(points):
            """
            Function to deterministically and randomly enough assign a point
            to empty or occupied space.
            Return 1 if second digit of sum after 0 is odd else 0
            """
            return (
                ((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1
            ).float()

        def new_density(points):
            # check that every point which reached here is in occupied space
            occupied = scaffold(points)
            self.assertTrue(bool(torch.all(occupied)), (occupied, points.shape))
            return points.sum(dim=-1, keepdim=True)

        def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
            # check that every point which reached here is in occupied space
            occupied = scaffold(points)
            self.assertTrue(bool(torch.all(occupied)), (occupied, points.shape))
            return points * 2

        # check both computation paths that they contain only points
        # which are not in empty space
        func._get_density = new_density
        func._get_color = new_color
        func.voxel_grid_scaffold.forward = scaffold
        func._scaffold_ready = True

        bundle = ImplicitronRayBundle(
            origins=torch.rand((n_test_points, 2, 1, 3), device=device),
            directions=torch.rand((n_test_points, 2, 1, 3), device=device),
            lengths=torch.rand((n_test_points, 2, 1, 4), device=device),
            xys=None,
        )
        points = ray_bundle_to_ray_points(bundle)
        result_density, result_color, _ = func(bundle)

        # construct the wanted result 'by hand'; this is deliberately a plain
        # per-point loop so that it is independent of the code under test
        flat_points = points.view(-1, 3)
        expected_result_density, expected_result_color = [], []
        for point in flat_points:
            if scaffold(point) == 1:
                expected_result_density.append(point.sum(dim=-1, keepdim=True))
                expected_result_color.append(point * 2)
            else:
                expected_result_density.append(point.new_zeros((1,)))
                expected_result_color.append(point.new_zeros((3,)))
        expected_result_density = torch.stack(expected_result_density, dim=0).view(
            *points.shape[:-1], 1
        )
        expected_result_color = torch.stack(expected_result_color, dim=0).view(
            *points.shape[:-1], 3
        )

        # check that the result is expected
        self.assertTrue(
            torch.allclose(result_density, expected_result_density),
            (result_density, expected_result_density),
        )
        self.assertTrue(
            torch.allclose(result_color, expected_result_color),
            (result_color, expected_result_color),
        )

    def test_cropping(self, scaffold_res=9):
        """
        Tests whether implicit function finds the bounding box of the object and sends
        correct min and max points to voxel grids for rescaling.
        """
        # Precondition on the test parameter, checked before any setup work.
        self.assertGreaterEqual(scaffold_res, 8)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device)

        div = (scaffold_res - 1) / 2
        true_min_point = torch.tensor(
            [-3 / div, 0 / div, -3 / div],
            device=device,
        )
        true_max_point = torch.tensor(
            [1 / div, 2 / div, 3 / div],
            device=device,
        )

        def new_scaffold(points):
            # 1 if between true_min and true_max point else 0
            return (
                torch.logical_and(true_min_point <= points, points <= true_max_point)
                .all(dim=-1)
                .float()[..., None]
            )

        # One entry is appended per crop_self call so the calls can be counted.
        called_crop = []

        def assert_min_max_points(min_point, max_point):
            called_crop.append(1)
            self.assertClose(min_point, true_min_point)
            self.assertClose(max_point, true_max_point)

        func.voxel_grid_density.crop_self = assert_min_max_points
        func.voxel_grid_color.crop_self = assert_min_max_points
        func.voxel_grid_scaffold.forward = new_scaffold
        func._scaffold_ready = True
        func._crop(epoch=0)

        # Both the density grid and the color grid must have been cropped.
        self.assertEqual(len(called_crop), 2)