mirror of https://github.com/pytorch/pytorch.git
Revert "[invoke_subgraph] make collect_meta_analysis fake prop cachable (#156347)"
This reverts commit f179b7198522e6d93bd103efba1a1ebd5a2cf891. Reverted https://github.com/pytorch/pytorch/pull/156347 on behalf of https://github.com/ydwu4 due to no signal, it breaks linter tests. ([comment](https://github.com/pytorch/pytorch/pull/156347#issuecomment-2997453729))
@@ -1091,60 +1091,6 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(exp_out, out)
         self.assertEqual(x_clone, x)
 
-    def test_input_mutation_mutiple_times_fake_tensor_cahche_hit(self):
-        @mark_compile_region
-        def gn(x, y):
-            x.add_(1)
-            return torch.mul(x, y)
-
-        def fn(x, y):
-            z = gn(x, y)
-            for _ in range(16):
-                z += gn(x, y)
-            return z
-
-        x = torch.randn(8, requires_grad=False)
-        x_clone = x.clone()
-        y = torch.randn(8, requires_grad=False)
-
-        backend = AotEagerAndRecordGraphs()
-        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
-
-        fake_prop_count = 0
-
-        def _mock_invoke_subgraph(mode, subgraph, identifer, *operands):
-            nonlocal fake_prop_count
-            fake_prop_count += 1
-            return (operands[0].clone(),)
-
-        with mock.patch(
-            "torch._higher_order_ops.utils.registered_hop_fake_fns",
-            {torch.ops.higher_order.invoke_subgraph: _mock_invoke_subgraph},
-        ), mock.patch(
-            "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
-            True,
-        ), torch.no_grad():
-            out = opt_fn(x, y)
-
-        # Fake propagation occurs only twice, with subsequent calls using cached results.
-        #
-        # First fake propagation (in collect_metadata_analysis of AOT):
-        # - Uses the original Dynamo graph
-        # - Flow: functionalization -> fake tensor
-        #
-        # Second fake propagation (in _create_graph of AOT):
-        # - Uses a materialized graph that includes epilogue operations
-        # - Flow: functionalization -> proxy -> fake tensor
-        #
-        # The key difference: the second time we materialize the graph with epilogue
-        # operations included in the proxy key. Since the dynamo graph module is not
-        # in the functional + epilogue format, the cache key should be different,
-        # preventing cache reuse between these two phases.
-        self.assertEqual(fake_prop_count, 2)
-        exp_out = fn(x_clone, y)
-        self.assertEqual(exp_out, out)
-        self.assertEqual(x_clone, x)
-
     def test_input_mutation_inference_mode(self):
         @nested_compile_region
         def gn(x, y):
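The assertion self.assertEqual(fake_prop_count, 2) in the removed test hinges on memoizing fake propagation by a key derived from the subgraph and its inputs. As a rough illustration of that caching pattern only (not PyTorch's actual cache; fake_prop and _cache_key are hypothetical names, and the real key also encodes the graph format, which is why the test expects two misses rather than one), here is a minimal sketch:

import torch

_cache: dict = {}
fake_prop_count = 0


def _cache_key(subgraph, *operands):
    # Key on the subgraph's identity plus input metadata (shape/dtype), not on values.
    return (id(subgraph), tuple((tuple(t.shape), t.dtype) for t in operands))


def fake_prop(subgraph, *operands):
    # Hypothetical stand-in for fake-tensor propagation: run once per key, then reuse.
    global fake_prop_count
    key = _cache_key(subgraph, *operands)
    if key not in _cache:
        fake_prop_count += 1  # real propagation happens only on a cache miss
        _cache[key] = subgraph(*operands)
    return _cache[key]


def gn(x, y):
    return torch.mul(x, y)


x, y = torch.randn(8), torch.randn(8)
for _ in range(17):
    fake_prop(gn, x, y)
assert fake_prop_count == 1  # the 16 repeated calls all hit the cache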
@@ -3,7 +3,7 @@ import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, get_args, Optional, Union
+from typing import Any, get_args, Optional, Union
 
 import torch
 import torch._library.utils as library_utils
@@ -571,28 +571,6 @@ def do_auto_functionalize(
     return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
 
 
-# Wrapper for GraphModule that applies functionalization during execution to enable
-# epilogue graph inlining and better fusion opportunities in subgraphs.
-# When tracing this wrapper, we'll get a graph module with epilogue.
-#
-# We want to hash it according to the original graph module, so that when we go
-# from Functional mode -> fake mode for multiple invoke_subgraph calls that share
-# the same inner graph module, we can hit the cache.
-class FunctionalCallableWithEpilogue:
-    def __init__(self, orig_callable: Callable):
-        self.orig_callable = orig_callable
-
-    def __call__(self, *args, **kwargs):
-        # We call torch.func.functionalize. This allows us to inline the epilogue graph.
-        # Inlining has the benefit of allowing easier fusion inside the subgraph.
-        # Though the epilogue graph contains copy_, it is OK because inductor can handle it,
-        # and this is also how we have been supporting top-level graph input mutation.
-        return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs))
-
-    def __hash__(self):
-        return id(self.orig_callable)
-
-
 def do_auto_functionalize_v2(
     mode: "torch._subclasses.functional_tensor.FunctionalTensorMode",
     op: Union[OpOverload, HopInstance],
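The __hash__ in the removed class is the load-bearing detail: by hashing the wrapper by the identity of the callable it wraps, two distinct wrapper instances built around the same inner graph module contribute the same component to a hash-based cache key. A standalone sketch of that idea, using a hypothetical CallableWrapper rather than the class from the diff:

# Illustrative only (not PyTorch internals): hash a wrapper by the identity of the
# callable it wraps, so different wrapper instances around the same inner callable
# produce the same cache-key component.
class CallableWrapper:
    def __init__(self, orig_callable):
        self.orig_callable = orig_callable

    def __call__(self, *args, **kwargs):
        return self.orig_callable(*args, **kwargs)

    def __hash__(self):
        # Hash by the wrapped callable's identity, not the wrapper instance's.
        return id(self.orig_callable)


def inner(x):
    return x + 1


w1, w2 = CallableWrapper(inner), CallableWrapper(inner)
assert w1 is not w2
assert hash(w1) == hash(w2)  # both map to id(inner), so a hash-keyed cache can match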
@@ -613,7 +591,19 @@ def do_auto_functionalize_v2(
 
     def _functionalize_callable(arg: Any):
         if callable(arg):
-            return FunctionalCallableWithEpilogue(arg)
+
+            def functional_fn(*args, **kwargs):
+                # We call torch.func.functionalize. This allows us to inline the epilogue graph.
+                # Inlining has the benefit of allowing easier fusion inside the subgraph.
+                # Though the epilogue graph contains copy_, it is OK because inductor can handle it,
+                # and this is also how we have been supporting top-level graph input mutation.
+                return tuple(
+                    pytree.tree_leaves(torch.func.functionalize(arg)(*args, **kwargs))
+                )
+
+            return torch._higher_order_ops.base_hop.FunctionWithNoFreeVars(
+                functional_fn
+            )
         return arg
 
     args, kwargs = pytree.tree_map(_functionalize_callable, (args, kwargs))
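The comments above refer to the copy_ epilogue that functionalization leaves behind for input mutations. A small, hedged illustration of that effect using public APIs (torch.func.functionalize plus make_fx); the exact node names in the printed graph may vary across PyTorch versions:

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x, y):
    x.add_(1)  # in-place mutation of an input
    return x * y


# Tracing the functionalized version yields a purely functional body; the mutation of x
# typically shows up as a copy_ near the end of the graph (the "epilogue").
gm = make_fx(torch.func.functionalize(f))(torch.randn(3), torch.randn(3))
print(gm.code)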
@ -6,7 +6,6 @@ import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._higher_order_ops.auto_functionalize import FunctionalCallableWithEpilogue
|
||||
from torch._higher_order_ops.utils import (
|
||||
check_input_alias_and_mutation_return_outputs,
|
||||
HopInstance,
|
||||
@@ -68,14 +67,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
         )
 
     def __call__(self, subgraph, *operands, **kwargs):
-        if not isinstance(
-            subgraph,
-            (
-                torch.fx.GraphModule,
-                FunctionWithNoFreeVars,
-                FunctionalCallableWithEpilogue,
-            ),
-        ):
+        if not isinstance(subgraph, (torch.fx.GraphModule, FunctionWithNoFreeVars)):
             raise RuntimeError(
                 f"{self._name}: when calling this API without torch.compile, "
                 f"we require that the subgraph be a torch.fx.GraphModule (or "
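For context on the check above: when a BaseHOP is called outside torch.compile, the subgraph is expected to already be a torch.fx.GraphModule (or one of the listed wrapper types). One common way to obtain such a GraphModule is symbolic tracing; the sketch below only demonstrates the accepted type and does not invoke any higher-order op:

import torch


def subgraph(x, y):
    return torch.mul(x, y)


gm = torch.fx.symbolic_trace(subgraph)
assert isinstance(gm, torch.fx.GraphModule)  # the type the isinstance gate above accepts
print(gm.code)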
@@ -1645,9 +1645,6 @@ class FakeTensorMode(TorchDispatchMode):
         convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
         unsupported cases that should bypass caching.
         """
-        from torch._higher_order_ops.auto_functionalize import (
-            FunctionalCallableWithEpilogue,
-        )
         from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
 
         if isinstance(args, dict):
@@ -1688,10 +1685,6 @@ class FakeTensorMode(TorchDispatchMode):
                 # functional wrapper is destroyed after fake tensor prop. We
                 # need to put the finalizer on the subgraph.
                 id_hashed_objects.append(arg.subgraph)
-            elif isinstance(arg, FunctionalCallableWithEpilogue):
-                result.append(type(arg))
-                result.append(hash(arg))
-                id_hashed_objects.append(arg.orig_callable)
            else:
                 # It's important to capture the type of the arg since, e.g., 1 and 1.0
                 # hash to the same value, but can produce different dtypes for the
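The comment in the else: branch above is easy to verify: Python hashes numerically equal values identically across int and float, so without recording type(arg) the cache-key contributions for 1 and 1.0 would collide even though they imply different dtypes. A standalone check (not the FakeTensorMode code itself):

import torch

assert hash(1) == hash(1.0)
assert type(1) is not type(1.0)
# Same hash, but the resulting dtypes differ, which is why the key must include the type.
print(torch.tensor(1).dtype, torch.tensor(1.0).dtype)  # torch.int64 torch.float32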