mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC tensors (#129645)
As in the title: To register indices/values of a sparse XYZ tensor with CUDA, the following methods are supported - `sparse_xyz_tensor(indices, values, pin_memory=True)` - `sparse_xyz_tensor(indices, values).pin_memory()` - `sparse_xyz_tensor(indices.pin_memory(), values.pin_memory())` Fixes https://github.com/pytorch/pytorch/issues/115330 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129645 Approved by: https://github.com/amjames, https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							babb249a89
						
					
				
				
					commit
					a4ea776881
				
			| @ -3330,6 +3330,8 @@ class TestCase(expecttest.TestCase): | ||||
|                                device=None, | ||||
|                                dtype=None, | ||||
|                                index_dtype=None, | ||||
|                                pin_memory=None, | ||||
|                                members_pin_memory=None, | ||||
|                                enable_batch=True, | ||||
|                                enable_hybrid=True, | ||||
|                                enable_zero_sized=True, | ||||
| @ -3353,10 +3355,11 @@ class TestCase(expecttest.TestCase): | ||||
|         constructors: | ||||
|  | ||||
|           - sparse compressed input is defined as | ||||
|             (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype) | ||||
|             (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, | ||||
|                                                               pin_memory=pin_memory) | ||||
|  | ||||
|           - sparse COO input is defined as | ||||
|             (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype) | ||||
|             (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory) | ||||
|  | ||||
|           - strided input is defined as | ||||
|             (values,), dict(device=device, dtype=dtype) | ||||
| @ -3368,17 +3371,23 @@ class TestCase(expecttest.TestCase): | ||||
|  | ||||
|         if output_tensor: | ||||
|             for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype, | ||||
|                                                             pin_memory=pin_memory, | ||||
|                                                             enable_batch=enable_batch, enable_hybrid=enable_hybrid, | ||||
|                                                             enable_zero_sized=enable_zero_sized, | ||||
|                                                             enable_non_contiguous_indices=enable_non_contiguous_indices, | ||||
|                                                             enable_non_contiguous_values=enable_non_contiguous_values, | ||||
|                                                             enable_batch_variable_nse=enable_batch_variable_nse, | ||||
|                                                             output_tensor=False): | ||||
|                 if members_pin_memory: | ||||
|                     args = tuple(a.pin_memory() for a in args) | ||||
|                 if layout is torch.strided: | ||||
|                     assert len(args) == 1 | ||||
|                     size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape | ||||
|                     assert size is not None | ||||
|                     yield args[0].reshape(size) | ||||
|                     if pin_memory: | ||||
|                         yield args[0].reshape(size).pin_memory() | ||||
|                     else: | ||||
|                         yield args[0].reshape(size) | ||||
|                 elif layout is torch.sparse_coo: | ||||
|                     yield torch.sparse_coo_tensor(*args, **kwargs) | ||||
|                 elif is_compressed_sparse_layout: | ||||
| @ -3594,25 +3603,24 @@ class TestCase(expecttest.TestCase): | ||||
|                 for densesize in densesizes: | ||||
|                     indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]] | ||||
|                     values = generate_values(data[-1], densesize).to(device=device, dtype=dtype) | ||||
|                     yield (*indices, values), dict(device=device, dtype=dtype, | ||||
|                                                    size=pattern.shape + densesize) | ||||
|                     kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize) | ||||
|                     if pin_memory is not None: | ||||
|                         kwargs.update(pin_memory=pin_memory) | ||||
|  | ||||
|                     yield (*indices, values), kwargs.copy() | ||||
|                     if enable_non_contiguous_indices and pattern.ndim > 2: | ||||
|                         # sparse compressed indices can be sliced only along batch dimensions | ||||
|                         for (dim, offset) in {(0, 1), (-2, 0)}: | ||||
|                             indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices] | ||||
|                             yield (*indices_copy, values), dict(device=device, dtype=dtype, | ||||
|                                                                 size=pattern.shape + densesize) | ||||
|                             yield (*indices_copy, values), kwargs.copy() | ||||
|  | ||||
|                             if enable_non_contiguous_values: | ||||
|                                 values_copy = non_contiguous_copy(values, dim=-1, offset=1) | ||||
|                                 yield (*indices_copy, values_copy), dict(device=device, dtype=dtype, | ||||
|                                                                          size=pattern.shape + densesize) | ||||
|                                 yield (*indices_copy, values_copy), kwargs.copy() | ||||
|  | ||||
|                     if enable_non_contiguous_values: | ||||
|                         values_copy = non_contiguous_copy(values, dim=-1, offset=1) | ||||
|                         yield (*indices, values_copy), dict(device=device, dtype=dtype, | ||||
|                                                             size=pattern.shape + densesize) | ||||
|                         yield (*indices, values_copy), kwargs.copy() | ||||
|  | ||||
|         # zero-sized tensor inputs, non-batch, non-hybrid/hybrid | ||||
|         if enable_zero_sized: | ||||
| @ -3651,7 +3659,10 @@ class TestCase(expecttest.TestCase): | ||||
|                             values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) | ||||
|                         else: | ||||
|                             assert 0  # unreachable | ||||
|                         yield (*indices, values), dict(device=device, dtype=dtype, size=basesize + densesize) | ||||
|                         kwargs = dict(device=device, dtype=dtype, size=basesize + densesize) | ||||
|                         if pin_memory is not None: | ||||
|                             kwargs.update(pin_memory=pin_memory) | ||||
|                         yield (*indices, values), kwargs | ||||
|  | ||||
|     def safeToDense(self, t): | ||||
|         # coalesce is only implemented for COO | ||||
|  | ||||
		Reference in New Issue
	
	Block a user