[RELAND v2] Close some sources of fake tensors (#164372)

Changelog:

1. When we run into an operation we didn't proxy, we end up emitting fake constants. We error under a config and we disable the config for some internal users. The reason we want to error is this signals a coverage problem we need to address but at the same time, we don't wnat to be disruptive to already working flows.

2. Previous attribute mutation detection logic in non-strict didn't account for nested module structure. This fixes silent incorrectness issue of exporting esm and qwen in non-strict and some torchbench models like levit_128 and demucs.

3. Previous logic didn't work on the cases where we mutate a container attribute as the previous approach used to pytree over old and new attributes resulting in length mismatch. We gracefully handle this now.

Differential Revision: [D83673054](https://our.internmc.facebook.com/intern/diff/D83673054)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164372
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-10-02 07:20:08 -07:00
committed by PyTorch MergeBot
parent 6a31f42da4
commit 4661200125
5 changed files with 202 additions and 41 deletions

View File

@ -786,6 +786,94 @@ graph():
# instead of the scripted function, so we get x.sin()
self.assertEqual(res, x.sin())
def test_nested_module_fake_tensor_leak(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self._tensor_cache = None
def forward(self, x):
if self._tensor_cache is None:
self._tensor_cache = x + 2
return self._tensor_cache.sum() + x.sum()
class Foo(torch.nn.Module):
def __init__(self, bar):
super().__init__()
self.bar = bar
def forward(self, x):
return self.bar(x)
foo = Foo(Bar())
_ = export(foo, (torch.ones(4, 4),), strict=False)
self.assertTrue(foo.bar._tensor_cache is None)
def test_export_leak_compile(self):
class BaseModule(torch.nn.Module):
def forward(self, *args, **kwargs):
raise NotImplementedError
class CacheModule(BaseModule):
def __init__(self, cache: torch.Tensor):
super().__init__()
assert cache.ndim == 3
self.cache = torch.nn.Parameter(cache, requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
n_tokens = x.size(1)
rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
rolled_cache[:, -n_tokens:, :] = x
self.cache.data = rolled_cache
return self.cache
class LinearBlock(torch.nn.Module):
def __init__(self, in_features, out_features, activation=None):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)
self.activation = activation
def forward(self, x):
x = self.linear(x)
return self.activation(x) if self.activation else x
class MyModel(BaseModule):
def __init__(self):
super().__init__()
default_cache = torch.zeros(1, 10, 5)
self.cache_layer = CacheModule(default_cache)
self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
self.fc2 = LinearBlock(10, 5)
def forward(self, x):
cached = self.cache_layer(x)
out = self.fc1(cached)
out = self.fc2(out)
return out
with self.assertRaisesRegex(
RuntimeError,
"We found a fake tensor in the exported program constant's list. "
"This typically means our tracing system encountered an op that we can't trace through. "
"For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
"Please file an issue on github.",
):
_ = export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
with self.assertWarnsRegex(
UserWarning,
"We found a fake tensor in the exported program constant's list. "
"This typically means our tracing system encountered an op that we can't trace through. "
"For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
"Please file an issue on github.",
):
# can't trigger all variant of export because later on it will crash
# and it is good because we warned :).
with torch._export.config.patch(error_on_lifted_constant_tensors=False):
_ = torch.export.export(
MyModel(), (torch.randn(1, 3, 5),), strict=False
)
def test_inline_script_class_method(self):
class M(torch.nn.Module):
@staticmethod
@ -13843,9 +13931,9 @@ def forward(self, x):
self.bar = x.sum()
return x + 2
with self.assertRaisesRegex(
ValueError,
"During torch.export, following attrs were created in the model.forward:",
with self.assertWarnsRegex(
UserWarning,
"The tensor attribute self.bar was assigned during export",
):
_ = export(Foo(), (torch.randn(4, 4),), strict=False)

View File

@ -411,7 +411,7 @@ def forward(self, token, x, cc):
F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
)
def test_torchbind_register_attr_at_runtime_error(self):
def test_torchbind_register_attr_at_runtime_get_restored(self):
# alias as model attribute
class F3(torch.nn.Module):
def forward(self, x, foo):
@ -419,10 +419,8 @@ def forward(self, token, x, cc):
return x + self.foo.add_tensor(x)
foo = torch.classes._TorchScriptTesting._Foo(10, 20)
with self.assertRaisesRegex(
ValueError, "following attrs were created in the model"
):
torch.export.export(F3(), (torch.ones(2, 3), foo))
torch.export.export(F3(), (torch.ones(2, 3), foo), strict=False)
self.assertFalse(hasattr(foo, "foo"))
@parametrize("pre_dispatch", [True, False])
def test_torchbind_input_and_alias(self, pre_dispatch):

View File

@ -22,6 +22,11 @@ use_new_tracer_experimental = False
# by default, but user can turn it on to debug leaks.
detect_non_strict_fake_tensor_leaks = False
# error on potentially pre-dispatch/non-strict tracing limitation
# this type of error usually happens when we encounter an op
# that we don't know how to proxy, resulting in untracked fake tensors
error_on_lifted_constant_tensors = True
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -224,10 +224,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
# return any attributes of a module that are not standard attributes
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
def _get_all_module_attributes(mod):
# return attributes from all modules and submodules
result = {}
for name, submodule in mod.named_modules():
result[name] = _get_attributes(submodule)
return result
def _restore_all_module_attributes(mod, snapshot):
# restore attributes to all modules and submodules
for name, submodule in mod.named_modules():
if name in snapshot:
submodule.__dict__.update(snapshot[name])
# save state of attributes before enter
snapshot = pytree.tree_map(
lambda x: x,
_get_attributes(mod),
_get_all_module_attributes(mod),
is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
)
try:
@ -235,45 +248,70 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
finally:
# after exit, compare state of attributes with snapshot
# to detect which tensor attributes were assigned
assigned_tensor_attributes = []
def _collect_assigned_tensor_attributes(kp, v, _v):
if _v is not v:
attr, *rest = kp
if isinstance(v, torch.Tensor):
assigned_tensor_attributes.append(
f"self.{attr.key}{pytree.keystr(rest)}"
)
# TODO(avik): Assigning all other types are allowed right now.
# Maybe in the future we want to limit this to primitive types?
return v
def _collect_assigned_tensor_attributes(snapshot, new_attrs):
assigned_tensor_attributes = []
new_attrs = _get_attributes(mod)
if len(new_attrs) != len(snapshot):
added_attrs = new_attrs.keys() - snapshot.keys()
deleted_attrs = snapshot.keys() - new_attrs.keys()
def _compare_values(path, old_val, new_val):
"""Recursively compare values, handling containers."""
# Same object, no change
if old_val is new_val:
return
if len(added_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if old_val is None or new_val is None:
if isinstance(new_val, torch.Tensor):
assigned_tensor_attributes.append(path)
return
if len(deleted_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
# Check if it's a tensor that was reassigned
if isinstance(new_val, torch.Tensor):
assigned_tensor_attributes.append(path)
return
pytree.tree_map_with_path(
_collect_assigned_tensor_attributes, snapshot, new_attrs
# Handle dict containers
if isinstance(old_val, dict) and isinstance(new_val, dict):
all_keys = set(old_val.keys()) | set(new_val.keys())
for key in all_keys:
old_item = old_val.get(key)
new_item = new_val.get(key)
_compare_values(f"{path}[{key!r}]", old_item, new_item)
return
# Handle list/tuple containers
if isinstance(old_val, (list, tuple)) and isinstance(
new_val, (list, tuple)
):
# Different lengths = mutation happened
max_len = max(len(old_val), len(new_val))
for i in range(max_len):
old_item = old_val[i] if i < len(old_val) else None
new_item = new_val[i] if i < len(new_val) else None
_compare_values(f"{path}[{i}]", old_item, new_item)
return
# For other types, just check if they're different objects
# (we don't care about non-tensor mutations)
for module_name in snapshot.keys() | new_attrs.keys():
old_module_attrs = snapshot.get(module_name, {})
new_module_attrs = new_attrs.get(module_name, {})
for attr_name in old_module_attrs.keys() | new_module_attrs.keys():
module_prefix = f"self.{module_name}." if module_name else "self."
full_path = f"{module_prefix}{attr_name}"
old_val = old_module_attrs.get(attr_name)
new_val = new_module_attrs.get(attr_name)
_compare_values(full_path, old_val, new_val)
return assigned_tensor_attributes
new_attrs = _get_all_module_attributes(mod)
assigned_tensor_attributes = _collect_assigned_tensor_attributes(
snapshot, new_attrs
)
# restore state of all attributes (including, e.g., of primitive types)
mod.__dict__.update(snapshot)
_restore_all_module_attributes(mod, snapshot)
if assigned_tensor_attributes:
if len(assigned_tensor_attributes) > 1:

View File

@ -206,6 +206,14 @@ def _strip_root(x):
return x
def _is_bogus_const_name(name: str):
splitted_names = name.split(".")
if len(splitted_names) < 1:
return True
return splitted_names[-1].startswith("lifted_tensor")
def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
"""
In-place modify input graph module by replacing the export tracepoint with a new node
@ -2069,6 +2077,11 @@ def _export_for_training(
original_state_dict = _get_original_state_dict(mod)
has_ambient_mode = False
if not strict:
flat_args, _ = pytree.tree_flatten((args, kwargs))
has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
# Call the appropriate export function based on the strictness of tracing.
export_func = _strict_export if strict else _non_strict_export
@ -2088,6 +2101,25 @@ def _export_for_training(
_to_aten_func=_export_to_aten_ir_make_fx,
)
# If we are tracing with fake inputs, it is expected to
# see fake tensor constants.
if not strict and not has_ambient_mode:
for const, val in export_artifact.aten.constants.items():
if isinstance(
val, torch._subclasses.fake_tensor.FakeTensor
) and _is_bogus_const_name(const):
error_msg = (
f"We found a fake tensor in the exported program constant's list. "
f"This typically means our tracing system encountered an op that "
f"we can't trace through. For the potential source, you can refer to "
f"following model attribute: {const}. "
f"Please file an issue on github. "
)
if torch._export.config.error_on_lifted_constant_tensors:
raise RuntimeError(error_msg)
else:
warnings.warn(error_msg)
export_graph_signature = export_artifact.aten.sig
forward_arg_names = _get_forward_arg_names(mod, args, kwargs)