Files
pytorch/test/test_modules.py
Zsolt Dollenstein b004307252 [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
2021-08-12 10:58:35 -07:00

163 lines
6.9 KiB
Python

from unittest.mock import call, patch

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    freeze_rng_state,
    mock_wrapper,
    get_tensors_from,
)
class TestModule(TestCase):
    """Generic correctness tests run over every entry in ``module_db``.

    Each test is parametrized by ``@modules(module_db)`` and receives the
    ``device`` and ``dtype`` to test under, plus the ``module_info`` entry that
    provides the module class (``module_cls``) and a function generating sample
    inputs for it (``module_inputs_func``).
    """

    # Flags consumed by the TestCase base class; presumably they enable CUDA
    # memory-leak checking and running on a non-default CUDA stream — confirm
    # against torch.testing._internal.common_utils.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    # Comparison tolerances; presumably read by TestCase.assertEqual — confirm.
    precision = 1e-5
    rel_tol = 1e-5

    # For every sample that has a forward input: construct the module, move it to
    # `device` / `dtype`, run the forward pass and, if the sample provides a
    # reference_fn, compare the outputs against it. Each sample runs under
    # freeze_rng_state().
    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = (
                    module_input.constructor_input.args,
                    module_input.constructor_input.kwargs,
                )
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = (
                    module_input.forward_input.args,
                    module_input.forward_input.kwargs,
                )
                outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)

    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False
        )
        for module_input in module_inputs:
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )

            # Factory kwargs only matter for modules that create their own
            # parameters / buffers (rather than only wrapping tensors handed to
            # the constructor); skip everything else.
            if not self._module_creates_params_or_buffers(module_cls, args, kwargs):
                continue

            # Instantiate module with the factory kwargs. A new dict is built so
            # the sample's constructor kwargs are not mutated in place.
            factory_kwargs = {**kwargs, "device": device, "dtype": dtype}

            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                self._check_lazy_factory_kwargs(
                    module_cls, args, factory_kwargs, device, dtype
                )
                continue

            # Check device placement and dtype for created parameters and buffers.
            m = module_cls(*args, **factory_kwargs)
            for name, param in m.named_parameters():
                self._check_factory_kwargs_applied(
                    "Parameter", name, param, device, dtype
                )
            for name, buffer in m.named_buffers():
                self._check_factory_kwargs_applied(
                    "Buffer", name, buffer, device, dtype
                )

    @staticmethod
    def _module_creates_params_or_buffers(module_cls, args, kwargs):
        """Return True if ``module_cls(*args, **kwargs)`` creates its own state.

        "Its own state" means a ``Parameter`` constructed from, or a buffer
        registered with, at least one tensor that was NOT among the constructor
        arguments. The mock magic here passes through to the real Parameter /
        register_buffer logic and is only used to record call inputs.
        """
        parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
        register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
        with patch.object(torch.nn.Parameter, "__new__", parameter_new):
            with patch.object(torch.nn.Module, "register_buffer", register_buffer):
                module_cls(*args, **kwargs)

        constructor_tensors = get_tensors_from(args, kwargs)
        # Named `recorder` (not `mock`) so it cannot shadow unittest.mock names.
        for recorder in (parameter_new.mock, register_buffer.mock):
            for call_args, call_kwargs in recorder.call_args_list:
                call_tensors = get_tensors_from(call_args, call_kwargs)
                if call_tensors and not constructor_tensors.intersection(call_tensors):
                    return True
        return False

    def _check_lazy_factory_kwargs(self, module_cls, args, kwargs, device, dtype):
        """Assert every Uninitialized{Parameter,Buffer} got ``device``/``dtype``.

        Constructs ``module_cls(*args, **kwargs)`` while recording the keyword
        arguments passed to ``UninitializedParameter.__new__`` and
        ``UninitializedBuffer.__new__``, then requires that each recorded call
        was exactly ``call(device=device, dtype=dtype)``.
        """
        uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
        with patch.object(torch.nn.UninitializedParameter, "__new__", uninit_param_new):
            uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
            with patch.object(
                torch.nn.UninitializedBuffer, "__new__", uninit_buffer_new
            ):
                module_cls(*args, **kwargs)

        # Use unittest.mock.call here: the previous `mock.call(...)` resolved to
        # a leaked loop variable (a MagicMock instance), not the `call` helper.
        expected = call(device=device, dtype=dtype)
        for recorder in (uninit_param_new.mock, uninit_buffer_new.mock):
            recorder.assert_has_calls([expected] * len(recorder.mock_calls))

    def _check_factory_kwargs_applied(self, kind, name, tensor, device, dtype):
        """Assert ``tensor`` is on ``device`` and, if floating point, of ``dtype``.

        ``kind`` ("Parameter" or "Buffer") and ``name`` are used only to build
        the failure messages.
        """
        # Report the full device (including index) since that is what is compared.
        self.assertEqual(
            str(tensor.device),
            device,
            f"{kind} {name} is on {tensor.device} instead of the expected device {device}",
        )
        # Only verify floating point dtypes since that's what the kwarg applies to.
        if tensor.dtype.is_floating_point:
            self.assertEqual(
                tensor.dtype,
                dtype,
                f"{kind} {name} is of dtype {tensor.dtype} instead of the expected dtype {dtype}",
            )
# Instantiate the device-type variants of the generic TestModule class into
# this module's global namespace.
instantiate_device_type_tests(TestModule, globals())

if __name__ == "__main__":
    run_tests()