mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
serde unbacked bindings (#144894)
Adds unbacked bindings during deserialization. These are carried by a node's metadata, and map pending fresh unbacked symbols to paths to such symbols inside the corresponding example value carried by the node's metadata. Since it is awkward to serialize paths, we only serialize the names of these symbols and reconstruct the paths on deserialization, using a shape env util. We also need to bump counters for unbacked symbols here, because the shape env util we use to create these symbols (when deserializing example values) don't do so, and not doing so makes later passes (like `run_decompositions`) crash because new unbacked symbols don't get new names. This is enough for non-strict. For strict, the unbacked bindings and example values in node metadata can get out of sync, because of running AOTAutograd as an additional step after Dynamo. So we have to sync those back. Differential Revision: [D68232274](https://our.internmc.facebook.com/intern/diff/D68232274/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144894 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
5725462cd8
commit
42b8e233d9
@ -3710,9 +3710,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
dynamic_shapes = ({"k": {"k2": [(dim,)], "k1": [(dim,)]}},) # ok
|
||||
export(N(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
@testing.expectedFailureSerDer # no unbacked bindings after deserialization?
|
||||
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
def test_unbacked_bindings_for_divisible_u_symint(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
|
@ -42,6 +42,7 @@ from torch.fx.experimental import symbolic_shapes
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._pytree import treespec_dumps, treespec_loads
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from ..utils import remove_proxy_from_state_dict
|
||||
@ -599,6 +600,13 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
|
||||
def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
|
||||
ret = {}
|
||||
if unbacked_bindings := node.meta.get("unbacked_bindings"):
|
||||
# serialize the symbol names of unbacked bindings;
|
||||
# reconstruct the key paths to those symbols when deserializing
|
||||
ret["unbacked_bindings"] = ",".join(
|
||||
u.name for u in unbacked_bindings.keys()
|
||||
)
|
||||
|
||||
if stack_trace := node.meta.get("stack_trace"):
|
||||
ret["stack_trace"] = stack_trace
|
||||
|
||||
@ -1878,6 +1886,24 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
fx_node.kwargs,
|
||||
fx_node.meta.get("val"),
|
||||
)
|
||||
if "unbacked_bindings" in serialized_node.metadata:
|
||||
for u_name in serialized_node.metadata["unbacked_bindings"].split(","):
|
||||
u = self.symbol_name_to_symbol[u_name]
|
||||
# these are pending fresh unbacked symbols, so update shape env
|
||||
if symbol_is_type(u, SymT.UNBACKED_FLOAT):
|
||||
suffix = str(next(self.shape_env.unbacked_symfloat_counter))
|
||||
assert u.name.endswith(suffix)
|
||||
elif symbol_is_type(u, SymT.UNBACKED_INT):
|
||||
suffix = str(next(self.shape_env.unbacked_symint_counter))
|
||||
assert u.name.endswith(suffix)
|
||||
else:
|
||||
raise AssertionError(f"Illegal unbacked symbol {u}")
|
||||
self.shape_env.pending_fresh_unbacked_symbols.append(u)
|
||||
# consume pending fresh unbacked symbols and reconstruct key paths to them
|
||||
unbacked_bindings = symbolic_shapes.compute_unbacked_bindings(
|
||||
self.shape_env, fx_node.meta["val"]
|
||||
)
|
||||
fx_node.meta["unbacked_bindings"] = unbacked_bindings
|
||||
if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta:
|
||||
fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts
|
||||
|
||||
|
@ -1395,6 +1395,22 @@ def _strict_export_lower_to_aten_ir(
|
||||
export_graph_signature = aten_export_artifact.sig
|
||||
constants = aten_export_artifact.constants
|
||||
|
||||
# update unbacked bindings that might have gone out of sync
|
||||
# between Dynamo and AOTAutograd
|
||||
for node in gm.graph.nodes:
|
||||
if "unbacked_bindings" in node.meta:
|
||||
old_unbacked_bindings = node.meta["unbacked_bindings"]
|
||||
val = node.meta["val"]
|
||||
new_unbacked_bindings = {}
|
||||
for key in old_unbacked_bindings.values():
|
||||
expr = pytree.key_get(val, key).node.expr
|
||||
if expr.is_symbol:
|
||||
new_unbacked_bindings[expr] = key
|
||||
if new_unbacked_bindings:
|
||||
node.meta["unbacked_bindings"] = new_unbacked_bindings
|
||||
else:
|
||||
del node.meta["unbacked_bindings"]
|
||||
|
||||
_populate_param_buffer_metadata_to_new_gm(
|
||||
params_buffers_to_node_meta, gm, export_graph_signature
|
||||
)
|
||||
|
Reference in New Issue
Block a user