[RELAND] Close some sources of fake tensor leakage (#161589)

Reland of https://github.com/pytorch/pytorch/pull/159923

A couple of fixes:
1. When we run into an operation we didn't proxy, we end up emitting fake tensor constants. We now detect this and warn using the FQN of the lifted constant. We warn rather than error because some internal users complained that erroring was regressing their exportability. (A simplified sketch of the check follows after this list.)

2. The previous attribute mutation detection logic in non-strict didn't account for nested module structure. This fixes a silent incorrectness issue when exporting esm and qwen in non-strict. (A minimal repro sketch is included at the end of this description.)

3. We modify yolov3 to fix the previously silently incorrect behaviour.
4. We use strict export for levit_128 because it errors in non-strict due to the stricter side-effect checking.
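
A rough sketch of what the check in (1) looks for, simplified from `emit_bogus_const_warning` in this PR (the helper name below is illustrative, not part of the PR):

```python
import torch
from torch.export import ExportedProgram


def find_leaked_fake_constants(ep: ExportedProgram) -> set[str]:
    # Lifted constants that are still FakeTensors after tracing usually mean
    # an op escaped proxying; the new check flags them by their FQN suffix.
    leaked: set[str] = set()
    for name, val in ep.constants.items():
        is_fake = isinstance(val, torch._subclasses.fake_tensor.FakeTensor)
        if is_fake and name.split(".")[-1].startswith("lifted_tensor"):
            leaked.add(name)
    return leaked
```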

After upgrading the torchbench pin, opacus_cifar10 no longer seems to run in eager mode. I verified this by pushing a temporary PR to master with the new pin, so I added it to the expect_fail list.
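
A minimal repro for (2), adapted from the `test_nested_module_fake_tensor_leak` test added in this PR:

```python
import torch
from torch.export import export


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._tensor_cache = None  # plain attribute, not a registered buffer

    def forward(self, x):
        if self._tensor_cache is None:
            # attribute assignment on a *nested* module during tracing
            self._tensor_cache = x + 2
        return self._tensor_cache.sum() + x.sum()


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()

    def forward(self, x):
        return self.bar(x)


foo = Foo()
export(foo, (torch.ones(4, 4),), strict=False)
# Before this fix the fake tensor assigned during tracing leaked into
# foo.bar._tensor_cache; now the nested attribute is restored after export.
assert foo.bar._tensor_cache is None
```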

Differential Revision: [D81133908](https://our.internmc.facebook.com/intern/diff/D81133908)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161589
Approved by: https://github.com/avikchaudhuri
Tugsbayasgalan Manlaibaatar
2025-08-27 07:55:44 -07:00
committed by PyTorch MergeBot
parent 2e77a08b95
commit 5790b00975
5 changed files with 239 additions and 30 deletions

View File

@ -1 +1 @@
e03a63be43e33596f7f0a43b0f530353785e4a59
22bc29b4d503fc895ff73bc720ff396e9723465f

View File

@ -1427,13 +1427,23 @@ class AOTInductorModelCache:
inductor_configs = {}
if mode == "max-autotune":
inductor_configs["max_autotune"] = True
ep = torch.export.export(
model_clone,
example_args,
example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=False,
)
# We can't support this in non-strict
if hasattr(model_clone, "name") and model.name == "levit_128":
ep = torch.export.export(
model_clone,
example_args,
example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
else:
ep = torch.export.export(
model_clone,
example_args,
example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=False,
)
with torch.no_grad():
package_path = torch._inductor.aoti_compile_and_package(
ep, inductor_configs=inductor_configs
@ -2317,6 +2327,7 @@ class BenchmarkRunner:
# no need for n iterations
# the logic should be the same to self.model_iter_fn (forward_pass)
with self.autocast(**self.autocast_arg):
model_copy.name = name
optimized_model_iter_fn = optimize_ctx(
model_copy, example_inputs
)

View File

@ -420,6 +420,28 @@ graph():
):
ep.module()(torch.tensor([3]))
def test_container_leak(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self._cache = {}
def forward(self, x):
self._cache["leaky"] = x.sum()
return x.sum()
class Foo(torch.nn.Module):
def __init__(self, bar):
super().__init__()
self.bar = bar
def forward(self, x):
return self.bar(x)
foo = Foo(Bar())
with self.assertRaisesRegex(ValueError, "self.bar._cache"):
export(foo, (torch.randn(4, 4),), strict=False)
def test_export_assume_static_by_default(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor):
@ -4341,6 +4363,79 @@ def forward(self, x):
x = torch.tensor([1, 2])
self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
def test_nested_module_fake_tensor_leak(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self._tensor_cache = None
def forward(self, x):
if self._tensor_cache is None:
self._tensor_cache = x + 2
return self._tensor_cache.sum() + x.sum()
class Foo(torch.nn.Module):
def __init__(self, bar):
super().__init__()
self.bar = bar
def forward(self, x):
return self.bar(x)
foo = Foo(Bar())
_ = export(foo, (torch.ones(4, 4),), strict=False)
self.assertTrue(foo.bar._tensor_cache is None)
def test_export_leak_compile(self):
class BaseModule(torch.nn.Module):
def forward(self, *args, **kwargs):
raise NotImplementedError
class CacheModule(BaseModule):
def __init__(self, cache: torch.Tensor):
super().__init__()
assert cache.ndim == 3
self.cache = torch.nn.Parameter(cache, requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
n_tokens = x.size(1)
rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
rolled_cache[:, -n_tokens:, :] = x
self.cache.data = rolled_cache
return self.cache
class LinearBlock(torch.nn.Module):
def __init__(self, in_features, out_features, activation=None):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)
self.activation = activation
def forward(self, x):
x = self.linear(x)
return self.activation(x) if self.activation else x
class MyModel(BaseModule):
def __init__(self):
super().__init__()
default_cache = torch.zeros(1, 10, 5)
self.cache_layer = CacheModule(default_cache)
self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
self.fc2 = LinearBlock(10, 5)
def forward(self, x):
cached = self.cache_layer(x)
out = self.fc1(cached)
out = self.fc2(out)
return out
with self.assertRaisesRegex(
RuntimeError,
"cached = self.cache_layer\(x\)",
):
# Intentionally using training IR here because it will crash in inference IR
# anyways.
_ = torch.export.export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
def test_export_for_training_with_container_type(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:

View File

@ -221,10 +221,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
# return any attributes of a module that are not standard attributes
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
def _get_all_module_attributes(mod):
# return attributes from all modules and submodules
result = {}
for name, submodule in mod.named_modules():
result[name] = _get_attributes(submodule)
return result
def _restore_all_module_attributes(mod, snapshot):
# restore attributes to all modules and submodules
for name, submodule in mod.named_modules():
if name in snapshot:
submodule.__dict__.update(snapshot[name])
# save state of attributes before enter
snapshot = pytree.tree_map(
lambda x: x,
_get_attributes(mod),
_get_all_module_attributes(mod),
is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
)
try:
@ -236,41 +249,75 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
def _collect_assigned_tensor_attributes(kp, v, _v):
if _v is not v:
attr, *rest = kp
module_name, attr, *rest = kp
if isinstance(v, torch.Tensor):
module_prefix = f"{module_name.key}." if module_name.key else ""
assigned_tensor_attributes.append(
f"self.{attr.key}{pytree.keystr(rest)}"
f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
)
# TODO(avik): Assigning all other types are allowed right now.
# Maybe in the future we want to limit this to primitive types?
return v
new_attrs = _get_attributes(mod)
if len(new_attrs) != len(snapshot):
added_attrs = new_attrs.keys() - snapshot.keys()
deleted_attrs = snapshot.keys() - new_attrs.keys()
new_attrs = _get_all_module_attributes(mod)
if len(added_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
# Check for added/deleted attributes across all modules
for module_name in snapshot.keys() | new_attrs.keys():
old_module_attrs = snapshot.get(module_name, {})
new_module_attrs = new_attrs.get(module_name, {})
if len(deleted_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
module_prefix = f"self.{module_name}." if module_name else "self."
if len(new_module_attrs) != len(old_module_attrs):
added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
if len(added_attrs) > 0:
formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(deleted_attrs) > 0:
formatted_attrs = [
f"{module_prefix}{attr}" for attr in deleted_attrs
]
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
# Tensors could have leaked at container attributes
for k, new_v in new_module_attrs.items():
assert k in old_module_attrs
if isinstance(new_v, (tuple, list, dict)):
flat_new_v, _ = pytree.tree_flatten(new_v)
flat_old_v, _ = pytree.tree_flatten(old_module_attrs[k])
if len(flat_new_v) != len(flat_old_v):
leaked_values = [
v
for v in flat_new_v
if v not in flat_old_v and isinstance(v, torch.Tensor)
]
if len(leaked_values) > 0:
raise ValueError(
f"During torch.export, following tensors were leaked at {module_prefix}{k}: {leaked_values} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer). " # noqa: 950
f"Alternatively, consider using `torch.export.export(strict=True)` to export the model."
)
pytree.tree_map_with_path(
_collect_assigned_tensor_attributes, snapshot, new_attrs
)
# restore state of all attributes (including, e.g., of primitive types)
mod.__dict__.update(snapshot)
_restore_all_module_attributes(mod, snapshot)
if assigned_tensor_attributes:
if len(assigned_tensor_attributes) > 1:

View File

@ -1849,6 +1849,11 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
return next(iter(node for node in gm.graph.nodes if node.name == name))
def _is_invalid_const_name(name: str):
splitted_names = name.split(".")
return splitted_names[-1].startswith("lifted_tensor")
def _non_strict_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
@ -2024,6 +2029,43 @@ def _non_strict_export(
)
def emit_bogus_const_warning(constants, gs, gm):
bogus_constants: set[str] = set()
for const, val in constants.items():
if isinstance(
val, torch._subclasses.fake_tensor.FakeTensor
) and _is_invalid_const_name(const):
bogus_constants.add(const)
if len(bogus_constants) == 0:
return
bogus_constant_names: set[str] = set()
for inp in gs.input_specs:
if inp.kind == InputKind.CONSTANT_TENSOR and inp.target in bogus_constants:
bogus_constant_names.add(inp.arg.name)
placeholders = {
node.name: node for node in gm.graph.nodes if node.op == "placeholder"
}
for name in bogus_constant_names:
placeholder_node = placeholders[name]
dependencies: list[str] = []
for user in placeholder_node.users:
if user.meta.get("stack_trace", None) is not None:
dependencies.append(user.meta["stack_trace"])
if len(placeholder_node.users) > 0:
raise RuntimeError(
f"We found a fake tensor in the exported program constant's list. "
f"This typically means our tracing system encountered an op that "
f"we can't trace through. For the potential source, you can refer to "
f"following model attribute: {name}. We found following stacktrace might "
f"be helpful: \n\n"
f"{dependencies if dependencies else '<unknown>'} \n\n"
f"Please file an issue on github if you need further help.\n"
)
@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export_for_training(
@ -2049,6 +2091,11 @@ def _export_for_training(
original_state_dict = _get_original_state_dict(mod)
has_ambient_mode = False
if not strict:
flat_args, _ = pytree.tree_flatten((args, kwargs))
has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
# Call the appropriate export function based on the strictness of tracing.
export_func = _strict_export if strict else _non_strict_export
@ -2065,6 +2112,15 @@ def _export_for_training(
export_graph_signature = export_artifact.aten.sig
# If we are tracing with fake inputs, it is expected to
# see fake tensor constants.
if not strict and not has_ambient_mode:
emit_bogus_const_warning(
export_artifact.aten.constants,
export_graph_signature,
export_artifact.aten.gm,
)
forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
# The unbacked symint symbols are updated in aot_export