628 lines
23 KiB
Python
628 lines
23 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
from pytorch3d.common.compat import meshgrid_ij
|
|
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
|
|
from pytorch3d.renderer.splatter_blend import (
|
|
_compute_occlusion_layers,
|
|
_compute_splatted_colors_and_weights,
|
|
_compute_splatting_colors_and_weights,
|
|
_get_splat_kernel_normalization,
|
|
_normalize_and_compose_all_layers,
|
|
_offset_splats,
|
|
_precompute,
|
|
_prepare_pixels_and_colors,
|
|
)
|
|
|
|
from .common_testing import TestCaseMixin
|
|
|
|
# The 3x3 splatting neighbourhood: all nine pairs over {-1, 0, 1}, enumerated
# with the first coordinate varying slowest. With this ordering index 4 is the
# (0, 0) self-splat and index 8 - d holds the negation of the pair at index d.
offsets = torch.tensor(
    [[first, second] for first in (-1, 0, 1) for second in (-1, 0, 1)],
    device=torch.device("cpu"),
)
|
|
|
|
|
|
def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma):
    """Loop-based reference implementation of splatting colors and weights.

    For every pixel (n, h, w) and layer k, the point q at
    ``pixel_coords_screen[n, h, w, k]`` is splatted onto the 9 neighbouring
    pixel centers given by the module-level ``offsets``. An isotropic Gaussian
    of standard deviation ``sigma`` is used, scaled by q's alpha and by the
    kernel normalizer computed for the same ``sigma``.

    Args:
        pixel_coords_screen: (N, H, W, K, C) screen coordinates with C >= 2;
            only the first two channels are used.
        colors: (N, H, W, K, 4) colors; channel 3 is used as alpha.
        sigma: standard deviation of the splatting Gaussian.

    Returns:
        (N, H, W, K, 9, 5) tensor with the dtype/device of ``colors``:
        channels 0-3 hold the weighted colors, channel 4 the splat weight.
    """
    # Use the same sigma for the normalizer as for the kernel itself;
    # otherwise the reference is only correct for the default sigma.
    normalizer = float(_get_splat_kernel_normalization(offsets, sigma))
    N, H, W, K, _ = colors.shape
    kernel_offsets = offsets.to(device=pixel_coords_screen.device)
    splat_weights_and_colors = colors.new_zeros((N, H, W, K, 9, 5))
    for n in range(N):
        for h in range(H):
            for w in range(W):
                for k in range(K):
                    q_xy = pixel_coords_screen[n, h, w, k, :2]
                    # Vector from q to the center of the pixel containing q.
                    q_to_px_center = torch.floor(q_xy) - q_xy + 0.5
                    color = colors[n, h, w, k]
                    alpha = colors[n, h, w, k, 3:4]
                    for d in range(9):
                        dist_p_q = torch.sum((q_to_px_center + kernel_offsets[d]) ** 2)
                        splat_weight = (
                            alpha * torch.exp(-dist_p_q / (2 * sigma**2)) * normalizer
                        )
                        splat_color = splat_weight * color
                        splat_weights_and_colors[n, h, w, k, d, :4] = splat_color
                        splat_weights_and_colors[n, h, w, k, d, 4:5] = splat_weight
    return splat_weights_and_colors
