mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] _detect_attribute_assignment gives warning instead of raising ValueError (#163809)
Summary: LSTM was not exportable with non-strict export as it failed at `_detect_attribute_assignment`. This is because the `_flat_weights` attribute in LSTM is a list of registered parameters and is updated by the `_update_flat_weights` method in `forward`. However, in `_detect_attribute_assignment`, we manually restore the state of the module by `mod.__dict__.update(snapshot)`. Therefore, it should be fine to turn the `ValueError` into a warning so that RNN models are exportable with non-strict export. Added a test to verify that there is no lifted tensor constant and no fake tensor leakage. Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_variants_with_warning Differential Revision: D83196971 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163809 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
b4be380480
commit
5c2f09d1f9
@ -2007,8 +2007,8 @@ class GraphModule(torch.nn.Module):
|
||||
# z = 3
|
||||
return x + y + z
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"The tensor attribute self.buf was assigned during export",
|
||||
):
|
||||
export(M(), (torch.randn(2, 3),), strict=False)
|
||||
@ -2065,8 +2065,8 @@ class GraphModule(torch.nn.Module):
|
||||
# z = 3 + 3
|
||||
return x + y + z
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
|
||||
):
|
||||
export(M(), (torch.randn(2, 3),), strict=False)
|
||||
@ -15898,6 +15898,50 @@ class GraphModule(torch.nn.Module):
|
||||
]
|
||||
self.assertEqual(len(shift_op), 1)
|
||||
|
||||
def test_export_rnn_variants_with_warning(self):
    """
    Test that when exporting RNN, LSTM, and GRU models in non-strict mode, it:

    1. Produces expected warnings about tensor attributes being assigned during export
    2. Does not leak fake tensors in the model's flat weights
    3. Does not produce extra tensor constants in the graph signature
    """
    rnn_types = [
        (torch.nn.RNN, "RNN"),
        (torch.nn.LSTM, "LSTM"),
        (torch.nn.GRU, "GRU"),
    ]

    for rnn_class, rnn_name in rnn_types:
        # subTest keeps the three variants independent: a failure in one
        # is reported under its own name and the others still run.
        with self.subTest(rnn_type=rnn_name):
            m = rnn_class(
                input_size=2, hidden_size=4, num_layers=1, batch_first=True
            )
            sample_inputs = (torch.randn(1, 2, 2),)
            # Reference output computed before export so the exported
            # program can be compared against it below.
            eager_out = m(*sample_inputs)

            # Verify that export produces the expected warning about tensor attributes
            with self.assertWarnsRegex(
                UserWarning,
                r"The tensor attributes self\._flat_weights\[0\], self\._flat_weights\[1\], "
                r"self\._flat_weights\[2\], self\._flat_weights\[3\] were assigned during export.*",
            ):
                ep = torch.export.export(m, sample_inputs, strict=False)

            ep_out = ep.module()(*sample_inputs)
            self.assertEqual(eager_out, ep_out)

            # Verify no fake tensor leakage: flat weights should be real tensors.
            # assertNotIsInstance (rather than assertTrue(not isinstance(...)))
            # names the offending object in the failure message.
            for flat_weight in m._flat_weights:
                self.assertNotIsInstance(
                    flat_weight, torch._subclasses.fake_tensor.FakeTensor
                )

            # Verify no tensor constants in graph signature
            self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 0)
|
||||
|
||||
@contextmanager
|
||||
def distributed_env(self, world_size):
|
||||
try:
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
import warnings
|
||||
from collections.abc import KeysView
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
@ -277,7 +278,7 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
|
||||
noun, verb = "attributes", "were"
|
||||
else:
|
||||
noun, verb = "attribute", "was"
|
||||
raise ValueError(
|
||||
warnings.warn(
|
||||
f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
|
||||
"Such attributes must be registered as buffers using the `register_buffer` API "
|
||||
"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||
|
Reference in New Issue
Block a user