Revert "Rewrite _reparametrize_module to use contextmanager (#138203)"

This reverts commit 7bf3b7cdc5631f9991eebcdd8ec09095339a9973.

Reverted https://github.com/pytorch/pytorch/pull/138203 on behalf of https://github.com/guilhermeleobas due to breaking one of the benchmarks (moco) ([comment](https://github.com/pytorch/pytorch/pull/138203#issuecomment-2569634001))
This commit is contained in:
PyTorch MergeBot
2025-01-03 18:17:31 +00:00
parent 60fe8a65af
commit 2409b49a33

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Any, Dict, Optional, Set, Tuple, Union
from typing_extensions import deprecated
@ -95,70 +94,89 @@ def _untie_named_tensors_map(
return untied_parameters_and_buffers
@contextlib.contextmanager
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
):
parameters_and_buffers = parameters_and_buffers
stack_weights = stack_weights
class _ReparametrizeModule:
def __init__(
self,
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
):
self.parameters_and_buffers = parameters_and_buffers
self.stack_weights = stack_weights
if tie_weights:
untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
else:
untied_parameters_and_buffers = parameters_and_buffers
accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = accessor.check_keys(
untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
if tie_weights:
self.untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
if len(missing_keys) > 0:
error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
else:
self.untied_parameters_and_buffers = parameters_and_buffers
self.accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = self.accessor.check_keys(
self.untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
)
if len(missing_keys) > 0:
error_msgs.append(
f"Missing key(s): {', '.join(map(repr, missing_keys))}."
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
)
)
)
orig_parameters_and_buffers: Dict[str, Tensor] = {}
try:
orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
untied_parameters_and_buffers, allow_missing=True
def __enter__(self):
self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
self.untied_parameters_and_buffers, allow_missing=True
)
yield
finally:
if stack_weights:
def __exit__(self, exception_type, exception_value, traceback):
if self.stack_weights:
# When stacking is enabled, we will restore the weights in LIFO order.
orig_parameters_and_buffers = dict(
reversed(orig_parameters_and_buffers.items())
self.orig_parameters_and_buffers = dict(
reversed(self.orig_parameters_and_buffers.items())
)
new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
orig_parameters_and_buffers, allow_missing=True
new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
self.orig_parameters_and_buffers, allow_missing=True
)
# Sometimes the module is not completely stateless and has some in-place modifications on
# the _parameters and _buffers dictionaries.
# Write the changed parameters and buffers back to the original dict.
parameters_and_buffers.update(
self.parameters_and_buffers.update(
{
k: new_parameters_and_buffers[k]
for k in parameters_and_buffers
for k in self.parameters_and_buffers
if k in new_parameters_and_buffers
}
)
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
*,
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
) -> _ReparametrizeModule:
return _ReparametrizeModule(
module,
parameters_and_buffers,
tie_weights=tie_weights,
strict=strict,
stack_weights=stack_weights,
)
@deprecated(
"`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 "
"and will be removed in a future version of PyTorch. "