Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Better handling of restore_state_dict (#164401)
After lean export, we might want to be able to restore the original FQNs. This PR refactors one util function in export that sort of does this already. Note that strict_export has some complicated logic for updating the graph signature as well, which we don't want. I think we can gradually make this util more refined by handling constants, non-persistent buffers, etc., and change how strict_export does it today.

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
Committed by: PyTorch MergeBot
Parent: 47956196d9
Commit: a57a14868d
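For orientation before the diffs: after tracing, parameters and buffers end up as flat attributes on the graph module (e.g. "foo_weight" instead of "foo.weight"), and the util walks a traced-name-to-original-FQN mapping to reattach them under their dotted paths. The snippet below is a minimal sketch of that idea only, not the PyTorch implementation; the helper name restore_fqns and the shape of name_to_fqn are assumptions, and the real util (last hunk below) uses _get_param_buffer_mapping plus torch.fx.graph_module helpers and also rewrites get_attr nodes in the graph.

import torch


def restore_fqns(traced: torch.nn.Module, name_to_fqn: dict[str, str]) -> None:
    """Reattach flat attributes of `traced` under their original dotted FQNs (sketch)."""
    for traced_name, fqn in name_to_fqn.items():
        if not hasattr(traced, traced_name):
            continue
        value = getattr(traced, traced_name)
        delattr(traced, traced_name)
        # Recreate the submodule hierarchy implied by the dotted FQN.
        *prefix, leaf = fqn.split(".")
        owner = traced
        for part in prefix:
            if not hasattr(owner, part):
                owner.add_module(part, torch.nn.Module())
            owner = getattr(owner, part)
        if isinstance(value, torch.nn.Parameter):
            owner.register_parameter(leaf, value)
        elif isinstance(value, torch.Tensor):
            owner.register_buffer(leaf, value)
        else:
            setattr(owner, leaf, value)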
@@ -10907,17 +10907,16 @@ graph():
         test_inp = torch.randn(2, 3)
 
         torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {})
+        torch_gm.state_dict().keys()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
-            self.assertIn(normalized_k, torch_gm.state_dict())
-            self.assertEqual(v, torch_gm.state_dict()[normalized_k])
+            self.assertIn(k, torch_gm.state_dict())
+            self.assertEqual(v, torch_gm.state_dict()[k])
         self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp)))
 
         pre_autograd_gm = torch.export._trace._export(
             orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True
         ).module()
         for k, v in orig_eager.state_dict().items():
-            normalized_k = k.replace(".", "_")
             self.assertIn(k, pre_autograd_gm.state_dict())
             self.assertEqual(v, pre_autograd_gm.state_dict()[k])
         self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp)))
@@ -10929,6 +10928,7 @@ graph():
             self.assertIn(k, ep.state_dict)
             self.assertEqual(v, ep.state_dict[k])
         self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp)))
+        self.assertTrue(torch_gm.state_dict().keys(), orig_eager.state_dict().keys())
 
     def test_nn_module_stack(self):
         class Leaf(torch.nn.Module):
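For context on what the test change asserts: before this PR the torch-IR graph module exposed flattened keys, so the test normalized "." to "_" before looking them up; now the original dotted keys are expected to round-trip unchanged. A hedged sketch of that check on a made-up module, assuming the private helper _export_to_torch_ir is importable from torch.export._trace as in the test:

import torch
from torch.export._trace import _export_to_torch_ir  # private helper used by the test


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)


m = M()
gm = _export_to_torch_ir(m, (torch.rand(2, 3),), {})
# With the restored FQNs, every eager key such as "linear.weight" appears as-is;
# previously the lookup key would have been the flattened "linear_weight".
assert set(m.state_dict().keys()) <= set(gm.state_dict().keys())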
@@ -10,6 +10,7 @@ import time
 import warnings
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
 
 
@@ -693,24 +694,19 @@ def _restore_state_dict(
     Restores the state dict of the traced module to that of the original module.
     """
     param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
-    # Since the graph module is flattened (no module hierarchy), we
-    # need to normalize the module by replacing "." with "_". If we
-    # don't, it will try to save the weight to a submodule which no
-    # longer exists.
-    for name, fqn in param_buffer_table.items():
-        param_buffer_table[name] = fqn.replace(".", "_")
+    # Don't want to change the convention of previous call.
+    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}
 
     # Replace state dict attr names with the fqn
-    for name, fqn in param_buffer_table.items():
-        if not hasattr(traced_module, name):
-            continue
-
-        attr = getattr(traced_module, name)
-        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
-            traced_module.register_buffer(fqn, attr)
-        else:
-            setattr(traced_module, fqn, attr)
-        delattr(traced_module, name)
+    for name, _ in chain(
+        original_module.named_parameters(remove_duplicate=False),
+        original_module.named_buffers(remove_duplicate=False),
+    ):
+        if name in param_buffer_table_reverse:
+            dynamo_name = param_buffer_table_reverse[name]
+            param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
+            torch.fx.graph_module._assign_attr(param, traced_module, name)
+            torch.fx.graph_module._del_attr(traced_module, dynamo_name)
 
     # Replace graph getattr nodes with the correct name
     for node in traced_module.graph.nodes:
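One detail worth calling out in the new loop: it iterates the original module's named_parameters / named_buffers with remove_duplicate=False rather than the mapping itself, so tied weights that appear under several FQNs are each restored. A small illustration of the difference (module and names are invented for this example):

import torch


class Tied(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Linear(4, 4, bias=False)
        self.proj = torch.nn.Linear(4, 4, bias=False)
        self.proj.weight = self.embed.weight  # weight tying: two FQNs, one tensor


m = Tied()
print([n for n, _ in m.named_parameters()])
# ['embed.weight']  (the duplicate FQN is dropped by default)
print([n for n, _ in m.named_parameters(remove_duplicate=False)])
# ['embed.weight', 'proj.weight']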