Add some autograd producer consumer stream sync tests (#150952)

Thanks @ngimel and @albanD for some ideas on test cases

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150952
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2025-04-11 16:30:46 -07:00
Committed by: PyTorch MergeBot
Parent: 397b7f9b82
Commit: 32f0f414ab


@@ -12926,6 +12926,455 @@ class TestAutogradInferenceMode(TestCase):
run_test(lambda x: x.transpose_(0, 1))
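# Approximate number of GPU cycles in one second (assumes a roughly 2 GHz
# clock); used with torch.cuda._sleep below to keep a stream busy.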
NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000
@contextlib.contextmanager
def _set_device_index(target_device):
orig_device = torch.accelerator.current_device_index()
try:
torch.accelerator.set_device_index(target_device)
yield
finally:
torch.accelerator.set_device_index(orig_device)
def _sleep_if_cuda(cycles):
if "cuda" == torch.accelerator.current_accelerator().type:
return torch.cuda._sleep(cycles)
else:
# Update this if non-cuda accelerators support something like sleep
return
def _get_device_name(idx):
return f"{torch.accelerator.current_accelerator().type}:{idx}"
# Although this is written to be generic over all accelerators, non-cuda accelerators
# are not fully tested since sleep is only supported on cuda.
class TestAutogradStreamSynchronization(TestCase):
def get_default_streams(self, num_devices=1):
out = []
for i in range(num_devices):
with _set_device_index(i):
acc = torch.accelerator.current_accelerator()
out.append(torch.get_device_module(acc).default_stream())
return tuple(out)
def synchronize_all_devices(self, num_devices=1):
for i in range(num_devices):
torch.accelerator.synchronize(i)
def assert_all_streams_default(self, num_devices=1):
# Sanity check
default_streams = self.get_default_streams(num_devices)
for i in range(num_devices):
with _set_device_index(i):
acc = torch.accelerator.current_accelerator()
# Do this instead of using torch.accelerator.current_stream(i)
# Otherwise, e.g. in the case of cuda, we'd be trying to compare
# torch.cuda.Stream with torch.Stream
self.assertEqual(
torch.get_device_module(acc).current_stream(), default_streams[i]
)
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
def test_consumer_to_single_producer_case_2_correctness(self):
# Device Stream
# Consumer (MulBackward): cuda:0 s0
# Producer : cuda:0 s1
# Gradient : cuda:0 s1
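# The engine must make the consumer stream (s0) wait for the producer
# stream (s1); otherwise MulBackward reads the gradient before add_(1)
# has run and a.grad comes out as 2 instead of 4.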
class Producer(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, gO):
out = gO.clone()
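# The sleep delays add_(1) so that a missing sync is actually observable.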
_sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
out.add_(1)
return out
def test():
self.synchronize_all_devices()
self.assert_all_streams_default()
with torch.Stream(0) as s0:
a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
b = a * 2
with torch.Stream(0) as s1:
s1.wait_stream(s0)
out = Producer.apply(b)
with torch.autograd.grad_mode.set_multithreading_enabled(False):
out.sum().backward()
self.synchronize_all_devices()
# Expected result: a.grad = (grad_out + 1) * 2 = 4
self.assertEqual(a.grad, torch.full_like(a, 4))
# Run an extra time to warm up
for _ in range(2):
test()
def _test_consumer_to_single_producer_case_3_correctness(
self, non_default_ambient_stream
):
# Device Stream
# Consumer (MulBackward): cuda:0 s0
# Producer : cuda:1 cuda:1 default
# Gradient : cuda:0 cuda:0 default
class Producer(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# The node's canonical stream is the current stream
# of the device of the first output.
ctx.node_stream = torch.accelerator.current_stream(1)
return x.to(_get_device_name(1))
@staticmethod
def backward(ctx, gO):
out = gO.to(_get_device_name(0))
with _set_device_index(0):
_sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
# It's the node's responsibility to sync back to its canonical stream.
out.add_(1)
ctx.node_stream.wait_stream(torch.accelerator.current_stream(0))
return out
def test():
self.synchronize_all_devices(2)
self.assert_all_streams_default(2)
(default_stream_0,) = self.get_default_streams()
# Ensure the consumer node runs on a non-default stream so that a sync
# is actually necessary when the Producer node's backward emits its
# gradient on the default stream.
with torch.Stream(0) as s0:
a = torch.ones(256, 256, requires_grad=True, device="cuda")
b = a * 2
default_stream_0.wait_stream(s0)
out = Producer.apply(b)
def call_backward(x):
with torch.autograd.grad_mode.set_multithreading_enabled(False):
x.sum().backward()
if non_default_ambient_stream:
with torch.Stream(0) as s1:
s1.wait_stream(default_stream_0)
call_backward(out)
else:
call_backward(out)
self.synchronize_all_devices(2)
# Expected result: a.grad = (grad_out + 1) * 2 = 4
self.assertEqual(a.grad, torch.full_like(a, 4))
# Run an extra time to warm up
for _ in range(2):
test()
# This fails because we currently sync to the default stream
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
@unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
)
@unittest.expectedFailure
def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
self,
):
self._test_consumer_to_single_producer_case_3_correctness(
non_default_ambient_stream=True
)
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
@unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
)
def test_consumer_to_single_producer_case_3_correctness(self):
self._test_consumer_to_single_producer_case_3_correctness(
non_default_ambient_stream=False
)
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
@unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
)
def test_consumer_to_single_producer_case_4_correctness(self):
# Device Stream
# Consumer: cuda:0 cuda:0 default
# Producer: cuda:1 s1
# Gradient: cuda:1 s1
class Producer(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, gO):
out = gO.clone()
_sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
return out.add_(1)
class Consumer(torch.autograd.Function):
# In the multi-output case, the node's canonical device and stream correspond to
# that of its first output. This is required to induce cases 4/5.
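# The first output stays on cuda:0's default stream, so that becomes the
# node's canonical stream even though the incoming gradient (gO_1) is
# produced on s1 on cuda:1.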
@staticmethod
def forward(ctx, x):
return x.clone(), x.to(_get_device_name(1))
@staticmethod
def backward(ctx, gO_0, gO_1):
# gO_1 was produced on s1, but we are currently computing on cuda:1's
# default stream. It's the user's responsibility to sync with the
# consumer (.to() should already do this). Things would also work out
# if the engine synced s1 with the consumer.
# gO_0, the incoming grad for the unused first output, is ignored.
return gO_1.to(_get_device_name(0))
def test():
self.synchronize_all_devices(2)
self.assert_all_streams_default(2)
_, default_stream_1 = self.get_default_streams(2)
a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
_unused, b = Consumer.apply(a)
with torch.Stream(1) as s1:
s1.wait_stream(default_stream_1)
out = Producer.apply(b)
with torch.autograd.grad_mode.set_multithreading_enabled(False):
out.sum().backward()
self.synchronize_all_devices(2)
# Expected result: a.grad = grad_out + 1 = 2
self.assertEqual(a.grad, torch.full_like(a, 2))
# Run an extra time to warm up
for _ in range(2):
test()
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
@unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
)
def test_consumer_to_multi_producer_case_4_correctness(self):
# Device Stream
# Consumer : cuda:0 cuda:0 default
#
# Producer 1: cuda:1 s1
# Gradient 1: cuda:1 s1
#
# Producer 2: cuda:1 s2
# Gradient 2: cuda:1 s2
#
# Accumulation stream: s2 since it is scheduled first
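# The assertEqual below distinguishes the correctly synced result
# (a.grad == 8) from the race where ProducerSlow's in-place mul_(2) lands
# only after ProducerFast's gradient has been accumulated (a.grad == 12).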
class ProducerFast(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, gO):
out = gO.clone()
return out * 2
class ProducerSlow(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, gO):
out = gO.clone()
_sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
return out.mul_(2)
class Consumer(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.node_stream = torch.accelerator.current_stream(x.device)
return x.clone(), x.to(_get_device_name(1))
@staticmethod
def backward(ctx, gO_0, gO_1):
torch.accelerator.current_stream(gO_1.device).wait_stream(
ctx.node_stream
)
return (gO_1 * 2).to(_get_device_name(0))
def test():
self.synchronize_all_devices(2)
self.assert_all_streams_default(2)
default_stream_0, default_stream_1 = self.get_default_streams(2)
a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
_unused, b = Consumer.apply(a)
with torch.Stream(1) as s1:
s1.wait_stream(default_stream_1)
out1 = ProducerFast.apply(b)
with torch.Stream(1) as s2:
s2.wait_stream(default_stream_1)
out2 = ProducerSlow.apply(b)
default_stream_1.wait_stream(s1)
default_stream_1.wait_stream(s2)
with torch.autograd.grad_mode.set_multithreading_enabled(False):
(out1 + out2).sum().backward()
self.synchronize_all_devices(2)
# If the accumulation stream does not wait for the slow producer stream
# the in-place mul-by-2 is performed on the accumulated buffer AFTER
# ProducerFast has already accumulated!
#
# Correct: (1.mul_(2) + 2) * 2 = 8
# Incorrect: (1 + 2).mul_(2) * 2 = 12
self.assertEqual(a.grad, torch.full_like(a, 8))
# Run an extra time to warm up
for _ in range(2):
test()
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
# This test may spuriously fail on non-cuda accelerators (since we won't
# be calling sleep)
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
@unittest.expectedFailure
def test_side_stream_backward_overlap(self):
# In cases 2/3 we designate the consumer as the accumulation stream,
# and naively one might have the consumer wait for the producer as soon
# as the first gradient is added to the InputBuffer.
#
# However, when the consumer's stream also happens to be the stream of
# one of the producers, this is suboptimal because it would prevent the
# computation of the two producers from overlapping. What we really
# want is to delay that producer-consumer sync until right before the
# accumulation. Note that this doesn't address N=3, but the side-stream
# N=2 case is the common case.
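# The events recorded below capture when each backward starts and ends so
# we can assert that Main's backward (default stream) begins while Side's
# backward (s0) is still sleeping, i.e. the two overlap.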
events = {
"main_backward_start": None,
"side_backward_start": None,
"side_backward_end": None,
}
class Main(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, gO):
# Record when main backward starts
evt = torch.Event(enable_timing=True)
evt.record()
events["main_backward_start"] = evt
return gO
class Side(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, gO):
evt = torch.Event(enable_timing=True)
evt.record()
events["side_backward_start"] = evt
_sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
result = gO.clone()
evt = torch.Event(enable_timing=True)
evt.record()
events["side_backward_end"] = evt
return result
def populate_events():
self.synchronize_all_devices()
self.assert_all_streams_default()
(default_stream_0,) = self.get_default_streams()
a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
b = a.clone() # not a leaf, does it matter?
evt = torch.Event()
evt.record()
# Overlap during forward
c_main = Main.apply(b)
with torch.Stream(0) as s0:
s0.wait_event(evt)
c_side = Side.apply(b)
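# Side's forward ran under s0, so its backward is replayed on s0; Main's
# forward ran on the default stream, so its backward runs there. This is
# what allows the two backward computations to overlap.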
default_stream_0.wait_stream(s0)
with torch.autograd.grad_mode.set_multithreading_enabled(False):
(c_main + c_side).sum().backward()
self.synchronize_all_devices()
def check_ordering():
# Sanity check: side backward's end happens after start
self.assertTrue(
events["side_backward_start"].elapsed_time(events["side_backward_end"])
> 0
)
# Sanity check: main's backward started after side's backward started
self.assertTrue(
events["side_backward_start"].elapsed_time(
events["main_backward_start"]
)
> 0
)
# Overlap check: main's backward starts before side's backward ends
self.assertTrue(
events["main_backward_start"].elapsed_time(events["side_backward_end"])
> 0
)
# Warmup
for _ in range(2):
populate_events()
# Reset events (not really necessary but OK)
events["side_backward_start"] = None
events["side_backward_end"] = None
events["main_backward_start"] = None
# Test
populate_events()
check_ordering()
class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn(
self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False