mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC tensors (#129645)
As in the title: To register indices/values of a sparse XYZ tensor with CUDA, the following methods are supported - `sparse_xyz_tensor(indices, values, pin_memory=True)` - `sparse_xyz_tensor(indices, values).pin_memory()` - `sparse_xyz_tensor(indices.pin_memory(), values.pin_memory())` Fixes https://github.com/pytorch/pytorch/issues/115330 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129645 Approved by: https://github.com/amjames, https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
babb249a89
commit
a4ea776881
@ -3330,6 +3330,8 @@ class TestCase(expecttest.TestCase):
|
||||
device=None,
|
||||
dtype=None,
|
||||
index_dtype=None,
|
||||
pin_memory=None,
|
||||
members_pin_memory=None,
|
||||
enable_batch=True,
|
||||
enable_hybrid=True,
|
||||
enable_zero_sized=True,
|
||||
@ -3353,10 +3355,11 @@ class TestCase(expecttest.TestCase):
|
||||
constructors:
|
||||
|
||||
- sparse compressed input is defined as
|
||||
(compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype)
|
||||
(compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
- sparse COO input is defined as
|
||||
(indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype)
|
||||
(indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
|
||||
|
||||
- strided input is defined as
|
||||
(values,), dict(device=device, dtype=dtype)
|
||||
@ -3368,17 +3371,23 @@ class TestCase(expecttest.TestCase):
|
||||
|
||||
if output_tensor:
|
||||
for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
|
||||
pin_memory=pin_memory,
|
||||
enable_batch=enable_batch, enable_hybrid=enable_hybrid,
|
||||
enable_zero_sized=enable_zero_sized,
|
||||
enable_non_contiguous_indices=enable_non_contiguous_indices,
|
||||
enable_non_contiguous_values=enable_non_contiguous_values,
|
||||
enable_batch_variable_nse=enable_batch_variable_nse,
|
||||
output_tensor=False):
|
||||
if members_pin_memory:
|
||||
args = tuple(a.pin_memory() for a in args)
|
||||
if layout is torch.strided:
|
||||
assert len(args) == 1
|
||||
size = kwargs.pop('size', None) # to ensure that a zero-sized tensor has the desired shape
|
||||
assert size is not None
|
||||
yield args[0].reshape(size)
|
||||
if pin_memory:
|
||||
yield args[0].reshape(size).pin_memory()
|
||||
else:
|
||||
yield args[0].reshape(size)
|
||||
elif layout is torch.sparse_coo:
|
||||
yield torch.sparse_coo_tensor(*args, **kwargs)
|
||||
elif is_compressed_sparse_layout:
|
||||
@ -3594,25 +3603,24 @@ class TestCase(expecttest.TestCase):
|
||||
for densesize in densesizes:
|
||||
indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
|
||||
values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
|
||||
yield (*indices, values), dict(device=device, dtype=dtype,
|
||||
size=pattern.shape + densesize)
|
||||
kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
|
||||
if pin_memory is not None:
|
||||
kwargs.update(pin_memory=pin_memory)
|
||||
|
||||
yield (*indices, values), kwargs.copy()
|
||||
if enable_non_contiguous_indices and pattern.ndim > 2:
|
||||
# sparse compressed indices can be sliced only along batch dimensions
|
||||
for (dim, offset) in {(0, 1), (-2, 0)}:
|
||||
indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
|
||||
yield (*indices_copy, values), dict(device=device, dtype=dtype,
|
||||
size=pattern.shape + densesize)
|
||||
yield (*indices_copy, values), kwargs.copy()
|
||||
|
||||
if enable_non_contiguous_values:
|
||||
values_copy = non_contiguous_copy(values, dim=-1, offset=1)
|
||||
yield (*indices_copy, values_copy), dict(device=device, dtype=dtype,
|
||||
size=pattern.shape + densesize)
|
||||
yield (*indices_copy, values_copy), kwargs.copy()
|
||||
|
||||
if enable_non_contiguous_values:
|
||||
values_copy = non_contiguous_copy(values, dim=-1, offset=1)
|
||||
yield (*indices, values_copy), dict(device=device, dtype=dtype,
|
||||
size=pattern.shape + densesize)
|
||||
yield (*indices, values_copy), kwargs.copy()
|
||||
|
||||
# zero-sized tensor inputs, non-batch, non-hybrid/hybrid
|
||||
if enable_zero_sized:
|
||||
@ -3651,7 +3659,10 @@ class TestCase(expecttest.TestCase):
|
||||
values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
|
||||
else:
|
||||
assert 0 # unreachable
|
||||
yield (*indices, values), dict(device=device, dtype=dtype, size=basesize + densesize)
|
||||
kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
|
||||
if pin_memory is not None:
|
||||
kwargs.update(pin_memory=pin_memory)
|
||||
yield (*indices, values), kwargs
|
||||
|
||||
def safeToDense(self, t):
|
||||
# coalesce is only implemented for COO
|
||||
|
||||
Reference in New Issue
Block a user