add deprecation warning to nn stateless functional_call (#87367)

Same as the release-branch version, but applied to master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87367
Approved by: https://github.com/albanD, https://github.com/atalman
Author: samdow
Date: 2022-10-20 13:45:20 -04:00
Committed by: PyTorch MergeBot
Parent: 9b88dcf248
Commit: bc8cf33244
2 changed files with 47 additions and 2 deletions
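The warning targets tied weights: two module attribute paths that alias the same underlying tensor. A minimal sketch of the ambiguity functional_call now flags (TiedModule and its forward are illustrative, not part of this diff):

import torch
import torch.nn as nn
from torch.nn.utils import stateless

class TiedModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)
        self.tied_bias = self.l1.bias  # second name for the same Parameter

    def forward(self, x):
        return self.l1(x) + self.tied_bias

module = TiedModule()
bias = torch.tensor([0.0])
x = torch.randn(1, 1)

# Same tensor under both tied names: unambiguous, no warning.
stateless.functional_call(module, {'l1.bias': bias, 'tied_bias': bias}, x)

# Different tensors for the tied names: ambiguous, so this now emits a
# UserWarning and will become an error in a future release.
stateless.functional_call(
    module, {'l1.bias': bias, 'tied_bias': torch.tensor([5.0])}, x)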


@@ -176,6 +176,37 @@ class TestStatelessFunctionalAPI(TestCase):
         self.assertEqual(orig_sn_weight, module.l1.weight)
 
+    def test_tied_weights_warns(self):
+        module = MockModule()
+        module.tied_bias = module.l1.bias
+        module.register_buffer("tied_buffer", module.buffer)
+
+        weight = torch.tensor([[1.0]])
+        bias = torch.tensor([0.0])
+        buffer = torch.tensor([0.0])
+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer}
+        x = torch.randn(1, 1)
+        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
+
+        # if tied values are the same tensors, shouldn't warn
+        parameters['tied_bias'] = bias
+        parameters['tied_buffer'] = buffer
+        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
+
+        del parameters['tied_bias']
+        del parameters['tied_buffer']
+
+        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+            parameters['tied_bias'] = torch.tensor([5.0])
+            stateless.functional_call(module, parameters, x)
+        del parameters['tied_bias']
+
+        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+            parameters['tied_buffer'] = torch.tensor([5.0])
+            stateless.functional_call(module, parameters, x)
+
     def test_setattr(self):
         class Foo(torch.nn.Module):
             def __init__(self):
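The test relies on the suite's assertWarnsOnceRegex helper, which ensures once-only warnings are still observed. Outside the harness, the same behavior can be checked with the standard library; a sketch reusing the hypothetical TiedModule example above:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # surface warnings even if raised before
    stateless.functional_call(
        module, {'l1.bias': bias, 'tied_bias': torch.tensor([5.0])}, x)
assert any("passed multiple values" in str(w.message) for w in caught)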


@@ -1,3 +1,4 @@
+import warnings
 import contextlib
 from typing import Any, Callable, Dict, Iterator, List, Tuple
@@ -35,10 +36,22 @@ def _change_class(module, params_and_buffers) -> None:
     module.__class__ = param_cls
     module._orig_class = cls
 
-def _create_swap_params(params_and_buffers):
+def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
+    if old_val not in replaced_tensors_map:
+        replaced_tensors_map[old_val] = new_val
+    elif replaced_tensors_map[old_val] is not new_val:
+        warnings.warn("functional_call was passed multiple values for tied weights. "
+                      "This behavior is deprecated and will be an error in future versions")
+
+
+def _create_swap_params(params_and_buffers, replaced_tensors_map):
     def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
         # Changes the module class to get a new __getattr__ dunder method
         # that looks for the reparametrized tensor
+        if hasattr(module, tensor_name):
+            old_val = getattr(module, tensor_name)
+            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
         if hasattr(module, "_attr_to_path"):
             module._attr_to_path[tensor_name] = full_path
         else:
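The detection hinges on replaced_tensors_map being keyed by the original tensor object (torch.Tensor hashes by identity), so tied attribute paths land on the same key and a second, different replacement is caught. A standalone sketch of just that mechanism, with illustrative names:

import warnings
import torch

replaced = {}  # original tensor -> first replacement seen for it

def record_swap(old_val, new_val):
    if old_val not in replaced:
        # First swap of this original tensor: remember its replacement.
        replaced[old_val] = new_val
    elif replaced[old_val] is not new_val:
        # Same original reappeared with a different replacement object:
        # tied names were given conflicting values. The identity check means
        # even value-equal but distinct tensors trigger the warning.
        warnings.warn("conflicting replacements for a tied tensor")

shared = torch.zeros(1)             # one tensor reachable under two names
record_swap(shared, torch.ones(1))  # records the replacement
record_swap(shared, torch.ones(1))  # distinct object -> warns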
@@ -60,9 +73,10 @@ def _reparametrize_module(
     module: 'torch.nn.Module',
     parameters_and_buffers: Dict[str, Tensor],
 ) -> Iterator[None]:
+    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers),
+            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
             module, name.split("."), name, (tensor,))
     try:
         yield
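Because orig_tensors_to_replacements is created fresh on each call, the tied-weight bookkeeping is scoped to a single reparametrization and nothing leaks between functional_call invocations. The enclosing _reparametrize_module is a contextlib.contextmanager: it swaps the tensors in, yields to the caller's forward pass, and restores the originals afterwards (the restore logic sits below the hunk shown here). A generic sketch of that swap-and-restore pattern, with hypothetical names:

import contextlib

@contextlib.contextmanager
def swap_attrs(obj, replacements):
    # Simplified analogue: install replacement attributes, yield control,
    # then restore the saved originals even if the body raises.
    saved = {name: getattr(obj, name) for name in replacements}
    try:
        for name, value in replacements.items():
            setattr(obj, name, value)
        yield
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)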