pytorch3d/tests/test_splatter_blend.py

628 lines
23 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
from pytorch3d.renderer.splatter_blend import (
_compute_occlusion_layers,
_compute_splatted_colors_and_weights,
_compute_splatting_colors_and_weights,
_get_splat_kernel_normalization,
_normalize_and_compose_all_layers,
_offset_splats,
_precompute,
_prepare_pixels_and_colors,
)
from .common_testing import TestCaseMixin
# The nine splatting directions: every (dh, dw) offset in the 3x3 pixel
# neighborhood, enumerated row-major from top-left to bottom-right.
# Index 4 is the (0, 0) self-splat.
offsets = torch.tensor(
    [[dh, dw] for dh in (-1, 0, 1) for dw in (-1, 0, 1)],
    device=torch.device("cpu"),
)
def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma):
    """Reference (loop-based) computation of splatting colors and weights.

    For every pixel and each of its nine splatting directions, compute the
    Gaussian splat weight (scaled by alpha and the kernel normalization) and
    the weight-scaled color.

    Args:
        pixel_coords_screen: (N, H, W, K, 2) screen-space xy coordinates.
        colors: (N, H, W, K, 4) RGBA colors.
        sigma: splat kernel standard deviation.

    Returns:
        (N, H, W, K, 9, 5) tensor whose last dimension holds
        [R, G, B, A, weight] for each of the 9 directions.
    """
    normalizer = float(_get_splat_kernel_normalization(offsets))
    batch, height, width, layers, _ = colors.shape
    out = torch.zeros((batch, height, width, layers, 9, 5))
    for b in range(batch):
        for y in range(height):
            for x in range(width):
                for layer in range(layers):
                    xy = pixel_coords_screen[b, y, x, layer]
                    # Vector from this point to the center of its own pixel.
                    to_px_center = torch.floor(xy) - xy + 0.5
                    rgba = colors[b, y, x, layer]
                    opacity = colors[b, y, x, layer, 3:4]
                    for direction in range(9):
                        sq_dist = torch.sum(
                            (to_px_center + offsets[direction]) ** 2
                        )
                        weight = (
                            opacity
                            * torch.exp(-sq_dist / (2 * sigma**2))
                            * normalizer
                        )
                        out[b, y, x, layer, direction, :4] = weight * rgba
                        out[b, y, x, layer, direction, 4:5] = weight
    return out
class TestPrecompute(TestCaseMixin, unittest.TestCase):
    """Tests for _precompute, which returns the index tensors
    (crop_ids_h, crop_ids_w, offsets) used to shift splat buffers."""

    def setUp(self):
        # Precomputed results for a (N, H, W, K) = (2, 3, 4, 5) input and for
        # the minimal (1, 1, 1, 1) input, both on CPU.
        self.results_cpu = _precompute((2, 3, 4, 5), torch.device("cpu"))
        self.results1_cpu = _precompute((1, 1, 1, 1), torch.device("cpu"))

    def test_offsets(self):
        # The third precomputed tensor must equal the canonical 9x2 offsets.
        self.assertClose(self.results_cpu[2].shape, offsets.shape, atol=0)
        self.assertClose(self.results_cpu[2], offsets, atol=0)
        # Offsets should be independent of input_size.
        self.assertClose(self.results_cpu[2], self.results1_cpu[2], atol=0)

    def test_crops_h(self):
        # Expected h-crop indices for the minimal (1, 1, 1, 1) input.
        target_crops_h1 = torch.tensor(
            [
                # channels being offset:
                # R  G  B  A  W(eight)
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
            ]
            * 3,  # 3 because we're aiming at (N, H, W+2, K, 9, 5) with W=1.
            device=torch.device("cpu"),
        ).reshape(1, 1, 3, 1, 9, 5)
        self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0)
        # For larger H, each successive h-row increments the base indices by 1.
        target_crops_h_base = target_crops_h1[0, 0, 0]
        target_crops_h = torch.cat(
            [target_crops_h_base, target_crops_h_base + 1, target_crops_h_base + 2],
            dim=0,
        )
        # Check that we have the right shape, and (after broadcasting) it has the right
        # values. These should be repeated (tiled) for each n and k.
        self.assertClose(
            self.results_cpu[0].shape, torch.tensor([2, 3, 6, 5, 9, 5]), atol=0
        )
        for n in range(2):
            for w in range(6):
                for k in range(5):
                    self.assertClose(
                        self.results_cpu[0][n, :, w, k],
                        target_crops_h,
                    )

    def test_crops_w(self):
        # Expected w-crop indices for the minimal (1, 1, 1, 1) input.
        target_crops_w1 = torch.tensor(
            [
                # channels being offset:
                # R  G  B  A  W(eight)
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [2, 2, 2, 2, 2],
                [2, 2, 2, 2, 2],
            ],
            device=torch.device("cpu"),
        ).reshape(1, 1, 1, 1, 9, 5)
        self.assertClose(self.results1_cpu[1], target_crops_w1)
        target_crops_w_base = target_crops_w1[0, 0, 0]
        target_crops_w = torch.cat(
            [
                target_crops_w_base,
                target_crops_w_base + 1,
                target_crops_w_base + 2,
                target_crops_w_base + 3,
            ],
            dim=0,
        )  # Each w value needs an increment.
        # Check that we have the right shape, and (after broadcasting) it has the right
        # values. These should be repeated (tiled) for each n and k.
        self.assertClose(self.results_cpu[1].shape, torch.tensor([2, 3, 4, 5, 9, 5]))
        for n in range(2):
            for h in range(3):
                for k in range(5):
                    self.assertClose(
                        self.results_cpu[1][n, h, :, k],
                        target_crops_w,
                        atol=0,
                    )
