mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix persistent buffer bug (#162190)
For non-persistent buffers, we should properly register them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162190 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
c3f30eca9e
commit
c924c675d0
@ -12073,8 +12073,6 @@ graph():
|
||||
|
||||
test(export(M(), inp))
|
||||
|
||||
# Preserving signature hook is messing with dynamo tracing
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_unflatten_multiple_graphs_state(self):
|
||||
class N(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -61,7 +61,7 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
|
||||
# Move the parameter to the new name
|
||||
if hasattr(graph_module, old_target):
|
||||
param = torch.fx.graph_module._get_attr(graph_module, old_target)
|
||||
torch.fx.graph_module._set_attr(graph_module, new_target, param)
|
||||
torch.fx.graph_module._assign_attr(param, graph_module, new_target)
|
||||
torch.fx.graph_module._del_attr(graph_module, old_target)
|
||||
|
||||
|
||||
|
@ -302,12 +302,6 @@ def _has_attr(model: torch.nn.Module, attr_name: str):
|
||||
return hasattr(t, field)
|
||||
|
||||
|
||||
def _set_attr(model: torch.nn.Module, attr_name: str, value):
|
||||
attr_names = attr_name.split(".")
|
||||
t = _get_attr_via_attr_list(model, attr_names[:-1])
|
||||
setattr(t, attr_names[-1], value)
|
||||
|
||||
|
||||
def _print_readable(
|
||||
module,
|
||||
module_name,
|
||||
|
Reference in New Issue
Block a user