|
|
|
|
|
|
class TestPrecompute(TestCaseMixin, unittest.TestCase):
    """Tests for `_precompute`, which returns (crop_ids_h, crop_ids_w, offsets)."""

    def setUp(self):
        self.results_cpu = _precompute((2, 3, 4, 5), torch.device("cpu"))
        self.results1_cpu = _precompute((1, 1, 1, 1), torch.device("cpu"))

    def test_offsets(self):
        # Shapes are compared exactly; assertClose is meant for tensor values.
        self.assertEqual(self.results_cpu[2].shape, offsets.shape)
        self.assertClose(self.results_cpu[2], offsets, atol=0)

        # Offsets should be independent of input_size.
        self.assertClose(self.results_cpu[2], self.results1_cpu[2], atol=0)

    def test_crops_h(self):
        target_crops_h1 = torch.tensor(
            [
                # channels being offset:
                # R  G  B  A  W(eight)
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
            ]
            * 3,  # 3 because we're aiming at (N, H, W+2, K, 9, 5) with W=1.
            device=torch.device("cpu"),
        ).reshape(1, 1, 3, 1, 9, 5)
        self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0)

        # Each h value needs an increment.
        target_crops_h_base = target_crops_h1[0, 0, 0]
        target_crops_h = torch.cat(
            [target_crops_h_base, target_crops_h_base + 1, target_crops_h_base + 2],
            dim=0,
        )

        # Check that we have the right shape, and (after broadcasting) it has the
        # right values. These should be repeated (tiled) for each n, w and k.
        self.assertEqual(tuple(self.results_cpu[0].shape), (2, 3, 6, 5, 9, 5))
        for n in range(2):
            for w in range(6):
                for k in range(5):
                    self.assertClose(
                        self.results_cpu[0][n, :, w, k],
                        target_crops_h,
                        atol=0,
                    )

    def test_crops_w(self):
        target_crops_w1 = torch.tensor(
            [
                # channels being offset:
                # R  G  B  A  W(eight)
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [2, 2, 2, 2, 2],
                [2, 2, 2, 2, 2],
                [2, 2, 2, 2, 2],
            ],
            device=torch.device("cpu"),
        ).reshape(1, 1, 1, 1, 9, 5)
        self.assertClose(self.results1_cpu[1], target_crops_w1, atol=0)

        # Each w value needs an increment.
        target_crops_w_base = target_crops_w1[0, 0, 0]
        target_crops_w = torch.cat(
            [
                target_crops_w_base,
                target_crops_w_base + 1,
                target_crops_w_base + 2,
                target_crops_w_base + 3,
            ],
            dim=0,
        )

        # Check that we have the right shape, and (after broadcasting) it has the
        # right values. These should be repeated (tiled) for each n, h and k.
        self.assertEqual(tuple(self.results_cpu[1].shape), (2, 3, 4, 5, 9, 5))
        for n in range(2):
            for h in range(3):
                for k in range(5):
                    self.assertClose(
                        self.results_cpu[1][n, h, :, k],
                        target_crops_w,
                        atol=0,
                    )
|
|
|
|
|
|
class TestPreparPixelsAndColors(TestCaseMixin, unittest.TestCase):
    """Background handling of `_prepare_pixels_and_colors` on random CPU inputs."""

    def setUp(self):
        self.device = torch.device("cpu")
        grid_shape = (2, 3, 4, 5)  # N, H, W, K
        self.pixel_coords_cameras = torch.randn(
            (*grid_shape, 3), device=self.device, requires_grad=True
        )
        self.colors_before = torch.rand((*grid_shape, 3), device=self.device)
        self.cameras = FoVPerspectiveCameras(device=self.device)
        # Roughly half of the (n, h, w, k) entries are marked as background.
        self.background_mask = torch.rand(grid_shape, device=self.device) < 0.5
        prepared = _prepare_pixels_and_colors(
            self.pixel_coords_cameras,
            self.colors_before,
            self.cameras,
            self.background_mask,
        )
        self.pixel_coords_screen, self.colors_after = prepared

    def test_background_z(self):
        # Every background entry must end up with z == 1.
        background_z = self.pixel_coords_screen[..., 2][self.background_mask]
        self.assertTrue((background_z == 1.0).all())

    def test_background_alpha(self):
        # Every background entry must end up fully transparent.
        background_alpha = self.colors_after[..., 3][self.background_mask]
        self.assertTrue((background_alpha == 0.0).all())