class TestPreparPixelsAndColors(TestCaseMixin, unittest.TestCase):
    # NOTE(review): class name has a typo ("Prepar"); kept as-is so that
    # existing test ids remain stable.

    def setUp(self):
        """Run _prepare_pixels_and_colors on random inputs with a random
        background mask, keeping both the inputs and outputs for the tests."""
        self.device = torch.device("cpu")
        batch, height, width, layers = 2, 3, 4, 5
        base_shape = (batch, height, width, layers)
        self.pixel_coords_cameras = torch.randn(
            base_shape + (3,), device=self.device, requires_grad=True
        )
        self.colors_before = torch.rand(base_shape + (3,), device=self.device)
        self.cameras = FoVPerspectiveCameras(device=self.device)
        # Roughly half the pixels are marked as background.
        self.background_mask = torch.rand(base_shape, device=self.device) < 0.5
        self.pixel_coords_screen, self.colors_after = _prepare_pixels_and_colors(
            self.pixel_coords_cameras,
            self.colors_before,
            self.cameras,
            self.background_mask,
        )

    def test_background_z(self):
        """Background pixels must be pushed to max depth (z == 1)."""
        z_values = self.pixel_coords_screen[..., 2]
        self.assertTrue(bool((z_values[self.background_mask] == 1.0).all()))

    def test_background_alpha(self):
        """Background pixels must be fully transparent (alpha == 0)."""
        alphas = self.colors_after[..., 3]
        self.assertTrue(bool((alphas[self.background_mask] == 0.0).all()))
class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase):
    def test_splat_kernel_normalization(self):
        """Check the normalization constant for the default sigma and a tiny
        sigma, and that a non-positive sigma raises ValueError."""
        default_norm = _get_splat_kernel_normalization(offsets)
        self.assertAlmostEqual(float(default_norm), 0.6503, places=3)
        tiny_sigma_norm = _get_splat_kernel_normalization(offsets, 0.01)
        self.assertAlmostEqual(float(tiny_sigma_norm), 1.05, places=3)
        with self.assertRaisesRegex(ValueError, "Only positive standard deviations"):
            _get_splat_kernel_normalization(offsets, 0)
