# Listing metadata: 227 lines, 7.7 KiB, Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import unittest
|
|
|
|
import torch
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
|
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
|
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
|
|
from tests.common_testing import get_pytorch3d_dir
|
|
|
|
from .common_resources import provide_resnet34
|
|
|
|
# Directory with the implicitron_trainer YAML configs that
# TestGenericModel.test_all_gm_configs globs for model settings.
IMPLICITRON_CONFIGS_DIR = (
    get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs"
)
|
|
|
|
|
|
class TestGenericModel(unittest.TestCase):
    """Smoke tests running GenericModel forward/backward passes on random inputs.

    All tests place the model and inputs on ``cuda:0``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # NOTE(review): presumably makes the resnet34 weights available to the
        # feature extractor used in test_viewpool -- confirm in common_resources.
        provide_resnet34()

    def setUp(self):
        # Fixed seed so the random cameras/inputs built by each test are repeatable.
        torch.manual_seed(42)

    def test_gm(self):
        """Forward and backward pass of the default GenericModel."""
        device = torch.device("cuda:0")
        expand_args_fields(GenericModel)
        model = GenericModel(render_image_height=80, render_image_width=80)
        model.to(device)
        self._one_model_test(model, device)

    def test_all_gm_configs(self):
        """Run the model test for every model setting in the implicitron_trainer configs.

        ``*_base.yaml`` files are excluded from the list of configs to test;
        they are still read when a tested config names them in ``defaults``.
        """
        device = torch.device("cuda:0")

        # Path.glob yields matches in filesystem order. Sorting each pattern's
        # matches keeps the sub-test order (and hence the RNG stream consumed
        # after the single seed in setUp) identical across machines.
        config_files = [
            config_file
            for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml")
            for config_file in sorted(IMPLICITRON_CONFIGS_DIR.glob(pattern))
            if not config_file.name.endswith("_base.yaml")
        ]

        for config_file in config_files:
            with self.subTest(name=config_file.stem):
                cfg = _load_model_config_from_yaml(str(config_file))
                # Override the render size with a small 80x80 one for the test.
                cfg.render_image_height = 80
                cfg.render_image_width = 80
                model = GenericModel(**cfg)
                model.to(device)
                self._one_model_test(
                    model,
                    device,
                    eval_test=True,
                    bw_test=True,
                )

    def _one_model_test(
        self,
        model,
        device,
        n_train_cameras: int = 5,
        eval_test: bool = True,
        bw_test: bool = True,
    ):
        """Exercise ``model`` on a batch of random inputs.

        Args:
            model: The model under test; must expose ``render_image_height``
                and ``render_image_width``.
            device: Device on which the cameras and input tensors are created.
            n_train_cameras: Number of cameras / batch elements to generate.
            eval_test: If True, also run an evaluation-mode forward pass and
                check the shape of ``images_render``.
            bw_test: If True, backpropagate through the training objective.
        """
        # Cameras looking at the origin from random azimuths.
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        N, H, W = n_train_cameras, model.render_image_height, model.render_image_width

        random_args = {
            "camera": cameras,
            "fg_probability": _random_input_tensor(N, 1, H, W, True, device),
            # Offset by 0.1 so that no depth value is exactly zero.
            "depth_map": _random_input_tensor(N, 1, H, W, False, device) + 0.1,
            "mask_crop": _random_input_tensor(N, 1, H, W, True, device),
            "sequence_name": ["sequence"] * N,
            "image_rgb": _random_input_tensor(N, 3, H, W, False, device),
        }

        # training forward pass
        model.train()
        train_preds = model(
            **random_args,
            evaluation_mode=EvaluationMode.TRAINING,
        )
        self.assertTrue(
            train_preds["objective"].isfinite().item()
        )  # check finiteness of the objective

        if bw_test:
            train_preds["objective"].backward()

        if eval_test:
            model.eval()
            with torch.no_grad():
                eval_preds = model(
                    **random_args,
                    evaluation_mode=EvaluationMode.EVALUATION,
                )
                # NOTE(review): a batch of 1 is expected here, presumably because
                # evaluation renders a single target view -- confirm in GenericModel.
                self.assertEqual(
                    eval_preds["images_render"].shape,
                    (1, 3, model.render_image_height, model.render_image_width),
                )

    def test_idr(self):
        """Training forward pass of GenericModel configured with IDR."""
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.renderer_class_type = "SignedDistanceFunctionRenderer"
        args.implicit_function_class_type = "IdrFeatureField"
        args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6

        model = GenericModel(**args)
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        # Inputs deliberately left unset for this test.
        defaulted_args = {
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }

        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        fg_probability = torch.rand(
            (n_train_cameras, 1, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            fg_probability=fg_probability,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

    def test_viewpool(self):
        """Training forward pass with view pooling and a ResNet feature extractor."""
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.view_pooler_enabled = True
        args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
        model = GenericModel(**args)
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        # Inputs deliberately left unset for this test.
        defaulted_args = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
        }

        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            sequence_name=["a"] * n_train_cameras,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)
|
|
|
|
|
|
def _random_input_tensor(
    N: int,
    C: int,
    H: int,
    W: int,
    is_binary: bool,
    device: torch.device,
) -> torch.Tensor:
    """Create a random ``(N, C, H, W)`` tensor on ``device``.

    Args:
        N: Size of the first dimension.
        C: Size of the second dimension.
        H: Size of the third dimension.
        W: Size of the fourth dimension.
        is_binary: If True, threshold the uniform samples at 0.5 so that the
            result contains only 0.0 and 1.0.
        device: Device on which the tensor is allocated.

    Returns:
        A float tensor of uniform samples in [0, 1), or its 0/1 thresholded
        version when ``is_binary`` is set.
    """
    tensor = torch.rand(N, C, H, W, device=device)
    if is_binary:
        tensor = (tensor > 0.5).float()
    return tensor
|
|
|
|
|
|
def _load_model_config_from_yaml(config_path: str, strict=True) -> DictConfig:
    """Build a GenericModel config from a trainer YAML file.

    Starts from ``get_default_args(GenericModel)`` and overlays the model
    arguments found in ``config_path`` and in the files it lists as defaults.

    Args:
        config_path: Path of the YAML config to load.
        strict: Currently unused; kept so existing callers remain valid.

    Returns:
        The merged model config.
    """
    default_cfg = get_default_args(GenericModel)
    return _load_model_config_from_yaml_rec(default_cfg, config_path)
|
|
|
|
|
|
def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig:
    """Overlay the GenericModel args of a YAML file, and of its defaults, onto ``cfg``.

    Each entry of the file's ``defaults`` list (other than ``_self_`` and
    ``default_config``) is taken as a YAML file in the same directory as
    ``config_path`` and is applied recursively, in listed order. The file's own
    ``model_factory_ImplicitronModelFactory_args.model_GenericModel_args``
    section, when present, is merged last, so it overrides its defaults
    regardless of where ``_self_`` appears in the list.

    Args:
        cfg: Config accumulated so far.
        config_path: Path of the YAML file to apply.

    Returns:
        The config with this file's (and its defaults') model args merged in.
    """
    cfg_loaded = OmegaConf.load(config_path)

    # Pick out this file's own model section, if it has one.
    cfg_model_loaded = None
    if "model_factory_ImplicitronModelFactory_args" in cfg_loaded:
        factory_args = cfg_loaded.model_factory_ImplicitronModelFactory_args
        if "model_GenericModel_args" in factory_args:
            cfg_model_loaded = factory_args.model_GenericModel_args

    defaults = cfg_loaded.pop("defaults", None)
    if defaults is not None:
        config_dir = os.path.dirname(config_path)
        for default_name in defaults:
            if default_name in ("_self_", "default_config"):
                continue
            # Normalise the entry so it always resolves to "<name>.yaml".
            default_stem = os.path.splitext(default_name)[0]
            default_path = os.path.join(config_dir, default_stem + ".yaml")
            cfg = _load_model_config_from_yaml_rec(cfg, default_path)

    # Applied after the defaults so that this file's values take precedence.
    if cfg_model_loaded is not None:
        cfg = OmegaConf.merge(cfg, cfg_model_loaded)
    return cfg
|