Fix persistent buffer bug (#162190)

For non-persistent buffers, we should properly register them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162190
Approved by: https://github.com/zhxchen17
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-09-10 20:21:06 -07:00
committed by PyTorch MergeBot
parent c3f30eca9e
commit c924c675d0
3 changed files with 1 additions and 9 deletions

View File

@ -12073,8 +12073,6 @@ graph():
test(export(M(), inp))
# Preserving signature hook is messing with dynamo tracing
@testing.expectedFailureStrictV2
def test_unflatten_multiple_graphs_state(self):
class N(torch.nn.Module):
def __init__(self):

View File

@ -61,7 +61,7 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
# Move the parameter to the new name
if hasattr(graph_module, old_target):
param = torch.fx.graph_module._get_attr(graph_module, old_target)
torch.fx.graph_module._set_attr(graph_module, new_target, param)
torch.fx.graph_module._assign_attr(param, graph_module, new_target)
torch.fx.graph_module._del_attr(graph_module, old_target)

View File

@ -302,12 +302,6 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
return hasattr(t, field)
def _set_attr(model: torch.nn.Module, attr_name: str, value):
attr_names = attr_name.split(".")
t = _get_attr_via_attr_list(model, attr_names[:-1])
setattr(t, attr_names[-1], value)
def _print_readable(
module,
module_name,