# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
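
"""Tests for the point compositing functions in pytorch3d.renderer.compositing
(alpha_composite, norm_weighted_sum, weighted_sum): the CPU and CUDA ops are
checked against naive pure-Python reference implementations."""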

import unittest

import torch
from pytorch3d.renderer.compositing import (
    alpha_composite,
    norm_weighted_sum,
    weighted_sum,
)

from .common_testing import get_random_cuda_device, TestCaseMixin


class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
    # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
    @staticmethod
    def accumulate_alphacomposite_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of alpha_composite.
        Inputs / Outputs: Same as the alpha_composite function.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(B):
            for c in range(C):
                for i in range(W):
                    for j in range(H):
                        # Transmittance remaining after the points seen so far.
                        t_alpha = 1
                        for k in range(K):
                            n_idx = points_idx[b, k, j, i]

                            # A negative index marks an empty slot: no point
                            # projects to this pixel at this depth.
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha * t_alpha
                            t_alpha = (1 - alpha) * t_alpha

        return output

    @staticmethod
    def accumulate_weightedsum_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of weighted_sum rasterization.
        Inputs / Outputs: Same as the weighted_sum function.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(B):
            for c in range(C):
                for i in range(W):
                    for j in range(H):
                        for k in range(K):
                            n_idx = points_idx[b, k, j, i]

                            # A negative index marks an empty slot.
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha

        return output

    @staticmethod
    def accumulate_weightedsumnorm_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of norm_weighted_sum.
        Inputs / Outputs: Same as the norm_weighted_sum function.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(B):
            for c in range(C):
                for i in range(W):
                    for j in range(H):
                        # First pass: accumulate the normalizing weight.
                        t_alpha = 0
                        for k in range(K):
                            n_idx = points_idx[b, k, j, i]

                            if n_idx < 0:
                                continue

                            t_alpha += alphas[b, k, j, i]

                        # Clamp to avoid division by zero for empty pixels.
                        t_alpha = max(t_alpha, 1e-4)

                        # Second pass: accumulate the normalized features.
                        for k in range(K):
                            n_idx = points_idx[b, k, j, i]

                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha / t_alpha

        return output

    def test_python(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(self.accumulate_alphacomposite_python, device)
        self._simple_wsum(self.accumulate_weightedsum_python, device)
        self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)

    def test_cpu(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_cuda(self):
        device = get_random_cuda_device()
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_python_vs_cpu_vs_cuda(self):
        self._python_vs_cpu_vs_cuda(
            self.accumulate_alphacomposite_python, alpha_composite
        )
        self._python_vs_cpu_vs_cuda(
            self.accumulate_weightedsumnorm_python, norm_weighted_sum
        )
        self._python_vs_cpu_vs_cuda(self.accumulate_weightedsum_python, weighted_sum)

    def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
        torch.manual_seed(231)

        W = 8
        C = 3
        P = 32

        for d in ["cpu", get_random_cuda_device()]:
            # TODO(gkioxari) add torch.float64 to types after double precision
            # support is added to atomicAdd
            for t in [torch.float32]:
                device = torch.device(d)

                # Create input values.
                alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
                alphas.requires_grad = True
                alphas_cpu = alphas.detach().cpu()
                alphas_cpu.requires_grad = True

                features = torch.randn(C, P, dtype=t).to(device)
                features.requires_grad = True
                features_cpu = features.detach().cpu()
                features_cpu.requires_grad = True

                # Indices in [-1, P - 1]; -1 marks an empty slot.
                inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
                inds_cpu = inds.detach().cpu()

                args_cuda = (inds, alphas, features)
                args_cpu = (inds_cpu, alphas_cpu, features_cpu)

                self._compare_impls(
                    accumulate_func_python,
                    accumulate_func,
                    args_cpu,
                    args_cuda,
                    (alphas_cpu, features_cpu),
                    (alphas, features),
                    compare_grads=True,
                )

    def _compare_impls(
        self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
    ):
        res1 = fn1(*args1)
        res2 = fn2(*args2)

        self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)

        if not compare_grads:
            return

        # Compare gradients: backprop the same random upstream gradient through
        # both implementations via loss = (res * grad_res).sum(), which is
        # equivalent to res.backward(grad_res).
        torch.manual_seed(231)
        grad_res = torch.randn_like(res1)
        loss1 = (res1 * grad_res).sum()
        loss1.backward()

        grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]
        grad_res = grad_res.to(res2)

        loss2 = (res2 * grad_res).sum()
        loss2.backward()
        grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]

        for g1, g2 in zip(grads1, grads2):
            self.assertClose(g1, g2, atol=1e-6)

    def _simple_wsum(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        points_idx = (
            torch.Tensor(
                [
                    [
                        # fmt: off
                        [
                            [0, 0, 0, 0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0, 1, 1, 0],  # noqa: E241, E201
                            [0, 0, 0, 0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2, 2, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                        # fmt: on
                    ]
                ]
            )
            .long()
            .to(device)
        )

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))
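
        # Spot check, pixel (0, 0): indices [0, 2] with alphas [0.5, 0.5] give
        # 0.1 * 0.5 + 0.6 * 0.5 = 0.35 in each feature channel.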
        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                ]
            ]
        ).to(device)

        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_wsumnorm(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0, 0, 0, 0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0, 1, 1, 0],  # noqa: E241, E201
                            [0, 0, 0, 0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2, 2, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))
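
        # Spot check, pixel (3, 2): indices [0, -1] with alphas [0.5, 0.5]; only
        # the first slot is valid, so w_sum = 0.5 and the result is
        # 0.1 * 0.5 / 0.5 = 0.10.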
        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                ]
            ]
        ).to(device)

        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_alphacomposite(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0, 0, 0, 0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0, 1, 1, 0],  # noqa: E241, E201
                            [0, 0, 0, 0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2, 2, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 3, 3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))
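
        # Spot check, pixel (0, 0): indices [0, 2] with alphas [0.5, 0.5] give
        # 0.1 * 0.5 + 0.6 * 0.5 * (1 - 0.5) = 0.05 + 0.15 = 0.20 in each channel.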
        true_result = torch.Tensor(
            [
                [
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                ]
            ]
        ).to(device)

        # Unlike the tests above, this compares for exact equality.
        self.assertTrue((result == true_result).all().item())
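

# Optional convenience: run this module's tests with the standard unittest
# runner. Note the relative import above means the module must be executed in
# package context (e.g. with `python -m`), not as a bare script.
if __name__ == "__main__":
    unittest.main()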