Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Better handling of restore_state_dict (#164401)
After lean export, we might want to be able to restore the original fqn. This PR refactors one util function in export that sort of does this. Note that strict_export has some complicated logic for updating the graph signature as well, which we don't want. I think we can gradually make this util more refined by handling constants, non-persistent buffers, etc., and change how strict_export does it today.

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
committed by PyTorch MergeBot
parent 47956196d9
commit a57a14868d
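Illustration (not part of the commit): a minimal sketch of the behavior the commit message is after, using a hypothetical toy module rather than the PR's actual test model; it assumes the usual torch.export flow, where the exported program's module should expose the original dotted FQNs in its state dict.

    import torch

    class Toy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)
            self.register_buffer("scale", torch.ones(3))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x) * self.scale

    mod = Toy()
    ep = torch.export.export(mod, (torch.randn(2, 3),))
    unflat = ep.module()

    # The original dotted FQNs ("linear.weight", "linear.bias", "scale") should
    # be visible after export, not flattened names like "linear_weight".
    for k, v in mod.state_dict().items():
        assert k in unflat.state_dict()
        assert torch.equal(v, unflat.state_dict()[k])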
@@ -10907,17 +10907,16 @@ graph():
         test_inp = torch.randn(2, 3)

         torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {})
+        torch_gm.state_dict().keys()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
-            self.assertIn(normalized_k, torch_gm.state_dict())
-            self.assertEqual(v, torch_gm.state_dict()[normalized_k])
+            self.assertIn(k, torch_gm.state_dict())
+            self.assertEqual(v, torch_gm.state_dict()[k])
         self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp)))

         pre_autograd_gm = torch.export._trace._export(
             orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True
         ).module()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
             self.assertIn(k, pre_autograd_gm.state_dict())
             self.assertEqual(v, pre_autograd_gm.state_dict()[k])
         self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp)))
@@ -10929,6 +10928,7 @@ graph():
             self.assertIn(k, ep.state_dict)
             self.assertEqual(v, ep.state_dict[k])
         self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp)))
+        self.assertTrue(torch_gm.state_dict().keys(), orig_eager.state_dict().keys())

     def test_nn_module_stack(self):
         class Leaf(torch.nn.Module):
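Illustration (not part of the commit): the test change above amounts to dropping the "." to "_" key normalization when checking the exported module's state dict; a tiny self-contained sketch with a hypothetical key name:

    # "linear.weight" is a hypothetical parameter FQN, used only for illustration.
    state_dict = {"linear.weight": "tensor placeholder"}

    k = "linear.weight"
    normalized_k = k.replace(".", "_")      # what the old test looked up: "linear_weight"

    assert k in state_dict                  # new expectation: the original dotted FQN
    assert normalized_k not in state_dict   # the flattened name is no longer the key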
@@ -10,6 +10,7 @@ import time
 import warnings
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@@ -693,24 +694,19 @@ def _restore_state_dict(
     Restores the state dict of the traced module to that of the original module.
     """
     param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
-    # Since the graph module is flattened (no module hierarchy), we
-    # need to normalize the module by replacing "." with "_". If we
-    # don't, it will try to save the weight to a submodule which no
-    # longer exists.
-    for name, fqn in param_buffer_table.items():
-        param_buffer_table[name] = fqn.replace(".", "_")
-
-    # Replace state dict attr names with the fqn
-    for name, fqn in param_buffer_table.items():
-        if not hasattr(traced_module, name):
-            continue
-
-        attr = getattr(traced_module, name)
-        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
-            traced_module.register_buffer(fqn, attr)
-        else:
-            setattr(traced_module, fqn, attr)
-        delattr(traced_module, name)
+    # Don't want to change the convention of previous call.
+    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}
+
+    for name, _ in chain(
+        original_module.named_parameters(remove_duplicate=False),
+        original_module.named_buffers(remove_duplicate=False),
+    ):
+        if name in param_buffer_table_reverse:
+            dynamo_name = param_buffer_table_reverse[name]
+            param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
+            torch.fx.graph_module._assign_attr(param, traced_module, name)
+            torch.fx.graph_module._del_attr(traced_module, dynamo_name)

     # Replace graph getattr nodes with the correct name
     for node in traced_module.graph.nodes:
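Illustration (not part of the commit): a self-contained sketch of the pattern the new loop relies on, with hand-written stand-ins for the private torch.fx.graph_module helpers (_get_attr, _assign_attr, _del_attr); the flattened attribute name and the mapping below are hypothetical.

    import torch

    # Simplified stand-ins for torch.fx.graph_module._get_attr / _assign_attr / _del_attr.
    def get_attr(mod: torch.nn.Module, dotted: str):
        obj = mod
        for part in dotted.split("."):
            obj = getattr(obj, part)
        return obj

    def assign_attr(value, mod: torch.nn.Module, dotted: str) -> None:
        *prefix, leaf = dotted.split(".")
        for part in prefix:
            if not hasattr(mod, part):
                mod.add_module(part, torch.nn.Module())  # recreate the module hierarchy
            mod = getattr(mod, part)
        if isinstance(value, torch.nn.Parameter):
            mod.register_parameter(leaf, value)
        else:
            mod.register_buffer(leaf, value)

    def del_attr(mod: torch.nn.Module, dotted: str) -> None:
        *prefix, leaf = dotted.split(".")
        for part in prefix:
            mod = getattr(mod, part)
        delattr(mod, leaf)

    # A flattened module, roughly what tracing produces (hypothetical flat name).
    flat = torch.nn.Module()
    flat.register_parameter("L__self___linear_weight", torch.nn.Parameter(torch.randn(3, 3)))

    # Reverse mapping from original FQN to the flattened attribute name.
    fqn_to_flat = {"linear.weight": "L__self___linear_weight"}

    # Move each parameter/buffer back under its original dotted FQN.
    for fqn, flat_name in fqn_to_flat.items():
        value = get_attr(flat, flat_name)
        assign_attr(value, flat, fqn)
        del_attr(flat, flat_name)

    assert "linear.weight" in flat.state_dict()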