class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase):
    """Tests for _compute_occlusion_layers, which, for each pixel and each of
    the 9 splatting directions, returns the index of the splatting pixel's
    intersection layer that is co-surface with the splatted pixel."""

    def test_single_layer(self):
        # If there's only one layer, all splats must be on the surface level.
        N, H, W, K = 2, 3, 4, 1
        q_depth = torch.rand(N, H, W, K)
        occlusion_layers = _compute_occlusion_layers(q_depth)
        self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)

    def test_all_equal(self):
        # If all q-vals are equal, then all splats must be on the surface level.
        N, H, W, K = 2, 3, 4, 5
        q_depth = torch.ones((N, H, W, K)) * 0.1234
        occlusion_layers = _compute_occlusion_layers(q_depth)
        self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)

    def test_mid_to_top_level_splatting(self):
        # Check that occlusion buffers get accumulated as expected when the splatting
        # and splatted pixels are co-surface on different intersection layers.
        # This test will make best sense with accompanying Fig. 4 from "Differentiable
        # Surface Rendering via Non-differentiable Sampling" by Cole et al.
        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Skip self-splatting which is always co-surface.
            # 1x3x3 image with K=3 depth layers per pixel, all zeros except
            # for p and q below.
            depths = torch.zeros(1, 3, 3, 3)
            # This is our q, the pixel splatted onto, in the center of the image.
            depths[0, 1, 1] = torch.tensor([0.71, 0.8, 1.0])
            # This is our p, the splatting pixel, one step away from q in the
            # current direction. Its middle layer (0.7) is closest to q's top
            # layer (0.71).
            depths[0, offset[0] + 1, offset[1] + 1] = torch.tensor([0.5, 0.7, 0.9])
            occlusion_layers = _compute_occlusion_layers(depths)
            # Check that we computed that it is the middle layer of p that is co-
            # surface with q. (1, 1) is the id of q in the depth array, and offset_id
            # is the id of p's direction w.r.t. q.
            psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction]
            self.assertEqual(int(psurfaceid_onto_q), 1)
            # Conversely, if we swap p and q, we have a top-level splatting onto
            # mid-level. offset + 1 is the id of p, and 8-offset_id is the id of
            # q's direction w.r.t. p (e.g. if p is [-1, -1] w.r.t. q, then q is
            # [1, 1] w.r.t. p; we use the ids of these two directions in the offsets
            # array).
            qsurfaceid_onto_p = occlusion_layers[
                0, offset[0] + 1, offset[1] + 1, 8 - direction
            ]
            self.assertEqual(int(qsurfaceid_onto_p), -1)
class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """Tests for _compute_splatting_colors_and_weights, which produces the
    (N, H, W, K, 9, 5) tensor of alpha-weighted splat colors and weights."""

    def setUp(self):
        # Pixel coordinates on a regular grid: screen xy of (h + .5, w + .5)
        # for each pixel, shared across the N batch and K layer dimensions.
        self.N, self.H, self.W, self.K = 2, 3, 4, 5
        self.pixel_coords_screen = (
            torch.stack(
                meshgrid_ij(torch.arange(self.H), torch.arange(self.W)),
                dim=-1,
            )
            .reshape(1, self.H, self.W, 1, 2)
            .expand(self.N, self.H, self.W, self.K, 2)
            .float()
            + 0.5
        )
        # Fully opaque all-ones RGBA colors.
        self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))

    def test_all_equal(self):
        # If all colors are equal and on a regular grid, all weights and reweighted
        # colors should be equal given a specific splatting direction.
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets
        )
        # Splatting directly to the top/bottom/left/right should have the same strength.
        non_diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([1, 3, 5, 7])
        ]
        # Same for diagonal splats.
        diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([0, 2, 6, 8])
        ]
        # And for self-splats.
        self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])]
        for splats in non_diag_splats, diag_splats, self_splats:
            # Colors should be equal.
            self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0]))
            # Weights should be equal.
            self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4]))
        # Non-diagonal weights should be greater than diagonal weights.
        self.assertGreater(
            non_diag_splats[0, 0, 0, 0, 0, 0], diag_splats[0, 0, 0, 0, 0, 0]
        )
        # Self-splats should be strongest of all.
        self.assertGreater(
            self_splats[0, 0, 0, 0, 0, 0], non_diag_splats[0, 0, 0, 0, 0, 0]
        )
        # Splatting colors should be reweighted proportionally to their splat weights.
        diag_self_color_ratio = (
            diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
        )
        diag_self_weight_ratio = (
            diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
        )
        self.assertEqual(diag_self_color_ratio, diag_self_weight_ratio)
        non_diag_self_color_ratio = (
            non_diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
        )
        non_diag_self_weight_ratio = (
            non_diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
        )
        self.assertEqual(non_diag_self_color_ratio, non_diag_self_weight_ratio)

    def test_zero_alpha_zero_weight(self):
        # Pixels with zero alpha do no splatting, but should still be splatted on.
        colors = self.colors.clone()
        colors[0, 1, 1, 0, 3] = 0.0
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )
        # The transparent pixel should do no splatting.
        self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0))
        # Splatting *onto* the transparent pixel should be unaffected.
        reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1]
        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Ignore self-splats
            # We invert the direction to get the right (h, w, d) coordinate of each
            # pixel splatting *onto* the pixel with zero alpha.
            self.assertClose(
                splatting_colors_and_weights[
                    0, 1 + offset[0], 1 + offset[1], 0, 8 - direction
                ],
                reference_weights_colors[8 - direction],
                atol=0.001,
            )

    def test_random_inputs(self):
        # The vectorized implementation must agree with the naive loop-based
        # reference on jittered coordinates and random colors.
        pixel_coords_screen = (
            self.pixel_coords_screen
            + torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1
        )
        colors = torch.rand((self.N, self.H, self.W, self.K, 4))
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )
        naive_colors_and_weights = compute_splatting_colors_and_weights_naive(
            pixel_coords_screen, colors, sigma=0.5
        )
        self.assertClose(
            splatting_colors_and_weights, naive_colors_and_weights, atol=0.01
        )
