[ghstack-poisoned]
This commit is contained in:
Animesh Jain
2025-11-09 20:33:49 -08:00
parent 877d854d8c
commit 96f19e8a9e

View File

@ -322,6 +322,7 @@ def _call_function_and_unflatten_output_wrap_semantics(
orig_vt.proxy = subgraph_vt.proxy
return body_r
def _call_function_and_unflatten_output(
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
):
@ -1021,7 +1022,9 @@ def speculate_subgraph_with_wrap_semantics(
source_target: Optional[HigherOrderOperator] = None,
enable_grad: Optional[bool] = None,
# TODO - We can probably just make everyone use automatic for wrap_semantics
set_subgraph_inputs: Literal["automatic", "semi_automatic", "flatten_manual", "manual"] = "automatic",
set_subgraph_inputs: Literal[
"automatic", "semi_automatic", "flatten_manual", "manual"
] = "automatic",
# Make default False
restore_side_effects: bool = True,
under_activation_checkpoint: bool = False,
@ -1034,8 +1037,13 @@ def speculate_subgraph_with_wrap_semantics(
) -> tuple[
VariableTracker, # output: The VT that Dynamo continues tracing with
torch.fx.Graph, # graph: The FX graph representing the subgraph computation
dict[torch.fx.Proxy, torch.fx.Proxy], # lifted_freevars: Free variables lifted as inputs
VariableTracker | tuple[VariableTracker, ...], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs
dict[
torch.fx.Proxy, torch.fx.Proxy
], # lifted_freevars: Free variables lifted as inputs
VariableTracker
| tuple[
VariableTracker, ...
], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs
]:
"""
Speculate subgraph for Higher-Order Operators (HOPs) with wrap semantics.
@ -4570,7 +4578,13 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
# Step 5: Install local_map subgraph
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
out = _call_function_and_unflatten_output_wrap_semantics(
tx, self.value, p_args, p_kwargs, example_value, body_r, body_graph_output_vts
tx,
self.value,
p_args,
p_kwargs,
example_value,
body_r,
body_graph_output_vts,
)
# Step 6: Restore inputs and outputs to global shapes