Warn if AccumulateGrad stream does not match producer node stream (#165065)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
ghstack dependencies: #162815
Author: soulitzer
Date: 2025-10-10 06:50:06 -07:00
Committed by: PyTorch MergeBot
Parent: 01a2812f48
Commit: a70ef954b9
11 changed files with 141 additions and 7 deletions
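
The new warning fires when a tensor's AccumulateGrad node was created under one stream but the node producing its gradient runs under a different stream during backward. A minimal sketch distilled from the test below (requires CUDA; set_warn_on_accumulate_grad_stream_mismatch is the toggle this commit adds):

import torch

with torch.Stream(0) as s0:
    a = torch.ones(8, 8, device="cuda", requires_grad=True)
    # a's AccumulateGrad node is created under s0 and kept alive via b
    b = a.clone()

with torch.Stream(0) as s1:
    s1.wait_stream(s0)
    a.sum().backward()  # producer runs under s1 -> stream mismatch -> warning

# The warning is on by default; it can be suppressed globally:
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)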

@@ -13712,6 +13712,53 @@ class TestAutogradStreamSynchronization(TestCase):
        populate_events()
        check_ordering()

    # Fails on MPS
    @skipIfMPS
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_warn_on_accumulate_grad_stream_mismatch_flag(self):
        def do_test(suppress_warn, keep_grad_acc):
            def _test():
                with warnings.catch_warnings(record=True) as warns:
                    warnings.simplefilter("always")

                    with torch.Stream(0) as s0:
                        a = torch.ones(8, 8, device="cuda", requires_grad=True)

                        if keep_grad_acc:
                            # create grad_acc under s0 and keep alive with b
                            b = a.clone()

                    with torch.Stream(0) as s1:
                        s1.wait_stream(s0)
                        c = a.sum()
                        c.backward()

                    filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
                    return sum([filter_str in str(w.message) for w in warns]) > 0

            if suppress_warn:
                try:
                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
                        False
                    )
                    actual_warn = _test()
                finally:
                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
                        True
                    )
            else:
                actual_warn = _test()

            expect_warn = not suppress_warn and keep_grad_acc
            self.assertEqual(actual_warn, expect_warn)

        # Warn by default
        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())

        for suppress_warn in (True, False):
            for keep_grad_acc in (True, False):
                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)


class TestMultithreadAutograd(TestCase):
    def _run_py_multithread_fn(