[export] _detect_attribute_assignment gives warning instead of raising ValueError (#163809)

Summary:
LSTM was not exportable with non-strict export because it failed the `_detect_attribute_assignment` check.

This is because the `_flat_weights` attribute of LSTM is a list of registered parameters that is reassigned by the `_update_flat_weights` method inside `forward`, so the check flags it as an attribute assignment during export.

However, `_detect_attribute_assignment` already restores the module's state via `mod.__dict__.update(snapshot)` before reporting anything. Since the assignment is undone either way, it is safe to downgrade the `ValueError` to a warning, which makes RNN models exportable with non-strict export.
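
For context, the restore-then-report flow looks roughly like the sketch below. This is a simplified illustration, not the actual implementation: the real check snapshots only non-standard module attributes and compares them with a pytree traversal (so it also catches containers like `_flat_weights`), but the key property is the same — module state is restored before anything is reported.

```python
# Minimal sketch of the snapshot/restore pattern, NOT the actual
# implementation: the real check walks non-standard module attributes
# with pytree. The point is that state is restored before reporting,
# so a warning is sufficient.
import warnings
from contextlib import contextmanager

import torch


@contextmanager
def detect_attribute_assignment_sketch(mod: torch.nn.Module):
    snapshot = dict(mod.__dict__)  # shallow copy of the attribute state
    try:
        yield
    finally:
        assigned = [
            name
            for name, value in mod.__dict__.items()
            if isinstance(value, torch.Tensor) and snapshot.get(name) is not value
        ]
        # The module is restored to its pre-export state regardless of
        # whether anything was assigned...
        mod.__dict__.update(snapshot)
        if assigned:
            # ...so flagging the assignment does not need to abort export.
            warnings.warn(
                f"The tensor attribute(s) {', '.join(assigned)} were "
                "assigned during export."
            )
```

Because the restore happens unconditionally, escalating to `ValueError` adds no safety; it only blocks modules like LSTM whose `forward` legitimately rebuilds `_flat_weights`.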

Added a test to verify that export produces no lifted tensor constants and no fake tensor leakage.
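
Concretely, the following minimal repro (mirroring the new test) now completes with only a `UserWarning`, where it previously raised `ValueError`:

```python
import torch

m = torch.nn.LSTM(input_size=2, hidden_size=4, num_layers=1, batch_first=True)
inputs = (torch.randn(1, 2, 2),)

# Previously raised ValueError at _detect_attribute_assignment;
# now emits a UserWarning and exports successfully.
ep = torch.export.export(m, inputs, strict=False)

# The exported program matches eager: compare the output sequences.
assert torch.allclose(m(*inputs)[0], ep.module()(*inputs)[0])
```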

Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_variants_with_warning

Differential Revision: D83196971

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163809
Approved by: https://github.com/tugsbayasgalan
Yiming Zhou
2025-09-26 00:43:26 +00:00
committed by PyTorch MergeBot
parent b4be380480
commit 5c2f09d1f9
2 changed files with 50 additions and 5 deletions

test/export/test_export.py

@@ -2007,8 +2007,8 @@ class GraphModule(torch.nn.Module):
                 # z = 3
                 return x + y + z

-        with self.assertRaisesRegex(
-            ValueError,
+        with self.assertWarnsRegex(
+            UserWarning,
             "The tensor attribute self.buf was assigned during export",
         ):
             export(M(), (torch.randn(2, 3),), strict=False)
@@ -2065,8 +2065,8 @@ class GraphModule(torch.nn.Module):
                 # z = 3 + 3
                 return x + y + z

-        with self.assertRaisesRegex(
-            ValueError,
+        with self.assertWarnsRegex(
+            UserWarning,
             "The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
         ):
             export(M(), (torch.randn(2, 3),), strict=False)
@@ -15898,6 +15898,50 @@ class GraphModule(torch.nn.Module):
         ]
         self.assertEqual(len(shift_op), 1)

+    def test_export_rnn_variants_with_warning(self):
+        """
+        Test that when exporting RNN, LSTM, and GRU models in non-strict mode, it:
+        1. Produces expected warnings about tensor attributes being assigned during export
+        2. Does not leak fake tensors in the model's flat weights
+        3. Does not produce extra tensor constants in the graph signature
+        """
+        rnn_types = [
+            (torch.nn.RNN, "RNN"),
+            (torch.nn.LSTM, "LSTM"),
+            (torch.nn.GRU, "GRU"),
+        ]
+
+        for rnn_class, rnn_name in rnn_types:
+            with self.subTest(rnn_type=rnn_name):
+                m = rnn_class(
+                    input_size=2, hidden_size=4, num_layers=1, batch_first=True
+                )
+                sample_inputs = (torch.randn(1, 2, 2),)
+                eager_out = m(*sample_inputs)
+
+                # Verify that export produces the expected warning about tensor attributes
+                with self.assertWarnsRegex(
+                    UserWarning,
+                    r"The tensor attributes self\._flat_weights\[0\], self\._flat_weights\[1\], "
+                    r"self\._flat_weights\[2\], self\._flat_weights\[3\] were assigned during export.*",
+                ):
+                    ep = torch.export.export(m, sample_inputs, strict=False)
+
+                ep_out = ep.module()(*sample_inputs)
+                self.assertEqual(eager_out, ep_out)
+
+                # Verify no fake tensor leakage: flat weights should be real tensors
+                for flat_weight in m._flat_weights:
+                    self.assertTrue(
+                        not isinstance(
+                            flat_weight, torch._subclasses.fake_tensor.FakeTensor
+                        )
+                    )
+
+                # Verify no tensor constants in graph signature
+                self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 0)
+
     @contextmanager
     def distributed_env(self, world_size):
         try:

torch/_export/non_strict_utils.py

@@ -1,5 +1,6 @@
 # mypy: ignore-errors
+import warnings
 from collections.abc import KeysView
 from contextlib import contextmanager
 from typing import Any, Optional
@@ -277,7 +278,7 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
                 noun, verb = "attributes", "were"
             else:
                 noun, verb = "attribute", "was"
-            raise ValueError(
+            warnings.warn(
                 f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
                 "Such attributes must be registered as buffers using the `register_buffer` API "
                 "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."