Better handling of restore_state_dict (#164401)

After lean export, we might want to be able to restore the original FQNs. This PR refactors the export util (`_restore_state_dict`) that already does roughly this. Note that strict_export also has some complicated logic for updating the graph signature, which we don't want here. I think we can gradually refine this util by handling constants, non-persistent buffers, etc., and change how strict_export does it today.

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-10-09 11:06:03 -07:00
Committed by: PyTorch MergeBot
parent 47956196d9
commit a57a14868d
2 changed files with 16 additions and 20 deletions
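
For context, the "original FQN" in the summary is the dotted state-dict key of the eager module; the flattened traced module previously renamed these keys by replacing "." with "_" (see the removed code below). A minimal illustration with a made-up module:

```python
import torch


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()


eager = Foo()
# Original FQNs: dotted paths through the module hierarchy.
print(list(eager.state_dict().keys()))
# ['bar.linear.weight', 'bar.linear.bias']
# The old util stored these on the flat traced module as
# 'bar_linear_weight' / 'bar_linear_bias'; the refactor keeps the dotted names.
```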


@@ -10907,17 +10907,16 @@ graph():

         test_inp = torch.randn(2, 3)
         torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {})
+        torch_gm.state_dict().keys()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
-            self.assertIn(normalized_k, torch_gm.state_dict())
-            self.assertEqual(v, torch_gm.state_dict()[normalized_k])
+            self.assertIn(k, torch_gm.state_dict())
+            self.assertEqual(v, torch_gm.state_dict()[k])
         self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp)))

         pre_autograd_gm = torch.export._trace._export(
             orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True
         ).module()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
             self.assertIn(k, pre_autograd_gm.state_dict())
             self.assertEqual(v, pre_autograd_gm.state_dict()[k])
         self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp)))

@@ -10929,6 +10928,7 @@ graph():
             self.assertIn(k, ep.state_dict)
             self.assertEqual(v, ep.state_dict[k])
         self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp)))
+        self.assertTrue(torch_gm.state_dict().keys(), orig_eager.state_dict().keys())

     def test_nn_module_stack(self):
         class Leaf(torch.nn.Module):
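
A standalone sketch of roughly what the updated test asserts, assuming the private `_export_to_torch_ir` entry point from `torch.export._trace` (subject to change) and a made-up toy module:

```python
import torch
from torch.export._trace import _export_to_torch_ir  # private API, used by the test


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)


eager = M()
gm = _export_to_torch_ir(eager, (torch.rand(2, 3),), {})

# After this change, the traced module's state dict is keyed by the original
# dotted FQNs, so no "." -> "_" normalization of the eager keys is needed.
for k, v in eager.state_dict().items():
    assert k in gm.state_dict()
    assert torch.equal(v, gm.state_dict()[k])
```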


@@ -10,6 +10,7 @@ import time
 import warnings
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union

@@ -693,24 +694,19 @@ def _restore_state_dict(
     Restores the state dict of the traced module to that of the original module.
     """
     param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
-    # Since the graph module is flattened (no module hierarchy), we
-    # need to normalize the module by replacing "." with "_". If we
-    # don't, it will try to save the weight to a submodule which no
-    # longer exists.
-    for name, fqn in param_buffer_table.items():
-        param_buffer_table[name] = fqn.replace(".", "_")
+    # Don't want to change the convention of previous call.
+    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}

     # Replace state dict attr names with the fqn
-    for name, fqn in param_buffer_table.items():
-        if not hasattr(traced_module, name):
-            continue
-
-        attr = getattr(traced_module, name)
-        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
-            traced_module.register_buffer(fqn, attr)
-        else:
-            setattr(traced_module, fqn, attr)
-        delattr(traced_module, name)
+    for name, _ in chain(
+        original_module.named_parameters(remove_duplicate=False),
+        original_module.named_buffers(remove_duplicate=False),
+    ):
+        if name in param_buffer_table_reverse:
+            dynamo_name = param_buffer_table_reverse[name]
+            param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
+            torch.fx.graph_module._assign_attr(param, traced_module, name)
+            torch.fx.graph_module._del_attr(traced_module, dynamo_name)

     # Replace graph getattr nodes with the correct name
     for node in traced_module.graph.nodes:
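
The new body moves each tensor from its dynamo-mangled attribute back under the original dotted FQN via private fx helpers. A rough standalone sketch of the same idea using only public `nn.Module` machinery; `restore_fqns` and `flat_to_fqn` are illustrative names, not part of this diff:

```python
import torch


def restore_fqns(traced: torch.nn.Module, flat_to_fqn: dict[str, str]) -> None:
    """Re-register each tensor stored under a flat attribute name beneath its
    original dotted FQN, creating intermediate submodules as needed."""
    for flat_name, fqn in flat_to_fqn.items():
        tensor = getattr(traced, flat_name)
        *path, leaf = fqn.split(".")
        owner = traced
        for part in path:
            if not hasattr(owner, part):
                owner.add_module(part, torch.nn.Module())
            owner = getattr(owner, part)
        if isinstance(tensor, torch.nn.Parameter):
            owner.register_parameter(leaf, tensor)
        else:
            owner.register_buffer(leaf, tensor)
        # Drop the flat alias so only the dotted FQN remains in state_dict().
        delattr(traced, flat_name)


# Toy usage: a flat module with a dynamo-style mangled attribute name.
flat = torch.nn.Module()
flat.register_parameter("linear_weight", torch.nn.Parameter(torch.randn(3, 3)))
restore_fqns(flat, {"linear_weight": "linear.weight"})
print(list(flat.state_dict().keys()))  # ['linear.weight']
```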