Better handling of restore_state_dict (#164401)

After lean export, we might want to be able to restore the original FQNs. This PR refactors a util function in export that already does roughly this. Note that strict_export also has some complicated logic for updating the graph signature, which we don't want here. I think we can gradually make this util more refined by handling constants, non-persistent buffers, etc., and then change how strict_export does it today.
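
As a rough illustration of what "restoring the original FQNs" means (the toy module and the mangled attribute names below are made up for this example, not taken from this PR): a flattened graph module holds the same tensors under underscore-mangled names, and restoration should make its state dict keys line up with the eager module's dotted names again.

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

eager = Toy()

# A flattened graph module (no submodule hierarchy) ends up holding the
# same parameters under mangled names.
flat = torch.nn.Module()
flat.linear_weight = eager.linear.weight
flat.linear_bias = eager.linear.bias

print(list(eager.state_dict().keys()))  # ['linear.weight', 'linear.bias']
print(list(flat.state_dict().keys()))   # ['linear_weight', 'linear_bias']
# Restoring the state dict means making the second set of keys match the first.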

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-10-09 11:06:03 -07:00
Committed by: PyTorch MergeBot
Parent: 47956196d9
Commit: a57a14868d

2 changed files with 16 additions and 20 deletions


@@ -10907,17 +10907,16 @@ graph():
         test_inp = torch.randn(2, 3)
         torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {})
+        torch_gm.state_dict().keys()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
-            self.assertIn(normalized_k, torch_gm.state_dict())
-            self.assertEqual(v, torch_gm.state_dict()[normalized_k])
+            self.assertIn(k, torch_gm.state_dict())
+            self.assertEqual(v, torch_gm.state_dict()[k])
         self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp)))
         pre_autograd_gm = torch.export._trace._export(
             orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True
         ).module()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
             self.assertIn(k, pre_autograd_gm.state_dict())
             self.assertEqual(v, pre_autograd_gm.state_dict()[k])
         self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp)))
@@ -10929,6 +10928,7 @@ graph():
             self.assertIn(k, ep.state_dict)
             self.assertEqual(v, ep.state_dict[k])
         self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp)))
+        self.assertTrue(torch_gm.state_dict().keys(), orig_eager.state_dict().keys())
 
     def test_nn_module_stack(self):
         class Leaf(torch.nn.Module):


@@ -10,6 +10,7 @@ import time
 import warnings
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@@ -693,24 +694,19 @@ def _restore_state_dict(
     Restores the state dict of the traced module to that of the original module.
     """
     param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
 
-    # Since the graph module is flattened (no module hierarchy), we
-    # need to normalize the module by replacing "." with "_". If we
-    # don't, it will try to save the weight to a submodule which no
-    # longer exists.
-    for name, fqn in param_buffer_table.items():
-        param_buffer_table[name] = fqn.replace(".", "_")
+    # Don't want to change the convention of previous call.
+    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}
 
-    # Replace state dict attr names with the fqn
-    for name, fqn in param_buffer_table.items():
-        if not hasattr(traced_module, name):
-            continue
-        attr = getattr(traced_module, name)
-        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
-            traced_module.register_buffer(fqn, attr)
-        else:
-            setattr(traced_module, fqn, attr)
-        delattr(traced_module, name)
+    for name, _ in chain(
+        original_module.named_parameters(remove_duplicate=False),
+        original_module.named_buffers(remove_duplicate=False),
+    ):
+        if name in param_buffer_table_reverse:
+            dynamo_name = param_buffer_table_reverse[name]
+            param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
+            torch.fx.graph_module._assign_attr(param, traced_module, name)
+            torch.fx.graph_module._del_attr(traced_module, dynamo_name)
 
     # Replace graph getattr nodes with the correct name
     for node in traced_module.graph.nodes:
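
For readers not familiar with the private torch.fx.graph_module._assign_attr / _del_attr helpers used above: conceptually, they move a flat attribute back under its dotted FQN, creating intermediate submodules as needed. Below is a simplified, self-contained sketch of that idea; the helper name is hypothetical and uses plain nn.Module calls, not the actual PyTorch internals.

import torch

def assign_by_fqn(root: torch.nn.Module, fqn: str, value: torch.Tensor) -> None:
    # Walk (or create) the intermediate submodules, then attach the leaf
    # as a parameter or buffer under its original dotted name.
    *prefix, leaf = fqn.split(".")
    mod = root
    for part in prefix:
        if not hasattr(mod, part):
            mod.add_module(part, torch.nn.Module())
        mod = getattr(mod, part)
    if isinstance(value, torch.nn.Parameter):
        mod.register_parameter(leaf, value)
    else:
        mod.register_buffer(leaf, value)

# Usage: move a flattened attribute back to its original dotted name.
gm = torch.nn.Module()
gm.sub_linear_weight = torch.nn.Parameter(torch.randn(3, 3))  # mangled name
assign_by_fqn(gm, "sub.linear.weight", gm.sub_linear_weight)
delattr(gm, "sub_linear_weight")
print(list(gm.state_dict().keys()))  # ['sub.linear.weight']

The function in the diff then also rewrites the graph's get_attr node targets (the loop over traced_module.graph.nodes where the hunk is cut off) so that the graph references the restored names.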