mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add deprecation warning to nn stateless functional_call (#87367)
Same as the release version but just for master Pull Request resolved: https://github.com/pytorch/pytorch/pull/87367 Approved by: https://github.com/albanD, https://github.com/atalman
This commit is contained in:
@ -176,6 +176,37 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
self.assertEqual(orig_sn_weight, module.l1.weight)
|
||||
|
||||
|
||||
def test_tied_weights_warns(self):
|
||||
module = MockModule()
|
||||
module.tied_bias = module.l1.bias
|
||||
module.register_buffer("tied_buffer", module.buffer)
|
||||
weight = torch.tensor([[1.0]],)
|
||||
bias = torch.tensor([0.0])
|
||||
buffer = torch.tensor([0.0])
|
||||
|
||||
parameters = {'l1.weight': weight,
|
||||
'l1.bias': bias,
|
||||
'buffer': buffer}
|
||||
x = torch.randn(1, 1)
|
||||
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
|
||||
|
||||
# if tied values are the same tensors, shouldn't warn
|
||||
parameters['tied_bias'] = bias
|
||||
parameters['tied_buffer'] = buffer
|
||||
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
|
||||
del parameters['tied_bias']
|
||||
del parameters['tied_buffer']
|
||||
|
||||
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
|
||||
parameters['tied_bias'] = torch.tensor([5.0])
|
||||
stateless.functional_call(module, parameters, x)
|
||||
del parameters['tied_bias']
|
||||
|
||||
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
|
||||
parameters['tied_buffer'] = torch.tensor([5.0])
|
||||
stateless.functional_call(module, parameters, x)
|
||||
|
||||
|
||||
def test_setattr(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, Iterator, List, Tuple
|
||||
|
||||
@ -35,10 +36,22 @@ def _change_class(module, params_and_buffers) -> None:
|
||||
module.__class__ = param_cls
|
||||
module._orig_class = cls
|
||||
|
||||
def _create_swap_params(params_and_buffers):
|
||||
|
||||
def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
|
||||
if old_val not in replaced_tensors_map:
|
||||
replaced_tensors_map[old_val] = new_val
|
||||
elif replaced_tensors_map[old_val] is not new_val:
|
||||
warnings.warn("functional_call was passed multiple values for tied weights. "
|
||||
"This behavior is deprecated and will be an error in future versions")
|
||||
|
||||
|
||||
def _create_swap_params(params_and_buffers, replaced_tensors_map):
|
||||
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
|
||||
# Changes the module class to get a new __getattr__ dunder method
|
||||
# that looks for the reparametrized tensor
|
||||
if hasattr(module, tensor_name):
|
||||
old_val = getattr(module, tensor_name)
|
||||
_check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
|
||||
if hasattr(module, "_attr_to_path"):
|
||||
module._attr_to_path[tensor_name] = full_path
|
||||
else:
|
||||
@ -60,9 +73,10 @@ def _reparametrize_module(
|
||||
module: 'torch.nn.Module',
|
||||
parameters_and_buffers: Dict[str, Tensor],
|
||||
) -> Iterator[None]:
|
||||
orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
|
||||
for name, tensor in parameters_and_buffers.items():
|
||||
_apply_func_submodules(
|
||||
_create_swap_params(parameters_and_buffers),
|
||||
_create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
|
||||
module, name.split("."), name, (tensor,))
|
||||
try:
|
||||
yield
|
||||
|
Reference in New Issue
Block a user