Files
pytorch/torch/export/_leakage_detection_utils.py
Yuanyuan Chen a029675f6f More ruff SIM fixes (#164695)
This PR applies ruff `SIM` rules to more files. Most changes are about simplifying `dict.get` because `None` is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
2025-10-09 03:24:50 +00:00

113 lines
3.4 KiB
Python

import gc
import types
import typing
import weakref
import torch
"""
These functions are used to detect potential fake tensor leakage when using PT2 export.
See NOTE [export non-strict fake tensor leak detection]
There are some complications that made this logic overly complicated:
1) Python 3.10 and Python 3.12 have different ways of implementing referrer so
we need to account for whether it is ref.__dict__ or the real ref object
2) There are some internal PT2 references to fake tensors like `TrackedFake`
3) closures, generators, and bound methods can hold fake tensors.
4) global object can hold onto a fake tensor
In general, these utils are our last resort to detect fake tensors. if the leak happens
within the model attributes, we have a separate mechanism to detect. This tool relies a bit
on garbage collector internal details, so I think it is unsafe to turn on by default, hence
this tool should be used as debugging tool.
"""
# Things we never want to flag as leaks
_SKIP_TYPES = (
types.FrameType,
types.ModuleType,
)
def _is_globals_or_locals(obj: typing.Any) -> bool:
# These comparisons only make sense within this frame; still cheap to check.
return obj is globals() or obj is locals()
def _is_tracked_fake(obj: typing.Any) -> bool:
return isinstance(obj, torch.fx.experimental.symbolic_shapes.TrackedFake)
def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool:
# Hope gm.meta was a custom dict we can assert on
return d.get("val") is o
def _dict_is_attr_of_tracked_fake(d: dict) -> bool:
"""
Python 3.10 quirk: sometimes the referrer is obj.__dict__ instead of obj.
Check if this dict is exactly the __dict__ of a TrackedFake.
"""
for parent in gc.get_referrers(d):
if (
hasattr(parent, "__dict__")
and parent.__dict__ is d
and _is_tracked_fake(parent)
):
return True
return False
def find_legit_leaks_from_referrers(active_fakes: weakref.WeakSet) -> weakref.WeakSet:
legit_leak: weakref.WeakSet = weakref.WeakSet()
# This is so that we don't falsely flag generator to be holding fake tensor
fake_list = list(active_fakes)
fake_list_id = id(fake_list)
for act in fake_list:
# Track by id to avoid processing duplicate referrers
seen = set()
# Assume it's a leak unless we find only ignorable referrers
flagged = False
for r in gc.get_referrers(act):
rid = id(r)
if rid in seen:
continue
seen.add(rid)
# Skip our own fake_list
if rid == fake_list_id:
continue
# Fast-path: skip obvious non-owners
if _is_globals_or_locals(r):
continue
if isinstance(r, _SKIP_TYPES):
continue
if _is_tracked_fake(r):
# TrackedFake should be ignored
continue
# Handle dicts carefully (Python 3.10 sometimes shows __dict__)
if isinstance(r, dict):
if _is_gm_meta_like_dict(r, act):
continue
if _dict_is_attr_of_tracked_fake(r):
continue
flagged = True
break
# Any other referrer we don't explicitly whitelist counts as a leak
flagged = True
break
if flagged:
legit_leak.add(act)
return legit_leak