class TestOffsetSplats(TestCaseMixin, unittest.TestCase):
    def test_offset(self):
        """_offset_splats must shift each of the nine splat directions by its
        (dh, dw) offset, zero-filling at the image border.

        Each direction is checked by comparing the appropriately cropped output
        against the corresponding input crop, and verifying that the border
        rows/columns left uncovered by the shift are exactly zero.
        """
        # Fix: the original hard-coded cuda:0 and crashed on CPU-only machines.
        # Skip explicitly when CUDA is unavailable instead of erroring out.
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        device = torch.device("cuda:0")
        N, H, W, K = 2, 3, 4, 5
        colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device)
        crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device)
        offset_colors_and_weights = _offset_splats(
            colors_and_weights, crop_ids_h, crop_ids_w
        )
        # Check each splatting direction individually, for clarity.
        # offset_x, offset_y = (-1, -1)
        direction = 0
        self.assertClose(
            offset_colors_and_weights[:, 1:, 1:, :, direction],
            colors_and_weights[:, :-1, :-1, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
        )
        # offset_x, offset_y = (-1, 0)
        direction = 1
        self.assertClose(
            offset_colors_and_weights[:, :, 1:, :, direction],
            colors_and_weights[:, :, :-1, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
        )
        # offset_x, offset_y = (-1, 1)
        direction = 2
        self.assertClose(
            offset_colors_and_weights[:, :-1, 1:, :, direction],
            colors_and_weights[:, 1:, :-1, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
        )
        # offset_x, offset_y = (0, -1)
        direction = 3
        self.assertClose(
            offset_colors_and_weights[:, 1:, :, :, direction],
            colors_and_weights[:, :-1, :, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
        )
        # self-splat: direction 4 must pass through unchanged.
        direction = 4
        self.assertClose(
            offset_colors_and_weights[..., direction, :],
            colors_and_weights[..., direction, :],
            atol=0.001,
        )
        # offset_x, offset_y = (0, 1)
        direction = 5
        self.assertClose(
            offset_colors_and_weights[:, :-1, :, :, direction],
            colors_and_weights[:, 1:, :, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
        )
        # offset_x, offset_y = (1, -1)
        direction = 6
        self.assertClose(
            offset_colors_and_weights[:, 1:, :-1, :, direction],
            colors_and_weights[:, :-1, 1:, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
        )
        # offset_x, offset_y = (1, 0)
        direction = 7
        self.assertClose(
            offset_colors_and_weights[:, :, :-1, :, direction],
            colors_and_weights[:, :, 1:, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
        )
        # offset_x, offset_y = (1, 1)
        direction = 8
        self.assertClose(
            offset_colors_and_weights[:, :-1, :-1, :, direction],
            colors_and_weights[:, 1:, 1:, :, direction],
            atol=0.001,
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
        )
        self.assertTrue(
            torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
        )
class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """Tests for _compute_splatted_colors_and_weights, which accumulates splat
    colors into foreground/surface/background layers (last-dim indices 0/1/2)
    according to the occlusion-layer offsets."""

    def test_accumulation_background(self):
        # Set occlusion_layers to all -1, so all splats are background splats.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.zeros((1, 1, 1, 9)) - 1
        splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        # Foreground splats (there are none).
        self.assertClose(
            splatted_colors[0, 0, 0, :, 0],
            torch.zeros((4)),
            atol=0.001,
        )
        # Surface splats (there are none).
        self.assertClose(
            splatted_colors[0, 0, 0, :, 1],
            torch.zeros((4)),
            atol=0.001,
        )
        # Background splats: all K layers and all 9 directions accumulate here.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 2],
            splat_colors_and_weights[0, 0, 0, :, :, :4].sum(dim=0).sum(dim=0),
            atol=0.001,
        )

    def test_accumulation_middle(self):
        # Set occlusion_layers to all 0, so top splats are co-surface with splatted
        # pixels. Thus, the top splatting layer should be accumulated to surface, and
        # all other layers to background.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.zeros((1, 1, 1, 9))
        splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        # Foreground splats (there are none).
        self.assertClose(
            splatted_colors[0, 0, 0, :, 0],
            torch.zeros((4)),
            atol=0.001,
        )
        # Surface splats: only the top (k=0) layer.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 1],
            splat_colors_and_weights[0, 0, 0, 0, :, :4].sum(dim=0),
            atol=0.001,
        )
        # Background splats: the remaining (k>=1) layers.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 2],
            splat_colors_and_weights[0, 0, 0, 1:, :, :4].sum(dim=0).sum(dim=0),
            atol=0.001,
        )

    def test_accumulation_foreground(self):
        # Set occlusion_layers to all 1. Then the top splatter is a foreground
        # splatter, mid splatter is surface, and bottom splatter is background.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.zeros((1, 1, 1, 9)) + 1
        splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        # Foreground splats: the top (k=0) layer.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 0],
            splat_colors_and_weights[0, 0, 0, 0:1, :, :4].sum(dim=0).sum(dim=0),
            atol=0.001,
        )
        # Surface splats: the middle (k=1) layer.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 1],
            splat_colors_and_weights[0, 0, 0, 1:2, :, :4].sum(dim=0).sum(dim=0),
            atol=0.001,
        )
        # Background splats: the bottom (k=2) layer.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 2],
            splat_colors_and_weights[0, 0, 0, 2:3, :, :4].sum(dim=0).sum(dim=0),
            atol=0.001,
        )
class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase):
    """Tests for _normalize_and_compose_all_layers, which normalizes each
    color layer by its accumulated weight and alpha-composites the
    foreground/surface/background layers over the background color."""

    def test_background_color(self):
        # Background should always have alpha=0, and the chosen RGB.
        N, H, W = 2, 3, 4
        # Make a mask with background in the zeroth row of the first image.
        bg_mask = torch.zeros([N, H, W, 1, 1])
        bg_mask[0, :, 0] = 1
        bg_color = torch.tensor([0.2, 0.3, 0.4])
        # Zero out colors and weights wherever the mask marks background.
        color_layers = torch.rand((N, H, W, 4, 3)) * (1 - bg_mask)
        color_weights = torch.rand((N, H, W, 1, 3)) * (1 - bg_mask)
        colors = _normalize_and_compose_all_layers(
            bg_color, color_layers, color_weights
        )
        # Background RGB should be .2, .3, .4, and alpha should be 0.
        self.assertClose(
            torch.masked_select(colors, bg_mask.bool()[..., 0]),
            torch.tensor([0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0.0]),
            atol=0.001,
        )

    def test_compositing_opaque(self):
        # When all colors are opaque, only the foreground layer should be visible.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        color_layers[..., 3, :] = 1.0  # Set alpha of all three layers to 1.
        color_weights = torch.ones((N, H, W, 1, 3))
        out_colors = _normalize_and_compose_all_layers(
            torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
        )
        self.assertClose(out_colors, color_layers[..., 0], atol=0.001)

    def test_compositing_transparencies(self):
        # When foreground layer is transparent and surface and bg are semi-transparent,
        # we should return a mix of the two latter.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        color_layers[..., 3, 0] = 0.1  # fg
        color_layers[..., 3, 1] = 0.2  # surface
        color_layers[..., 3, 2] = 0.3  # bg
        color_weights = torch.ones((N, H, W, 1, 3))
        out_colors = _normalize_and_compose_all_layers(
            torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
        )
        # Standard over-compositing: fg + (1 - 0.1) * (surf + (1 - 0.2) * bg).
        self.assertClose(
            out_colors,
            color_layers[..., 0]
            + 0.9 * (color_layers[..., 1] + 0.8 * color_layers[..., 2]),
        )