Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[RELAND v2] Close some sources of fake tensors (#164372)
Changelog:

1. When we run into an operation we didn't proxy, we end up emitting fake constants. We now error under a config, and we disable the config for some internal users. The reason we want to error is that this signals a coverage problem we need to address, but at the same time we don't want to be disruptive to already-working flows.
2. The previous attribute-mutation detection logic in non-strict didn't account for nested module structure. This fixes a silent incorrectness issue when exporting esm and qwen in non-strict, as well as some torchbench models like levit_128 and demucs.
3. The previous logic didn't work in cases where we mutate a container attribute, because the old approach pytree-flattened the old and new attributes, resulting in a length mismatch. We now handle this gracefully.

Differential Revision: [D83673054](https://our.internmc.facebook.com/intern/diff/D83673054)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164372
Approved by: https://github.com/avikchaudhuri
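For item (1), a minimal sketch of the opt-out that the new config enables. The toy module below is illustrative and does not itself trigger the error; the `torch._export.config.patch` call and the `error_on_lifted_constant_tensors` knob are the ones added in this diff:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# By default the new check raises:
#   RuntimeError: We found a fake tensor in the exported program constant's list...
# Patching the config downgrades that error to a UserWarning.
with torch._export.config.patch(error_on_lifted_constant_tensors=False):
    ep = torch.export.export(M(), (torch.ones(4, 4),), strict=False)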
committed by PyTorch MergeBot
parent 6a31f42da4
commit 4661200125
@@ -786,6 +786,94 @@ graph():
         # instead of the scripted function, so we get x.sin()
         self.assertEqual(res, x.sin())
 
+    def test_nested_module_fake_tensor_leak(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._tensor_cache = None
+
+            def forward(self, x):
+                if self._tensor_cache is None:
+                    self._tensor_cache = x + 2
+                return self._tensor_cache.sum() + x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self, bar):
+                super().__init__()
+                self.bar = bar
+
+            def forward(self, x):
+                return self.bar(x)
+
+        foo = Foo(Bar())
+        _ = export(foo, (torch.ones(4, 4),), strict=False)
+        self.assertTrue(foo.bar._tensor_cache is None)
+
+    def test_export_leak_compile(self):
+        class BaseModule(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                raise NotImplementedError
+
+        class CacheModule(BaseModule):
+            def __init__(self, cache: torch.Tensor):
+                super().__init__()
+                assert cache.ndim == 3
+                self.cache = torch.nn.Parameter(cache, requires_grad=False)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                n_tokens = x.size(1)
+                rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
+                rolled_cache[:, -n_tokens:, :] = x
+                self.cache.data = rolled_cache
+                return self.cache
+
+        class LinearBlock(torch.nn.Module):
+            def __init__(self, in_features, out_features, activation=None):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features, out_features)
+                self.activation = activation
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.activation(x) if self.activation else x
+
+        class MyModel(BaseModule):
+            def __init__(self):
+                super().__init__()
+                default_cache = torch.zeros(1, 10, 5)
+                self.cache_layer = CacheModule(default_cache)
+                self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
+                self.fc2 = LinearBlock(10, 5)
+
+            def forward(self, x):
+                cached = self.cache_layer(x)
+                out = self.fc1(cached)
+                out = self.fc2(out)
+                return out
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "We found a fake tensor in the exported program constant's list. "
+            "This typically means our tracing system encountered an op that we can't trace through. "
+            "For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
+            "Please file an issue on github.",
+        ):
+            _ = export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
+
+        with self.assertWarnsRegex(
+            UserWarning,
+            "We found a fake tensor in the exported program constant's list. "
+            "This typically means our tracing system encountered an op that we can't trace through. "
+            "For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
+            "Please file an issue on github.",
+        ):
+            # can't trigger all variants of export because later on it will crash,
+            # and it is good because we warned :)
+            with torch._export.config.patch(error_on_lifted_constant_tensors=False):
+                _ = torch.export.export(
+                    MyModel(), (torch.randn(1, 3, 5),), strict=False
+                )
+
     def test_inline_script_class_method(self):
         class M(torch.nn.Module):
             @staticmethod
@@ -13843,9 +13931,9 @@ def forward(self, x):
                 self.bar = x.sum()
                 return x + 2
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "During torch.export, following attrs were created in the model.forward:",
+        with self.assertWarnsRegex(
+            UserWarning,
+            "The tensor attribute self.bar was assigned during export",
         ):
             _ = export(Foo(), (torch.randn(4, 4),), strict=False)
 
@@ -411,7 +411,7 @@ def forward(self, token, x, cc):
             F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
         )
 
-    def test_torchbind_register_attr_at_runtime_error(self):
+    def test_torchbind_register_attr_at_runtime_get_restored(self):
         # alias as model attribute
         class F3(torch.nn.Module):
             def forward(self, x, foo):
@@ -419,10 +419,8 @@ def forward(self, token, x, cc):
                 return x + self.foo.add_tensor(x)
 
         foo = torch.classes._TorchScriptTesting._Foo(10, 20)
-        with self.assertRaisesRegex(
-            ValueError, "following attrs were created in the model"
-        ):
-            torch.export.export(F3(), (torch.ones(2, 3), foo))
+        torch.export.export(F3(), (torch.ones(2, 3), foo), strict=False)
+        self.assertFalse(hasattr(foo, "foo"))
 
     @parametrize("pre_dispatch", [True, False])
     def test_torchbind_input_and_alias(self, pre_dispatch):
@@ -22,6 +22,11 @@ use_new_tracer_experimental = False
 # by default, but user can turn it on to debug leaks.
 detect_non_strict_fake_tensor_leaks = False
 
+# error on a potential pre-dispatch/non-strict tracing limitation;
+# this type of error usually happens when we encounter an op
+# that we don't know how to proxy, resulting in untracked fake tensors
+error_on_lifted_constant_tensors = True
+
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
 
@@ -224,10 +224,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
         # return any attributes of a module that are not standard attributes
        return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
 
+    def _get_all_module_attributes(mod):
+        # return attributes from all modules and submodules
+        result = {}
+        for name, submodule in mod.named_modules():
+            result[name] = _get_attributes(submodule)
+        return result
+
+    def _restore_all_module_attributes(mod, snapshot):
+        # restore attributes to all modules and submodules
+        for name, submodule in mod.named_modules():
+            if name in snapshot:
+                submodule.__dict__.update(snapshot[name])
+
     # save state of attributes before enter
     snapshot = pytree.tree_map(
         lambda x: x,
-        _get_attributes(mod),
+        _get_all_module_attributes(mod),
         is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
     )
     try:
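To make the snapshot shape concrete, a minimal sketch (names and the toy modules are illustrative) of the per-submodule dict that `_get_all_module_attributes` builds, keyed by `named_modules()` names with "" for the root:

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._cache = None  # non-standard attribute; lands in the snapshot

m = torch.nn.Module()
m.inner = Inner()

# One attribute dict per (sub)module, keyed by dotted submodule name.
snapshot = {name: dict(sub.__dict__) for name, sub in m.named_modules()}
print(sorted(snapshot.keys()))  # ['', 'inner'] -- '' is the root module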
@@ -235,45 +248,70 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
     finally:
         # after exit, compare state of attributes with snapshot
         # to detect which tensor attributes were assigned
-        assigned_tensor_attributes = []
-
-        def _collect_assigned_tensor_attributes(kp, v, _v):
-            if _v is not v:
-                attr, *rest = kp
-                if isinstance(v, torch.Tensor):
-                    assigned_tensor_attributes.append(
-                        f"self.{attr.key}{pytree.keystr(rest)}"
-                    )
-                # TODO(avik): Assigning all other types are allowed right now.
-                # Maybe in the future we want to limit this to primitive types?
-            return v
-
-        new_attrs = _get_attributes(mod)
-        if len(new_attrs) != len(snapshot):
-            added_attrs = new_attrs.keys() - snapshot.keys()
-            deleted_attrs = snapshot.keys() - new_attrs.keys()
-
-            if len(added_attrs) > 0:
-                raise ValueError(
-                    f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
-                    f"Such attributes must be registered as buffers using the `register_buffer` "
-                    f"API and must be initialized at model.__init__ "
-                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-                )
-
-            if len(deleted_attrs) > 0:
-                raise ValueError(
-                    f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
-                    f"Such attributes must be registered as buffers using the `register_buffer` "
-                    f"API and must be initialized at model.__init__ "
-                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-                )
-
-        pytree.tree_map_with_path(
-            _collect_assigned_tensor_attributes, snapshot, new_attrs
-        )
+        def _collect_assigned_tensor_attributes(snapshot, new_attrs):
+            assigned_tensor_attributes = []
+
+            def _compare_values(path, old_val, new_val):
+                """Recursively compare values, handling containers."""
+                # Same object, no change
+                if old_val is new_val:
+                    return
+
+                if old_val is None or new_val is None:
+                    if isinstance(new_val, torch.Tensor):
+                        assigned_tensor_attributes.append(path)
+                    return
+
+                # Check if it's a tensor that was reassigned
+                if isinstance(new_val, torch.Tensor):
+                    assigned_tensor_attributes.append(path)
+                    return
+
+                # Handle dict containers
+                if isinstance(old_val, dict) and isinstance(new_val, dict):
+                    all_keys = set(old_val.keys()) | set(new_val.keys())
+                    for key in all_keys:
+                        old_item = old_val.get(key)
+                        new_item = new_val.get(key)
+                        _compare_values(f"{path}[{key!r}]", old_item, new_item)
+                    return
+
+                # Handle list/tuple containers
+                if isinstance(old_val, (list, tuple)) and isinstance(
+                    new_val, (list, tuple)
+                ):
+                    # Different lengths = mutation happened
+                    max_len = max(len(old_val), len(new_val))
+                    for i in range(max_len):
+                        old_item = old_val[i] if i < len(old_val) else None
+                        new_item = new_val[i] if i < len(new_val) else None
+                        _compare_values(f"{path}[{i}]", old_item, new_item)
+                    return
+
+                # For other types, just check if they're different objects
+                # (we don't care about non-tensor mutations)
+
+            for module_name in snapshot.keys() | new_attrs.keys():
+                old_module_attrs = snapshot.get(module_name, {})
+                new_module_attrs = new_attrs.get(module_name, {})
+
+                for attr_name in old_module_attrs.keys() | new_module_attrs.keys():
+                    module_prefix = f"self.{module_name}." if module_name else "self."
+                    full_path = f"{module_prefix}{attr_name}"
+
+                    old_val = old_module_attrs.get(attr_name)
+                    new_val = new_module_attrs.get(attr_name)
+                    _compare_values(full_path, old_val, new_val)
+
+            return assigned_tensor_attributes
+
+        new_attrs = _get_all_module_attributes(mod)
+        assigned_tensor_attributes = _collect_assigned_tensor_attributes(
+            snapshot, new_attrs
+        )
         # restore state of all attributes (including, e.g., of primitive types)
-        mod.__dict__.update(snapshot)
+        _restore_all_module_attributes(mod, snapshot)
 
         if assigned_tensor_attributes:
             if len(assigned_tensor_attributes) > 1:
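To illustrate changelog item (3), a self-contained sketch of the container-aware comparison implemented above (the function name here is illustrative; the real code also restores attributes afterwards):

import torch

# Walk old/new attribute values and record paths where a tensor was
# (re)assigned, descending into dicts and lists/tuples that may have
# changed length -- the case the previous pytree-based approach missed.
def find_assigned_tensor_paths(path, old_val, new_val, out):
    if old_val is new_val:
        return
    if isinstance(old_val, dict) and isinstance(new_val, dict):
        for key in set(old_val) | set(new_val):
            find_assigned_tensor_paths(
                f"{path}[{key!r}]", old_val.get(key), new_val.get(key), out
            )
        return
    if isinstance(old_val, (list, tuple)) and isinstance(new_val, (list, tuple)):
        for i in range(max(len(old_val), len(new_val))):
            old_item = old_val[i] if i < len(old_val) else None
            new_item = new_val[i] if i < len(new_val) else None
            find_assigned_tensor_paths(f"{path}[{i}]", old_item, new_item, out)
        return
    if isinstance(new_val, torch.Tensor):
        out.append(path)

out = []
t = torch.ones(1)
find_assigned_tensor_paths("self.cache", [t], [t, torch.zeros(1)], out)
print(out)  # ['self.cache[1]'] -- the appended tensor is flagged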
@@ -206,6 +206,14 @@ def _strip_root(x):
     return x
 
 
+def _is_bogus_const_name(name: str):
+    splitted_names = name.split(".")
+    if len(splitted_names) < 1:
+        return True
+
+    return splitted_names[-1].startswith("lifted_tensor")
+
+
 def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
     """
     In-place modify input graph module by replacing the export tracepoint with a new node
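A quick sanity check of the naming heuristic above, assuming `_is_bogus_const_name` from this diff is in scope (the first name matches the constant flagged in the new test_export_leak_compile test):

assert _is_bogus_const_name("cache_layer.lifted_tensor_0")
assert not _is_bogus_const_name("cache_layer.cache")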
@@ -2069,6 +2077,11 @@ def _export_for_training(
 
     original_state_dict = _get_original_state_dict(mod)
 
+    has_ambient_mode = False
+    if not strict:
+        flat_args, _ = pytree.tree_flatten((args, kwargs))
+        has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
+
     # Call the appropriate export function based on the strictness of tracing.
     export_func = _strict_export if strict else _non_strict_export
 
@@ -2088,6 +2101,25 @@ def _export_for_training(
             _to_aten_func=_export_to_aten_ir_make_fx,
         )
 
+    # If we are tracing with fake inputs, it is expected to
+    # see fake tensor constants.
+    if not strict and not has_ambient_mode:
+        for const, val in export_artifact.aten.constants.items():
+            if isinstance(
+                val, torch._subclasses.fake_tensor.FakeTensor
+            ) and _is_bogus_const_name(const):
+                error_msg = (
+                    f"We found a fake tensor in the exported program constant's list. "
+                    f"This typically means our tracing system encountered an op that "
+                    f"we can't trace through. For the potential source, you can refer to "
+                    f"following model attribute: {const}. "
+                    f"Please file an issue on github. "
+                )
+                if torch._export.config.error_on_lifted_constant_tensors:
+                    raise RuntimeError(error_msg)
+                else:
+                    warnings.warn(error_msg)
+
     export_graph_signature = export_artifact.aten.sig
 
     forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
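For the `has_ambient_mode` guard above: when the user already traces with fake inputs, fake tensor constants are expected, so the new check is skipped. A hedged sketch of such a flow (whether a given model exports cleanly with fake inputs depends on the model; this toy case is illustrative):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Inputs created under an ambient FakeTensorMode: detect_fake_mode(flat_args)
# returns that mode, so fake constants in the result are tolerated.
fake_x = FakeTensorMode().from_tensor(torch.randn(2, 3))
ep = torch.export.export(M(), (fake_x,), strict=False)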