# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
|
|
|
|
import torch
|
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
|
|
|
from .common_testing import TestCaseMixin
|
|
|
|
|
|
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Tests for ``LinearWithRepeat``."""

    def setUp(self) -> None:
        super().setUp()
        # Fix the RNG so the random inputs and parameter init are reproducible.
        torch.manual_seed(42)

    def test_simple(self):
        """Check ``LinearWithRepeat`` against an explicit ``nn.Linear``.

        Applying ``LinearWithRepeat`` to ``(x, y)`` should match applying an
        ``nn.Linear`` with identical weights to the concatenation of ``x``
        with ``y`` repeated along ``x``'s second-to-last dimension.
        """
        x = torch.rand(4, 6, 7, 3)
        y = torch.rand(4, 6, 4)
        # The reference layer sees x's and y's features side by side.
        in_features = x.shape[-1] + y.shape[-1]
        out_features = 8

        linear = torch.nn.Linear(in_features, out_features)
        # nn.init functions run under no_grad, so .data is not needed.
        torch.nn.init.xavier_uniform_(linear.weight)
        torch.nn.init.uniform_(linear.bias)

        # Materialize the repetition explicitly: y is broadcast to every
        # position along x's dim -2 before concatenating on the feature dim.
        y_repeated = y.unsqueeze(-2).expand(-1, -1, x.shape[-2], -1)
        expected = linear(torch.cat([x, y_repeated], dim=-1))

        linear_with_repeat = LinearWithRepeat(in_features, out_features)
        linear_with_repeat.load_state_dict(linear.state_dict())
        # Call the module (not .forward) so nn.Module hooks are honored.
        actual = linear_with_repeat((x, y))

        self.assertClose(actual, expected, rtol=1e-4)