Loading LLFF and Blender datasets

Summary: Copy code from NeRF for loading LLFF data and Blender synthetic data, and create dataset objects for them.

Reviewed By: shapovalov

Differential Revision: D35581039

fbshipit-source-id: af7a6f3e9a42499700693381b5b147c991f57e5d
Jeremy Reizenstein 2022-06-16 03:09:15 -07:00 committed by Facebook GitHub Bot
parent 7978ffd1e4
commit 65f667fd2e
16 changed files with 992 additions and 67 deletions

View File

@ -46,3 +46,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
NeRF https://github.com/bmild/nerf/
Copyright (c) 2020 bmild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -5,7 +5,7 @@ Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling t
# License
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
It includes code from the [NeRF](https://github.com/bmild/nerf), [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.

View File

@ -315,7 +315,7 @@ def trainvalidate(
epoch,
loader,
optimizer,
validation,
validation: bool,
bp_var: str = "objective",
metric_print_interval: int = 5,
visualize_interval: int = 100,

View File

@ -95,13 +95,6 @@ generic_model_args:
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
raymarcher_EmissionAbsorptionRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 10000000000.0
density_relu: true
blend_output: false
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
@ -109,6 +102,13 @@ generic_model_args:
background_opacity: 0.0
density_relu: true
blend_output: false
raymarcher_EmissionAbsorptionRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 10000000000.0
density_relu: true
blend_output: false
renderer_SignedDistanceFunctionRenderer_args:
render_features_dimensions: 3
ray_tracer_args:
@ -157,6 +157,21 @@ generic_model_args:
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_IdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
@ -168,21 +183,6 @@ generic_model_args:
reduction_functions:
- AVG
- STD
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
@ -203,19 +203,6 @@ generic_model_args:
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
@ -229,24 +216,19 @@ generic_model_args:
n_layers_xyz: 2
append_xyz:
- 1
implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
latent_dim: 0
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
implicit_function_SRNHyperNetImplicitFunction_args:
hypernet_args:
n_harmonic_functions: 3
@ -267,6 +249,24 @@ generic_model_args:
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
latent_dim: 0
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
solver_args:
breed: adam
weight_decay: 0.0
@ -282,6 +282,13 @@ solver_args:
data_source_args:
dataset_map_provider_class_type: ???
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_BlenderDatasetMapProvider_args:
base_dir: ???
object_name: ???
path_manager_factory_class_type: PathManagerFactory
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
@ -317,6 +324,13 @@ data_source_args:
sort_frames: false
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_LlffDatasetMapProvider_args:
base_dir: ???
object_name: ???
path_manager_factory_class_type: PathManagerFactory
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
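For illustration, a minimal sketch of overriding these generated defaults to select the new Blender provider. It uses OmegaConf (also used in the repo's tests) purely to show the shape of the keys above; the base_dir value is a placeholder, not a real path.

from omegaconf import OmegaConf

# Hypothetical override of the defaults shown above; base_dir is a placeholder.
cfg = OmegaConf.create(
    {
        "dataset_map_provider_class_type": "BlenderDatasetMapProvider",
        "dataset_map_provider_BlenderDatasetMapProvider_args": {
            "base_dir": "/path/to/nerf_synthetic/lego",
            "object_name": "lego",
            "n_known_frames_for_test": None,
        },
    }
)
print(OmegaConf.to_yaml(cfg))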

View File

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.tools.config import registry
from .load_blender import load_blender_data
from .single_sequence_dataset import (
_interpret_blender_cameras,
SingleSceneDatasetMapProviderBase,
)
@registry.register
class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
"""
Provides data for one scene from the Blender synthetic dataset.
Uses the code in load_blender.py.
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
def _load_data(self) -> None:
path_manager = self.path_manager_factory.get()
images, poses, _, hwf, i_split = load_blender_data(
self.base_dir,
testskip=1,
path_manager=path_manager,
)
H, W, focal = hwf
H, W = int(H), int(W)
images = torch.from_numpy(images)
# pyre-ignore[16]
self.poses = _interpret_blender_cameras(poses, H, W, focal)
# pyre-ignore[16]
self.images = images
# pyre-ignore[16]
self.i_split = i_split
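A usage sketch mirroring the tests added in this commit; the data directory is a placeholder and would normally point at a local copy of a NeRF synthetic scene.

from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
    BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Dataclass fields must be expanded before constructing the provider directly.
expand_args_fields(BlenderDatasetMapProvider)
provider = BlenderDatasetMapProvider(
    base_dir="/path/to/nerf_synthetic/lego",  # placeholder path
    object_name="lego",
)
dataset_map = provider.get_dataset_map()
frame = dataset_map.train[0]  # FrameData with camera, image_rgb, frame_type, ...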

View File

@ -8,9 +8,11 @@ from typing import Tuple
from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
from . import json_index_dataset_map_provider # noqa
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
class DataSourceBase(ReplaceableBase):

View File

@ -36,10 +36,11 @@ class FrameData(Mapping[str, Any]):
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
frame_timestamp: The time elapsed since the start of a sequence in sec.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
image_size_hw: The size of the image in pixels; (height, width) tuple.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the image in pixels; (height, width) tensor
of shape (2,).
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
@ -81,9 +82,9 @@ class FrameData(Mapping[str, Any]):
"""
frame_number: Optional[torch.LongTensor]
frame_timestamp: Optional[torch.Tensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
@ -101,7 +102,7 @@ class FrameData(Mapping[str, Any]):
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # seen | unseen
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):

View File

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from pytorch3d.implicitron.tools.config import registry
from .load_llff import load_llff_data
from .single_sequence_dataset import (
_interpret_blender_cameras,
SingleSceneDatasetMapProviderBase,
)
@registry.register
class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
"""
Provides data for one scene from the LLFF dataset.
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "fern"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
def _load_data(self) -> None:
path_manager = self.path_manager_factory.get()
images, poses, _ = load_llff_data(
self.base_dir, factor=8, path_manager=path_manager
)
hwf = poses[0, :3, -1]
poses = poses[:, :3, :4]
i_test = np.arange(images.shape[0])[::8]
i_test_index = set(i_test.tolist())
i_train = np.array(
[i for i in np.arange(images.shape[0]) if i not in i_test_index]
)
i_split = (i_train, i_test, i_test)
H, W, focal = hwf
H, W = int(H), int(W)
images = torch.from_numpy(images)
poses = torch.from_numpy(poses)
# pyre-ignore[16]
self.poses = _interpret_blender_cameras(poses, H, W, focal)
# pyre-ignore[16]
self.images = images
# pyre-ignore[16]
self.i_split = i_split
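The split above holds out every eighth frame for evaluation. A small standalone check of that logic, assuming a scene with 20 frames (matching the "fern" scene used in the tests below):

import numpy as np

# Every 8th frame goes to val/test; the rest are used for training.
n_frames = 20
i_test = np.arange(n_frames)[::8]                         # array([0, 8, 16])
i_test_index = set(i_test.tolist())
i_train = np.array([i for i in range(n_frames) if i not in i_test_index])
i_split = (i_train, i_test, i_test)                       # train, val, test indices
assert len(i_train) == 17 and len(i_test) == 3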

View File

@ -0,0 +1,131 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
# Copyright (c) 2020 bmild
import json
import os
import numpy as np
import torch
from PIL import Image
def translate_by_t_along_z(t):
tform = np.eye(4).astype(np.float32)
tform[2][3] = t
return tform
def rotate_by_phi_along_x(phi):
tform = np.eye(4).astype(np.float32)
tform[1, 1] = tform[2, 2] = np.cos(phi)
tform[1, 2] = -np.sin(phi)
tform[2, 1] = -tform[1, 2]
return tform
def rotate_by_theta_along_y(theta):
tform = np.eye(4).astype(np.float32)
tform[0, 0] = tform[2, 2] = np.cos(theta)
tform[0, 2] = -np.sin(theta)
tform[2, 0] = -tform[0, 2]
return tform
def pose_spherical(theta, phi, radius):
c2w = translate_by_t_along_z(radius)
c2w = rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w
c2w = rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w
c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
return c2w
def _local_path(path_manager, path):
if path_manager is None:
return path
return path_manager.get_local_path(path)
def load_blender_data(
basedir, half_res=False, testskip=1, debug=False, path_manager=None
):
splits = ["train", "val", "test"]
metas = {}
for s in splits:
path = os.path.join(basedir, f"transforms_{s}.json")
with open(_local_path(path_manager, path)) as fp:
metas[s] = json.load(fp)
all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s == "train" or testskip == 0:
skip = 1
else:
skip = testskip
for frame in meta["frames"][::skip]:
fname = os.path.join(basedir, frame["file_path"] + ".png")
imgs.append(np.array(Image.open(_local_path(path_manager, fname))))
poses.append(np.array(frame["transform_matrix"]))
imgs = (np.array(imgs) / 255.0).astype(np.float32)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)
i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)
H, W = imgs[0].shape[:2]
camera_angle_x = float(meta["camera_angle_x"])
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
render_poses = torch.stack(
[
torch.from_numpy(pose_spherical(angle, -30.0, 4.0))
for angle in np.linspace(-180, 180, 40 + 1)[:-1]
],
0,
)
# In debug mode, return extremely tiny images
if debug:
import cv2
H = H // 32
W = W // 32
focal = focal / 32.0
imgs = [
torch.from_numpy(
cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
)
for i in range(imgs.shape[0])
]
imgs = torch.stack(imgs, 0)
poses = torch.from_numpy(poses)
return imgs, poses, render_poses, [H, W, focal], i_split
if half_res:
import cv2
# TODO: resize images using INTER_AREA (cv2)
H = H // 2
W = W // 2
focal = focal / 2.0
imgs = [
torch.from_numpy(
cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
)
for i in range(imgs.shape[0])
]
imgs = torch.stack(imgs, 0)
poses = torch.from_numpy(poses)
return imgs, poses, render_poses, [H, W, focal], i_split
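A hedged sketch of calling the loader directly; the path is a placeholder for a local copy of a NeRF synthetic scene laid out as the code above expects (transforms_{train,val,test}.json plus the referenced PNG frames).

from pytorch3d.implicitron.dataset.load_blender import load_blender_data

# Sketch only: the path below is a placeholder, not a real dataset location.
images, poses, render_poses, hwf, i_split = load_blender_data(
    "/path/to/nerf_synthetic/lego", testskip=1
)
H, W, focal = hwf
# images: (N, H, W, 4) RGBA values in [0, 1]; poses: N camera-to-world matrices;
# render_poses: 40 spiral poses for visualization; i_split: (train, val, test)
# index arrays.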

View File

@ -0,0 +1,343 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
# Copyright (c) 2020 bmild
import logging
import os
import warnings
import numpy as np
from PIL import Image
# Slightly modified version of LLFF data loading code
# see https://github.com/Fyusion/LLFF for original
logger = logging.getLogger(__name__)
def _minify(basedir, path_manager, factors=(), resolutions=()):
needtoload = False
for r in factors:
imgdir = os.path.join(basedir, "images_{}".format(r))
if not _exists(path_manager, imgdir):
needtoload = True
for r in resolutions:
imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
if not _exists(path_manager, imgdir):
needtoload = True
if not needtoload:
return
assert path_manager is None
from subprocess import check_output
imgdir = os.path.join(basedir, "images")
imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
imgs = [
f
for f in imgs
if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
]
imgdir_orig = imgdir
wd = os.getcwd()
for r in factors + resolutions:
if isinstance(r, int):
name = "images_{}".format(r)
resizearg = "{}%".format(100.0 / r)
else:
name = "images_{}x{}".format(r[1], r[0])
resizearg = "{}x{}".format(r[1], r[0])
imgdir = os.path.join(basedir, name)
if os.path.exists(imgdir):
continue
logger.info(f"Minifying {r}, {basedir}")
os.makedirs(imgdir)
check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)
ext = imgs[0].split(".")[-1]
args = " ".join(
["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
)
logger.info(args)
os.chdir(imgdir)
check_output(args, shell=True)
os.chdir(wd)
if ext != "png":
check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
logger.info("Removed duplicates")
logger.info("Done")
def _load_data(
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):
poses_arr = np.load(
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
)
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
bds = poses_arr[:, -2:].transpose([1, 0])
img0 = [
os.path.join(basedir, "images", f)
for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
][0]
def imread(f):
return np.array(Image.open(f))
sh = imread(_local_path(path_manager, img0)).shape
sfx = ""
if factor is not None:
sfx = "_{}".format(factor)
_minify(basedir, path_manager, factors=[factor])
factor = factor
elif height is not None:
factor = sh[0] / float(height)
width = int(sh[1] / factor)
_minify(basedir, path_manager, resolutions=[[height, width]])
sfx = "_{}x{}".format(width, height)
elif width is not None:
factor = sh[1] / float(width)
height = int(sh[0] / factor)
_minify(basedir, path_manager, resolutions=[[height, width]])
sfx = "_{}x{}".format(width, height)
else:
factor = 1
imgdir = os.path.join(basedir, "images" + sfx)
if not _exists(path_manager, imgdir):
raise ValueError(f"{imgdir} does not exist, returning")
imgfiles = [
_local_path(path_manager, os.path.join(imgdir, f))
for f in sorted(_ls(path_manager, imgdir))
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
]
if poses.shape[-1] != len(imgfiles):
raise ValueError(
"Mismatch between imgs {} and poses {} !!!!".format(
len(imgfiles), poses.shape[-1]
)
)
sh = imread(imgfiles[0]).shape
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor
if not load_imgs:
return poses, bds
imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
imgs = np.stack(imgs, -1)
logger.info(f"Loaded image data, shape {imgs.shape}")
return poses, bds, imgs
def normalize(x):
denom = np.linalg.norm(x)
if denom < 0.001:
warnings.warn("unsafe normalize()")
return x / denom
def viewmatrix(z, up, pos):
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, pos], 1)
return m
def ptstocam(pts, c2w):
tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
return tt
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
vec2 = normalize(poses[:, :3, 2].sum(0))
up = poses[:, :3, 1].sum(0)
c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
return c2w
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
render_poses = []
rads = np.array(list(rads) + [1.0])
hwf = c2w[:, 4:5]
for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
c = np.dot(
c2w[:3, :4],
np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
* rads,
)
z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
return render_poses
def recenter_poses(poses):
poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
c2w = poses_avg(poses)
c2w = np.concatenate([c2w[:3, :4], bottom], -2)
bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
poses = np.concatenate([poses[:, :3, :4], bottom], -2)
poses = np.linalg.inv(c2w) @ poses
poses_[:, :3, :4] = poses[:, :3, :4]
poses = poses_
return poses
def spherify_poses(poses, bds):
def add_row_to_homogenize_transform(p):
r"""Add the last row to homogenize 3 x 4 transformation matrices."""
return np.concatenate(
[p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
)
# p34_to_44 = lambda p: np.concatenate(
# [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
# )
p34_to_44 = add_row_to_homogenize_transform
rays_d = poses[:, :3, 2:3]
rays_o = poses[:, :3, 3:4]
def min_line_dist(rays_o, rays_d):
A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
b_i = -A_i @ rays_o
pt_mindist = np.squeeze(
-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
)
return pt_mindist
pt_mindist = min_line_dist(rays_o, rays_d)
center = pt_mindist
up = (poses[:, :3, 3] - center).mean(0)
vec0 = normalize(up)
vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
vec2 = normalize(np.cross(vec0, vec1))
pos = center
c2w = np.stack([vec1, vec2, vec0, pos], 1)
poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
sc = 1.0 / rad
poses_reset[:, :3, 3] *= sc
bds *= sc
rad *= sc
centroid = np.mean(poses_reset[:, :3, 3], 0)
zh = centroid[2]
radcircle = np.sqrt(rad**2 - zh**2)
new_poses = []
for th in np.linspace(0.0, 2.0 * np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.0])
vec2 = normalize(camorigin)
vec0 = normalize(np.cross(vec2, up))
vec1 = normalize(np.cross(vec2, vec0))
pos = camorigin
p = np.stack([vec0, vec1, vec2, pos], 1)
new_poses.append(p)
new_poses = np.stack(new_poses, 0)
new_poses = np.concatenate(
[new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
)
poses_reset = np.concatenate(
[
poses_reset[:, :3, :4],
np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
],
-1,
)
return poses_reset, new_poses, bds
def _local_path(path_manager, path):
if path_manager is None:
return path
return path_manager.get_local_path(path)
def _ls(path_manager, path):
if path_manager is None:
return os.listdir(path)
return path_manager.ls(path)
def _exists(path_manager, path):
if path_manager is None:
return os.path.exists(path)
return path_manager.exists(path)
def load_llff_data(
basedir,
factor=8,
recenter=True,
bd_factor=0.75,
spherify=False,
path_zflat=False,
path_manager=None,
):
poses, bds, imgs = _load_data(
basedir, factor=factor, path_manager=path_manager
) # factor=8 downsamples original imgs by 8x
logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")
# Correct rotation matrix ordering and move variable dim to axis 0
poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
images = imgs
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
# Rescale if bd_factor is provided
sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
poses[:, :3, 3] *= sc
bds *= sc
if recenter:
poses = recenter_poses(poses)
if spherify:
poses, render_poses, bds = spherify_poses(poses, bds)
images = images.astype(np.float32)
poses = poses.astype(np.float32)
return images, poses, bds
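A hedged sketch of the top-level LLFF loader, mirroring how the provider above consumes its outputs; the path is a placeholder for a local LLFF scene directory containing images/ and poses_bounds.npy.

from pytorch3d.implicitron.dataset.load_llff import load_llff_data

# Sketch only: factor=8 downsamples the original images by 8x.
images, poses, bds = load_llff_data("/path/to/nerf_llff_data/fern", factor=8)
hwf = poses[0, :3, -1]    # image height, width and focal length
poses = poses[:, :3, :4]  # per-frame camera-to-world matrices
# images: (N, H, W, 3) floats in [0, 1]; bds: per-frame near/far depth bounds.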

View File

@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file defines a base class for dataset map providers which
# provide data for a single scene.
from dataclasses import field
from typing import Iterable, List, Optional
import numpy as np
import torch
from pytorch3d.implicitron.tools.config import (
Configurable,
expand_args_fields,
run_auto_creation,
)
from pytorch3d.renderer import PerspectiveCameras
from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
class SingleSceneDataset(DatasetBase, Configurable):
"""
A dataset from images from a single scene.
"""
images: List[torch.Tensor] = field()
poses: List[PerspectiveCameras] = field()
object_name: str = field()
frame_types: List[str] = field()
eval_batches: Optional[List[List[int]]] = field()
def sequence_names(self) -> Iterable[str]:
return [_SINGLE_SEQUENCE_NAME]
def __len__(self) -> int:
return len(self.poses)
def __getitem__(self, index) -> FrameData:
if index >= len(self):
raise IndexError(f"index {index} out of range {len(self)}")
image = self.images[index]
pose = self.poses[index]
frame_type = self.frame_types[index]
frame_data = FrameData(
frame_number=index,
sequence_name=_SINGLE_SEQUENCE_NAME,
sequence_category=self.object_name,
camera=pose,
image_size_hw=torch.tensor(image.shape[1:]),
image_rgb=image,
frame_type=frame_type,
)
return frame_data
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
# pyre-fixme[13]: Uninitialized attribute
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
"""
Base class for providers of data for one scene from the LLFF or Blender datasets.
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
base_dir: str
object_name: str
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
n_known_frames_for_test: Optional[int] = None
def __post_init__(self) -> None:
run_auto_creation(self)
self._load_data()
def _load_data(self) -> None:
# This must be defined by each subclass,
# and should set poses, images and i_split on self.
raise NotImplementedError
def _get_dataset(
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
) -> SingleSceneDataset:
expand_args_fields(SingleSceneDataset)
# pyre-ignore[16]
split = self.i_split[split_idx]
frame_types = [frame_type] * len(split)
eval_batches = [[i] for i in range(len(split))]
if split_idx != 0 and self.n_known_frames_for_test is not None:
train_split = self.i_split[0]
if set_eval_batches:
generator = np.random.default_rng(seed=0)
for batch in eval_batches:
to_add = generator.choice(
len(train_split), self.n_known_frames_for_test
)
batch.extend((to_add + len(split)).tolist())
split = np.concatenate([split, train_split])
frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))
# pyre-ignore[28]
return SingleSceneDataset(
object_name=self.object_name,
# pyre-ignore[16]
images=self.images[split],
# pyre-ignore[16]
poses=[self.poses[i] for i in split],
frame_types=frame_types,
eval_batches=eval_batches if set_eval_batches else None,
)
def get_dataset_map(self) -> DatasetMap:
return DatasetMap(
train=self._get_dataset(0, DATASET_TYPE_KNOWN),
val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def _interpret_blender_cameras(
poses: torch.Tensor, H: int, W: int, focal: float
) -> List[PerspectiveCameras]:
"""
Convert camera-to-world matrices in Blender/NeRF convention into
PyTorch3D PerspectiveCameras.
Args:
poses: N x 3 x 4 (or N x 4 x 4) camera-to-world matrices; only the
top 3 x 4 rows of each matrix are used.
H: image height in pixels.
W: image width in pixels.
focal: focal length in pixels.
"""
pose_target_cameras = []
for pose_target in poses:
pose_target = pose_target[:3, :4]
mtx = torch.eye(4, dtype=pose_target.dtype)
mtx[:3, :3] = pose_target[:3, :3].t()
mtx[3, :3] = pose_target[:, 3]
mtx = mtx.inverse()
# flip the XZ coordinates.
mtx[:, [0, 2]] *= -1.0
Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
cameras = PerspectiveCameras(
focal_length=focal_length_pt3,
principal_point=principal_point_pt3,
R=Rpt3[None],
T=Tpt3,
)
pose_target_cameras.append(cameras)
return pose_target_cameras
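A small sketch of the private conversion helper above on a single identity camera-to-world matrix; the focal length and image size are arbitrary values chosen only for this sketch.

import torch
from pytorch3d.implicitron.dataset.single_sequence_dataset import (
    _interpret_blender_cameras,
)

# One identity pose in Blender/NeRF convention; focal and image size are assumptions.
poses = torch.eye(4)[None]  # (1, 4, 4)
H, W, focal = 800, 800, 1111.1
[camera] = _interpret_blender_cameras(poses, H, W, focal)
print(camera.R.shape, camera.T.shape)  # torch.Size([1, 3, 3]) torch.Size([1, 3])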

View File

@ -220,6 +220,7 @@ class Configurable:
_X = TypeVar("X", bound=ReplaceableBase)
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
class _Registry:
@ -307,20 +308,23 @@ class _Registry:
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
Returns:
list of class types
list of class types in alphabetical order of registered name.
"""
if self._is_base_class(base_class_wanted):
return list(self._mapping[base_class_wanted].values())
source = self._mapping[base_class_wanted]
return [source[key] for key in sorted(source)]
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
source = self._mapping[base_class]
return [
class_
for class_ in self._mapping[base_class].values()
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
source[key]
for key in sorted(source)
if issubclass(source[key], base_class_wanted)
and source[key] is not base_class_wanted
]
@staticmethod
@ -647,8 +651,8 @@ def _is_actually_dataclass(some_class) -> bool:
def expand_args_fields(
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_Y]:
"""
This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function.
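The registry change above makes lookups return classes sorted by their registered name rather than by registration order, which is likely why the *_args blocks in the generated YAML configs earlier in this commit were re-ordered alphabetically. A tiny standalone illustration of the ordering (plain dict, no Implicitron registry involved):

# Classes are returned sorted by registered name, not by insertion order.
class Alpha: ...
class Mid: ...
class Zeta: ...

source = {"Zeta": Zeta, "Alpha": Alpha, "Mid": Mid}  # non-alphabetical insertion
ordered = [source[key] for key in sorted(source)]
assert ordered == [Alpha, Mid, Zeta]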

View File

@ -13,6 +13,7 @@ from .blending import (
from .camera_utils import join_cameras_as_batch, rotate_on_spot
from .cameras import ( # deprecated # deprecated # deprecated # deprecated
camera_position_from_spherical_angles,
CamerasBase,
FoVOrthographicCameras,
FoVPerspectiveCameras,
get_world_to_view_transform,

View File

@ -1,5 +1,12 @@
dataset_map_provider_class_type: ???
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_BlenderDatasetMapProvider_args:
base_dir: ???
object_name: ???
path_manager_factory_class_type: PathManagerFactory
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
@ -35,6 +42,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
sort_frames: false
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_LlffDatasetMapProvider_args:
base_dir: ???
object_name: ???
path_manager_factory_class_type: PathManagerFactory
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0

View File

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
LlffDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin
# These tests are only run internally, where the data is available.
internal = os.environ.get("FB_TEST", False)
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
skip_tests = not internal or inside_re_worker
@unittest.skipIf(skip_tests, "no data")
class TestDataLlff(TestCaseMixin, unittest.TestCase):
def test_synthetic(self):
expand_args_fields(BlenderDatasetMapProvider)
provider = BlenderDatasetMapProvider(
base_dir="manifold://co3d/tree/nerf_data/nerf_synthetic/lego",
object_name="lego",
)
dataset_map = provider.get_dataset_map()
for name, length in [("train", 100), ("val", 100), ("test", 200)]:
dataset = getattr(dataset_map, name)
self.assertEqual(len(dataset), length)
# try getting a value
value = dataset[0]
self.assertIsInstance(value, FrameData)
def test_llff(self):
expand_args_fields(LlffDatasetMapProvider)
provider = LlffDatasetMapProvider(
base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
object_name="fern",
)
dataset_map = provider.get_dataset_map()
for name, length, frame_type in [
("train", 17, "known"),
("test", 3, "unseen"),
("val", 3, "unseen"),
]:
dataset = getattr(dataset_map, name)
self.assertEqual(len(dataset), length)
# try getting a value
value = dataset[0]
self.assertIsInstance(value, FrameData)
self.assertEqual(value.frame_type, frame_type)
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
for batch in dataset_map.test.get_eval_batches():
self.assertEqual(len(batch), 1)
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
def test_include_known_frames(self):
expand_args_fields(LlffDatasetMapProvider)
provider = LlffDatasetMapProvider(
base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
object_name="fern",
n_known_frames_for_test=2,
)
dataset_map = provider.get_dataset_map()
for name, types in [
("train", ["known"] * 17),
("val", ["unseen"] * 3 + ["known"] * 17),
("test", ["unseen"] * 3 + ["known"] * 17),
]:
dataset = getattr(dataset_map, name)
self.assertEqual(len(dataset), len(types))
for i, frame_type in enumerate(types):
value = dataset[i]
self.assertEqual(value.frame_type, frame_type)
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
for batch in dataset_map.test.get_eval_batches():
self.assertEqual(len(batch), 3)
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
for i in batch[1:]:
self.assertEqual(dataset_map.test[i].frame_type, "known")

View File

@ -6,6 +6,7 @@
import os
import unittest
import unittest.mock
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource