Revert "[invoke_subgraph] make collect_meta_analysis fake prop cachable (#156347)"

This reverts commit f179b7198522e6d93bd103efba1a1ebd5a2cf891.

Reverted https://github.com/pytorch/pytorch/pull/156347 on behalf of https://github.com/ydwu4 due to no signal, it breaks linter tests. ([comment](https://github.com/pytorch/pytorch/pull/156347#issuecomment-2997453729))
PyTorch MergeBot
2025-06-23 18:19:29 +00:00
parent 98a34e8d4b
commit 35d03398e5
4 changed files with 15 additions and 94 deletions

@@ -1091,60 +1091,6 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(exp_out, out)
         self.assertEqual(x_clone, x)
 
-    def test_input_mutation_mutiple_times_fake_tensor_cahche_hit(self):
-        @mark_compile_region
-        def gn(x, y):
-            x.add_(1)
-            return torch.mul(x, y)
-
-        def fn(x, y):
-            z = gn(x, y)
-            for _ in range(16):
-                z += gn(x, y)
-            return z
-
-        x = torch.randn(8, requires_grad=False)
-        x_clone = x.clone()
-        y = torch.randn(8, requires_grad=False)
-
-        backend = AotEagerAndRecordGraphs()
-        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
-
-        fake_prop_count = 0
-
-        def _mock_invoke_subgraph(mode, subgraph, identifer, *operands):
-            nonlocal fake_prop_count
-            fake_prop_count += 1
-            return (operands[0].clone(),)
-
-        with mock.patch(
-            "torch._higher_order_ops.utils.registered_hop_fake_fns",
-            {torch.ops.higher_order.invoke_subgraph: _mock_invoke_subgraph},
-        ), mock.patch(
-            "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
-            True,
-        ), torch.no_grad():
-            out = opt_fn(x, y)
-
-        # Fake propagation occurs only twice, with subsequent calls using cached results.
-        #
-        # First fake propagation (in collect_metadata_analysis of AOT):
-        # - Uses the original Dynamo graph
-        # - Flow: functionalization -> fake tensor
-        #
-        # Second fake propagation (in _create_graph of AOT):
-        # - Uses a materialized graph that includes epilogue operations
-        # - Flow: functionalization -> proxy -> fake tensor
-        #
-        # The key difference: the second time we materialize the graph with epilogue
-        # operations included in the proxy key. Since the dynamo graph module is not
-        # in the functional + epilogue format, the cache key should be different,
-        # preventing cache reuse between these two phases.
-        self.assertEqual(fake_prop_count, 2)
-
-        exp_out = fn(x_clone, y)
-        self.assertEqual(exp_out, out)
-        self.assertEqual(x_clone, x)
-
     def test_input_mutation_inference_mode(self):
         @nested_compile_region
         def gn(x, y):

@@ -3,7 +3,7 @@ import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, get_args, Optional, Union
+from typing import Any, get_args, Optional, Union
 
 import torch
 import torch._library.utils as library_utils
@@ -571,28 +571,6 @@ def do_auto_functionalize(
     return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
 
 
-# Wrapper for GraphModule that applies functionalization during execution to enable
-# epilogue graph inlining and better fusion opportunities in subgraphs.
-# When tracing this wrapper, we'll get a graph module with an epilogue.
-#
-# We want to hash it according to the original graph module, so that when we go
-# from Functional mode -> fake mode for multiple invoke_subgraph calls that share
-# the same inner graph module, we can hit the cache.
-class FunctionalCallableWithEpilogue:
-    def __init__(self, orig_callable: Callable):
-        self.orig_callable = orig_callable
-
-    def __call__(self, *args, **kwargs):
-        # We call torch.func.functionalize. This allows us to inline the epilogue graph.
-        # Inlining has the benefit of allowing easier fusion inside the subgraph.
-        # Though the epilogue graph contains copy_, that is OK because inductor can handle it,
-        # and this is also how we have been supporting top-level graph input mutation.
-        return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs))
-
-    def __hash__(self):
-        return id(self.orig_callable)
-
-
 def do_auto_functionalize_v2(
     mode: "torch._subclasses.functional_tensor.FunctionalTensorMode",
     op: Union[OpOverload, HopInstance],
@@ -613,7 +591,19 @@ def do_auto_functionalize_v2(
     def _functionalize_callable(arg: Any):
         if callable(arg):
-            return FunctionalCallableWithEpilogue(arg)
+
+            def functional_fn(*args, **kwargs):
+                # We call torch.func.functionalize. This allows us to inline the epilogue graph.
+                # Inlining has the benefit of allowing easier fusion inside the subgraph.
+                # Though the epilogue graph contains copy_, that is OK because inductor can handle it,
+                # and this is also how we have been supporting top-level graph input mutation.
+                return tuple(
+                    pytree.tree_leaves(torch.func.functionalize(arg)(*args, **kwargs))
+                )
+
+            return torch._higher_order_ops.base_hop.FunctionWithNoFreeVars(
+                functional_fn
+            )
         return arg
 
     args, kwargs = pytree.tree_map(_functionalize_callable, (args, kwargs))
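
Both the removed FunctionalCallableWithEpilogue and the restored functional_fn lean on the behavior described in the comments above: torch.func.functionalize turns in-place ops inside the subgraph into out-of-place ops plus a copy_ epilogue that writes mutated inputs back. A minimal standalone sketch of that behavior (not part of this diff, assuming a recent PyTorch build):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def gn(x, y):
    x.add_(1)  # in-place input mutation inside the subgraph
    return torch.mul(x, y)

# Tracing the functionalized callable materializes a graph of functional
# add/mul ops followed by a copy_ that writes the result back into x.
gm = make_fx(torch.func.functionalize(gn))(torch.randn(8), torch.randn(8))
print(gm.graph)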

@@ -6,7 +6,6 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._dispatch.python import suspend_functionalization
-from torch._higher_order_ops.auto_functionalize import FunctionalCallableWithEpilogue
 from torch._higher_order_ops.utils import (
     check_input_alias_and_mutation_return_outputs,
     HopInstance,
@@ -68,14 +67,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
         )
 
     def __call__(self, subgraph, *operands, **kwargs):
-        if not isinstance(
-            subgraph,
-            (
-                torch.fx.GraphModule,
-                FunctionWithNoFreeVars,
-                FunctionalCallableWithEpilogue,
-            ),
-        ):
+        if not isinstance(subgraph, (torch.fx.GraphModule, FunctionWithNoFreeVars)):
             raise RuntimeError(
                 f"{self._name}: when calling this API without torch.compile, "
                 f"we require that the subgraph be a torch.fx.GraphModule (or "

@@ -1645,9 +1645,6 @@ class FakeTensorMode(TorchDispatchMode):
         convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
         unsupported cases that should bypass caching.
         """
-        from torch._higher_order_ops.auto_functionalize import (
-            FunctionalCallableWithEpilogue,
-        )
         from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
 
         if isinstance(args, dict):
@@ -1688,10 +1685,6 @@
                 # functional wrapper is destroyed after fake tensor prop. We
                 # need to put the finalizer on the subgraph.
                 id_hashed_objects.append(arg.subgraph)
-            elif isinstance(arg, FunctionalCallableWithEpilogue):
-                result.append(type(arg))
-                result.append(hash(arg))
-                id_hashed_objects.append(arg.orig_callable)
             else:
                 # It's important to capture the type of the arg since, e.g., 1 and 1.0
                 # hash to the same value, but can produce different dtypes for the
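
The truncated comment above makes a point that is easy to verify in isolation; a quick sketch (not part of this diff, using default tensor dtypes):

import torch

# 1 and 1.0 hash to the same value in Python, yet imply different tensor
# dtypes, so the cache key records the argument's type alongside its value.
assert hash(1) == hash(1.0)
assert torch.tensor(1).dtype == torch.int64
assert torch.tensor(1.0).dtype == torch.float32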