Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Rewrite _reparametrize_module to use contextmanager
(#138203)"
This reverts commit 7bf3b7cdc5631f9991eebcdd8ec09095339a9973. Reverted https://github.com/pytorch/pytorch/pull/138203 on behalf of https://github.com/guilhermeleobas due to breaking one of the benchmarks (moco) ([comment](https://github.com/pytorch/pytorch/pull/138203#issuecomment-2569634001))
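For context: the reverted change had replaced a hand-written __enter__/__exit__ class with a generator-based @contextlib.contextmanager function, and this revert restores the class. The two forms are interchangeable for plain `with` usage, as this minimal sketch shows (swap_attr/SwapAttr are hypothetical names for illustration, not PyTorch code):

import contextlib

# Generator-based form (what #138203 introduced): code before the yield is
# the setup, the finally block is the teardown.
@contextlib.contextmanager
def swap_attr(obj, name, value):
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, old)

# Class-based form (what this revert restores): setup in __enter__,
# teardown in __exit__.
class SwapAttr:
    def __init__(self, obj, name, value):
        self.obj, self.name, self.value = obj, name, value

    def __enter__(self):
        self.old = getattr(self.obj, self.name)
        setattr(self.obj, self.name, self.value)

    def __exit__(self, exc_type, exc_value, traceback):
        setattr(self.obj, self.name, self.old)

# Both are used identically at the call site:
#     with swap_attr(cfg, "debug", True): ...
#     with SwapAttr(cfg, "debug", True): ...

The commit message does not say how the contextmanager form broke the moco benchmark; the diff below simply restores the class-based implementation.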
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import contextlib
 from typing import Any, Dict, Optional, Set, Tuple, Union
 from typing_extensions import deprecated
 
@@ -95,70 +94,89 @@ def _untie_named_tensors_map(
     return untied_parameters_and_buffers
 
 
-@contextlib.contextmanager
-def _reparametrize_module(
-    module: "torch.nn.Module",
-    parameters_and_buffers: Dict[str, Tensor],
-    tie_weights: bool = False,
-    strict: bool = False,
-    stack_weights: bool = False,
-):
-    parameters_and_buffers = parameters_and_buffers
-    stack_weights = stack_weights
-
-    if tie_weights:
-        untied_parameters_and_buffers = _untie_named_tensors_map(
-            module, parameters_and_buffers
-        )
-    else:
-        untied_parameters_and_buffers = parameters_and_buffers
-
-    accessor = NamedMemberAccessor(module)
-    if strict:
-        missing_keys, unexpected_keys = accessor.check_keys(
-            untied_parameters_and_buffers
-        )
-        error_msgs = []
-        if len(unexpected_keys) > 0:
-            error_msgs.append(
-                f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
-            )
-        if len(missing_keys) > 0:
-            error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in reparametrizing for {}:\n\t{}".format(
-                    module._get_name(), "\n\t".join(error_msgs)
-                )
-            )
-
-    orig_parameters_and_buffers: Dict[str, Tensor] = {}
-    try:
-        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
-            untied_parameters_and_buffers, allow_missing=True
-        )
-        yield
-    finally:
-        if stack_weights:
-            # When stacking is enabled, we will restore the weights in LIFO order.
-            orig_parameters_and_buffers = dict(
-                reversed(orig_parameters_and_buffers.items())
-            )
-        new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
-            orig_parameters_and_buffers, allow_missing=True
-        )
-        # Sometimes the module is not completely stateless and has some in-place modifications on
-        # the _parameters and _buffers dictionaries.
-        # Write the changed parameters and buffers back to the original dict.
-        parameters_and_buffers.update(
-            {
-                k: new_parameters_and_buffers[k]
-                for k in parameters_and_buffers
-                if k in new_parameters_and_buffers
-            }
-        )
+class _ReparametrizeModule:
+    def __init__(
+        self,
+        module: "torch.nn.Module",
+        parameters_and_buffers: Dict[str, Tensor],
+        tie_weights: bool = False,
+        strict: bool = False,
+        stack_weights: bool = False,
+    ):
+        self.parameters_and_buffers = parameters_and_buffers
+        self.stack_weights = stack_weights
+
+        if tie_weights:
+            self.untied_parameters_and_buffers = _untie_named_tensors_map(
+                module, parameters_and_buffers
+            )
+        else:
+            self.untied_parameters_and_buffers = parameters_and_buffers
+
+        self.accessor = NamedMemberAccessor(module)
+        if strict:
+            missing_keys, unexpected_keys = self.accessor.check_keys(
+                self.untied_parameters_and_buffers
+            )
+            error_msgs = []
+            if len(unexpected_keys) > 0:
+                error_msgs.append(
+                    f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
+                )
+            if len(missing_keys) > 0:
+                error_msgs.append(
+                    f"Missing key(s): {', '.join(map(repr, missing_keys))}."
+                )
+            if len(error_msgs) > 0:
+                raise RuntimeError(
+                    "Error(s) in reparametrizing for {}:\n\t{}".format(
+                        module._get_name(), "\n\t".join(error_msgs)
+                    )
+                )
+
+    def __enter__(self):
+        self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
+            self.untied_parameters_and_buffers, allow_missing=True
+        )
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        if self.stack_weights:
+            # When stacking is enabled, we will restore the weights in LIFO order.
+            self.orig_parameters_and_buffers = dict(
+                reversed(self.orig_parameters_and_buffers.items())
+            )
+        new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
+            self.orig_parameters_and_buffers, allow_missing=True
+        )
+        # Sometimes the module is not completely stateless and has some in-place modifications on
+        # the _parameters and _buffers dictionaries.
+        # Write the changed parameters and buffers back to the original dict.
+        self.parameters_and_buffers.update(
+            {
+                k: new_parameters_and_buffers[k]
+                for k in self.parameters_and_buffers
+                if k in new_parameters_and_buffers
+            }
+        )
+
+
+def _reparametrize_module(
+    module: "torch.nn.Module",
+    parameters_and_buffers: Dict[str, Tensor],
+    *,
+    tie_weights: bool = False,
+    strict: bool = False,
+    stack_weights: bool = False,
+) -> _ReparametrizeModule:
+    return _ReparametrizeModule(
+        module,
+        parameters_and_buffers,
+        tie_weights=tie_weights,
+        strict=strict,
+        stack_weights=stack_weights,
+    )
 
 
 @deprecated(
     "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 "
     "and will be removed in a future version of PyTorch. "
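For reference, a minimal usage sketch of the restored helper (illustrative names and values; this assumes _reparametrize_module keeps its private import path in torch/nn/utils/stateless.py, where it backs torch.func.functional_call):

import torch
from torch.nn.utils.stateless import _reparametrize_module

linear = torch.nn.Linear(2, 2)
zero_params = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}

# _ReparametrizeModule defines __enter__/__exit__, so call sites use it
# exactly like the contextmanager version being reverted:
with _reparametrize_module(linear, zero_params):
    out = linear(torch.ones(2))  # computed with the swapped-in zero tensors
# On exit, the original weight and bias are swapped back, and any in-place
# changes the module made to its parameters are written back into zero_params.
print(out)            # tensor([0., 0.])
print(linear.weight)  # original (randomly initialized) weight, restored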