|
|
|
|
|
|
class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase):
    """Pins the splat-kernel normalizer values and its sigma validation."""

    def test_splat_kernel_normalization(self):
        # (extra positional args, expected normalizer) pairs.
        cases = (
            ((), 0.6503),  # default standard deviation
            ((0.01,), 1.05),  # very narrow kernel
        )
        for extra_args, expected in cases:
            normalizer = _get_splat_kernel_normalization(offsets, *extra_args)
            self.assertAlmostEqual(float(normalizer), expected, places=3)

        # A zero standard deviation must be rejected.
        with self.assertRaisesRegex(ValueError, "Only positive standard deviations"):
            _get_splat_kernel_normalization(offsets, 0)
|
|
|
|
|
|
class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase):
    """Tests for `_compute_occlusion_layers`, which yields one id per direction."""

    def test_single_layer(self):
        # If there's only one layer, all splats must be on the surface level.
        N, H, W, K = 2, 3, 4, 1
        q_depth = torch.rand(N, H, W, K)
        occlusion_layers = _compute_occlusion_layers(q_depth)
        self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)

    def test_all_equal(self):
        # If all q-vals are equal, then all splats must be on the surface level.
        N, H, W, K = 2, 3, 4, 5
        q_depth = torch.ones((N, H, W, K)) * 0.1234
        occlusion_layers = _compute_occlusion_layers(q_depth)
        self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)

    def test_mid_to_top_level_splatting(self):
        # Check that occlusion buffers get accumulated as expected when the
        # splatting and splatted pixels are co-surface on different intersection
        # layers. This test will make best sense with accompanying Fig. 4 from
        # "Differentiable Surface Rendering via Non-differentiable Sampling" by
        # Cole et al.
        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Skip self-splatting which is always co-surface.

            # Position of p, the splatting pixel, in the 3x3 depth image; q, the
            # pixel splatted onto, sits at the center (1, 1).
            p_h, p_w = 1 + int(offset[0]), 1 + int(offset[1])

            with self.subTest(direction=direction):
                depths = torch.zeros(1, 3, 3, 3)
                depths[0, 1, 1] = torch.tensor([0.71, 0.8, 1.0])  # q
                depths[0, p_h, p_w] = torch.tensor([0.5, 0.7, 0.9])  # p

                occlusion_layers = _compute_occlusion_layers(depths)

                # It is the middle layer of p that is co-surface with q. (1, 1)
                # is the position of q and `direction` is the index of p's
                # direction w.r.t. q.
                psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction]
                self.assertEqual(int(psurfaceid_onto_q), 1)

                # Conversely, swapping p and q gives a top-level splatting onto
                # mid-level. (p_h, p_w) is the position of p and 8 - direction is
                # the index of q's direction w.r.t. p (e.g. if p is [-1, -1]
                # w.r.t. q, then q is [1, 1] w.r.t. p).
                qsurfaceid_onto_p = occlusion_layers[0, p_h, p_w, 8 - direction]
                self.assertEqual(int(qsurfaceid_onto_p), -1)
