Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Warn if AccumulateGrad stream does not match producer node stream (#165065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
ghstack dependencies: #162815
Committed by: PyTorch MergeBot
Parent: 01a2812f48
Commit: a70ef954b9
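For context, here is a minimal sketch of the scenario the new warning targets, distilled from the test added in the diff below. It assumes a CUDA build; the `torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch` toggle is exactly the API the test exercises, and the rest mirrors the test body.

    import torch

    # The warning is on by default; it can be suppressed process-wide.
    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)

    with torch.Stream(0) as s0:
        a = torch.ones(8, 8, device="cuda", requires_grad=True)
        # Building graph nodes under s0 creates a's AccumulateGrad node on s0;
        # keeping b alive keeps that node alive across the stream switch.
        b = a.clone()

    with torch.Stream(0) as s1:
        s1.wait_stream(s0)
        c = a.sum()  # the producer of a's gradient is recorded under s1

    c.backward()  # warns: AccumulateGrad lives on s0, its producer ran on s1

To suppress the warning (e.g. when the mismatch is synchronized deliberately), the test's try/finally pattern applies: flip the flag to False around the backward call and restore the default afterwards.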
@@ -13712,6 +13712,53 @@ class TestAutogradStreamSynchronization(TestCase):
         populate_events()
         check_ordering()
+
+    # Fails on MPS
+    @skipIfMPS
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_warn_on_accumulate_grad_stream_mismatch_flag(self):
+        def do_test(suppress_warn, keep_grad_acc):
+            def _test():
+                with warnings.catch_warnings(record=True) as warns:
+                    warnings.simplefilter("always")
+
+                    with torch.Stream(0) as s0:
+                        a = torch.ones(8, 8, device="cuda", requires_grad=True)
+                        if keep_grad_acc:
+                            # create grad_acc under s0 and keep alive with b
+                            b = a.clone()
+
+                    with torch.Stream(0) as s1:
+                        s1.wait_stream(s0)
+                        c = a.sum()
+
+                    c.backward()
+
+                    filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
+                    return sum([filter_str in str(w.message) for w in warns]) > 0
+
+            if suppress_warn:
+                try:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        False
+                    )
+                    actual_warn = _test()
+                finally:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        True
+                    )
+            else:
+                actual_warn = _test()
+
+            expect_warn = not suppress_warn and keep_grad_acc
+            self.assertEqual(actual_warn, expect_warn)
+
+        # Warn by default
+        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
+
+        for suppress_warn in (True, False):
+            for keep_grad_acc in (True, False):
+                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)


 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(