pytorch3d/tests/test_raysampling.py

640 lines
22 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Callable
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.ops import eyes
from pytorch3d.renderer import (
MonteCarloRaysampler,
MultinomialRaysampler,
NDCGridRaysampler,
NDCMultinomialRaysampler,
)
from pytorch3d.renderer.cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from pytorch3d.renderer.implicit.raysampling import (
_jiggle_within_stratas,
_safe_multinomial,
)
from pytorch3d.renderer.implicit.utils import (
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)
from pytorch3d.transforms import Rotate
from .common_testing import TestCaseMixin
from .test_cameras import init_random_cameras
class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase):
    """
    Checks that the points emitted by `NDCGridRaysampler`, once projected
    back through the camera, coincide with the expected NDC xy grid.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)

    def test_ndc_convention(
        self,
        h=428,
        w=760,
    ):
        """
        Sample one point per ray at unit depth from an `h` x `w` grid,
        project the resulting points with the same camera and compare the
        xy coordinates against the grid built by `_get_ndc_grid`.
        """
        device = torch.device("cuda")
        camera = init_random_cameras(PerspectiveCameras, 1, random_z=True).to(device)
        depth_map = torch.ones((1, 1, h, w)).to(device)

        grid_sampler = NDCGridRaysampler(
            image_width=w,
            image_height=h,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )
        # overwrite the sampled lengths with the (all-ones) depth map
        ray_bundle = grid_sampler(camera)._replace(
            lengths=depth_map[:, 0, ..., None]
        )
        xyz = ray_bundle_to_ray_points(ray_bundle).view(1, -1, 3)

        # project the point cloud and keep only the xy coordinates
        xy = camera.transform_points(xyz)[:, :, :2].squeeze()
        expected_xy = self._get_ndc_grid(h, w, device)
        self.assertClose(xy, expected_xy, atol=1e-4)

    def _get_ndc_grid(self, h, w, device):
        """
        Build the expected xy coordinates of an `h` x `w` grid of points.

        The shorter image side spans a half-range of 1.0 and the longer side
        is scaled by the aspect ratio. Along each axis the coordinates run
        from `+range - half_pixel` down to `-range + half_pixel`.

        Returns:
            A tensor of shape `(h * w, 2)` on `device` holding (x, y) pairs.
        """
        range_x, range_y = (w / h, 1.0) if w >= h else (1.0, h / w)
        half_pix_width = range_x / w
        half_pix_height = range_y / h

        ys = torch.linspace(
            range_y - half_pix_height,
            -range_y + half_pix_height,
            h,
            dtype=torch.float32,
        )
        xs = torch.linspace(
            range_x - half_pix_width,
            -range_x + half_pix_width,
            w,
            dtype=torch.float32,
        )
        y_grid, x_grid = meshgrid_ij(ys, xs)

        flat_xy = [g.contiguous().view(-1).to(device) for g in (x_grid, y_grid)]
        return torch.stack(flat_xy, dim=1)
class TestRaysampling(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
    """
    Seed the torch random number generator with a fixed value.
    """
    seed = 42
    torch.manual_seed(seed)
@staticmethod
def raysampler(
    raysampler_type,
    camera_type,
    n_pts_per_ray: int,
    batch_size: int,
    image_width: int,
    image_height: int,
) -> Callable[[], None]:
    """
    Benchmark helper.

    Builds a raysampler of `raysampler_type` and a batch of `batch_size`
    random cameras of `camera_type` on the CUDA device, and returns a
    zero-argument callable that runs the raysampler on those cameras and
    then synchronizes CUDA.
    """
    device = torch.device("cuda")

    # the raysampler is created first, then the random cameras
    sampler_kwargs = {
        "raysampler_type": raysampler_type,
        "min_x": -1.0,
        "max_x": 1.0,
        "min_y": -1.0,
        "max_y": 1.0,
        "image_width": image_width,
        "image_height": image_height,
        "min_depth": 1.0,
        "max_depth": 10.0,
        "n_pts_per_ray": n_pts_per_ray,
    }
    sampler = TestRaysampling.init_raysampler(**sampler_kwargs).to(device)
    cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)

    def run_raysampler() -> None:
        sampler(cameras=cameras)
        torch.cuda.synchronize()

    return run_raysampler
@staticmethod
def init_raysampler(
    raysampler_type,
    min_x=-1.0,
    max_x=1.0,
    min_y=-1.0,
    max_y=1.0,
    image_width=10,
    image_height=20,
    min_depth=1.0,
    max_depth=10.0,
    n_pts_per_ray=10,
    n_rays_total=None,
    n_rays_per_image=None,
):
    """
    Instantiate `raysampler_type` with the keyword arguments it accepts.

    Subclasses of `MultinomialRaysampler` additionally receive the image
    size. Subclasses of `MonteCarloRaysampler` get
    `n_rays_per_image = image_width * image_height` when neither
    `n_rays_total` nor `n_rays_per_image` was given. Subclasses of
    `NDCMultinomialRaysampler` are constructed without `min/max_x/y`.

    Raises:
        ValueError: if `raysampler_type` is neither a `MultinomialRaysampler`
            nor a `MonteCarloRaysampler` subclass.
    """
    is_grid_sampler = issubclass(raysampler_type, MultinomialRaysampler)
    if not is_grid_sampler and not issubclass(raysampler_type, MonteCarloRaysampler):
        raise ValueError(str(raysampler_type))

    kwargs = dict(
        min_x=min_x,
        max_x=max_x,
        min_y=min_y,
        max_y=max_y,
        n_pts_per_ray=n_pts_per_ray,
        min_depth=min_depth,
        max_depth=max_depth,
        n_rays_total=n_rays_total,
        n_rays_per_image=n_rays_per_image,
    )

    if is_grid_sampler:
        kwargs["image_width"] = image_width
        kwargs["image_height"] = image_height
    elif n_rays_total is None and n_rays_per_image is None:
        # Monte Carlo sampler with no explicit ray budget:
        # draw as many rays as there are pixels
        kwargs["n_rays_per_image"] = image_width * image_height

    if issubclass(raysampler_type, NDCMultinomialRaysampler):
        # the NDC grid raysampler does not take min/max_x/y
        for bound_name in ("min_x", "max_x", "min_y", "max_y"):
            kwargs.pop(bound_name)

    return raysampler_type(**kwargs)
def test_raysamplers(
    self,
    batch_size=25,
    min_x=-1.0,
    max_x=1.0,
    min_y=-1.0,
    max_y=1.0,
    image_width=10,
    image_height=20,
    min_depth=1.0,
    max_depth=10.0,
):
    """
    Runs the Monte Carlo, multinomial and NDC multinomial raysamplers on
    batches of random cameras of four camera types, with 100 and with 1
    point per ray, and checks output shapes, sampled ray points and
    ray directions.
    """
    device = torch.device("cuda")
    sampler_types = (
        MonteCarloRaysampler,
        MultinomialRaysampler,
        NDCMultinomialRaysampler,
    )
    camera_types = (
        FoVPerspectiveCameras,
        FoVOrthographicCameras,
        OrthographicCameras,
        PerspectiveCameras,
    )

    for n_pts_per_ray in (100, 1):
        for raysampler_type in sampler_types:
            raysampler = TestRaysampling.init_raysampler(
                raysampler_type=raysampler_type,
                min_x=min_x,
                max_x=max_x,
                min_y=min_y,
                max_y=max_y,
                image_width=image_width,
                image_height=image_height,
                min_depth=min_depth,
                max_depth=max_depth,
                n_pts_per_ray=n_pts_per_ray,
            )

            # ground-truth xy bounds, ordered (min_x, max_x, min_y, max_y)
            if not issubclass(raysampler_type, NDCMultinomialRaysampler):
                gt_bounds = (min_x, max_x, min_y, max_y)
            else:
                # the NDC grid raysampler ignores min/max_x/y: its bounds
                # follow from the aspect ratio, offset by half a pixel
                if image_width >= image_height:
                    range_x, range_y = image_width / image_height, 1.0
                else:
                    range_x, range_y = 1.0, image_height / image_width
                half_pix_width = range_x / image_width
                half_pix_height = range_y / image_height
                gt_bounds = (
                    range_x - half_pix_width,
                    -range_x + half_pix_width,
                    range_y - half_pix_height,
                    -range_y + half_pix_height,
                )

            for cam_type in camera_types:
                # a fresh batch of random cameras per camera type
                cameras = init_random_cameras(
                    cam_type, batch_size, random_z=True
                ).to(device)
                ray_bundle = raysampler(cameras=cameras)

                # shapes of the raysampler outputs
                self._check_raysampler_output_shapes(
                    raysampler,
                    ray_bundle,
                    batch_size,
                    image_width,
                    image_height,
                    n_pts_per_ray,
                )
                # points sampled along each ray
                self._check_raysampler_ray_points(
                    raysampler,
                    cameras,
                    ray_bundle,
                    *gt_bounds,
                    image_width,
                    image_height,
                    min_depth,
                    max_depth,
                )
                # output direction vectors
                self._check_raysampler_ray_directions(
                    cameras, raysampler, ray_bundle
                )
def _check_grid_shape(self, grid, batch_size, spatial_size, n_pts_per_ray, dim):
    """
    A helper for checking the desired size of a variable output by a raysampler.

    The target shape is `[batch_size, *spatial_size, n_pts_per_ray, dim]`
    with every non-positive entry dropped, so passing 0 for `n_pts_per_ray`
    or `dim` omits that axis.

    Args:
        grid: object exposing a `.shape` sequence (e.g. a tensor).
        batch_size: expected size of the leading dimension.
        spatial_size: iterable of expected spatial dimension sizes.
        n_pts_per_ray: expected points-per-ray dimension, or 0 to skip it.
        dim: expected trailing dimension, or 0 to skip it.
    """
    tgt_shape = tuple(
        x for x in (batch_size, *spatial_size, n_pts_per_ray, dim) if x > 0
    )
    # Compare the complete shapes: a pairwise zip() stops at the shorter
    # sequence and would accept a grid with missing or extra dimensions.
    self.assertEqual(tuple(grid.shape), tgt_shape)
def _check_raysampler_output_shapes(
    self,
    raysampler,
    ray_bundle,
    batch_size,
    image_width,
    image_height,
    n_pts_per_ray,
):
    """
    Checks the shapes of the `xys`, `origins`, `directions` and `lengths`
    fields of `ray_bundle`.

    A `MultinomialRaysampler` is expected to yield a
    `(image_height, image_width)` spatial layout, a `MonteCarloRaysampler`
    a single flat spatial axis of `image_height * image_width` rays.

    Raises:
        ValueError: if `raysampler` is of neither type.
    """
    if isinstance(raysampler, MultinomialRaysampler):
        spatial_size = [image_height, image_width]
    elif isinstance(raysampler, MonteCarloRaysampler):
        spatial_size = [image_height * image_width]
    else:
        raise ValueError(str(type(raysampler)))

    # (field name, expected n_pts_per_ray axis, expected trailing dim);
    # a zero means the axis is absent for that field
    expected_layout = (
        ("xys", 0, 2),
        ("origins", 0, 3),
        ("directions", 0, 3),
        ("lengths", n_pts_per_ray, 0),
    )
    for field_name, n_pts, last_dim in expected_layout:
        self._check_grid_shape(
            getattr(ray_bundle, field_name), batch_size, spatial_size, n_pts, last_dim
        )
def _check_raysampler_ray_points(
    self,
    raysampler,
    cameras,
    ray_bundle,
    min_x,
    max_x,
    min_y,
    max_y,
    image_width,
    image_height,
    min_depth,
    max_depth,
):
    """
    Checks the world-space ray points and the ray lengths of `ray_bundle`.

    Verifies that:
        - both ray-point conversion helpers agree,
        - the camera-space depth of the points and `ray_bundle.lengths`
          match a linear spacing between `min_depth` and `max_depth`,
        - re-projecting the points reproduces `ray_bundle.xys`,
        - the projected xy coordinates match the expected grid
          (`MultinomialRaysampler`) or stay within the given bounds and are
          constant along each ray (`MonteCarloRaysampler`).

    Raises:
        ValueError: if `raysampler` is of neither type.
    """
    batch_size = cameras.R.shape[0]

    # convert the bundle to ray points
    rays_points_world = ray_bundle_variables_to_ray_points(
        ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
    )
    n_pts_per_ray = rays_points_world.shape[-2]
    device = rays_points_world.device
    points_flat = rays_points_world.view(batch_size, -1, 3)

    # the outputs of ray_bundle_variables_to_ray_points and
    # ray_bundle_to_ray_points have to match
    self.assertClose(rays_points_world, ray_bundle_to_ray_points(ray_bundle))

    # depth of each ray point in camera coords has to follow
    # the expected linearly-spaced depth
    depth_expected = torch.linspace(
        min_depth,
        max_depth,
        n_pts_per_ray,
        dtype=torch.float32,
        device=device,
    )
    world_to_view = cameras.get_world_to_view_transform()
    ray_points_camera = world_to_view.transform_points(points_flat).view(
        batch_size, -1, n_pts_per_ray, 3
    )
    depth_camera = ray_points_camera[..., 2]
    depth_target = depth_expected[None, None, :].expand_as(depth_camera)
    self.assertClose(depth_camera, depth_target, atol=1e-4)

    # the ray lengths have to be consistent with the expected depth as well
    self.assertClose(
        ray_bundle.lengths.view(batch_size, -1, n_pts_per_ray),
        depth_target,
        atol=1e-6,
    )

    # project the world ray points back to screen space
    ray_points_projected = cameras.transform_points(points_flat).view(
        rays_points_world.shape
    )

    # the bundle xy has to match the xy of the projected points
    rays_xy_projected = ray_points_projected[..., :2].view(
        batch_size, -1, n_pts_per_ray, 2
    )
    self.assertClose(
        ray_bundle.xys.view(batch_size, -1, 1, 2).expand_as(rays_xy_projected),
        rays_xy_projected,
        atol=1e-4,
    )

    # the projected xy coordinates have to range correctly
    # within [min_x/y, max_x/y]
    if isinstance(raysampler, MultinomialRaysampler):
        # expected coordinates along each grid axis
        ys = torch.linspace(
            min_y, max_y, image_height, dtype=torch.float32, device=device
        )
        xs = torch.linspace(
            min_x, max_x, image_width, dtype=torch.float32, device=device
        )
        # x varies along the third axis of the projected points, y along the second
        gt_axes = (xs[None, None, :, None], ys[None, :, None, None])
        for dim, gt_axis in enumerate(gt_axes):
            projected_axis = ray_points_projected[..., dim]
            self.assertClose(
                projected_axis,
                gt_axis.expand_as(projected_axis),
                atol=1e-4,
            )
    elif isinstance(raysampler, MonteCarloRaysampler):
        # randomly sampled locations have to stay within
        # the allowed bounds along both axes
        for dim, (low, high) in enumerate(((min_x, max_x), (min_y, max_y))):
            projected_axis = ray_points_projected[..., dim]
            self.assertTrue(
                ((projected_axis <= high) & (projected_axis >= low)).all()
            )
        # x, y have to be constant along each ray
        if n_pts_per_ray > 1:
            self.assertClose(
                ray_points_projected[..., :2].std(dim=-2),
                torch.zeros_like(ray_points_projected[..., 0, :2]),
                atol=1e-5,
            )
    else:
        raise ValueError(str(type(raysampler)))
def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
    """
    Checks the direction vectors of `ray_bundle`.

    First, points obtained by unprojecting (xy, length) through `cameras`
    are used to recover per-ray directions, which must match the normalized
    bundle directions. Second, the raysampler is re-run with a fixed seed on
    `cameras` and on a copy with identity rotation and zero translation; the
    directions from the former, transformed by `Rotate(cameras.R)`, must
    match the directions from the latter.
    """
    batch_size = cameras.R.shape[0]
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays_per_image = ray_bundle.xys.shape[1:-1].numel()
    normalize = torch.nn.functional.normalize

    # assemble (x, y, depth) triplets and unproject them to world coords
    xy_per_point = ray_bundle.xys.view(batch_size, n_rays_per_image, 1, 2).expand(
        batch_size, n_rays_per_image, n_pts_per_ray, 2
    )
    depth_per_point = ray_bundle.lengths.view(
        batch_size, n_rays_per_image, n_pts_per_ray, 1
    )
    xy_depth = torch.cat((xy_per_point, depth_per_point), dim=-1).view(
        batch_size, -1, 3
    )
    rays_points_world = cameras.unproject_points(xy_depth).view(
        batch_size, -1, n_pts_per_ray, 3
    )

    # bundle directions reshaped to the common testing size and l2-normed
    directions_normed = normalize(
        ray_bundle.directions.view(batch_size, -1, 3), dim=-1
    )

    # l2-normed offsets from every earlier point of a ray to its last point
    # have to match the bundle directions
    directions_from_points = normalize(
        rays_points_world[:, :, -1:] - rays_points_world[:, :, :-1], dim=-1
    )
    self.assertClose(
        directions_normed[:, :, None].expand_as(directions_from_points),
        directions_from_points,
        atol=1e-4,
    )

    # a copy of the cameras with trivial extrinsics
    cameras_trivial_extrinsic = cameras.clone()
    cameras_trivial_extrinsic.R = eyes(
        N=batch_size, dim=3, dtype=cameras.R.dtype, device=cameras.device
    )
    cameras_trivial_extrinsic.T = torch.zeros_like(cameras.T)

    # re-seed before each call so that a MonteCarloRaysampler draws
    # the same random rays both times; the RNG state is forked so the
    # seeding does not leak outside of this block
    with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
        torch.random.manual_seed(42)
        bundle_world = raysampler(cameras=cameras)
        torch.random.manual_seed(42)
        bundle_camera = raysampler(cameras=cameras_trivial_extrinsic)

    # world directions transformed by the camera rotation have to match
    # the directions obtained with trivial extrinsics
    rotation = Rotate(cameras.R, device=cameras.R.device)
    directions_rotated = rotation.transform_points(
        bundle_world.directions.view(batch_size, -1, 3)
    )
    self.assertClose(
        directions_rotated,
        bundle_camera.directions.view(batch_size, -1, 3),
        atol=1e-5,
    )
@unittest.skipIf(
    torch.__version__.startswith("1.5."), "non persistent buffer needs PyTorch 1.6"
)
def test_load_state_different_resolution(self):
    """
    Checks that the state dict of one `NDCGridRaysampler` can be loaded
    into another one built with a different image size and a different
    number of points per ray.
    """
    depth_range = {"min_depth": 1.2, "max_depth": 2.3}
    source = NDCGridRaysampler(
        image_width=20, image_height=30, n_pts_per_ray=40, **depth_range
    )
    target = NDCGridRaysampler(
        image_width=22, image_height=32, n_pts_per_ray=42, **depth_range
    )
    target.load_state_dict(source.state_dict())
def test_jiggle(self):
    """
    Checks that `_jiggle_within_stratas` keeps the shape and the ascending
    order of its input, keeps every output value between the neighbouring
    input values, and perturbs by close to, but less than, half of `scale`.
    """
    scale = 180
    # random data, ascending along the last dimension
    increments = torch.rand(8, 3, 4, 20)
    data = scale * increments.cumsum(dim=-1)

    out = _jiggle_within_stratas(data)
    self.assertTupleEqual(out.shape, data.shape)

    # `out` has to be in ascending order
    steps = out[..., 1:] - out[..., :-1]
    self.assertGreater(steps.min(), 0)

    # each output stays between the previous and the next input value
    self.assertConstant(out[..., :-1] < data[..., 1:], True)
    self.assertConstant(data[..., :-1] < out[..., 1:], True)

    # the perturbation is random between -scale/2 and scale/2
    jiggles = out - data
    lowest, highest = jiggles.min(), jiggles.max()
    self.assertLess(lowest, -0.4 * scale)
    self.assertGreater(lowest, -0.5 * scale)
    self.assertGreater(highest, 0.4 * scale)
    self.assertLess(highest, 0.5 * scale)
def test_safe_multinomial(self):
    """
    Checks `_safe_multinomial` when drawing 3 samples per row from rows
    that have 1, 2, 3 and 4 non-zero weights out of 5.
    """
    n_columns = 5
    # row k has its first k entries set to one, k = 1..4
    mask = [[1] * k + [0] * (n_columns - k) for k in range(1, 5)]
    tmask = torch.tensor(mask, dtype=torch.float32)

    for _ in range(5):
        random_scalar = torch.rand(1)
        samples = _safe_multinomial(tmask * random_scalar, 3)
        self.assertTupleEqual(samples.shape, (4, 3))

        row0, row1, row2, row3 = samples[0], samples[1], samples[2], samples[3]

        # a single source: the row is exactly determined
        self.assertConstant(row0, 0)

        # two sources: every sample is index 0 or 1
        self.assertGreaterEqual(row1.min(), 0)
        self.assertLessEqual(row1.max(), 1)

        # three sources for three samples: exactly determined as a set
        self.assertSetEqual(set(row2.tolist()), {0, 1, 2})

        # enough sources, so the row must contain 3 distinct values
        self.assertLessEqual(row3.max(), 3)
        self.assertEqual(len(set(row3.tolist())), 3)
def test_heterogeneous_sampling(self, batch_size=8):
    """
    Test that the output of heterogeneous sampling has the first dimension equal
    to n_rays_total and second to 1 and that ray_bundle elements from different
    sampled cameras are different and equal for same sampled cameras.
    """
    cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True)
    for n_rays_total in [2, 3, 17, 21, 32]:
        for cls in (MultinomialRaysampler, MonteCarloRaysampler):
            with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)):
                raysampler = self.init_raysampler(
                    cls, n_rays_total=n_rays_total, n_rays_per_image=None
                )
                ray_bundle = raysampler(cameras)

                # test whether they are of the correct shape; unittest
                # assertions are used because bare `assert` statements
                # are stripped when Python runs with -O
                for attr in ("origins", "directions", "lengths", "xys"):
                    tensor = getattr(ray_bundle, attr)
                    self.assertEqual(
                        tuple(tensor.shape[:2]),
                        (n_rays_total, 1),
                        msg=f"{attr}: {tuple(tensor.shape)}",
                    )

                # one camera id per ray, built once instead of on every
                # iteration of the pairwise loop below
                # NOTE(review): assumes camera_counts[i] rays were drawn
                # from camera camera_ids[i] -- confirm against the raysampler
                ray_camera_ids = torch.repeat_interleave(
                    ray_bundle.camera_ids, ray_bundle.camera_counts
                ).tolist()
                rays = list(
                    zip(
                        ray_bundle.origins,
                        ray_bundle.directions,
                        ray_bundle.lengths,
                        ray_camera_ids,
                    )
                )

                # if two camera ids are the same then the origins should
                # also be the same; directions are always different and
                # lengths always equal
                for i1, (origin1, dir1, len1, id1) in enumerate(rays):
                    for i2, (origin2, dir2, len2, id2) in enumerate(rays):
                        if i1 == i2:
                            continue
                        self.assertEqual(
                            torch.allclose(origin1, origin2, rtol=1e-4, atol=1e-4),
                            id1 == id2,
                            msg=f"{origin1}, {origin2}, {id1}, {id2}",
                        )
                        self.assertFalse(
                            torch.allclose(dir1, dir2), msg=f"{dir1}, {dir2}"
                        )
                        self.assertClose(len1, len2)