Compare commits

...

1 Commits

Author SHA1 Message Date
f1753fc368 [export] Disable side effects on dynamo_graph_capture_for_export and warn user.
Summary:
as title.

Test Plan:
test_dynamo_graph_capture_side_effects

Reviewers:

Subscribers:

Tasks:

Tags:
2025-11-13 12:48:22 -08:00
2 changed files with 34 additions and 0 deletions

View File

@ -3,6 +3,7 @@
import copy
import types
import unittest
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple
@ -18,6 +19,9 @@ from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA
GLOBAL_LIST = []
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_joint_basic(self) -> None:
@ -611,6 +615,34 @@ def forward(self, args_0):
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
def test_dynamo_graph_capture_side_effects(self):
GLOBAL_LIST.clear()
def foo(x):
z = x + 1
GLOBAL_LIST.append(z)
return z
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
with warnings.catch_warnings(record=True) as w:
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
cnt = 0
for entry in w:
if "While compiling, we found certain side effects happened" in str(
entry.message
):
cnt += 1
self.assertEqual(cnt, 1)
self.assertEqual(len(GLOBAL_LIST), 0)
test_inputs = make_inputs()
gm_results = gm(*test_inputs)
self.assertEqual(len(GLOBAL_LIST), 0)
self.assertEqual(gm_results, foo(*test_inputs))
self.assertEqual(len(GLOBAL_LIST), 1)
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

View File

@ -611,6 +611,8 @@ def dynamo_graph_capture_for_export(
def inner(*args: Any, **kwargs: Any) -> Any:
assert not torch._dynamo.config.install_free_tensors
with (
torch._dynamo.config.patch(replay_side_effects=False),
torch._dynamo.config.patch(side_effect_replay_policy="warn"),
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
):