serde unbacked bindings (#144894)

Adds unbacked bindings during deserialization. These are carried by a node's metadata, and map pending fresh unbacked symbols to paths to such symbols inside the corresponding example value carried by the node's metadata.

Since it is awkward to serialize paths, we only serialize the names of these symbols and reconstruct the paths on deserialization, using a shape env util. We also need to bump counters for unbacked symbols here, because the shape env util we use to create these symbols (when deserializing example values) don't do so, and not doing so makes later passes (like `run_decompositions`) crash because new unbacked symbols don't get new names.

This is enough for non-strict. For strict, the unbacked bindings and example values in node metadata can get out of sync, because of running AOTAutograd as an additional step after Dynamo. So we have to sync those back.

Differential Revision: [D68232274](https://our.internmc.facebook.com/intern/diff/D68232274/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144894
Approved by: https://github.com/pianpwk
This commit is contained in:
Avik Chaudhuri
2025-01-24 08:35:42 -08:00
committed by PyTorch MergeBot
parent 5725462cd8
commit 42b8e233d9
3 changed files with 42 additions and 3 deletions

View File

@ -3710,9 +3710,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dynamic_shapes = ({"k": {"k2": [(dim,)], "k1": [(dim,)]}},) # ok
export(N(), inputs, dynamic_shapes=dynamic_shapes)
@testing.expectedFailureSerDer # no unbacked bindings after deserialization?
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
@testing.expectedFailureSerDerNonStrict
def test_unbacked_bindings_for_divisible_u_symint(self):
class M(torch.nn.Module):
def forward(self, a, b):

View File

@ -42,6 +42,7 @@ from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from ..utils import remove_proxy_from_state_dict
@ -599,6 +600,13 @@ class GraphModuleSerializer(metaclass=Final):
def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
ret = {}
if unbacked_bindings := node.meta.get("unbacked_bindings"):
# serialize the symbol names of unbacked bindings;
# reconstruct the key paths to those symbols when deserializing
ret["unbacked_bindings"] = ",".join(
u.name for u in unbacked_bindings.keys()
)
if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace
@ -1878,6 +1886,24 @@ class GraphModuleDeserializer(metaclass=Final):
fx_node.kwargs,
fx_node.meta.get("val"),
)
if "unbacked_bindings" in serialized_node.metadata:
for u_name in serialized_node.metadata["unbacked_bindings"].split(","):
u = self.symbol_name_to_symbol[u_name]
# these are pending fresh unbacked symbols, so update shape env
if symbol_is_type(u, SymT.UNBACKED_FLOAT):
suffix = str(next(self.shape_env.unbacked_symfloat_counter))
assert u.name.endswith(suffix)
elif symbol_is_type(u, SymT.UNBACKED_INT):
suffix = str(next(self.shape_env.unbacked_symint_counter))
assert u.name.endswith(suffix)
else:
raise AssertionError(f"Illegal unbacked symbol {u}")
self.shape_env.pending_fresh_unbacked_symbols.append(u)
# consume pending fresh unbacked symbols and reconstruct key paths to them
unbacked_bindings = symbolic_shapes.compute_unbacked_bindings(
self.shape_env, fx_node.meta["val"]
)
fx_node.meta["unbacked_bindings"] = unbacked_bindings
if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta:
fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts

View File

@ -1395,6 +1395,22 @@ def _strict_export_lower_to_aten_ir(
export_graph_signature = aten_export_artifact.sig
constants = aten_export_artifact.constants
# update unbacked bindings that might have gone out of sync
# between Dynamo and AOTAutograd
for node in gm.graph.nodes:
if "unbacked_bindings" in node.meta:
old_unbacked_bindings = node.meta["unbacked_bindings"]
val = node.meta["val"]
new_unbacked_bindings = {}
for key in old_unbacked_bindings.values():
expr = pytree.key_get(val, key).node.expr
if expr.is_symbol:
new_unbacked_bindings[expr] = key
if new_unbacked_bindings:
node.meta["unbacked_bindings"] = new_unbacked_bindings
else:
del node.meta["unbacked_bindings"]
_populate_param_buffer_metadata_to_new_gm(
params_buffers_to_node_meta, gm, export_graph_signature
)