Add some autograd producer consumer stream sync tests (#150952)
Thanks @ngimel and @albanD for some ideas on test cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150952
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 397b7f9b82
Commit: 32f0f414ab
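For context: the tests added below exercise the standard producer/consumer stream pattern, where work enqueued on one stream must be ordered against the stream that consumes its results. A minimal sketch of that pattern using the public stream API (illustrative only, not code from this commit; assumes a CUDA device is present):

    import torch

    producer = torch.cuda.Stream()
    consumer = torch.cuda.default_stream()

    x = torch.ones(1 << 20, device="cuda")  # initialized on the consumer (default) stream
    producer.wait_stream(consumer)          # producer must see x's initialization
    with torch.cuda.stream(producer):
        y = x * 2                           # kernel enqueued on the producer stream
    consumer.wait_stream(producer)          # consumer must not read y too early
    y.record_stream(consumer)               # tell the caching allocator y is also used on consumer
    print(y.sum().item())                   # safe: ordered after the producer's kernel

During backward, the autograd engine has to insert the equivalent of these wait_stream calls itself whenever a producer node and a consumer node run on different streams or devices; the tests below check that it does so correctly.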
@@ -12926,6 +12926,455 @@ class TestAutogradInferenceMode(TestCase):
        run_test(lambda x: x.transpose_(0, 1))


NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000
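# torch.cuda._sleep spins for a given number of device cycles, so at a ~2 GHz
# clock this constant is roughly one second of GPU time (the actual wall time
# is hardware-dependent).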


@contextlib.contextmanager
def _set_device_index(target_device):
    orig_device = torch.accelerator.current_device_index()
    try:
        torch.accelerator.set_device_index(target_device)
        yield
    finally:
        torch.accelerator.set_device_index(orig_device)
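
# Note: _set_device_index mirrors the torch.cuda.device context-manager guard,
# but goes through the generic torch.accelerator API so it works for any backend.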


def _sleep_if_cuda(cycles):
    if torch.accelerator.current_accelerator().type == "cuda":
        return torch.cuda._sleep(cycles)
    else:
        # Update this if non-cuda accelerators support something like sleep
        return
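
# Note: torch.cuda._sleep enqueues a spin-wait kernel on the current stream; it
# keeps the device busy without blocking the host thread, which is what lets
# these tests expose missing cross-stream synchronization.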


def _get_device_name(idx):
    return f"{torch.accelerator.current_accelerator().type}:{idx}"


# Although this is written to be generic over all accelerators, non-cuda
# accelerators are not fully tested, since sleep is only supported on cuda.
class TestAutogradStreamSynchronization(TestCase):
    def get_default_streams(self, num_devices=1):
        out = []
        for i in range(num_devices):
            with _set_device_index(i):
                acc = torch.accelerator.current_accelerator()
                out.append(torch.get_device_module(acc).default_stream())
        return tuple(out)

    def synchronize_all_devices(self, num_devices=1):
        for i in range(num_devices):
            torch.accelerator.synchronize(i)

    def assert_all_streams_default(self, num_devices=1):
        # Sanity check
        default_streams = self.get_default_streams(num_devices)
        for i in range(num_devices):
            with _set_device_index(i):
                acc = torch.accelerator.current_accelerator()
                # Use the device module's current_stream() rather than
                # torch.accelerator.current_stream(i); otherwise, e.g. on cuda,
                # we would be comparing a torch.cuda.Stream with a torch.Stream.
                self.assertEqual(
                    torch.get_device_module(acc).current_stream(), default_streams[i]
                )

    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    def test_consumer_to_single_producer_case_2_correctness(self):
        #                          Device    Stream
        # Consumer (MulBackward):  cuda:0    s0
        # Producer:                cuda:0    s1
        # Gradient:                cuda:0    s1
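        #
        # If the engine fails to make s0 (the consumer's stream) wait on s1
        # before running MulBackward, MulBackward may read the incoming
        # gradient before the sleeping add_(1) has executed, and a.grad would
        # come out as 2 instead of 4.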
        class Producer(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                out = gO.clone()
                _sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
                out.add_(1)
                return out

        def test():
            self.synchronize_all_devices()
            self.assert_all_streams_default()

            with torch.Stream(0) as s0:
                a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
                b = a * 2

            with torch.Stream(0) as s1:
                s1.wait_stream(s0)
                out = Producer.apply(b)

                with torch.autograd.grad_mode.set_multithreading_enabled(False):
                    out.sum().backward()

            self.synchronize_all_devices()

            # Expected result: a.grad = (grad_out + 1) * 2 = 4
            self.assertEqual(a.grad, torch.full_like(a, 4))

        # Run an extra time to warm up
        for _ in range(2):
            test()

    def _test_consumer_to_single_producer_case_3_correctness(
        self, non_default_ambient_stream
    ):
        #                          Device    Stream
        # Consumer (MulBackward):  cuda:0    s0
        # Producer:                cuda:1    cuda:1 default
        # Gradient:                cuda:0    cuda:0 default
        class Producer(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # The node's canonical stream is the current stream
                # of the device of the first output.
                ctx.node_stream = torch.accelerator.current_stream(1)
                return x.to(_get_device_name(1))

            @staticmethod
            def backward(ctx, gO):
                out = gO.to(_get_device_name(0))
                with _set_device_index(0):
                    _sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
                    out.add_(1)
                    # It's the node's responsibility to sync back to its
                    # canonical stream.
                    ctx.node_stream.wait_stream(torch.accelerator.current_stream(0))
                return out

        def test():
            self.synchronize_all_devices(2)
            self.assert_all_streams_default(2)

            (default_stream_0,) = self.get_default_streams()

            # Ensure the consumer node runs on a non-default stream so that a
            # sync is necessary when the producer (FuncBackward) produces a
            # gradient on the default stream.
            with torch.Stream(0) as s0:
                a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
                b = a * 2

            default_stream_0.wait_stream(s0)
            out = Producer.apply(b)

            def call_backward(x):
                with torch.autograd.grad_mode.set_multithreading_enabled(False):
                    x.sum().backward()

            if non_default_ambient_stream:
                with torch.Stream(0) as s1:
                    s1.wait_stream(default_stream_0)
                    call_backward(out)
            else:
                call_backward(out)

            self.synchronize_all_devices(2)

            # Expected result: a.grad = (grad_out + 1) * 2 = 4
            self.assertEqual(a.grad, torch.full_like(a, 4))

        # Run an extra time to warm up
        for _ in range(2):
            test()

    # This fails because we currently sync to the default stream
    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    @unittest.skipIf(
        torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
    @unittest.expectedFailure
    def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
        self,
    ):
        self._test_consumer_to_single_producer_case_3_correctness(
            non_default_ambient_stream=True
        )

    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    @unittest.skipIf(
        torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
    def test_consumer_to_single_producer_case_3_correctness(self):
        self._test_consumer_to_single_producer_case_3_correctness(
            non_default_ambient_stream=False
        )

    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    @unittest.skipIf(
        torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
    def test_consumer_to_single_producer_case_4_correctness(self):
        #            Device    Stream
        # Consumer:  cuda:0    cuda:0 default
        # Producer:  cuda:1    s1
        # Gradient:  cuda:1    s1
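        #
        # The gradient is produced on s1, while Consumer's backward (the .to()
        # copy back to cuda:0) runs on cuda:1's default stream. If the engine
        # does not order that default stream after s1, the copy may read gO_1
        # before the sleeping add_(1) has executed, giving a.grad = 1 instead
        # of 2.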
        class Producer(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                out = gO.clone()
                _sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
                return out.add_(1)

        class Consumer(torch.autograd.Function):
            # In the multi-output case, the node's canonical device and stream
            # correspond to those of its first output. This is required to
            # induce cases 4/5.
            @staticmethod
            def forward(ctx, x):
                return x.clone(), x.to(_get_device_name(1))

            @staticmethod
            def backward(ctx, gO_0, gO_1):
                # gO_1 is on s1, but we're currently doing compute on cuda:1's
                # default stream. It's the user's responsibility to sync to the
                # consumer (.to() should already do this). Things would also
                # work out if the engine synced s1 with the consumer.
                # gO_0 is ignored because the first output is unused.
                return gO_1.to(_get_device_name(0))

        def test():
            self.synchronize_all_devices(2)
            self.assert_all_streams_default(2)

            _, default_stream_1 = self.get_default_streams(2)
            a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
            _unused, b = Consumer.apply(a)

            with torch.Stream(1) as s1:
                s1.wait_stream(default_stream_1)
                out = Producer.apply(b)

                with torch.autograd.grad_mode.set_multithreading_enabled(False):
                    out.sum().backward()

            self.synchronize_all_devices(2)

            # Expected result: a.grad = grad_out + 1 = 2
            self.assertEqual(a.grad, torch.full_like(a, 2))

        # Run an extra time to warm up
        for _ in range(2):
            test()

    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    @unittest.skipIf(
        torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
    def test_consumer_to_multi_producer_case_4_correctness(self):
        #              Device    Stream
        # Consumer:    cuda:0    cuda:0 default
        #
        # Producer 1:  cuda:1    s1
        # Gradient 1:  cuda:1    s1
        #
        # Producer 2:  cuda:1    s2
        # Gradient 2:  cuda:1    s2
        #
        # Accumulation stream: s2, since it is scheduled first
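        #
        # Background: when one tensor (here b) feeds multiple producers, the
        # engine sums their gradients into an InputBuffer; the stream on which
        # that running sum is performed is the "accumulation stream" above.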
        class ProducerFast(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                out = gO.clone()
                return out * 2

        class ProducerSlow(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                out = gO.clone()
                _sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
                return out.mul_(2)

        class Consumer(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.node_stream = torch.accelerator.current_stream(x.device)
                return x.clone(), x.to(_get_device_name(1))

            @staticmethod
            def backward(ctx, gO_0, gO_1):
                torch.accelerator.current_stream(gO_1.device).wait_stream(
                    ctx.node_stream
                )
                return (gO_1 * 2).to(_get_device_name(0))

        def test():
            self.synchronize_all_devices(2)
            self.assert_all_streams_default(2)

            default_stream_0, default_stream_1 = self.get_default_streams(2)

            a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
            _unused, b = Consumer.apply(a)

            with torch.Stream(1) as s1:
                s1.wait_stream(default_stream_1)
                out1 = ProducerFast.apply(b)

            with torch.Stream(1) as s2:
                s2.wait_stream(default_stream_1)
                out2 = ProducerSlow.apply(b)

            default_stream_1.wait_stream(s1)
            default_stream_1.wait_stream(s2)

            with torch.autograd.grad_mode.set_multithreading_enabled(False):
                (out1 + out2).sum().backward()

            self.synchronize_all_devices(2)

            # If the accumulation stream does not wait for the slow producer
            # stream, the in-place mul-by-2 is performed on the accumulated
            # buffer AFTER ProducerFast has already accumulated!
            #
            # Correct:   (1.mul_(2) + 2) * 2 = 8
            # Incorrect: (1 + 2).mul_(2) * 2 = 12
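            #
            # Step by step: each producer receives gO = 1. ProducerFast
            # contributes 1 * 2 = 2; ProducerSlow contributes 1.mul_(2) = 2;
            # the buffer sums to 4; Consumer's backward doubles it to 8. The
            # incorrect interleaving instead accumulates first (1 + 2 = 3),
            # then applies the in-place mul_(2) to the buffer (6), which
            # Consumer doubles to 12.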
            self.assertEqual(a.grad, torch.full_like(a, 8))

        # Run an extra time to warm up
        for _ in range(2):
            test()

    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    # This test may spuriously fail on non-cuda accelerators (since we won't
    # be calling sleep)
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    @unittest.expectedFailure
    def test_side_stream_backward_overlap(self):
        # In cases 2/3 we designate the consumer's stream as the accumulation
        # stream, and naively one might have the consumer wait for the producer
        # as soon as we've added to the InputBuffer the first time.
        #
        # However, when the stream of the consumer also happens to be the
        # stream of one of the producers, that is suboptimal: it would prevent
        # the computation of the two producers from overlapping. What you
        # really want is to delay the producer/consumer sync until right
        # before the accumulation.
        # Note that this doesn't address N=3, but the side-stream N=2 case is
        # the common case.
        events = {
            "main_backward_start": None,
            "side_backward_start": None,
            "side_backward_end": None,
        }
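        # The dict holds device-side torch.Event timestamps recorded inside
        # each backward. Event.elapsed_time(other) returns milliseconds and is
        # positive when `other` was recorded later, which is what
        # check_ordering() below relies on.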

        class Main(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                # Record when main backward starts
                evt = torch.Event(enable_timing=True)
                evt.record()
                events["main_backward_start"] = evt
                return gO

        class Side(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                evt = torch.Event(enable_timing=True)
                evt.record()
                events["side_backward_start"] = evt

                _sleep_if_cuda(NUM_GPU_CYCLES_IN_ONE_SEC // 2)
                result = gO.clone()

                evt = torch.Event(enable_timing=True)
                evt.record()
                events["side_backward_end"] = evt
                return result

        def populate_events():
            self.synchronize_all_devices()
            self.assert_all_streams_default()

            (default_stream_0,) = self.get_default_streams()

            a = torch.ones(256, 256, requires_grad=True, device=_get_device_name(0))
            b = a.clone()  # not a leaf; does that matter?

            evt = torch.Event()
            evt.record()

            # Overlap during forward
            c_main = Main.apply(b)

            with torch.Stream(0) as s0:
                s0.wait_event(evt)
                c_side = Side.apply(b)

            default_stream_0.wait_stream(s0)

            with torch.autograd.grad_mode.set_multithreading_enabled(False):
                (c_main + c_side).sum().backward()

            self.synchronize_all_devices()

        def check_ordering():
            # Sanity check: side backward's end happens after its start
            self.assertTrue(
                events["side_backward_start"].elapsed_time(events["side_backward_end"])
                > 0
            )
            # Sanity check: main's backward started after side's backward started
            self.assertTrue(
                events["side_backward_start"].elapsed_time(
                    events["main_backward_start"]
                )
                > 0
            )
            # Overlap check: main's backward starts before side's backward ends
            self.assertTrue(
                events["main_backward_start"].elapsed_time(events["side_backward_end"])
                > 0
            )

        # Warmup
        for _ in range(2):
            populate_events()

        # Reset events (not strictly necessary, but harmless)
        events["side_backward_start"] = None
        events["side_backward_end"] = None
        events["main_backward_start"] = None

        # Test
        populate_events()
        check_ordering()


class TestMultithreadAutograd(TestCase):
    def _run_py_multithread_fn(
        self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False