|
|
|
|
|
|
class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """Tests for `_compute_splatting_colors_and_weights` (last dim: RGBA + weight)."""

    def setUp(self):
        self.N, self.H, self.W, self.K = 2, 3, 4, 5
        # A regular grid of pixel centers (integer grid + 0.5), repeated over the
        # batch and all K layers.
        self.pixel_coords_screen = (
            torch.stack(
                meshgrid_ij(torch.arange(self.H), torch.arange(self.W)),
                dim=-1,
            )
            .reshape(1, self.H, self.W, 1, 2)
            .expand(self.N, self.H, self.W, self.K, 2)
            .float()
            + 0.5
        )
        self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))

    def test_all_equal(self):
        # If all colors are equal and on a regular grid, all weights and
        # reweighted colors should be equal given a specific splatting direction.
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets
        )

        # Splatting directly to the top/bottom/left/right should have the same
        # strength.
        non_diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([1, 3, 5, 7])
        ]
        # Same for diagonal splats.
        diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([0, 2, 6, 8])
        ]
        # And for self-splats.
        self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])]

        for splats in non_diag_splats, diag_splats, self_splats:
            # Colors should be equal.
            self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0]))
            # Weights should be equal.
            self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4]))

        # Non-diagonal weights should be greater than diagonal weights (channel 4
        # holds the splat weight).
        self.assertGreater(
            non_diag_splats[0, 0, 0, 0, 0, 4], diag_splats[0, 0, 0, 0, 0, 4]
        )
        # Self-splats should be strongest of all.
        self.assertGreater(
            self_splats[0, 0, 0, 0, 0, 4], non_diag_splats[0, 0, 0, 0, 0, 4]
        )

        # Splatting colors should be reweighted proportionally to their splat
        # weights. The two ratios come from different float32 products, so
        # compare them with a tolerance rather than bit-for-bit.
        for splats in diag_splats, non_diag_splats:
            color_ratio = splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
            weight_ratio = splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
            self.assertClose(color_ratio, weight_ratio)

    def test_zero_alpha_zero_weight(self):
        # Pixels with zero alpha do no splatting, but should still be splatted on.
        colors = self.colors.clone()
        colors[0, 1, 1, 0, 3] = 0.0
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )

        # The transparent pixel should do no splatting.
        self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0))

        # Splatting *onto* the transparent pixel should be unaffected.
        reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1]
        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Ignore self-splats
            with self.subTest(direction=direction):
                # We invert the direction to get the right (h, w, d) coordinate
                # of each pixel splatting *onto* the pixel with zero alpha.
                self.assertClose(
                    splatting_colors_and_weights[
                        0, 1 + int(offset[0]), 1 + int(offset[1]), 0, 8 - direction
                    ],
                    reference_weights_colors[8 - direction],
                    atol=0.001,
                )

    def test_random_inputs(self):
        # The vectorized implementation must match the loop-based reference.
        pixel_coords_screen = (
            self.pixel_coords_screen
            + torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1
        )
        colors = torch.rand((self.N, self.H, self.W, self.K, 4))
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )
        naive_colors_and_weights = compute_splatting_colors_and_weights_naive(
            pixel_coords_screen, colors, sigma=0.5
        )

        self.assertClose(
            splatting_colors_and_weights, naive_colors_and_weights, atol=0.01
        )
|
|
|
|
|
|
class TestOffsetSplats(TestCaseMixin, unittest.TestCase):
    """Tests for `_offset_splats` against the crop ids from `_precompute`."""

    @staticmethod
    def _dst_src_slices(delta):
        """Return (destination, source) slices along one axis for a shift of `delta`."""
        if delta < 0:
            return slice(1, None), slice(None, -1)
        if delta > 0:
            return slice(None, -1), slice(1, None)
        return slice(None), slice(None)

    def test_offset(self):
        # Fall back to CPU so the test also runs on machines without a GPU.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        N, H, W, K = 2, 3, 4, 5
        colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device)
        crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device)
        offset_colors_and_weights = _offset_splats(
            colors_and_weights, crop_ids_h, crop_ids_w
        )

        # Check each splatting direction individually. For direction d with
        # offsets[d] == (offset_x, offset_y), the expected output at (h, w) is
        # the input at (h + offset_y, w + offset_x) wherever that pixel exists.
        for direction, (offset_x, offset_y) in enumerate(offsets.tolist()):
            with self.subTest(direction=direction, offset_x=offset_x, offset_y=offset_y):
                dst_h, src_h = self._dst_src_slices(offset_y)
                dst_w, src_w = self._dst_src_slices(offset_x)
                self.assertClose(
                    offset_colors_and_weights[:, dst_h, dst_w, :, direction],
                    colors_and_weights[:, src_h, src_w, :, direction],
                    atol=0.001,
                )

                # The border row / column without a source pixel must be empty.
                if offset_y != 0:
                    row = 0 if offset_y < 0 else -1
                    self.assertTrue(
                        torch.all(offset_colors_and_weights[:, row, :, :, direction] == 0.0)
                    )
                if offset_x != 0:
                    col = 0 if offset_x < 0 else -1
                    self.assertTrue(
                        torch.all(offset_colors_and_weights[:, :, col, :, direction] == 0.0)
                    )
