Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC tensors (#129645)

As in the title:

To register the indices/values of a sparse XYZ tensor with CUDA (i.e., to allocate them in pinned host memory for faster host-to-device transfers), the following methods are supported:
- `sparse_xyz_tensor(indices, values, pin_memory=True)`
- `sparse_xyz_tensor(indices, values).pin_memory()`
- `sparse_xyz_tensor(indices.pin_memory(), values.pin_memory())`

Fixes https://github.com/pytorch/pytorch/issues/115330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129645
Approved by: https://github.com/amjames, https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
Pearu Peterson
2024-08-01 19:56:27 +03:00
committed by PyTorch MergeBot
parent babb249a89
commit a4ea776881
13 changed files with 327 additions and 58 deletions

View File

@ -3330,6 +3330,8 @@ class TestCase(expecttest.TestCase):
device=None,
dtype=None,
index_dtype=None,
pin_memory=None,
members_pin_memory=None,
enable_batch=True,
enable_hybrid=True,
enable_zero_sized=True,
@ -3353,10 +3355,11 @@ class TestCase(expecttest.TestCase):
constructors:
- sparse compressed input is defined as
(compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype)
(compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
pin_memory=pin_memory)
- sparse COO input is defined as
(indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype)
(indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
- strided input is defined as
(values,), dict(device=device, dtype=dtype)
@ -3368,17 +3371,23 @@ class TestCase(expecttest.TestCase):
if output_tensor:
for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
pin_memory=pin_memory,
enable_batch=enable_batch, enable_hybrid=enable_hybrid,
enable_zero_sized=enable_zero_sized,
enable_non_contiguous_indices=enable_non_contiguous_indices,
enable_non_contiguous_values=enable_non_contiguous_values,
enable_batch_variable_nse=enable_batch_variable_nse,
output_tensor=False):
if members_pin_memory:
args = tuple(a.pin_memory() for a in args)
if layout is torch.strided:
assert len(args) == 1
size = kwargs.pop('size', None) # to ensure that a zero-sized tensor has the desired shape
assert size is not None
yield args[0].reshape(size)
if pin_memory:
yield args[0].reshape(size).pin_memory()
else:
yield args[0].reshape(size)
elif layout is torch.sparse_coo:
yield torch.sparse_coo_tensor(*args, **kwargs)
elif is_compressed_sparse_layout:
@ -3594,25 +3603,24 @@ class TestCase(expecttest.TestCase):
for densesize in densesizes:
indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
yield (*indices, values), dict(device=device, dtype=dtype,
size=pattern.shape + densesize)
kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
if pin_memory is not None:
kwargs.update(pin_memory=pin_memory)
yield (*indices, values), kwargs.copy()
if enable_non_contiguous_indices and pattern.ndim > 2:
# sparse compressed indices can be sliced only along batch dimensions
for (dim, offset) in {(0, 1), (-2, 0)}:
indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
yield (*indices_copy, values), dict(device=device, dtype=dtype,
size=pattern.shape + densesize)
yield (*indices_copy, values), kwargs.copy()
if enable_non_contiguous_values:
values_copy = non_contiguous_copy(values, dim=-1, offset=1)
yield (*indices_copy, values_copy), dict(device=device, dtype=dtype,
size=pattern.shape + densesize)
yield (*indices_copy, values_copy), kwargs.copy()
if enable_non_contiguous_values:
values_copy = non_contiguous_copy(values, dim=-1, offset=1)
yield (*indices, values_copy), dict(device=device, dtype=dtype,
size=pattern.shape + densesize)
yield (*indices, values_copy), kwargs.copy()
# zero-sized tensor inputs, non-batch, non-hybrid/hybrid
if enable_zero_sized:
@ -3651,7 +3659,10 @@ class TestCase(expecttest.TestCase):
values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
else:
assert 0 # unreachable
yield (*indices, values), dict(device=device, dtype=dtype, size=basesize + densesize)
kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
if pin_memory is not None:
kwargs.update(pin_memory=pin_memory)
yield (*indices, values), kwargs
def safeToDense(self, t):
# coalesce is only implemented for COO