[invoke_subgraph] make collect_meta_analysis fake prop cachable (#156347)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156347
Approved by: https://github.com/anijain2305, https://github.com/zou3519
ghstack dependencies: #156260
This commit is contained in:
Yidi Wu
2025-06-24 10:21:57 -07:00
committed by PyTorch MergeBot
parent 558d7f7db0
commit f5f4beaf56
4 changed files with 98 additions and 15 deletions

View File

@@ -1645,6 +1645,9 @@ class FakeTensorMode(TorchDispatchMode):
convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
unsupported cases that should bypass caching.
"""
from torch._higher_order_ops.auto_functionalize import (
FunctionalCallableWithEpilogue,
)
from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
if isinstance(args, dict):
@@ -1685,6 +1688,10 @@ class FakeTensorMode(TorchDispatchMode):
# functional wrapper is destroyed after fake tensor prop. We
# need to put the finalizer on the subgraph.
id_hashed_objects.append(arg.subgraph)
elif isinstance(arg, FunctionalCallableWithEpilogue):
result.append(type(arg))
result.append(hash(arg))
id_hashed_objects.append(arg.orig_callable)
else:
# It's important to capture the type of the arg since, e.g., 1 and 1.0
# hash to the same value, but can produce different dtypes for the