mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 03:14:45 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -322,6 +322,7 @@ def _call_function_and_unflatten_output_wrap_semantics(
|
||||
orig_vt.proxy = subgraph_vt.proxy
|
||||
return body_r
|
||||
|
||||
|
||||
def _call_function_and_unflatten_output(
|
||||
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
|
||||
):
|
||||
@ -1021,7 +1022,9 @@ def speculate_subgraph_with_wrap_semantics(
|
||||
source_target: Optional[HigherOrderOperator] = None,
|
||||
enable_grad: Optional[bool] = None,
|
||||
# TODO - We can probably just make everyone use automatic for wrap_semantics
|
||||
set_subgraph_inputs: Literal["automatic", "semi_automatic", "flatten_manual", "manual"] = "automatic",
|
||||
set_subgraph_inputs: Literal[
|
||||
"automatic", "semi_automatic", "flatten_manual", "manual"
|
||||
] = "automatic",
|
||||
# Make default False
|
||||
restore_side_effects: bool = True,
|
||||
under_activation_checkpoint: bool = False,
|
||||
@ -1034,8 +1037,13 @@ def speculate_subgraph_with_wrap_semantics(
|
||||
) -> tuple[
|
||||
VariableTracker, # output: The VT that Dynamo continues tracing with
|
||||
torch.fx.Graph, # graph: The FX graph representing the subgraph computation
|
||||
dict[torch.fx.Proxy, torch.fx.Proxy], # lifted_freevars: Free variables lifted as inputs
|
||||
VariableTracker | tuple[VariableTracker, ...], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs
|
||||
dict[
|
||||
torch.fx.Proxy, torch.fx.Proxy
|
||||
], # lifted_freevars: Free variables lifted as inputs
|
||||
VariableTracker
|
||||
| tuple[
|
||||
VariableTracker, ...
|
||||
], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs
|
||||
]:
|
||||
"""
|
||||
Speculate subgraph for Higher-Order Operators (HOPs) with wrap semantics.
|
||||
@ -4570,7 +4578,13 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
||||
# Step 5: Install local_map subgraph
|
||||
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
|
||||
out = _call_function_and_unflatten_output_wrap_semantics(
|
||||
tx, self.value, p_args, p_kwargs, example_value, body_r, body_graph_output_vts
|
||||
tx,
|
||||
self.value,
|
||||
p_args,
|
||||
p_kwargs,
|
||||
example_value,
|
||||
body_r,
|
||||
body_graph_output_vts,
|
||||
)
|
||||
|
||||
# Step 6: Restore inputs and outputs to global shapes
|
||||
|
||||
Reference in New Issue
Block a user