	Customized pin_memory for PackedSequence (#18079)
Summary: fixes https://github.com/pytorch/pytorch/issues/18078

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18079

Reviewed By: ezyang

Differential Revision: D14521192

Pulled By: zou3519

fbshipit-source-id: cec773a3a6f2c405a0d9701e213b7caf81649181
Committed by: Facebook Github Bot
Parent: 916a670828
Commit: f212fd9fd6
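In user-facing terms, the change gives PackedSequence working pin_memory() and is_pinned() methods, so a batch that is a PackedSequence can be moved into page-locked memory just like a plain tensor. A minimal sketch of the call the new tests below exercise (the shapes and lengths are illustrative, and pinning requires a CUDA-capable build):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Pack a (seq_len=5, batch=3) tensor with descending per-sample lengths.
    packed = pack_padded_sequence(torch.zeros(5, 3), torch.tensor([5, 3, 2]))
    if torch.cuda.is_available():
        pinned = packed.pin_memory()  # a new PackedSequence backed by pinned memory
        assert pinned.is_pinned()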
				
@@ -1004,7 +1004,7 @@ class TestNamedTupleDataLoader(TestCase):
             self.assertIsInstance(batch.data, NamedTupleDataset.Data)
 
 
-class SimpleCustomBatch:
+class SimpleCustomBatch(object):
     def __init__(self, data):
         transposed_data = list(zip(*data))
         self.inp = torch.stack(transposed_data[0], 0)
@@ -1015,11 +1015,28 @@ class SimpleCustomBatch:
         self.tgt = self.tgt.pin_memory()
         return self
 
+    def is_pinned(self):
+        return self.inp.is_pinned() and self.tgt.is_pinned()
+
 
 def collate_wrapper(batch):
     return SimpleCustomBatch(batch)
 
 
+def collate_into_packed_sequence(batch):
+    data = torch.stack([sample[0] for sample in batch], 1)
+    t, b = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
+
+
+def collate_into_packed_sequence_batch_first(batch):
+    data = torch.stack([sample[0] for sample in batch], 0)
+    b, t = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
+
+
 class TestCustomPinFn(TestCase):
     def setUp(self):
         inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
@@ -1029,20 +1046,32 @@ class TestCustomPinFn(TestCase):
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin_worker(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True, num_workers=1)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True, num_workers=1)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
 
 class TestWorkerQueueDataset(Dataset):
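For context, the pattern these tests codify looks roughly like the following in user code: the DataLoader's pin-memory step defers to a batch object's own pin_memory() method, which is how SimpleCustomBatch is handled and what this commit makes work for PackedSequence as well. A minimal sketch; the dataset, sizes, and the pack_collate helper are illustrative, not part of the commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.nn.utils.rnn import pack_padded_sequence

    def pack_collate(batch):
        # Stack samples along dim 1 so the result is seq-first: (seq_len, batch).
        data = torch.stack([sample[0] for sample in batch], 1)
        lengths = torch.full((data.size(1),), data.size(0), dtype=torch.int64)
        return pack_padded_sequence(data, lengths)

    dataset = TensorDataset(torch.arange(10 * 5, dtype=torch.float32).view(10, 5))

    if torch.cuda.is_available():
        loader = DataLoader(dataset, batch_size=2, collate_fn=pack_collate,
                            pin_memory=True)
        for batch in loader:
            assert batch.is_pinned()  # each PackedSequence arrives pinned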