Mirror of https://github.com/pytorch/pytorch.git
[RELAND] Close some sources of fake tensor leakage (#161589)
Reland of https://github.com/pytorch/pytorch/pull/159923

A couple of fixes:
1. When we run into an operation we didn't proxy, we end up emitting fake constants. We detect this and warn using the FQN of the lifted constant. We warn rather than error because some internal users complained that erroring was regressing their exportability.
2. The previous attribute-mutation detection logic in non-strict export didn't account for nested module structure. This fixes a silent-incorrectness issue when exporting esm and qwen in non-strict mode.
3. We modify yolov3 to fix its previously silently incorrect behavior.
4. We use strict export for levit_128 because it errors in non-strict mode due to the stricter side-effect checking.

When upgrading the torchbench pin, opacus_cifar10 no longer seems to run in eager mode. I verified this by pushing a temporary PR to master with the new pin, so I added it to the expected-failure list.

Differential Revision: [D81133908](https://our.internmc.facebook.com/intern/diff/D81133908)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161589
Approved by: https://github.com/avikchaudhuri
Committed by: PyTorch MergeBot
Parent: 2e77a08b95
Commit: 5790b00975
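For context, here is a minimal sketch (module names invented, mirroring the new tests added in this diff) of the nested-module side effect that fix 2 targets: a submodule that caches a tensor during forward used to slip past the old top-level-only attribute check in non-strict export, leaving a fake tensor behind on the eager module.

import torch
from torch.export import export

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._cache = None  # mutated during forward, i.e. a side effect

    def forward(self, x):
        if self._cache is None:
            self._cache = x + 2  # during tracing this is a fake tensor
        return self._cache.sum() + x.sum()

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x)

mod = Outer()
# Previously, non-strict export snapshotted only the root module's __dict__,
# so mod.inner._cache silently retained a fake tensor after export. With this
# change the snapshot/restore walks named_modules(), so the attribute is
# rolled back to its pre-export value.
export(mod, (torch.ones(4, 4),), strict=False)
assert mod.inner._cache is None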
@@ -1 +1 @@
-e03a63be43e33596f7f0a43b0f530353785e4a59
+22bc29b4d503fc895ff73bc720ff396e9723465f
@@ -1427,13 +1427,23 @@ class AOTInductorModelCache:
         inductor_configs = {}
         if mode == "max-autotune":
             inductor_configs["max_autotune"] = True
-        ep = torch.export.export(
-            model_clone,
-            example_args,
-            example_kwargs,
-            dynamic_shapes=dynamic_shapes,
-            strict=False,
-        )
+        # We can't support this in non-strict
+        if hasattr(model_clone, "name") and model.name == "levit_128":
+            ep = torch.export.export(
+                model_clone,
+                example_args,
+                example_kwargs,
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+        else:
+            ep = torch.export.export(
+                model_clone,
+                example_args,
+                example_kwargs,
+                dynamic_shapes=dynamic_shapes,
+                strict=False,
+            )
         with torch.no_grad():
             package_path = torch._inductor.aoti_compile_and_package(
                 ep, inductor_configs=inductor_configs
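As a usage note, here is a minimal sketch (illustrative model; `aoti_load_package` is the standard loader for the produced package) of the export-then-AOTInductor flow this hunk configures, with the strict flag chosen per model as above:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

# Non-strict export is the default path in this benchmark; levit_128 is
# special-cased to strict=True because the tightened non-strict checks
# reject it.
ep = torch.export.export(M(), (torch.randn(2, 4),), strict=False)
with torch.no_grad():
    package_path = torch._inductor.aoti_compile_and_package(ep)
compiled = torch._inductor.aoti_load_package(package_path)
out = compiled(torch.randn(2, 4))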
@@ -2317,6 +2327,7 @@ class BenchmarkRunner:
             # no need for n iterations
             # the logic should be the same to self.model_iter_fn (forward_pass)
             with self.autocast(**self.autocast_arg):
+                model_copy.name = name
                 optimized_model_iter_fn = optimize_ctx(
                     model_copy, example_inputs
                 )
@@ -420,6 +420,28 @@ graph():
         ):
             ep.module()(torch.tensor([3]))
 
+    def test_container_leak(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._cache = {}
+
+            def forward(self, x):
+                self._cache["leaky"] = x.sum()
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self, bar):
+                super().__init__()
+                self.bar = bar
+
+            def forward(self, x):
+                return self.bar(x)
+
+        foo = Foo(Bar())
+        with self.assertRaisesRegex(ValueError, "self.bar._cache"):
+            export(foo, (torch.randn(4, 4),), strict=False)
+
     def test_export_assume_static_by_default(self):
         class Module(torch.nn.Module):
             def forward(self, x: torch.Tensor):
@@ -4341,6 +4363,79 @@ def forward(self, x):
         x = torch.tensor([1, 2])
         self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
 
+    def test_nested_module_fake_tensor_leak(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._tensor_cache = None
+
+            def forward(self, x):
+                if self._tensor_cache is None:
+                    self._tensor_cache = x + 2
+                return self._tensor_cache.sum() + x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self, bar):
+                super().__init__()
+                self.bar = bar
+
+            def forward(self, x):
+                return self.bar(x)
+
+        foo = Foo(Bar())
+        _ = export(foo, (torch.ones(4, 4),), strict=False)
+        self.assertTrue(foo.bar._tensor_cache is None)
+
+    def test_export_leak_compile(self):
+        class BaseModule(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                raise NotImplementedError
+
+        class CacheModule(BaseModule):
+            def __init__(self, cache: torch.Tensor):
+                super().__init__()
+                assert cache.ndim == 3
+                self.cache = torch.nn.Parameter(cache, requires_grad=False)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                n_tokens = x.size(1)
+                rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
+                rolled_cache[:, -n_tokens:, :] = x
+                self.cache.data = rolled_cache
+                return self.cache
+
+        class LinearBlock(torch.nn.Module):
+            def __init__(self, in_features, out_features, activation=None):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features, out_features)
+                self.activation = activation
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.activation(x) if self.activation else x
+
+        class MyModel(BaseModule):
+            def __init__(self):
+                super().__init__()
+                default_cache = torch.zeros(1, 10, 5)
+                self.cache_layer = CacheModule(default_cache)
+                self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
+                self.fc2 = LinearBlock(10, 5)
+
+            def forward(self, x):
+                cached = self.cache_layer(x)
+                out = self.fc1(cached)
+                out = self.fc2(out)
+                return out
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "cached = self.cache_layer\(x\)",
+        ):
+            # Intentionally using training IR here because it will crash in inference IR
+            # anyways.
+            _ = torch.export.export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
+
     def test_export_for_training_with_container_type(self):
         class Foo(torch.nn.Module):
            def __init__(self) -> None:
@@ -221,10 +221,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
         # return any attributes of a module that are not standard attributes
         return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
 
+    def _get_all_module_attributes(mod):
+        # return attributes from all modules and submodules
+        result = {}
+        for name, submodule in mod.named_modules():
+            result[name] = _get_attributes(submodule)
+        return result
+
+    def _restore_all_module_attributes(mod, snapshot):
+        # restore attributes to all modules and submodules
+        for name, submodule in mod.named_modules():
+            if name in snapshot:
+                submodule.__dict__.update(snapshot[name])
+
     # save state of attributes before enter
     snapshot = pytree.tree_map(
         lambda x: x,
-        _get_attributes(mod),
+        _get_all_module_attributes(mod),
         is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
     )
     try:
@@ -236,41 +249,75 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
 
     def _collect_assigned_tensor_attributes(kp, v, _v):
         if _v is not v:
-            attr, *rest = kp
+            module_name, attr, *rest = kp
             if isinstance(v, torch.Tensor):
+                module_prefix = f"{module_name.key}." if module_name.key else ""
                 assigned_tensor_attributes.append(
-                    f"self.{attr.key}{pytree.keystr(rest)}"
+                    f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
                 )
             # TODO(avik): Assigning all other types are allowed right now.
             # Maybe in the future we want to limit this to primitive types?
         return v
 
-    new_attrs = _get_attributes(mod)
-    if len(new_attrs) != len(snapshot):
-        added_attrs = new_attrs.keys() - snapshot.keys()
-        deleted_attrs = snapshot.keys() - new_attrs.keys()
-
-        if len(added_attrs) > 0:
-            raise ValueError(
-                f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
-                f"Such attributes must be registered as buffers using the `register_buffer` "
-                f"API and must be initialized at model.__init__ "
-                f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-            )
-
-        if len(deleted_attrs) > 0:
-            raise ValueError(
-                f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
-                f"Such attributes must be registered as buffers using the `register_buffer` "
-                f"API and must be initialized at model.__init__ "
-                f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-            )
+    new_attrs = _get_all_module_attributes(mod)
+
+    # Check for added/deleted attributes across all modules
+    for module_name in snapshot.keys() | new_attrs.keys():
+        old_module_attrs = snapshot.get(module_name, {})
+        new_module_attrs = new_attrs.get(module_name, {})
+
+        module_prefix = f"self.{module_name}." if module_name else "self."
+
+        if len(new_module_attrs) != len(old_module_attrs):
+            added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
+            deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
+
+            if len(added_attrs) > 0:
+                formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
+                raise ValueError(
+                    f"During torch.export, following attrs were created in the model.forward: {formatted_attrs} "
+                    f"Such attributes must be registered as buffers using the `register_buffer` "
+                    f"API and must be initialized at model.__init__ "
+                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
+                )
+
+            if len(deleted_attrs) > 0:
+                formatted_attrs = [
+                    f"{module_prefix}{attr}" for attr in deleted_attrs
+                ]
+                raise ValueError(
+                    f"During torch.export, following attrs were deleted in the model.forward: {formatted_attrs} "
+                    f"Such attributes must be registered as buffers using the `register_buffer` "
+                    f"API and must be initialized at model.__init__ "
+                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
+                )
+
+        # Tensors could have leaked at container attributes
+        for k, new_v in new_module_attrs.items():
+            assert k in old_module_attrs
+            if isinstance(new_v, (tuple, list, dict)):
+                flat_new_v, _ = pytree.tree_flatten(new_v)
+                flat_old_v, _ = pytree.tree_flatten(old_module_attrs[k])
+                if len(flat_new_v) != len(flat_old_v):
+                    leaked_values = [
+                        v
+                        for v in flat_new_v
+                        if v not in flat_old_v and isinstance(v, torch.Tensor)
+                    ]
+                    if len(leaked_values) > 0:
+                        raise ValueError(
+                            f"During torch.export, following tensors were leaked at {module_prefix}{k}: {leaked_values} "
+                            f"Such attributes must be registered as buffers using the `register_buffer` "
+                            f"API and must be initialized at model.__init__ "
+                            f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer). "  # noqa: B950
+                            f"Alternatively, consider using `torch.export.export(strict=True)` to export the model."
+                        )
 
     pytree.tree_map_with_path(
         _collect_assigned_tensor_attributes, snapshot, new_attrs
     )
     # restore state of all attributes (including, e.g., of primitive types)
-    mod.__dict__.update(snapshot)
+    _restore_all_module_attributes(mod, snapshot)
 
     if assigned_tensor_attributes:
         if len(assigned_tensor_attributes) > 1:
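A standalone sketch of the per-module snapshot/diff/restore scheme the helpers above implement (simplified: plain dict copies, no STD_ATTRS filtering, and no pytree-based container comparison):

import torch

def snapshot_attrs(mod: torch.nn.Module) -> dict[str, dict]:
    # Record every (sub)module's instance attributes keyed by module path;
    # the root module's path from named_modules() is "".
    return {name: dict(sub.__dict__) for name, sub in mod.named_modules()}

def restore_attrs(mod: torch.nn.Module, snapshot: dict[str, dict]) -> None:
    # Overwrite mutated attributes with their snapshotted values. As in the
    # real code, update() cannot delete newly added keys, which is why
    # additions/deletions are detected and raised as errors instead.
    for name, sub in mod.named_modules():
        if name in snapshot:
            sub.__dict__.update(snapshot[name])

class M(torch.nn.Module):
    def forward(self, x):
        self.leak = x.sum()  # attribute assigned inside forward
        return x

m = M()
before = snapshot_attrs(m)
m(torch.ones(2))
print(m.__dict__.keys() - before[""].keys())  # {'leak'}
restore_attrs(m, before)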
@@ -1849,6 +1849,11 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
     return next(iter(node for node in gm.graph.nodes if node.name == name))
 
 
+def _is_invalid_const_name(name: str):
+    splitted_names = name.split(".")
+    return splitted_names[-1].startswith("lifted_tensor")
+
+
 def _non_strict_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
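For intuition, the helper flags constants whose leaf name follows the `lifted_tensor_<n>` convention that export uses for unnamed lifted constants; a few illustrative cases (the definition is restated here so the snippet is self-contained, and the dotted names are hypothetical):

def _is_invalid_const_name(name: str) -> bool:
    return name.split(".")[-1].startswith("lifted_tensor")

assert _is_invalid_const_name("lifted_tensor_0")
assert _is_invalid_const_name("submod.lifted_tensor_3")
assert not _is_invalid_const_name("submod.weight_cache")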
@@ -2024,6 +2029,43 @@ def _non_strict_export(
     )
 
 
+def emit_bogus_const_warning(constants, gs, gm):
+    bogus_constants: set[str] = set()
+    for const, val in constants.items():
+        if isinstance(
+            val, torch._subclasses.fake_tensor.FakeTensor
+        ) and _is_invalid_const_name(const):
+            bogus_constants.add(const)
+
+    if len(bogus_constants) == 0:
+        return
+
+    bogus_constant_names: set[str] = set()
+    for inp in gs.input_specs:
+        if inp.kind == InputKind.CONSTANT_TENSOR and inp.target in bogus_constants:
+            bogus_constant_names.add(inp.arg.name)
+
+    placeholders = {
+        node.name: node for node in gm.graph.nodes if node.op == "placeholder"
+    }
+    for name in bogus_constant_names:
+        placeholder_node = placeholders[name]
+        dependencies: list[str] = []
+        for user in placeholder_node.users:
+            if user.meta.get("stack_trace", None) is not None:
+                dependencies.append(user.meta["stack_trace"])
+        if len(placeholder_node.users) > 0:
+            raise RuntimeError(
+                f"We found a fake tensor in the exported program's constants list. "
+                f"This typically means our tracing system encountered an op that "
+                f"we can't trace through. For the potential source, you can refer to "
+                f"the following model attribute: {name}. We found the following stack traces, which might "
+                f"be helpful: \n\n"
+                f"{dependencies if dependencies else '<unknown>'} \n\n"
+                f"Please file an issue on GitHub if you need further help.\n"
+            )
+
+
 @_log_export_wrapper
 @_disable_prexisiting_fake_mode
 def _export_for_training(
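A hedged sketch of the invariant this check enforces (the helper name here is hypothetical; `ExportedProgram.constants` is the lifted-constants table): after a successful non-strict export with real inputs, no lifted constant should still be a FakeTensor.

import torch
from torch._subclasses.fake_tensor import FakeTensor

def find_fake_constants(ep: torch.export.ExportedProgram) -> list[str]:
    # Simplified version of the scan in emit_bogus_const_warning: any lifted
    # constant that is still fake came from an op that tracing failed to proxy.
    return [
        name for name, val in ep.constants.items() if isinstance(val, FakeTensor)
    ]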
@@ -2049,6 +2091,11 @@ def _export_for_training(
 
     original_state_dict = _get_original_state_dict(mod)
 
+    has_ambient_mode = False
+    if not strict:
+        flat_args, _ = pytree.tree_flatten((args, kwargs))
+        has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
+
     # Call the appropriate export function based on the strictness of tracing.
     export_func = _strict_export if strict else _non_strict_export
 
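The `has_ambient_mode` guard covers fake-tensor export: when the inputs themselves are fake, fake constants are expected, so the bogus-constant check below is skipped. A rough sketch of that case (using the public `FakeTensorMode` API; simplified relative to the actual `detect_fake_mode` plumbing):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Inputs created under an ambient FakeTensorMode: any constants produced
# during export will themselves be fake, which is legitimate here.
with FakeTensorMode() as fake_mode:
    fake_x = fake_mode.from_tensor(torch.randn(3))

assert torch._guards.detect_fake_mode([fake_x]) is fake_mode
ep = torch.export.export(M(), (fake_x,), strict=False)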
@@ -2065,6 +2112,15 @@ def _export_for_training(
 
     export_graph_signature = export_artifact.aten.sig
 
+    # If we are tracing with fake inputs, it is expected to
+    # see fake tensor constants.
+    if not strict and not has_ambient_mode:
+        emit_bogus_const_warning(
+            export_artifact.aten.constants,
+            export_graph_signature,
+            export_artifact.aten.gm,
+        )
+
     forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
     inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
     # The unbacked symint symbols are updated in aot_export