Customized pin_memory for PackedSequence (#18079)

Summary:
fixes https://github.com/pytorch/pytorch/issues/18078
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18079

Reviewed By: ezyang

Differential Revision: D14521192

Pulled By: zou3519

fbshipit-source-id: cec773a3a6f2c405a0d9701e213b7caf81649181
Authored by: Tongzhou Wang, 2019-03-19 13:35:55 -07:00
Committed by: Facebook Github Bot
Parent: 916a670828
Commit: f212fd9fd6

3 changed files with 116 additions and 16 deletions
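
The diff shown here covers only the test changes. As a rough usage sketch of the behavior the new tests exercise (not part of the commit; it assumes a CUDA build so pinning is active, and uses a hypothetical collate_to_packed helper), a DataLoader that returns PackedSequence batches with pin_memory=True would look like this:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
from torch.utils.data import DataLoader, TensorDataset

def collate_to_packed(batch):
    # Stack the per-sample rows into a (batch, time) tensor and pack it.
    data = torch.stack([sample[0] for sample in batch], 0)
    lengths = torch.full((data.size(0),), data.size(1), dtype=torch.int64)
    return pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)

dataset = TensorDataset(torch.arange(10 * 5, dtype=torch.float32).view(10, 5))
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_to_packed,
                    pin_memory=True)

for sample in loader:
    assert isinstance(sample, PackedSequence)
    # With this change the pin_memory logic hands back a pinned PackedSequence,
    # which is what the updated tests below assert.
    assert sample.is_pinned()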


@@ -1004,7 +1004,7 @@ class TestNamedTupleDataLoader(TestCase):
         self.assertIsInstance(batch.data, NamedTupleDataset.Data)
 
 
-class SimpleCustomBatch:
+class SimpleCustomBatch(object):
     def __init__(self, data):
         transposed_data = list(zip(*data))
         self.inp = torch.stack(transposed_data[0], 0)
@@ -1015,11 +1015,28 @@ class SimpleCustomBatch:
         self.tgt = self.tgt.pin_memory()
         return self
 
+    def is_pinned(self):
+        return self.inp.is_pinned() and self.tgt.is_pinned()
+
 
 def collate_wrapper(batch):
     return SimpleCustomBatch(batch)
 
 
+def collate_into_packed_sequence(batch):
+    data = torch.stack([sample[0] for sample in batch], 1)
+    t, b = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
+
+
+def collate_into_packed_sequence_batch_first(batch):
+    data = torch.stack([sample[0] for sample in batch], 0)
+    b, t = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
+
+
 class TestCustomPinFn(TestCase):
     def setUp(self):
         inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
@@ -1029,20 +1046,32 @@ class TestCustomPinFn(TestCase):
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin_worker(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True, num_workers=1)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True, num_workers=1)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
 
 class TestWorkerQueueDataset(Dataset):
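
Outside the DataLoader, the same round trip can be exercised directly on a PackedSequence; a minimal sketch, again assuming a CUDA build where pinned memory is available:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# A (time, batch, feature) padded batch with per-sequence lengths.
padded = torch.randn(4, 3, 2)
ps = pack_padded_sequence(padded, torch.tensor([4, 2, 1]), enforce_sorted=False)

pinned = ps.pin_memory()   # returns a PackedSequence whose data tensor is in pinned memory
assert pinned.is_pinned()  # the same check the tests above make on DataLoader output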