mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Our strategy for detecting fake tensor leakage in non-strict for outside scope (side effects happening outside of model.forward) is: 1. We do gc.collect() before export and get the alive fake tensors 2. We dump the proxy to fake tensor map from make_fx tracer 3. We query gc again to get alive fake tensors 4. We take the delta between (1) and (3) 5. Filter out fake tensors that are: 1. Associated with `TrackedFake` (input tracking thing in symbolic_shapes) 2. Associated with `gm.meta` 6. Do ID match with the proxies and emit their stacktraces. We rely on (https://github.com/pytorch/pytorch/pull/159923) for other sources of leakages such as: 1. We failed to proxy an operator (like param.data) 2. We cache some tensor in model.forward (https://github.com/pytorch/pytorch/issues/155114) In general, we notice `gc.collect()` and query-ing gc for live objects are kinda slow. So we turn on this feature under env variable. We should document on export public facing documents that if you run into weird errors regarding fake tensors, they should look into turning on this env variable for further analysis. Test Plan: Test plan Rollback Plan: Differential Revision: D80003204 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160456 Approved by: https://github.com/pianpwk
113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
import gc
|
|
import types
|
|
import typing
|
|
import weakref
|
|
|
|
import torch
|
|
|
|
|
|
"""
|
|
These functions are used to detect potential fake tensor leakage when using PT2 export.
See NOTE [export non-strict fake tensor leak detection]

There are some complications that make this logic more involved than one would expect:
1) Python 3.10 and Python 3.12 implement referrers differently, so we need to
   account for whether the referrer is ref.__dict__ or the real ref object
2) There are some internal PT2 references to fake tensors like `TrackedFake`
3) Closures, generators, and bound methods can hold fake tensors.
4) A global object can hold onto a fake tensor

In general, these utils are our last resort to detect fake tensor leaks. If the leak
happens within the model attributes, we have a separate mechanism to detect it. This
tool relies on garbage collector internal details, so it is unsafe to turn on by
default; it should be used as a debugging tool only.
|
|
"""
|
|
|
|
|
|
# Things we never want to flag as leaks.
# Used with isinstance() on referrers: a frame or a module object that refers
# to a fake tensor is skipped rather than reported.
_SKIP_TYPES = (
    types.FrameType,
    types.ModuleType,
)
|
|
|
|
|
|
def _is_globals_or_locals(obj: typing.Any) -> bool:
    """Return True if ``obj`` is the mapping returned by ``globals()`` or ``locals()`` here.

    Both builtins are evaluated inside this helper, so ``globals()`` is the
    global namespace of the module that defines this function and ``locals()``
    is the local namespace of this very call. Namespaces belonging to other
    modules or other frames are not matched by this check.

    Args:
        obj: Any object; it is compared by identity only.

    Returns:
        True when ``obj`` is one of the two namespace mappings described above.
    """
    # These comparisons only make sense within this frame; still cheap to check.
    return obj is globals() or obj is locals()
|
|
|
|
|
|
def _is_tracked_fake(obj: typing.Any) -> bool:
|
|
return isinstance(obj, torch.fx.experimental.symbolic_shapes.TrackedFake)
|
|
|
|
|
|
def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool:
    """Return True if ``d`` holds ``o`` itself under the ``"val"`` key.

    This is a heuristic: any dict whose ``"val"`` entry is the very same
    object as ``o`` (identity, not equality) is treated as a metadata-style
    dict. A dict without a ``"val"`` key yields ``None`` from the lookup, so
    it only matches when ``o`` is ``None``.

    Args:
        d: Dict that refers to ``o``.
        o: Object expected to be stored under ``"val"``.

    Returns:
        True when ``d.get("val")`` is ``o``.
    """
    # Hope gm.meta was a custom dict we can assert on
    return d.get("val", None) is o
|
|
|
|
|
|
def _dict_is_attr_of_tracked_fake(d: dict) -> bool:
    """Return True if ``d`` is the instance ``__dict__`` of a ``TrackedFake``.

    Python 3.10 quirk: sometimes the referrer is obj.__dict__ instead of obj.
    To recover the owning object, the garbage collector is asked for the
    referrers of ``d``; a referrer counts only if its ``__dict__`` is exactly
    ``d`` (identity) and it passes ``_is_tracked_fake``.

    Args:
        d: Dict that showed up as a referrer of a fake tensor.

    Returns:
        True if some referrer of ``d`` is a ``TrackedFake`` whose ``__dict__``
        is ``d``; False otherwise.
    """
    for parent in gc.get_referrers(d):
        if (
            hasattr(parent, "__dict__")
            and parent.__dict__ is d
            and _is_tracked_fake(parent)
        ):
            return True
    return False
|
|
|
|
|
|
def find_legit_leaks_from_referrers(active_fakes: weakref.WeakSet) -> weakref.WeakSet:
    """Return the members of ``active_fakes`` that have an unexplained referrer.

    For every object in ``active_fakes`` the garbage collector is asked for
    its direct referrers. The following referrers are ignored:

    * the temporary list this function builds from ``active_fakes``;
    * anything matched by ``_is_globals_or_locals``;
    * instances of the types in ``_SKIP_TYPES``;
    * ``TrackedFake`` objects;
    * dicts matched by ``_is_gm_meta_like_dict`` or
      ``_dict_is_attr_of_tracked_fake``.

    The first referrer that is not on this list marks the object as a leak.

    Args:
        active_fakes: Weak set of candidate objects to inspect.

    Returns:
        A new ``weakref.WeakSet`` with the flagged objects. It is empty when
        every referrer of every candidate is ignorable.
    """
    legit_leak: weakref.WeakSet = weakref.WeakSet()

    # This is so that we don't falsely flag generator to be holding fake tensor
    fake_list = list(active_fakes)
    fake_list_id = id(fake_list)

    # NOTE: keep this loop free of closures and generator expressions that
    # capture `act`; such a capture would itself hold a reference to `act` and
    # could show up in gc.get_referrers(act).
    for act in fake_list:
        # `act` is not a leak until we meet a referrer that is not ignorable.
        for r in gc.get_referrers(act):
            # Skip our own fake_list
            if id(r) == fake_list_id:
                continue

            # Fast-path: skip obvious non-owners
            if _is_globals_or_locals(r):
                continue
            if isinstance(r, _SKIP_TYPES):
                continue
            if _is_tracked_fake(r):
                # TrackedFake should be ignored
                continue

            # Handle dicts carefully (Python 3.10 sometimes shows __dict__)
            if isinstance(r, dict) and (
                _is_gm_meta_like_dict(r, act) or _dict_is_attr_of_tracked_fake(r)
            ):
                continue

            # Any other referrer we don't explicitly whitelist counts as a leak;
            # one such referrer is enough, so stop scanning.
            legit_leak.add(act)
            break

    return legit_leak
|