Fix dynamic shapes reordering bug (#149528)
When we create constraints, we look at the ordering of kwargs according to the model signature. But when we trace, we use the ordering based on how the user passed in their kwargs. As a result, constraints and dynamic shapes end up in different orders, which causes issues when the kwargs have different dynamic tensor specs.

Differential Revision: [D71478578](https://our.internmc.facebook.com/intern/diff/D71478578)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149528
Approved by: https://github.com/ydwu4
committed by: PyTorch MergeBot
parent: 1e30192b19
commit: 3b7bd6c63d
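A minimal repro sketch of the ordering mismatch described in the commit message (module, names, and shapes are illustrative, mirroring the regression test added below; before this fix, the per-kwarg specs could be matched against the wrong tensors):

    import torch

    class M(torch.nn.Module):
        # keyword-only args in signature order: start, then end
        def forward(self, x, *, start, end):
            return x.sum() + start.sum() + end.sum()

    # The user passes kwargs in the opposite order, with differing specs.
    kwargs = {"end": torch.ones(4, 4, 4), "start": torch.ones(4, 4)}
    d = torch.export.Dim("d")
    dynamic_shapes = {
        "x": {0: d},
        "end": {0: d},
        "start": {0: d, 1: d},
    }
    # Constraints were built in signature order (start, end) while tracing
    # used the passed-in order (end, start), misaligning the specs.
    ep = torch.export.export(
        M(), (torch.ones(4, 4),), kwargs, dynamic_shapes=dynamic_shapes
    )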
@@ -10717,6 +10717,30 @@ def forward(self, x):
         # check that graph input names are as expected
         self.assertEqual(ep.graph_signature.user_inputs, ("x1", False, "x2"))
 
+    def test_kwarg_dynamic_shapes_diff_order(self):
+        class DummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.ones(4, 4)
+
+            def forward(self, baba, *, start, end):
+                return baba.sum() + start.sum() + end.sum()
+
+        f = DummyModel()
+        kwargs = {
+            "end": torch.ones(4, 4, 4),
+            "start": torch.ones(4, 4),
+        }
+        dynamic_shapes = {
+            "baba": {0: torch.export.Dim("end_dim")},
+            "end": {0: torch.export.Dim("end_dim")},
+            "start": {0: torch.export.Dim("end_dim"), 1: torch.export.Dim("end_dim")},
+        }
+        ep = torch.export.export(
+            f, (torch.ones(4, 4),), kwargs, dynamic_shapes=dynamic_shapes
+        ).run_decompositions()
+        ep.module()(torch.ones(4, 4), **kwargs)
+
     def test_placeholder_naming_order_variadic(self):
         class Mod(torch.nn.Module):
             def forward(self, a, b, c, **kwargs):
@@ -1194,7 +1194,12 @@ def _get_module_call_graph(
 
 
 def _get_range_constraints(
-    export_artifact: ExportArtifact, combined_args: dict[str, Any], dynamic_shapes
+    mod: torch.nn.Module,
+    export_artifact: ExportArtifact,
+    args,
+    kwargs,
+    dynamic_shapes,
+    _is_torch_jit_trace=False,
 ):
     gm: torch.fx.GraphModule = export_artifact.aten.gm
     export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
@@ -1207,6 +1212,25 @@ def _get_range_constraints(
         ),
         len(export_graph_signature.input_specs),
     )
+    combined_args = _combine_args(
+        mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace
+    )
+
+    # This is because we trace based on the kwargs passed in by the user,
+    # not based on the signature. It would be better to just enforce one
+    # ordering at the start of tracing to avoid confusion, but that is a
+    # bigger refactor, so do this to unblock for now.
+    if not _is_torch_jit_trace:
+        combined_args_traced_order = {}
+        for arg in combined_args:
+            if arg not in kwargs:
+                combined_args_traced_order[arg] = combined_args[arg]
+
+        for key in kwargs:
+            combined_args_traced_order[key] = kwargs[key]
+
+        combined_args = combined_args_traced_order
+
     range_constraints = make_constraints(
         fake_mode,
         gm,
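A standalone sketch of the reordering above, with hypothetical placeholder values: entries not in kwargs keep their signature order, and the kwargs are then re-appended in the order the user actually passed them, matching the order used during tracing.

    # combined_args as built from the signature: baba, start, end
    combined_args = {"baba": "t0", "start": "t1", "end": "t2"}
    # kwargs in the order the user passed them: end before start
    kwargs = {"end": "t2", "start": "t1"}

    traced_order = {k: v for k, v in combined_args.items() if k not in kwargs}
    traced_order.update(kwargs)
    assert list(traced_order) == ["baba", "end", "start"]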
@@ -1950,8 +1974,10 @@ def _export_for_training(
     # Note: _get_range_constraints depends on "inline_constraints" to be set.
     export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
     range_constraints = _get_range_constraints(
+        mod,
         export_artifact,
-        _combine_args(mod, args, kwargs, _is_torch_jit_trace=False),
+        args,
+        kwargs,
         dynamic_shapes,
     )
     # The returned gm is in-place modified
@@ -2114,9 +2140,12 @@ def _export(
     # Note: this step must be before _get_range_constraints.
     export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
     range_constraints = _get_range_constraints(
+        mod,
         export_artifact,
-        _combine_args(mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace),
+        args,
+        kwargs,
         dynamic_shapes,
+        _is_torch_jit_trace=_is_torch_jit_trace,
     )
     gm, module_call_graph = _get_module_call_graph(
         export_artifact,