mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[invoke_subgraph] make collect_meta_analysis fake prop cachable (#156347)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156347 Approved by: https://github.com/anijain2305, https://github.com/zou3519 ghstack dependencies: #156260
This commit is contained in:
committed by
PyTorch MergeBot
parent
558d7f7db0
commit
f5f4beaf56
@ -1645,6 +1645,9 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
|
||||
unsupported cases that should bypass caching.
|
||||
"""
|
||||
from torch._higher_order_ops.auto_functionalize import (
|
||||
FunctionalCallableWithEpilogue,
|
||||
)
|
||||
from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
|
||||
|
||||
if isinstance(args, dict):
|
||||
@ -1685,6 +1688,10 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# functional wrapper is destroyed after fake tensor prop. We
|
||||
# need to put the finalizer on the subgraph.
|
||||
id_hashed_objects.append(arg.subgraph)
|
||||
elif isinstance(arg, FunctionalCallableWithEpilogue):
|
||||
result.append(type(arg))
|
||||
result.append(hash(arg))
|
||||
id_hashed_objects.append(arg.orig_callable)
|
||||
else:
|
||||
# It's important to capture the type of the arg since, e.g., 1 and 1.0
|
||||
# hash to the same value, but can produce different dtypes for the
|
||||
|
||||
Reference in New Issue
Block a user