mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Sparse half embeddings on cuda (#19695)
Summary: ``` import torch a = torch.nn.Embedding(3, 4, sparse=True).half().cuda() a(torch.LongTensor([1, 0]).cuda()).sum().backward() ``` gave: `RuntimeError: torch.cuda.sparse.HalfTensor is not enabled` This PR enables sparse.HalfTensor on cuda. Still won't work for CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19695 Differential Revision: D15281162 Pulled By: nairbv fbshipit-source-id: 0d83d946a059393bd53d8b8102e2daa9b4c02588
This commit is contained in:
committed by
Facebook Github Bot
parent
148e90ba2a
commit
d68802ba47
@ -149,6 +149,7 @@
|
||||
[[
|
||||
name: _th_nonzero
|
||||
cname: nonzero
|
||||
cpu_half: True
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
variants:
|
||||
|
@ -390,9 +390,6 @@ def legacy_iterate_types():
|
||||
for scalar_type in (scalar_types + quantized_scalar_types):
|
||||
if density == 'Mkldnn' and (backend != 'CPU' or scalar_type[0] != 'Float'):
|
||||
continue
|
||||
if density == 'Sparse' and scalar_type[0] == 'Half':
|
||||
# THS does not do half type yet.
|
||||
continue
|
||||
else:
|
||||
yield (backend, density, scalar_type)
|
||||
for backend in quantized_backends:
|
||||
|
@ -2523,12 +2523,14 @@
|
||||
variants: function, method
|
||||
|
||||
- func: to_sparse(Tensor self, int sparse_dim) -> Tensor
|
||||
cpu_half: True
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU: dense_to_sparse
|
||||
CUDA: dense_to_sparse
|
||||
|
||||
- func: to_sparse(Tensor self) -> Tensor
|
||||
cpu_half: True
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU: dense_to_sparse
|
||||
|
@ -325,6 +325,9 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
// NB: Dropped the resizeNd variants
|
||||
|
||||
Tensor sparse_to_dense(const SparseTensor& self) {
|
||||
if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) {
|
||||
AT_ERROR("to_dense() not supported for float16 on CPU");
|
||||
}
|
||||
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
|
||||
return dst.add_(self);
|
||||
}
|
||||
|
@ -31,6 +31,9 @@
|
||||
#include <TH/generic/THTensorMath.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensorMath.h>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
/* fill and zero*/
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
@ -8,3 +8,6 @@
|
||||
|
||||
#include <TH/generic/THTensorEvenMoreMath.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensorEvenMoreMath.cpp>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
@ -11,7 +11,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
|
||||
int64_t *subscript_data;
|
||||
int64_t i = 0;
|
||||
#ifdef TH_REAL_IS_HALF
|
||||
#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
|
||||
#define IS_NONZERO(val) (c10::Half(0)!=val)
|
||||
#else
|
||||
#define IS_NONZERO(val) ((val)!=0)
|
||||
#endif
|
||||
@ -65,8 +65,12 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
|
||||
);
|
||||
delete [] sizes;
|
||||
delete [] idx;
|
||||
|
||||
#undef IS_NONZERO
|
||||
}
|
||||
|
||||
#if !defined(TH_REAL_IS_HALF) /* non half only part */
|
||||
|
||||
accreal THTensor_(sumall)(THTensor *tensor)
|
||||
{
|
||||
accreal sum = 0;
|
||||
@ -74,8 +78,7 @@ accreal THTensor_(sumall)(THTensor *tensor)
|
||||
scalar_t, tensor, *tensor_data, sum, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
|
||||
return sum;
|
||||
}
|
||||
|
||||
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
|
||||
#if !defined(TH_REAL_IS_BOOL)
|
||||
|
||||
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
|
||||
{
|
||||
@ -906,4 +909,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#endif /* TH_GENERIC_FILE */
|
||||
|
@ -4,6 +4,8 @@
|
||||
|
||||
TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
|
||||
|
||||
#ifndef TH_REAL_IS_HALF
|
||||
|
||||
TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value);
|
||||
TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value);
|
||||
TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value);
|
||||
@ -183,3 +185,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
@ -2070,24 +2070,61 @@ class TestNN(NNTestCase):
|
||||
def test_embedding_dense_grad_cuda(self):
|
||||
self._test_embedding_dense_grad("cuda")
|
||||
|
||||
def test_embedding_sparse_backward(self):
|
||||
def test_move_sparse_half_embedding(self):
|
||||
embedding = nn.Embedding(10, 3, sparse=True)
|
||||
embedding.zero_grad()
|
||||
embedding(torch.LongTensor([7, 1, 3])).sum().backward()
|
||||
self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]]))
|
||||
self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3))
|
||||
self.assertEqual(embedding.weight.device.type, 'cpu')
|
||||
self.assertEqual(embedding.weight.dtype, torch.float64)
|
||||
embedding.to(torch.float16)
|
||||
self.assertEqual(embedding.weight.dtype, torch.float16)
|
||||
self.assertEqual(embedding.embedding_dim, 3)
|
||||
self.assertEqual(embedding.num_embeddings, 10)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
embedding.to('cuda')
|
||||
self.assertEqual(embedding.weight.device.type, 'cuda')
|
||||
embedding.to('cpu')
|
||||
self.assertEqual(embedding.weight.device.type, 'cpu')
|
||||
|
||||
def test_embedding_sparse_backward(self):
|
||||
self._test_embedding_backward()
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
||||
def test_embedding_sparse_half_backward(self):
|
||||
# same as test_embedding_sparse_backward above but testing half types in
|
||||
# cuda. cpu sum not supported for half types.
|
||||
self._test_embedding_backward('cuda', torch.float16)
|
||||
|
||||
def _test_embedding_backward(self, device='cpu', dtype=torch.float64):
|
||||
embedding = nn.Embedding(10, 3, sparse=True)
|
||||
tensor = torch.tensor([[7, 1, 3]])
|
||||
ones = torch.tensor(1.).expand(3, 3)
|
||||
tensorTwice = tensor.repeat(1, 2)
|
||||
onesTwice = torch.cat((ones, ones))
|
||||
|
||||
embedding = embedding.to(dtype=dtype).to(device)
|
||||
tensor = tensor.to(device)
|
||||
ones = ones.to(device)
|
||||
tensorTwice = tensorTwice.to(device)
|
||||
onesTwice = onesTwice.to(device)
|
||||
|
||||
embedding.zero_grad()
|
||||
embedding(torch.LongTensor([7, 1, 3])).sum().backward()
|
||||
embedding(torch.LongTensor([7, 1, 3])).sum().backward()
|
||||
self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]]))
|
||||
self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
|
||||
embedding(tensor[0]).sum().backward()
|
||||
self.assertEqual(embedding.weight.grad._indices(), tensor)
|
||||
self.assertEqual(embedding.weight.grad._values(), ones)
|
||||
|
||||
embedding.zero_grad()
|
||||
embedding(torch.LongTensor([7, 1, 3])).sum().backward()
|
||||
embedding(torch.LongTensor([8, 1, 3])).sum().backward()
|
||||
self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]]))
|
||||
self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
|
||||
embedding(tensor[0]).sum().backward()
|
||||
embedding(tensor[0]).sum().backward()
|
||||
self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
|
||||
self.assertEqual(embedding.weight.grad._values(), onesTwice)
|
||||
|
||||
embedding.zero_grad()
|
||||
embedding(tensor[0]).sum().backward()
|
||||
tensor[0, 0] = 8
|
||||
embedding(tensor[0]).sum().backward()
|
||||
tensorTwice[0, 3] = 8
|
||||
self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
|
||||
self.assertEqual(embedding.weight.grad._values(), onesTwice)
|
||||
|
||||
def test_embedding_padding_idx(self):
|
||||
embedding = nn.Embedding(10, 20, padding_idx=0)
|
||||
@ -2377,6 +2414,7 @@ class TestNN(NNTestCase):
|
||||
needed_prec = dtype2prec[dtype] * 2
|
||||
else:
|
||||
needed_prec = backward_prec
|
||||
|
||||
self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
|
||||
|
||||
if test_per_sample_weights and trainable_per_sample_weights:
|
||||
@ -2564,12 +2602,13 @@ class TestNN(NNTestCase):
|
||||
|
||||
def test_embedding_bag(self):
|
||||
for dtype in [torch.double, torch.float]:
|
||||
# TODO: figure out why backward on float breaks
|
||||
test_backward = dtype is not torch.float
|
||||
self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'sum', False, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'mean', False, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'max', False, dtype=dtype)
|
||||
|
||||
# TODO: figure out why precision on sparse embeddings isn't the
|
||||
# same as for dense.
|
||||
test_backward = dtype is not torch.float
|
||||
self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype)
|
||||
self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype)
|
||||
|
||||
@ -2733,10 +2772,11 @@ class TestNN(NNTestCase):
|
||||
self._test_EmbeddingBag(True, 'sum', False, dtype)
|
||||
self._test_EmbeddingBag(True, 'mean', False, dtype)
|
||||
self._test_EmbeddingBag(True, 'max', False, dtype)
|
||||
if dtype != torch.half:
|
||||
# torch.cuda.sparse.HalfTensor is not enabled.
|
||||
self._test_EmbeddingBag(True, 'sum', True, dtype)
|
||||
self._test_EmbeddingBag(True, 'mean', True, dtype)
|
||||
|
||||
# see 'todo' in test_embedding_bag.
|
||||
test_backward = dtype is not torch.float16
|
||||
self._test_EmbeddingBag(True, 'sum', True, dtype, test_backward=test_backward)
|
||||
self._test_EmbeddingBag(True, 'mean', True, dtype, test_backward=test_backward)
|
||||
|
||||
def test_fractional_max_pool2d(self):
|
||||
x = torch.randn(1, 2, 7, 7, requires_grad=True)
|
||||
|
@ -234,33 +234,44 @@ class TestSparse(TestCase):
|
||||
[0, 0, 0, 3],
|
||||
[0, 0, 1, 4],
|
||||
])
|
||||
v = self.value_tensor([2, 1, 3, 4])
|
||||
x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
|
||||
res = self.value_tensor([
|
||||
[[2, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[0, 3, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 4]],
|
||||
])
|
||||
test_tensor(x, res)
|
||||
# we don't have to_dense for half types on CPU because it is implemented
|
||||
# with a slower add_ operation
|
||||
for dtype in [torch.float16, torch.float64] if self.device != 'cpu' else [torch.float64]:
|
||||
v = self.value_tensor([2, 1, 3, 4]).to(dtype=dtype)
|
||||
x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
|
||||
res = self.value_tensor([
|
||||
[[2, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[0, 3, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 4]],
|
||||
]).to(dtype=dtype)
|
||||
|
||||
i = self.index_tensor([
|
||||
[0, 1, 2, 2],
|
||||
[0, 0, 0, 3],
|
||||
[0, 0, 1, 4],
|
||||
])
|
||||
v = self.value_empty(4, 0)
|
||||
x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
|
||||
res = self.value_empty(3, 4, 5, 0)
|
||||
test_tensor(x, res)
|
||||
test_tensor(x, res)
|
||||
|
||||
i = self.index_tensor([
|
||||
[0, 1, 2, 2],
|
||||
[0, 0, 0, 3],
|
||||
[0, 0, 1, 4],
|
||||
])
|
||||
v = self.value_empty(4, 0).to(dtype=dtype)
|
||||
x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
|
||||
res = self.value_empty(3, 4, 5, 0).to(dtype=dtype)
|
||||
test_tensor(x, res)
|
||||
|
||||
# half tensors on cpu don't implement to_dense, so need to convert to float
|
||||
def _to_dense_half_safe(self, tensor):
|
||||
if(tensor.dtype == torch.half and tensor.device.type == 'cpu'):
|
||||
return tensor.to(torch.float).to_dense().to(torch.half)
|
||||
else:
|
||||
return tensor.to_dense()
|
||||
|
||||
def test_to_sparse(self):
|
||||
shape = [10, 5, 19, 8]
|
||||
@ -269,12 +280,15 @@ class TestSparse(TestCase):
|
||||
max_nnz *= dim_sz
|
||||
rnnz = torch.randint(2, max_nnz, (1,)).item()
|
||||
for nnz in [0, 1, rnnz]:
|
||||
expected, _, _ = self._gen_sparse(dim, nnz, shape)
|
||||
d = expected.to_dense()
|
||||
result = d.to_sparse(dim)
|
||||
self.assertEqual(d, result.to_dense()) # == not implemented for sparse tensors yet
|
||||
self.assertEqual(expected.size(), result.size())
|
||||
self.assertEqual(dim, result.sparse_dim())
|
||||
for dtype in [torch.float16, torch.float64, torch.int]:
|
||||
expected, _, _ = self._gen_sparse(dim, nnz, shape)
|
||||
expected = expected.to(dtype)
|
||||
|
||||
d = self._to_dense_half_safe(expected)
|
||||
result = d.to_sparse(dim)
|
||||
self.assertEqual(d, self._to_dense_half_safe(result)) # == not implemented for sparse tensors yet
|
||||
self.assertEqual(expected.size(), result.size())
|
||||
self.assertEqual(dim, result.sparse_dim())
|
||||
|
||||
sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3])
|
||||
self.assertRaises(RuntimeError, lambda: sp.to_sparse())
|
||||
@ -563,6 +577,12 @@ class TestSparse(TestCase):
|
||||
|
||||
# test type conversion (when x1.copy_(x2), x1.dtype should stay the same)
|
||||
x1 = x1.to(torch.float32)
|
||||
|
||||
x2 = x2.to(torch.float16)
|
||||
x1_dtype = x1.dtype
|
||||
x1.copy_(x2)
|
||||
self.assertEqual(x1_dtype, x1.dtype)
|
||||
|
||||
x2 = x2.to(torch.float64)
|
||||
x1_dtype = x1.dtype
|
||||
x1.copy_(x2)
|
||||
@ -630,6 +650,12 @@ class TestSparse(TestCase):
|
||||
x = torch.sparse.FloatTensor(2, 3, 4)
|
||||
test_tensor(x)
|
||||
|
||||
x = torch.sparse.HalfTensor(2, 3, 4)
|
||||
test_tensor(x)
|
||||
|
||||
x = torch.cuda.sparse.HalfTensor(2, 3, 4)
|
||||
test_tensor(x)
|
||||
|
||||
x = torch.sparse.FloatTensor(2, 3, 4, 0)
|
||||
test_tensor(x)
|
||||
|
||||
@ -1512,33 +1538,33 @@ class TestSparse(TestCase):
|
||||
for use_tensor_idx in [True, False]:
|
||||
for use_tensor_val in [True, False]:
|
||||
for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]):
|
||||
# have to include size with cuda sparse tensors
|
||||
include_size = include_size or use_cuda
|
||||
dtype = torch.float64
|
||||
long_dtype = torch.int64
|
||||
device = torch.device('cpu') if not use_cuda else \
|
||||
torch.device(torch.cuda.device_count() - 1)
|
||||
indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
|
||||
if test_empty_tensor:
|
||||
values = self.value_empty(1, 0)
|
||||
else:
|
||||
if use_tensor_val:
|
||||
values = torch.tensor([1.], dtype=dtype)
|
||||
for dtype in [torch.float64, torch.float16]:
|
||||
# have to include size with cuda sparse tensors
|
||||
include_size = include_size or use_cuda
|
||||
long_dtype = torch.int64
|
||||
device = torch.device('cpu') if not use_cuda else \
|
||||
torch.device(torch.cuda.device_count() - 1)
|
||||
indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
|
||||
if test_empty_tensor:
|
||||
values = self.value_empty(1, 0).to(dtype)
|
||||
else:
|
||||
values = 1.
|
||||
if include_size:
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype,
|
||||
device=device, requires_grad=True)
|
||||
else:
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype,
|
||||
device=device, requires_grad=True)
|
||||
self.assertEqual(indices, sparse_tensor._indices())
|
||||
self.assertEqual(values, sparse_tensor._values())
|
||||
self.assertEqual(size if include_size else default_size, sparse_tensor.size())
|
||||
self.assertEqual(dtype, sparse_tensor.dtype)
|
||||
if use_cuda:
|
||||
self.assertEqual(device, sparse_tensor._values().device)
|
||||
self.assertEqual(True, sparse_tensor.requires_grad)
|
||||
if use_tensor_val:
|
||||
values = torch.tensor([1.], dtype=dtype)
|
||||
else:
|
||||
values = 1.
|
||||
if include_size:
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype,
|
||||
device=device, requires_grad=True)
|
||||
else:
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype,
|
||||
device=device, requires_grad=True)
|
||||
self.assertEqual(indices, sparse_tensor._indices())
|
||||
self.assertEqual(values, sparse_tensor._values())
|
||||
self.assertEqual(size if include_size else default_size, sparse_tensor.size())
|
||||
self.assertEqual(dtype, sparse_tensor.dtype)
|
||||
if use_cuda:
|
||||
self.assertEqual(device, sparse_tensor._values().device)
|
||||
self.assertEqual(True, sparse_tensor.requires_grad)
|
||||
|
||||
def test_factory_size_check(self):
|
||||
indices = self.index_tensor([[1, 2],
|
||||
@ -1653,6 +1679,8 @@ class TestSparse(TestCase):
|
||||
|
||||
@cpu_only
|
||||
def test_factory_type_inference(self):
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float16))
|
||||
self.assertEqual(torch.float16, t.dtype)
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float32))
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float64))
|
||||
@ -1660,6 +1688,8 @@ class TestSparse(TestCase):
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1]))
|
||||
self.assertEqual(torch.int64, t.dtype)
|
||||
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.HalfTensor(1, 0))
|
||||
self.assertEqual(torch.float16, t.dtype)
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.FloatTensor(1, 0))
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.DoubleTensor(1, 0))
|
||||
@ -1713,6 +1743,10 @@ class TestSparse(TestCase):
|
||||
values = torch.tensor([1.], dtype=torch.float32)
|
||||
test_tensor(indices, values, True, False)
|
||||
|
||||
indices = torch.tensor(([0], [2]), dtype=torch.int64)
|
||||
values = torch.tensor([1.], dtype=torch.float16)
|
||||
test_tensor(indices, values, True, False)
|
||||
|
||||
indices = torch.tensor(([0], [2]), dtype=torch.int64)
|
||||
values = torch.FloatTensor(1, 0)
|
||||
test_tensor(indices, values, True, True) # An empty tensor's data_ptr is always equal to 0
|
||||
@ -1766,14 +1800,14 @@ class TestSparse(TestCase):
|
||||
|
||||
@cpu_only # not really, but we only really want to run this once
|
||||
def test_dtypes(self):
|
||||
all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
|
||||
all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes()]
|
||||
do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
|
||||
if torch.cuda.is_available():
|
||||
do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))
|
||||
|
||||
@cpu_only # not really, but we only really want to run this once
|
||||
def test_empty_full(self):
|
||||
all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
|
||||
all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes()]
|
||||
do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
|
||||
if torch.cuda.device_count() > 0:
|
||||
do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, None)
|
||||
|
@ -85,8 +85,8 @@ std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
|
||||
ScalarType::Int, ScalarType::Long, ScalarType::Short, ScalarType::Half, ScalarType::Bool};
|
||||
for (auto& backend : backends) {
|
||||
for (auto& scalar_type : scalar_types) {
|
||||
// there are no sparse half or bool types.
|
||||
if ((scalar_type == ScalarType::Half || scalar_type == ScalarType::Bool) && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
|
||||
// there is no sparse bool type.
|
||||
if (scalar_type == ScalarType::Bool && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
|
||||
continue;
|
||||
}
|
||||
ret.emplace_back(std::make_pair(backend, scalar_type));
|
||||
|
Reference in New Issue
Block a user