mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] _detect_attribute_assignment gives warning instead of raising ValueError (#163809)
Summary: LSTM was not exportable with non-strict export because it failed in `_detect_attribute_assignment`. This is because the `_flat_weights` attribute in LSTM is a list of registered parameters that is updated by the `_update_flat_weights` method in `forward`. However, `_detect_attribute_assignment` manually restores the state of the module via `mod.__dict__.update(snapshot)`, so it should be fine to turn the `ValueError` into a warning so that RNN models are exportable with non-strict export. Added a test to verify that there are no lifted tensor constants and no fake tensor leakage. Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_variants_with_warning Differential Revision: D83196971 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163809 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
b4be380480
commit
5c2f09d1f9
@ -2007,8 +2007,8 @@ class GraphModule(torch.nn.Module):
|
|||||||
# z = 3
|
# z = 3
|
||||||
return x + y + z
|
return x + y + z
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertWarnsRegex(
|
||||||
ValueError,
|
UserWarning,
|
||||||
"The tensor attribute self.buf was assigned during export",
|
"The tensor attribute self.buf was assigned during export",
|
||||||
):
|
):
|
||||||
export(M(), (torch.randn(2, 3),), strict=False)
|
export(M(), (torch.randn(2, 3),), strict=False)
|
||||||
@ -2065,8 +2065,8 @@ class GraphModule(torch.nn.Module):
|
|||||||
# z = 3 + 3
|
# z = 3 + 3
|
||||||
return x + y + z
|
return x + y + z
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertWarnsRegex(
|
||||||
ValueError,
|
UserWarning,
|
||||||
"The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
|
"The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
|
||||||
):
|
):
|
||||||
export(M(), (torch.randn(2, 3),), strict=False)
|
export(M(), (torch.randn(2, 3),), strict=False)
|
||||||
@ -15898,6 +15898,50 @@ class GraphModule(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
self.assertEqual(len(shift_op), 1)
|
self.assertEqual(len(shift_op), 1)
|
||||||
|
|
||||||
|
def test_export_rnn_variants_with_warning(self):
|
||||||
|
"""
|
||||||
|
Test that when exporting RNN, LSTM, and GRU models in non-strict mode, it:
|
||||||
|
|
||||||
|
1. Produces expected warnings about tensor attributes being assigned during export
|
||||||
|
2. Does not leak fake tensors in the model's flat weights
|
||||||
|
3. Does not produce extra tensor constants in the graph signature
|
||||||
|
"""
|
||||||
|
rnn_types = [
|
||||||
|
(torch.nn.RNN, "RNN"),
|
||||||
|
(torch.nn.LSTM, "LSTM"),
|
||||||
|
(torch.nn.GRU, "GRU"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for rnn_class, rnn_name in rnn_types:
|
||||||
|
with self.subTest(rnn_type=rnn_name):
|
||||||
|
m = rnn_class(
|
||||||
|
input_size=2, hidden_size=4, num_layers=1, batch_first=True
|
||||||
|
)
|
||||||
|
sample_inputs = (torch.randn(1, 2, 2),)
|
||||||
|
eager_out = m(*sample_inputs)
|
||||||
|
|
||||||
|
# Verify that export produces the expected warning about tensor attributes
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning,
|
||||||
|
r"The tensor attributes self\._flat_weights\[0\], self\._flat_weights\[1\], "
|
||||||
|
r"self\._flat_weights\[2\], self\._flat_weights\[3\] were assigned during export.*",
|
||||||
|
):
|
||||||
|
ep = torch.export.export(m, sample_inputs, strict=False)
|
||||||
|
|
||||||
|
ep_out = ep.module()(*sample_inputs)
|
||||||
|
self.assertEqual(eager_out, ep_out)
|
||||||
|
|
||||||
|
# Verify no fake tensor leakage: flat weights should be real tensors
|
||||||
|
for flat_weight in m._flat_weights:
|
||||||
|
self.assertTrue(
|
||||||
|
not isinstance(
|
||||||
|
flat_weight, torch._subclasses.fake_tensor.FakeTensor
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no tensor constants in graph signature
|
||||||
|
self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 0)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def distributed_env(self, world_size):
|
def distributed_env(self, world_size):
|
||||||
try:
|
try:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
|
|
||||||
|
import warnings
|
||||||
from collections.abc import KeysView
|
from collections.abc import KeysView
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@ -277,7 +278,7 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
|
|||||||
noun, verb = "attributes", "were"
|
noun, verb = "attributes", "were"
|
||||||
else:
|
else:
|
||||||
noun, verb = "attribute", "was"
|
noun, verb = "attribute", "was"
|
||||||
raise ValueError(
|
warnings.warn(
|
||||||
f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
|
f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
|
||||||
"Such attributes must be registered as buffers using the `register_buffer` API "
|
"Such attributes must be registered as buffers using the `register_buffer` API "
|
||||||
"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||||
|
Reference in New Issue
Block a user