Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[RELAND] Close some sources of fake tensor leakage (#161589)"
This reverts commit 5790b009751e6ebba35d3e6d05e7c1b135553eee.
Reverted https://github.com/pytorch/pytorch/pull/161589 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/17305150611/job/49128381649) [HUD commit link](5790b00975) ([comment](https://github.com/pytorch/pytorch/pull/161589#issuecomment-3235224249))
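For context, the change being reverted detected tensors that escape into module state during non-strict export. Below is a minimal sketch of the leak pattern, adapted from the removed `test_nested_module_fake_tensor_leak`; it is illustrative, not part of this commit:

```python
import torch
from torch.export import export


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._tensor_cache = None  # plain attribute, not a registered buffer

    def forward(self, x):
        if self._tensor_cache is None:
            # Under non-strict export this runs with FakeTensors, so the
            # assignment can leave a FakeTensor behind in module state.
            self._tensor_cache = x + 2
        return self._tensor_cache.sum() + x.sum()


mod = Bar()
export(mod, (torch.ones(4, 4),), strict=False)
# With #161589 applied, export restored this attribute to None; after the
# revert, it may hold a leaked fake tensor.
print(mod._tensor_cache)
```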
@@ -1 +1 @@
-22bc29b4d503fc895ff73bc720ff396e9723465f
+e03a63be43e33596f7f0a43b0f530353785e4a59
@@ -1427,23 +1427,13 @@ class AOTInductorModelCache:
         inductor_configs = {}
         if mode == "max-autotune":
             inductor_configs["max_autotune"] = True
-        # We can't support this in non-strict
-        if hasattr(model_clone, "name") and model.name == "levit_128":
-            ep = torch.export.export(
-                model_clone,
-                example_args,
-                example_kwargs,
-                dynamic_shapes=dynamic_shapes,
-                strict=True,
-            )
-        else:
-            ep = torch.export.export(
-                model_clone,
-                example_args,
-                example_kwargs,
-                dynamic_shapes=dynamic_shapes,
-                strict=False,
-            )
+        ep = torch.export.export(
+            model_clone,
+            example_args,
+            example_kwargs,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
         with torch.no_grad():
             package_path = torch._inductor.aoti_compile_and_package(
                 ep, inductor_configs=inductor_configs
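As a usage reference, here is a hedged sketch of the export-then-AOTInductor flow this hunk configures; `SmallModel` and its inputs are illustrative placeholders, not benchmark models:

```python
import torch


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


model = SmallModel()
example_args = (torch.randn(8, 16),)

# Non-strict export, as restored by this revert.
ep = torch.export.export(model, example_args, strict=False)

with torch.no_grad():
    # Ahead-of-time compile the exported program into a .pt2 package.
    package_path = torch._inductor.aoti_compile_and_package(
        ep, inductor_configs={"max_autotune": True}
    )
print(package_path)  # path to the compiled package
```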
@@ -2327,7 +2317,6 @@ class BenchmarkRunner:
         # no need for n iterations
         # the logic should be the same as self.model_iter_fn (forward_pass)
         with self.autocast(**self.autocast_arg):
-            model_copy.name = name
             optimized_model_iter_fn = optimize_ctx(
                 model_copy, example_inputs
             )
@@ -420,28 +420,6 @@ graph():
         ):
             ep.module()(torch.tensor([3]))
 
-    def test_container_leak(self):
-        class Bar(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self._cache = {}
-
-            def forward(self, x):
-                self._cache["leaky"] = x.sum()
-                return x.sum()
-
-        class Foo(torch.nn.Module):
-            def __init__(self, bar):
-                super().__init__()
-                self.bar = bar
-
-            def forward(self, x):
-                return self.bar(x)
-
-        foo = Foo(Bar())
-        with self.assertRaisesRegex(ValueError, "self.bar._cache"):
-            export(foo, (torch.randn(4, 4),), strict=False)
-
     def test_export_assume_static_by_default(self):
         class Module(torch.nn.Module):
             def forward(self, x: torch.Tensor):
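The removed `test_container_leak` covered tensors leaking through a plain dict attribute. For contrast, a sketch of the pattern the error message recommends instead (illustrative, assuming an in-place buffer update fits the use case):

```python
import torch


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer is part of module state from __init__ on,
        # unlike an entry stuffed into a plain dict during forward.
        self.register_buffer("cache", torch.zeros(()))

    def forward(self, x):
        s = x.sum()
        self.cache.copy_(s)  # in-place buffer mutation instead of a dict write
        return s
```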
@@ -4363,79 +4341,6 @@ def forward(self, x):
         x = torch.tensor([1, 2])
         self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
 
-    def test_nested_module_fake_tensor_leak(self):
-        class Bar(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self._tensor_cache = None
-
-            def forward(self, x):
-                if self._tensor_cache is None:
-                    self._tensor_cache = x + 2
-                return self._tensor_cache.sum() + x.sum()
-
-        class Foo(torch.nn.Module):
-            def __init__(self, bar):
-                super().__init__()
-                self.bar = bar
-
-            def forward(self, x):
-                return self.bar(x)
-
-        foo = Foo(Bar())
-        _ = export(foo, (torch.ones(4, 4),), strict=False)
-        self.assertTrue(foo.bar._tensor_cache is None)
-
-    def test_export_leak_compile(self):
-        class BaseModule(torch.nn.Module):
-            def forward(self, *args, **kwargs):
-                raise NotImplementedError
-
-        class CacheModule(BaseModule):
-            def __init__(self, cache: torch.Tensor):
-                super().__init__()
-                assert cache.ndim == 3
-                self.cache = torch.nn.Parameter(cache, requires_grad=False)
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                n_tokens = x.size(1)
-                rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
-                rolled_cache[:, -n_tokens:, :] = x
-                self.cache.data = rolled_cache
-                return self.cache
-
-        class LinearBlock(torch.nn.Module):
-            def __init__(self, in_features, out_features, activation=None):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_features, out_features)
-                self.activation = activation
-
-            def forward(self, x):
-                x = self.linear(x)
-                return self.activation(x) if self.activation else x
-
-        class MyModel(BaseModule):
-            def __init__(self):
-                super().__init__()
-                default_cache = torch.zeros(1, 10, 5)
-                self.cache_layer = CacheModule(default_cache)
-                self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
-                self.fc2 = LinearBlock(10, 5)
-
-            def forward(self, x):
-                cached = self.cache_layer(x)
-                out = self.fc1(cached)
-                out = self.fc2(out)
-                return out
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"cached = self.cache_layer\(x\)",
-        ):
-            # Intentionally using training IR here because it will crash in
-            # inference IR anyway.
-            _ = torch.export.export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
-
     def test_export_for_training_with_container_type(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
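The removed `test_export_leak_compile` builds its cache update around `torch.roll`. A standalone sketch of that rolling-cache step, outside any module:

```python
import torch

cache = torch.zeros(1, 10, 5)  # (batch, seq, dim)
x = torch.randn(1, 3, 5)       # three new tokens
n_tokens = x.size(1)

# Shift the cache left by n_tokens, then write the new tokens into the tail.
rolled = torch.roll(cache, -n_tokens, dims=1)
rolled[:, -n_tokens:, :] = x
cache = rolled

assert torch.equal(cache[:, -3:, :], x)
```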
@@ -221,23 +221,10 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
         # return any attributes of a module that are not standard attributes
         return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
 
-    def _get_all_module_attributes(mod):
-        # return attributes from all modules and submodules
-        result = {}
-        for name, submodule in mod.named_modules():
-            result[name] = _get_attributes(submodule)
-        return result
-
-    def _restore_all_module_attributes(mod, snapshot):
-        # restore attributes to all modules and submodules
-        for name, submodule in mod.named_modules():
-            if name in snapshot:
-                submodule.__dict__.update(snapshot[name])
-
     # save state of attributes before enter
     snapshot = pytree.tree_map(
         lambda x: x,
-        _get_all_module_attributes(mod),
+        _get_attributes(mod),
         is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
     )
     try:
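The snapshot idiom here relies on `tree_map` with an identity function to shallow-copy the attribute containers, so later mutations are detectable by comparison. A hedged illustration using the same (private) `torch.utils._pytree` module the export code imports as `pytree`:

```python
import torch
import torch.utils._pytree as pytree

attrs = {"cache": [torch.zeros(2)], "step": 0}
snapshot = pytree.tree_map(lambda x: x, attrs)  # rebuilds the containers

attrs["cache"].append(torch.ones(2))  # simulate a leak into a container

flat_new, _ = pytree.tree_flatten(attrs)
flat_old, _ = pytree.tree_flatten(snapshot)
print(len(flat_new) != len(flat_old))  # True: a value was added
```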
@@ -249,75 +236,41 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
 
         def _collect_assigned_tensor_attributes(kp, v, _v):
             if _v is not v:
-                module_name, attr, *rest = kp
+                attr, *rest = kp
                 if isinstance(v, torch.Tensor):
-                    module_prefix = f"{module_name.key}." if module_name.key else ""
                     assigned_tensor_attributes.append(
-                        f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
+                        f"self.{attr.key}{pytree.keystr(rest)}"
                     )
                 # TODO(avik): Assigning all other types is allowed right now.
                 # Maybe in the future we want to limit this to primitive types?
             return v
 
-        new_attrs = _get_all_module_attributes(mod)
+        new_attrs = _get_attributes(mod)
+        if len(new_attrs) != len(snapshot):
+            added_attrs = new_attrs.keys() - snapshot.keys()
+            deleted_attrs = snapshot.keys() - new_attrs.keys()
 
-        # Check for added/deleted attributes across all modules
-        for module_name in snapshot.keys() | new_attrs.keys():
-            old_module_attrs = snapshot.get(module_name, {})
-            new_module_attrs = new_attrs.get(module_name, {})
+            if len(added_attrs) > 0:
+                raise ValueError(
+                    f"During torch.export, the following attrs were created in the model.forward: {added_attrs} "
+                    f"Such attributes must be registered as buffers using the `register_buffer` "
+                    f"API and must be initialized at model.__init__ "
+                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
+                )
 
-            module_prefix = f"self.{module_name}." if module_name else "self."
-
-            if len(new_module_attrs) != len(old_module_attrs):
-                added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
-                deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
-
-                if len(added_attrs) > 0:
-                    formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
-                    raise ValueError(
-                        f"During torch.export, the following attrs were created in the model.forward: {formatted_attrs} "
-                        f"Such attributes must be registered as buffers using the `register_buffer` "
-                        f"API and must be initialized at model.__init__ "
-                        f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-                    )
-
-                if len(deleted_attrs) > 0:
-                    formatted_attrs = [
-                        f"{module_prefix}{attr}" for attr in deleted_attrs
-                    ]
-                    raise ValueError(
-                        f"During torch.export, the following attrs were deleted in the model.forward: {formatted_attrs} "
-                        f"Such attributes must be registered as buffers using the `register_buffer` "
-                        f"API and must be initialized at model.__init__ "
-                        f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
-                    )
-
-            # Tensors could have leaked at container attributes
-            for k, new_v in new_module_attrs.items():
-                assert k in old_module_attrs
-                if isinstance(new_v, (tuple, list, dict)):
-                    flat_new_v, _ = pytree.tree_flatten(new_v)
-                    flat_old_v, _ = pytree.tree_flatten(old_module_attrs[k])
-                    if len(flat_new_v) != len(flat_old_v):
-                        leaked_values = [
-                            v
-                            for v in flat_new_v
-                            if v not in flat_old_v and isinstance(v, torch.Tensor)
-                        ]
-                        if len(leaked_values) > 0:
-                            raise ValueError(
-                                f"During torch.export, the following tensors were leaked at {module_prefix}{k}: {leaked_values} "
-                                f"Such attributes must be registered as buffers using the `register_buffer` "
-                                f"API and must be initialized at model.__init__ "
-                                f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer). "  # noqa: B950
-                                f"Alternatively, consider using `torch.export.export(strict=True)` to export the model."
-                            )
+            if len(deleted_attrs) > 0:
+                raise ValueError(
+                    f"During torch.export, the following attrs were deleted in the model.forward: {deleted_attrs} "
+                    f"Such attributes must be registered as buffers using the `register_buffer` "
+                    f"API and must be initialized at model.__init__ "
+                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
+                )
 
         pytree.tree_map_with_path(
            _collect_assigned_tensor_attributes, snapshot, new_attrs
         )
         # restore state of all attributes (including, e.g., of primitive types)
-        _restore_all_module_attributes(mod, snapshot)
+        mod.__dict__.update(snapshot)
 
         if assigned_tensor_attributes:
             if len(assigned_tensor_attributes) > 1:
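`tree_map_with_path` gives `_collect_assigned_tensor_attributes` the key path of each leaf, which is how the error message can name the reassigned attribute. A small sketch of that mechanism with hypothetical data, not export internals:

```python
import torch
import torch.utils._pytree as pytree

old = {"w": torch.zeros(2), "cfg": {"lr": 0.1}}
new = {"w": torch.ones(2), "cfg": {"lr": 0.1}}
changed = []


def report(kp, v, _v):
    # v comes from `old`, _v from `new`; identity mismatch means reassignment.
    if _v is not v:
        changed.append(pytree.keystr(kp))
    return v


pytree.tree_map_with_path(report, old, new)
print(changed)  # key path of the reassigned tensor, e.g. ["['w']"]
```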
@@ -1852,11 +1852,6 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
     return next(iter(node for node in gm.graph.nodes if node.name == name))
 
 
-def _is_invalid_const_name(name: str):
-    splitted_names = name.split(".")
-    return splitted_names[-1].startswith("lifted_tensor")
-
-
 def _non_strict_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
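A hedged usage note for the removed helper: constants lifted by the tracer get names whose last dotted component starts with `lifted_tensor`, so only that component is checked. For example:

```python
def _is_invalid_const_name(name: str) -> bool:
    return name.split(".")[-1].startswith("lifted_tensor")


assert _is_invalid_const_name("submod.lifted_tensor_0")
assert not _is_invalid_const_name("submod.weight")
```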
@@ -2032,43 +2027,6 @@ def _non_strict_export(
     )
 
 
-def emit_bogus_const_warning(constants, gs, gm):
-    bogus_constants: set[str] = set()
-    for const, val in constants.items():
-        if isinstance(
-            val, torch._subclasses.fake_tensor.FakeTensor
-        ) and _is_invalid_const_name(const):
-            bogus_constants.add(const)
-
-    if len(bogus_constants) == 0:
-        return
-
-    bogus_constant_names: set[str] = set()
-    for inp in gs.input_specs:
-        if inp.kind == InputKind.CONSTANT_TENSOR and inp.target in bogus_constants:
-            bogus_constant_names.add(inp.arg.name)
-
-    placeholders = {
-        node.name: node for node in gm.graph.nodes if node.op == "placeholder"
-    }
-    for name in bogus_constant_names:
-        placeholder_node = placeholders[name]
-        dependencies: list[str] = []
-        for user in placeholder_node.users:
-            if user.meta.get("stack_trace", None) is not None:
-                dependencies.append(user.meta["stack_trace"])
-        if len(placeholder_node.users) > 0:
-            raise RuntimeError(
-                f"We found a fake tensor in the exported program's constants list. "
-                f"This typically means our tracing system encountered an op that "
-                f"we can't trace through. For the potential source, you can refer to "
-                f"the following model attribute: {name}. We found the following stack trace, which might "
-                f"be helpful:\n\n"
-                f"{dependencies if dependencies else '<unknown>'}\n\n"
-                f"Please file an issue on GitHub if you need further help.\n"
-            )
-
-
 @_log_export_wrapper
 @_disable_prexisiting_fake_mode
 def _export_for_training(
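The removed `emit_bogus_const_warning` boils down to scanning an exported program's constants for fake tensors. A hedged sketch of that core check (the helper name is ours, not a PyTorch API):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor


def find_fake_constants(ep: torch.export.ExportedProgram) -> list[str]:
    # Constant names whose values are FakeTensors, i.e. likely tracing leaks.
    return [
        name for name, val in ep.constants.items() if isinstance(val, FakeTensor)
    ]
```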
@@ -2094,11 +2052,6 @@ def _export_for_training(
 
     original_state_dict = _get_original_state_dict(mod)
 
-    has_ambient_mode = False
-    if not strict:
-        flat_args, _ = pytree.tree_flatten((args, kwargs))
-        has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
-
     # Call the appropriate export function based on the strictness of tracing.
     export_func = _strict_export if strict else _non_strict_export
 
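The removed `has_ambient_mode` probe used `torch._guards.detect_fake_mode`, which returns the active `FakeTensorMode` (if any) implied by the inputs. A small sketch:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as fake_mode:
    fake_inp = fake_mode.from_tensor(torch.randn(2, 2))

print(torch._guards.detect_fake_mode([fake_inp]) is not None)  # True
print(torch._guards.detect_fake_mode([torch.randn(2, 2)]))     # None
```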
@@ -2115,15 +2068,6 @@
 
     export_graph_signature = export_artifact.aten.sig
 
-    # If we are tracing with fake inputs, it is expected to
-    # see fake tensor constants.
-    if not strict and not has_ambient_mode:
-        emit_bogus_const_warning(
-            export_artifact.aten.constants,
-            export_graph_signature,
-            export_artifact.aten.gm,
-        )
-
     forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
     inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
     # The unbacked symint symbols are updated in aot_export