|
|
|
|
|
|
class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """Tests how `_compute_splatted_colors_and_weights` buckets splats by level.

    Output level 0 is foreground, 1 is surface and 2 is background.
    """

    def _check_accumulation(self, occlusion_value, layers_per_level):
        """Set every occlusion layer to `occlusion_value` and check each level.

        `layers_per_level` holds three slices over the K=3 splatting layers. The
        colors of each output level must equal the sum of the RGBA splats of
        those layers, taken over all 9 directions.
        """
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        # Use long ids, matching what `_compute_occlusion_layers` produces.
        occlusion_layers = torch.full((1, 1, 1, 9), occlusion_value, dtype=torch.long)
        splatted_colors, _ = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        for level, layers in enumerate(layers_per_level):
            expected = splat_colors_and_weights[0, 0, 0, layers, :, :4].sum(dim=(0, 1))
            self.assertClose(splatted_colors[0, 0, 0, :, level], expected, atol=0.001)

    def test_accumulation_background(self):
        # All -1: every splat is a background splat (no foreground or surface).
        self._check_accumulation(-1, (slice(0, 0), slice(0, 0), slice(0, 3)))

    def test_accumulation_middle(self):
        # All 0: top splats are co-surface with the splatted pixels, so the top
        # splatting layer goes to surface and all other layers to background.
        self._check_accumulation(0, (slice(0, 0), slice(0, 1), slice(1, 3)))

    def test_accumulation_foreground(self):
        # All 1: the top splatter is foreground, the middle one is surface and
        # the bottom one is background.
        self._check_accumulation(1, (slice(0, 1), slice(1, 2), slice(2, 3)))
|
|
|
|
|
|
class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase):
    """Tests for `_normalize_and_compose_all_layers`.

    In these tests color_layers is (N, H, W, 4, 3), where dim -2 is RGBA and
    dim -1 is the layer (fg, surface, bg).
    """

    def test_background_color(self):
        # Background should always have alpha=0, and the chosen RGB.
        N, H, W = 2, 3, 4
        # Make a mask with background in the zeroth column (w=0, every row) of
        # the first image.
        bg_mask = torch.zeros([N, H, W, 1, 1])
        bg_mask[0, :, 0] = 1

        bg_color = torch.tensor([0.2, 0.3, 0.4])

        # No layer color or weight where there is background.
        color_layers = torch.rand((N, H, W, 4, 3)) * (1 - bg_mask)
        color_weights = torch.rand((N, H, W, 1, 3)) * (1 - bg_mask)

        colors = _normalize_and_compose_all_layers(
            bg_color, color_layers, color_weights
        )

        # Each of the H background pixels should read RGBA = (.2, .3, .4, 0).
        expected = torch.tensor([0.2, 0.3, 0.4, 0.0]).repeat(H)
        self.assertClose(
            torch.masked_select(colors, bg_mask.bool()[..., 0]),
            expected,
            atol=0.001,
        )

    def test_compositing_opaque(self):
        # When all colors are opaque, only the foreground layer should be visible.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        color_layers[..., 3, :] = 1.0
        color_weights = torch.ones((N, H, W, 1, 3))

        out_colors = _normalize_and_compose_all_layers(
            torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
        )
        self.assertClose(out_colors, color_layers[..., 0], atol=0.001)

    def test_compositing_transparencies(self):
        # When all three layers are semi-transparent (alphas 0.1 / 0.2 / 0.3),
        # the result mixes all of them front to back. Each layer is attenuated
        # by (1 - alpha) of every layer in front of it; the black background
        # contributes nothing.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        color_layers[..., 3, 0] = 0.1  # fg
        color_layers[..., 3, 1] = 0.2  # surface
        color_layers[..., 3, 2] = 0.3  # bg
        color_weights = torch.ones((N, H, W, 1, 3))

        out_colors = _normalize_and_compose_all_layers(
            torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
        )
        self.assertClose(
            out_colors,
            color_layers[..., 0]
            + 0.9 * (color_layers[..., 1] + 0.8 * color_layers[..., 2]),
        )
|