Make CUDACachingAllocator::recordStream() a no-op on null ptrs (#20658)

Summary:
Fixes #20651

Communication collectives in `torch.distributed` call `CUDACachingAllocator::recordStream()` on input and output tensors to prevent their memory blocks from being freed too early. `CUDACachingAllocator` tracks memory blocks by the tensor's data pointer and does not accept null pointers. However, an empty tensor's `storage().data()` might be null, and in that case there is no memory block associated with the tensor, so it is safe to make `recordStream()` a no-op.
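For illustration, the condition in question can be seen directly from Python (a minimal sketch; the exact pointer value is allocator-dependent):

    import torch

    x = torch.FloatTensor([])      # empty tensor: zero elements
    print(x.numel())               # 0
    print(x.storage().data_ptr())  # typically 0 (null), so there is no block to record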

Tests only cover broadcasting empty tensors for the GLOO backend, because GLOO does not support empty inputs (facebookincubator/gloo/issues/179). That can be addressed in either `ProcessGroupGloo` or GLOO itself; more tests will be added once the gap is filled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20658

Differential Revision: D15399371

Pulled By: mrshenli

fbshipit-source-id: d29ebd1c72fddae49531f32695f81b89e42e5a4d
Author: Shen Li
Committed by: Facebook Github Bot
Date: 2019-05-20 07:11:12 -07:00
Parent: 071971476d
Commit: 8acaa286b7
2 changed files with 46 additions and 10 deletions


@@ -378,17 +378,21 @@ struct THCCachingAllocator
   void recordStream(void* ptr, cuda::CUDAStream stream)
   {
-    std::lock_guard<std::recursive_mutex> lock(mutex);
-    Block* block = find_allocated_block(ptr);
-    if (!block) {
-      AT_ERROR("invalid device pointer: ", ptr);
+    // Empty tensor's storage().data() might be a null ptr. As there is no
+    // blocks associated with those tensors, it is fine to do nothing here.
+    if (ptr) {
+      std::lock_guard<std::recursive_mutex> lock(mutex);
+      Block* block = find_allocated_block(ptr);
+      if (!block) {
+        AT_ERROR("invalid device pointer: ", ptr);
+      }
+      if (stream.stream() == block->stream) {
+        // ignore uses on the allocation stream, since those don't require any
+        // special synchronization
+        return;
+      }
+      block->stream_uses.insert(stream);
     }
-    if (stream.stream() == block->stream) {
-      // ignore uses on the allocation stream, since those don't require any
-      // special synchronization
-      return;
-    }
-    block->stream_uses.insert(stream);
   }

   /** moves a block into a pool of cached free blocks */
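For context, `torch.distributed` collectives use this API roughly as follows (a simplified sketch of the call site, not the exact `ProcessGroupNCCL` code; the variable names are assumed):

    // After enqueueing a collective on ncclStream, record each tensor's storage
    // pointer so the caching allocator defers reuse of the block until the
    // stream's pending work finishes. With the change above, the null pointer
    // of an empty tensor now falls through as a no-op instead of erroring.
    c10::cuda::CUDACachingAllocator::recordStream(
        output.storage().data(), ncclStream);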


@@ -579,6 +579,14 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
         opts.threads = threads
         return opts

+    def test_empty_tensors(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
+
+        xs = [torch.FloatTensor([])]
+        pg.broadcast(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
     def test_broadcast_checks(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
@@ -1344,6 +1352,30 @@ class ProcessGroupNCCLTest(TestCase):
     def tearDown(self):
         pass

+    def test_empty_tensors(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        xs = [torch.cuda.FloatTensor([])]
+        pg.broadcast(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        pg.allreduce(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        pg.reduce(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        ys = [[torch.cuda.FloatTensor([]) for _ in range(self.world_size)]]
+        pg.allgather(ys, xs).wait()
+        for y in ys[0]:
+            self.assertEqual(0, y.numel())
+
+        ys = [torch.cuda.FloatTensor([])]
+        xs = [[torch.cuda.FloatTensor([]) for _ in range(self.world_size)]]
+        pg.reduce_scatter(ys, xs).wait()
+        self.assertEqual(0, ys[0].numel())
+
     def test_broadcast_ops(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
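As a usage note, the public-API path that exercises this fix looks roughly like the following (a sketch assuming a single-rank NCCL group and a hypothetical init file path):

    import torch
    import torch.distributed as dist

    # One-rank process group on a machine with a CUDA device.
    dist.init_process_group(
        backend='nccl', init_method='file:///tmp/pg_init', rank=0, world_size=1)

    x = torch.empty(0, device='cuda')  # empty tensor; its storage pointer may be null
    dist.all_reduce(x)                 # internally calls recordStream() on x,
                                       # which is now a no-op for the null pointer
    assert x.numel() == 0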