Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 14:34:54 +08:00)
Separate cuda-ness from dtype. (#6470)
* Separate cuda-ness from dtype. There are no longer torch.cuda.int64, etc.; only torch.int64, which corresponds to at::ScalarType. At the Python arg-parser level, the corresponding ATen type is selected from the combination of (ScalarType, Layout, Device). There is also currently unused code in here for supporting ScalarType in native_functions; this will be used for specifying aggregate types on reduction functions.
* Fix test_autograd.
* Add defaults to randint_like.
* Track is_cuda in py tensor types.
* Fix test_sparse.
* Fix multiprocessing.
* Fix rnn.
* Fix test_nn.
* Fix flake8.
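For orientation before the diffs, here is a minimal sketch of the user-facing change (illustrative only, pieced together from the commit message and the updated tests below; it is not code from this commit):

    import torch

    # Before: dtypes encoded the device (torch.cuda.float32, torch.cuda.int64, ...).
    # After: a dtype is only a scalar type; the device is passed separately.
    x = torch.randn(2, 3)                        # CPU, torch.float32
    y = torch.ones_like(x, dtype=torch.float64)  # dtype selects the scalar type only
    assert y.dtype is torch.float64 and not y.is_cuda

    if torch.cuda.is_available():
        # The device is chosen independently of the dtype; dtype objects no
        # longer expose is_cuda, so the tensor carries the device instead.
        z = torch.ones_like(x, dtype=torch.float64, device='cuda')
        assert z.dtype is torch.float64 and z.is_cuda

The diffs below update the native_functions.yaml declarations, the Python tests, the binding-code generators, and the C++ dtype/type plumbing accordingly.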
@@ -310,6 +310,8 @@
 - func: empty_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: exp(Tensor self) -> Tensor

@@ -357,6 +359,8 @@
 - func: full_like(Tensor self, Scalar fill_value, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: hinge_embedding_loss(Tensor self, Tensor target, double margin=1.0, bool size_average=true, bool reduce=true) -> Tensor
   variants: function

@@ -470,6 +474,8 @@
 - func: ones_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: pairwise_distance(Tensor x1, Tensor x2, double p=2, double eps=1e-6, bool keepdim=false) -> Tensor
   variants: function

@@ -490,6 +496,9 @@
 - func: rand_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()


 - func: randint(Type dtype, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
   variants: function

@@ -511,9 +520,13 @@
 - func: randint_like(Tensor self, int64_t high, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: randint_like(Tensor self, int64_t low, int64_t high, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: randn(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor
   variants: function

@@ -526,6 +539,8 @@
 - func: randn_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: randperm(Type dtype, int64_t n, *, Generator* generator=nullptr) -> Tensor
   variants: function

@@ -732,6 +747,8 @@
 - func: zeros_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()

 - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
   dispatch:
@@ -851,19 +851,20 @@ class TestAutograd(TestCase):
     def test_requires_grad_factory(self):
         x = Variable(torch.randn(2, 3))
         fns = [torch.ones_like, torch.testing.randn_like]
-        dtypes = [torch.float32, torch.float64, torch.cuda.float32, torch.cuda.float64]
+        dtypes = [torch.float32, torch.float64]
         for fn in fns:
             for requires_grad in [True, False]:
                 for dtype in dtypes:
-                    if not dtype.is_cuda:
-                        output = fn(x, dtype=dtype, requires_grad=requires_grad)
-                        self.assertEqual(requires_grad, output.requires_grad)
-                        self.assertIs(dtype, output.dtype)
-                    elif torch.cuda.is_available() and torch.cuda.device_count() > 1:
-                        output = fn(x, dtype=dtype, device=1, requires_grad=requires_grad)
-                        self.assertEqual(requires_grad, output.requires_grad)
-                        self.assertIs(dtype, output.dtype)
-                        self.assertEqual(1, output.get_device())
+                    for use_cuda in [True, False]:
+                        if not use_cuda:
+                            output = fn(x, dtype=dtype, requires_grad=requires_grad)
+                            self.assertEqual(requires_grad, output.requires_grad)
+                            self.assertIs(dtype, output.dtype)
+                        elif torch.cuda.is_available() and torch.cuda.device_count() > 1:
+                            output = fn(x, dtype=dtype, device=1, requires_grad=requires_grad)
+                            self.assertEqual(requires_grad, output.requires_grad)
+                            self.assertIs(dtype, output.dtype)
+                            self.assertEqual(1, output.get_device())

     def test_grad_assignment(self):
         x = Variable(torch.randn(5, 5))
@@ -35,7 +35,7 @@ def is_floating(t):

 def is_half(t):
     if isinstance(t, torch.Tensor):
-        return t.dtype in [torch.float16, torch.cuda.float16]
+        return t.dtype == torch.float16
     assert isinstance(t, type)
     assert t != torch.autograd.Variable
     return t in [torch.HalfTensor, torch.cuda.HalfTensor]
@@ -1069,7 +1069,7 @@ class TestCuda(TestCase):
         TestTorch._test_cat_empty(self, use_cuda=True)

     def test_bernoulli(self):
-        x = torch.tensor([0, 1], dtype=torch.cuda.float32)
+        x = torch.tensor([0, 1], dtype=torch.float32, device='cuda')
         self.assertEqual(x.bernoulli().tolist(), [0, 1])

     def test_cat_bad_input_sizes(self):

@@ -1432,7 +1432,7 @@ class TestCuda(TestCase):
         TestTorch._test_int_pow(self, lambda x: x.cuda())

     def test_remainder_overflow(self):
-        TestTorch._test_remainder_overflow(self, dtype=torch.cuda.int64)
+        TestTorch._test_remainder_overflow(self, dtype=torch.int64, device='cuda')

     def test_var(self):
         cpu_tensor = torch.randn(2, 3, 3)

@@ -1541,10 +1541,10 @@ class TestCuda(TestCase):
         self.assertEqual(a, b.cuda())

     def test_diagonal(self):
-        TestTorch._test_diagonal(self, dtype=torch.cuda.float32)
+        TestTorch._test_diagonal(self, dtype=torch.float32, device='cuda')

     def test_diagflat(self):
-        TestTorch._test_diagflat(self, dtype=torch.cuda.float32)
+        TestTorch._test_diagflat(self, dtype=torch.float32, device='cuda')

     @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
     def test_get_set_rng_state_all(self):
@@ -371,21 +371,21 @@ class TestMultiprocessing(TestCase):
         self.assertEqual(list(tensor), [4, 4, 4, 4])
         p.join()

-    def _test_empty_tensor_sharing(self, dtype):
+    def _test_empty_tensor_sharing(self, dtype, device):
         q = mp.Queue()
-        empty = torch.tensor([], dtype=dtype)
+        empty = torch.tensor([], dtype=dtype, device=device)
         q.put(empty)
         out = q.get(timeout=1)
         self.assertEqual(out, empty)

     def test_empty_tensor_sharing(self):
-        self._test_empty_tensor_sharing(torch.float32)
-        self._test_empty_tensor_sharing(torch.int64)
+        self._test_empty_tensor_sharing(torch.float32, torch.device('cpu'))
+        self._test_empty_tensor_sharing(torch.int64, torch.device('cpu'))

     @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
     def test_empty_tensor_sharing_cuda(self):
-        self._test_empty_tensor_sharing(torch.cuda.float32)
-        self._test_empty_tensor_sharing(torch.cuda.int64)
+        self._test_empty_tensor_sharing(torch.float32, torch.device('cuda'))
+        self._test_empty_tensor_sharing(torch.int64, torch.device('cuda'))

     def _test_autograd_sharing(self, var):
         ready = mp.Event()
@@ -2155,7 +2155,7 @@ class TestNN(NNTestCase):

     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_broadcast_no_grad(self):
-        x = torch.randn(1, 2, dtype=torch.cuda.float32, requires_grad=True)
+        x = torch.randn(1, 2, dtype=torch.float32, requires_grad=True, device='cuda')
         with torch.no_grad():
             broadcasted = Broadcast.apply((0, 1), x)
         self.assertTrue(x.requires_grad)
@@ -850,9 +850,9 @@ class TestSparse(TestCase):
         for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]):
             # have to include size with cuda sparse tensors
             include_size = include_size or use_cuda
-            dtype = torch.cuda.float64 if use_cuda else torch.float64
-            long_dtype = torch.cuda.int64 if use_cuda else torch.int64
-            device = -1 if not use_cuda else torch.cuda.device_count() - 1
+            dtype = torch.float64
+            long_dtype = torch.int64
+            device = torch.device('cpu') if not use_cuda else torch.device(torch.cuda.device_count() - 1)
             indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
             values = torch.tensor([1.], dtype=dtype) if use_tensor_val else 1.
             if include_size:

@@ -866,7 +866,7 @@ class TestSparse(TestCase):
             self.assertEqual(size if include_size else default_size, sparse_tensor.size())
             self.assertEqual(dtype, sparse_tensor.dtype)
             if use_cuda:
-                self.assertEqual(device, sparse_tensor._values().get_device())
+                self.assertEqual(device, sparse_tensor._values().device)
             self.assertEqual(True, sparse_tensor.requires_grad)

     @cpu_only

@@ -910,17 +910,18 @@ class TestSparse(TestCase):

     @cpu_only  # not really, but we only really want to run this once
     def test_dtypes(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda and d != torch.float16]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda and d != torch.cuda.float16]
-        TestTorch._test_dtypes(self, cpu_dtypes, cuda_dtypes, torch.sparse_coo)
+        all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
+        TestTorch._test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
+        if torch.cuda.is_available():
+            TestTorch._test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))

     @cpu_only  # not really, but we only really want to run this once
     def test_empty_full(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda and d != torch.half]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda and d != torch.cuda.half]
-        TestTorch._test_empty_full(self, cpu_dtypes, cuda_dtypes, torch.sparse_coo)
+        all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
+        TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
+        if torch.cuda.device_count() > 0:
+            TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, -1)
+            TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))

     def test_is_sparse(self):
         x = torch.randn(3, 3)
@@ -957,9 +957,9 @@ class TestTorch(TestCase):
         long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))

     @staticmethod
-    def _test_remainder_overflow(self, dtype=torch.int64):
+    def _test_remainder_overflow(self, dtype, device):
         # Check Integer Overflows
-        x = torch.tensor(23500, dtype=dtype)
+        x = torch.tensor(23500, dtype=dtype, device=device)
         q = 392486996410368
         self.assertEqual(x % q, x)
         self.assertEqual(-x % q, q - x)

@@ -967,7 +967,7 @@ class TestTorch(TestCase):
         self.assertEqual(-x % -q, -x)

     def test_remainder_overflow(self):
-        self._test_remainder_overflow(self, dtype=torch.int64)
+        self._test_remainder_overflow(self, dtype=torch.int64, device='cpu')

     def test_mm(self):
         # helper function
@@ -1429,28 +1429,19 @@ class TestTorch(TestCase):
         self.assertEqual(output, expected)

     @staticmethod
-    def _test_dtypes(self, cpu_dtypes, cuda_dtypes, layout):
-        dtypes = cpu_dtypes + (cuda_dtypes if torch.cuda.is_available() else [])
+    def _test_dtypes(self, dtypes, layout, device):

         for dtype in dtypes:
-            # no ops on torch.float16 currently, cuda.float16 doesn't work on windows
             if dtype != torch.float16:
-                if dtype.is_cuda and torch.cuda.device_count() > 1:
-                    out = torch.zeros((2, 3), device=1, dtype=dtype, layout=layout)
-                    self.assertIs(dtype, out.dtype)
-                    self.assertIs(layout, out.layout)
-                    self.assertEqual(1, out.get_device())
-                else:
-                    out = torch.zeros((2, 3), dtype=dtype, layout=layout)
-                    self.assertIs(dtype, out.dtype)
-                    self.assertIs(layout, out.layout)
-                    self.assertEqual(dtype in cuda_dtypes, dtype.is_cuda)
+                out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
+                self.assertIs(dtype, out.dtype)
+                self.assertIs(layout, out.layout)
+                self.assertEqual(device, out.device)

     def test_dtypes(self):
         all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda]
-        self._test_dtypes(self, cpu_dtypes, cuda_dtypes, torch.strided)
+        self._test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu'))
+        if torch.cuda.is_available():
+            self._test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0'))

     def test_device(self):
         cpu = torch.device('cpu')
@@ -1508,20 +1499,19 @@ class TestTorch(TestCase):
         assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0'))
         self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu'))
         self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0'))
-        assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.cuda.int64, device=0))
-        assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda:0'))
+        assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0))
+        assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0'))
         assertEqual('cuda:' + str(torch.cuda.current_device()),
-                    lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda'))
+                    lambda: torch.tensor(5, dtype=torch.int64, device='cuda'))

         if torch.cuda.device_count() > 1:
             assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1))
             assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1'))
-            assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.cuda.int64, device=1))
-            assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda:1'))
+            assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1))
+            assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1'))

     @staticmethod
-    def _test_empty_full(self, cpu_dtypes, cuda_dtypes, layout):
-        dtypes = cpu_dtypes + (cuda_dtypes if torch.cuda.is_available() else [])
+    def _test_empty_full(self, dtypes, layout, device):
         shape = torch.Size([2, 3])

         def check_value(tensor, dtype, layout, device, value, requires_grad):
@@ -1530,7 +1520,7 @@ class TestTorch(TestCase):
             self.assertIs(layout, tensor.layout)
             self.assertEqual(tensor.requires_grad, requires_grad)
             if tensor.is_cuda and device != -1:
-                self.assertEqual(device, tensor.get_device())
+                self.assertEqual(device, tensor.device)
             if value is not None:
                 fill = tensor.new(shape).fill_(value)
                 self.assertEqual(tensor, fill)

@@ -1547,7 +1537,6 @@ class TestTorch(TestCase):
         for dtype in dtypes:
             for rg in [True, False]:
                 int64_dtype = get_int64_dtype(dtype)
-                device = -1 if not (dtype.is_cuda and torch.cuda.device_count() > 1) else 1
                 v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
                 check_value(v, dtype, layout, device, None, rg)
                 out = v.new()
@@ -1576,10 +1565,10 @@ class TestTorch(TestCase):
                             int64_dtype, layout, device, fv + 5, rg)

     def test_empty_full(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda]
-        self._test_empty_full(self, cpu_dtypes, cuda_dtypes, torch.strided)
+        self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
+        if torch.cuda.device_count() > 0:
+            self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, -1)
+            self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0'))

     def test_dtype_out_match(self):
         d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
@@ -1606,9 +1595,9 @@ class TestTorch(TestCase):
         self.assertIs(torch.FloatStorage, torch.Storage)

         if torch.cuda.is_available():
-            torch.set_default_tensor_type(torch.cuda.float32)
-            self.assertIs(torch.cuda.float32, torch.get_default_dtype())
-            self.assertIs(torch.cuda.float32, torch.cuda.FloatTensor.dtype)
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+            self.assertIs(torch.float32, torch.get_default_dtype())
+            self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
             self.assertIs(torch.cuda.FloatStorage, torch.Storage)

         # don't support integral or sparse default types.
@@ -1686,8 +1675,22 @@ class TestTorch(TestCase):
         saved_dtype = torch.get_default_dtype()
         torch.set_default_tensor_type(torch.float32)
         self.assertIs(torch.float32, torch.tensor(0.).dtype)
-        torch.set_default_tensor_type(torch.cuda.float64)
-        self.assertIs(torch.cuda.float64, torch.tensor(0.).dtype)
+        self.assertEqual(torch.device('cpu'), torch.tensor(0.).device)
+        torch.set_default_tensor_type(torch.float64)
+        self.assertIs(torch.float64, torch.tensor(0.).dtype)
+        torch.set_default_tensor_type(saved_dtype)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_tensor_factory_cuda_type(self):
+        saved_dtype = torch.get_default_dtype()
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        x = torch.zeros((5, 5))
+        self.assertIs(torch.float32, x.dtype)
+        self.assertTrue(x.is_cuda)
+        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
+        x = torch.zeros((5, 5))
+        self.assertIs(torch.float64, x.dtype)
+        self.assertTrue(x.is_cuda)
         torch.set_default_tensor_type(saved_dtype)

     def test_new_tensor(self):
@@ -1721,23 +1724,23 @@ class TestTorch(TestCase):
             expected = expected.cuda(1)
             res1 = expected.new_tensor([1, 1])
             self.assertEqual(res1.get_device(), expected.get_device())
-            res1 = expected.new_tensor([1, 1], dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res1 = expected.new_tensor([1, 1], dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res1.get_device(), expected.get_device())

             res2 = expected.new_tensor(expected)
             self.assertEqual(res2.get_device(), expected.get_device())
-            res2 = expected.new_tensor(expected, dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res2 = expected.new_tensor(expected, dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res2.get_device(), expected.get_device())
-            res2 = expected.new_tensor(expected, dtype=torch.cuda.int, device=0)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res2.get_device(), 0)

             res1 = expected.new_tensor(1)
             self.assertEqual(res1.get_device(), expected.get_device())
-            res1 = expected.new_tensor(1, dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res1 = expected.new_tensor(1, dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res1.get_device(), expected.get_device())

     def test_diag(self):
@@ -1748,49 +1751,49 @@ class TestTorch(TestCase):
         self.assertEqual(res1, res2)

     @staticmethod
-    def _test_diagonal(self, dtype=torch.float32):
-        x = torch.randn((100, 100), dtype=dtype)
+    def _test_diagonal(self, dtype, device):
+        x = torch.randn((100, 100), dtype=dtype, device=device)
         result = torch.diagonal(x)
         expected = torch.diag(x)
         self.assertEqual(result, expected)

-        x = torch.randn((100, 100), dtype=dtype)
+        x = torch.randn((100, 100), dtype=dtype, device=device)
         result = torch.diagonal(x, 17)
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)

     def test_diagonal(self):
-        self._test_diagonal(self, dtype=torch.float32)
+        self._test_diagonal(self, dtype=torch.float32, device='cpu')

     @staticmethod
-    def _test_diagflat(self, dtype=torch.float32):
+    def _test_diagflat(self, dtype, device):
         # Basic sanity test
-        x = torch.randn((100,), dtype=dtype)
+        x = torch.randn((100,), dtype=dtype, device=device)
         result = torch.diagflat(x)
         expected = torch.diag(x)
         self.assertEqual(result, expected)

         # Test offset
-        x = torch.randn((100,), dtype=dtype)
+        x = torch.randn((100,), dtype=dtype, device=device)
         result = torch.diagflat(x, 17)
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)

         # Test where input has more than one dimension
-        x = torch.randn((2, 3, 4), dtype=dtype)
+        x = torch.randn((2, 3, 4), dtype=dtype, device=device)
         result = torch.diagflat(x)
         expected = torch.diag(x.contiguous().view(-1))
         self.assertEqual(result, expected)

         # Noncontig input
-        x = torch.randn((2, 3, 4), dtype=dtype).transpose(2, 0)
+        x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
         self.assertFalse(x.is_contiguous())
         result = torch.diagflat(x)
         expected = torch.diag(x.contiguous().view(-1))
         self.assertEqual(result, expected)

     def test_diagflat(self):
-        self._test_diagflat(self, dtype=torch.float32)
+        self._test_diagflat(self, dtype=torch.float32, device='cpu')

     def test_eye(self):
         res1 = torch.eye(100, 100)
@@ -2667,13 +2670,11 @@ class TestTorch(TestCase):
     def _test_cat_empty(self, use_cuda=False):
         # FIXME: this is legacy behavior and should be removed
         # when we support empty tensors with arbitrary sizes
-        if use_cuda:
-            dtype = torch.cuda.float32
-        else:
-            dtype = torch.float32
+        dtype = torch.float32
+        device = 'cuda' if use_cuda else 'cpu'

-        x = torch.randn((4, 3, 32, 32), dtype=dtype)
-        empty = torch.randn((0,), dtype=dtype)
+        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
+        empty = torch.randn((0,), dtype=dtype, device=device)

         res1 = torch.cat([x, empty], dim=1)
         res2 = torch.cat([empty, x], dim=1)
@@ -73,7 +73,8 @@ if (r.isNone(${out_idx})) {
   ${call_dispatch}
 } else {
   if (!r.isNone(${type_idx})) {
-    check_out_type_matches(r.tensor(${out_idx}), r.dtype(${type_idx}), r.layout(${layout_idx}));
+    check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.layout(${layout_idx}),
+                           r.device(${device_idx}), r.isNone(${device_idx}));
   }
   ${call_dispatch_out}
 }
@@ -207,9 +208,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         'Tensor &': 'tensor',
         'Generator *': 'generator',
         'Storage &': 'storage',
-        'const Type &': 'dtype',
+        'const Type &': 'scalartype',
         'const THPLayout &': 'layout',
-        'const Device &': 'deviceInt64',
+        'const Device &': 'device',
         'int64_t': 'toInt64',
         'bool': 'toBool',
         'double': 'toDouble',

@@ -221,6 +222,10 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         'int64_t': 'toInt64WithDefault',
         'bool': 'setDefaultBool',
         'double': 'setDefaultDouble',
+        'const Type &': 'scalartypeWithDefault',
+        'const THPLayout &': 'layoutWithDefault',
+        'const Device &': 'deviceWithDefault',
+        'ScalarType': 'scalartypeWithDefault',
     }

     def first_tensor_arg(arguments):
@@ -286,6 +291,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             '`{}` type is not supported in python_default_init'.format(typename)
         unpack_with_default = unpack_with_default_methods.get(typename)
         default_expr = arg.get('python_default_init')
+        # TODO: Type currently maps to ScalarType, figure out a cleaner solution
+        if typename == 'const Type &':
+            default_expr += '.scalarType()'
         expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
     else:
         unpack = unpack_methods.get(typename, typename.lower())
@@ -335,7 +343,6 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             actuals.append('results[{}]'.format(i))

     layout = None
-    parsed_type_dispatch = None
     # type args go after the outputs to match the signature generation.
     arg_idx = arg_idx if out_idx is None else out_idx + 1
     for arg in type_args:
@@ -357,23 +364,29 @@ def create_python_bindings(python_functions, has_self, is_module=False):
     for arg in python_binding_arguments:
         if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
             pass  # already handled by type_dispatched_args
-        elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
-            if len(outputs) == 0:
-                has_device_bind = True
-                append_actuals_formals(*parse_arg(arg, device_idx))
-        elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
-            requires_grad = parse_arg(arg, requires_grad_idx)[0]
         elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
             # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
             if len(outputs) == 0:
-                layout = parse_arg(arg, layout_idx)[0]
+                layout = parse_arg(arg, layout_idx, arg.get('python_default_init'))[0]
+        elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
+            if len(outputs) == 0:
                 assert parsed_type_args
-                actuals.append("torch::getType({}, {})".format(parsed_type_args[0], layout))
+                assert layout
+                device_arg = parse_arg(arg, device_idx, True)
+                # add type, device formals and corresponding actuals.
+                # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
+                # The device actual is the corresponding AutoGPU index for the Device.
                 formal_args.append(parsed_type_args[1])
+                formal_args.append(device_arg[1])
+                actuals.append("torch::getType({}, {}, {}.type)".format(parsed_type_args[0], layout, device_arg[0]))
+                actuals.append('{}.deviceInt64()'.format(device_arg[0]))
+                has_device_bind = True
+        elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
+            requires_grad = parse_arg(arg, requires_grad_idx)[0]
         else:
             raise RuntimeError(("found {} in python_binding_arguments but only "
-                                "\"bool requires_grad\", \"Dtype dtype\", \"Layout layout\", \"Device device\" "
-                                "are supported".format(arg)))
+                                "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
+                                "\"Device device\" are supported".format(arg)))

     env['unpack_args'] = []
     env['formal_args'] = formal_args
@@ -414,7 +427,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
         if has_dtype_bind:
             body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
-                                                         layout_idx=out_idx + 2).split('\n')
+                                                         layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
         else:
             body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
     else:
@@ -463,6 +476,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         }
         python_binding_arguments.append(dtype_arg)
     if is_factory_function or is_typed_like_function:
+        py_default_layout = '*torch::getLayout(self.type().backend())' if is_typed_like_function else None
         layout_arg = {
             'default': 'torch.strided',
             'dynamic_type': 'Layout',

@@ -470,9 +484,10 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             'name': 'layout',
             'type': 'const THPLayout &',
             'simple_type': 'Layout',
+            'python_default_init': py_default_layout,
         }
         python_binding_arguments.append(layout_arg)
-    if is_factory_or_like_function:
+        py_default_device = 'torch::utils::getDevice(self)' if is_typed_like_function else None
         device_arg = {
             'default': 'None',
             'default_init': 'None',
@@ -480,9 +495,11 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             'kwarg_only': True,
             'name': 'device',
             'type': 'const Device &',
-            'simple_type': 'Device'
+            'simple_type': 'Device',
+            'python_default_init': py_default_device
         }
         python_binding_arguments.append(device_arg)
+    if is_factory_or_like_function:
         requires_grad_arg = {
             'default': False,
             'dynamic_type': 'bool',
@@ -590,7 +607,7 @@ def get_python_signature(declaration, include_out):
     positional = True

     def get_py_formal_arg(arg):
-        typename = arg['simple_type'] if arg['simple_type'] != 'Type' else 'Dtype'
+        typename = arg['simple_type'] if arg['simple_type'] != 'Type' else 'ScalarType'
         if arg.get('is_nullable'):
             typename = '{}?'.format(typename)
         if arg.get('size') is not None:
@@ -318,7 +318,8 @@ def emit_body(declaration):
     def emit_record_trace(env):
         # Operations involving Generator, Storage, Type are not traceable
         # at the moment
-        if any(arg['simple_type'] in {'Generator', 'Storage', 'Type'} for arg in declaration['arguments']):
+        if any(arg['simple_type'] in {'Generator', 'Storage', 'ScalarType', 'Type'}
+               for arg in declaration['arguments']):
             return ('', '')
         # We can't trace functions which don't have any Tensor or TensorList returns
         if 'Tensor' not in declaration['return_type']:
@@ -22,6 +22,7 @@ using at::Storage;
 using at::Tensor;
 using at::TensorList;
 using at::Type;
+using at::ScalarType;

 struct VariableType final : public at::Type {
   VariableType(Context* context, at::Type* baseType);
@@ -16,6 +16,7 @@
 #include "torch/csrc/utils/python_arg_parser.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
+#include "torch/csrc/utils/tensor_devices.h"
 #include "torch/csrc/utils/tensor_layouts.h"

 #include "python_torch_functions_dispatch.h"
@@ -33,8 +34,11 @@ static Tensor set_requires_grad(Tensor self, bool requires_grad) {
   return self;
 }

-static void check_out_type_matches(Tensor result, const THPDtype &dtype, const THPLayout& layout) {
-  const auto& type = torch::getType(dtype, layout);
+static void check_out_type_matches(Tensor result, ScalarType scalarType, const THPLayout& layout,
+                                   const Device& device, bool device_is_none) {
+  auto result_device_type = torch::getDeviceType(result.type());
+  auto device_type = device_is_none ? result_device_type : device.type;
+  const auto& type = torch::getType(scalarType, layout, device_type);
   if (result.type() != type) {
     AT_ERROR(
         "type corresponding to %s does not match type of out parameter (%s)",
@@ -90,19 +94,13 @@ static PyObject * THPVariable__promote_types(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "_promote_types(Dtype type1, Dtype type2)",
+    "_promote_types(ScalarType type1, ScalarType type2)",
   });
   ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    auto& d1 = r.dtype(0);
-    auto& d2 = r.dtype(1);
-    if (d1.is_cuda != d2.is_cuda) {
-      AT_ERROR("_promote_types only supports dtypes being both on cpu or cuda. Got %s and %s",
-               d1.is_cuda ? "true" : "false", d2.is_cuda ? "true" : "false");
-    }
-    ScalarType promoted = at::promoteTypes(d1.scalar_type, d2.scalar_type);
-    return torch::autograd::utils::wrap(torch::getDtype(promoted, d1.is_cuda));
+    ScalarType promoted = at::promoteTypes(r.scalartype(0), r.scalartype(1));
+    return torch::autograd::utils::wrap(torch::getDtype(promoted));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
@@ -565,7 +565,8 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
   } else {
     throw TypeError("dtype must be a type, str, or dtype object");
   }
-  auto& type = is_dtype ? torch::getType(r.dtype(0), *torch::getLayout(self_.type().backend())) :
+  auto self_device_type = torch::getDeviceType(self_.type());
+  auto& type = is_dtype ? torch::getType(r.scalartype(0), *torch::getLayout(self_.type().backend()), self_device_type) :
                           torch::utils::type_from_string(type_name);
   return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, -1, r.toBool(1)));
   END_HANDLE_TH_ERRORS
@@ -68,6 +68,7 @@ def is_jit_op(decl):
         not any(arg['simple_type'] == 'Generator' for arg in decl['arguments']) and
         not any(arg['simple_type'] == 'SparseTensor' for arg in decl['arguments']) and
         not any(arg['simple_type'] == 'Storage' for arg in decl['arguments']) and
+        not any(arg['simple_type'] == 'ScalarType' for arg in decl['arguments']) and
         not any(arg['simple_type'] == 'Type' for arg in decl['arguments']) and
         uses_tensors)

@@ -35,12 +35,12 @@ class Unserializable(object):
         self.inner = None


-def init_dropout_state(ty, dropout, train, dropout_seed, dropout_state):
+def init_dropout_state(ty, device, dropout, train, dropout_seed, dropout_state):
     dropout_desc_name = 'desc_' + str(torch.cuda.current_device())
     dropout_p = dropout if train else 0
     if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None):
         dropout_state[dropout_desc_name] = Unserializable(
-            torch._C._VariableFunctions._cudnn_init_dropout_state(dropout_p, train, dropout_seed, ty=ty)
+            torch._C._VariableFunctions._cudnn_init_dropout_state(dropout_p, train, dropout_seed, ty=ty, device=device)
             if dropout_p != 0 else None
         )
     dropout_ts = dropout_state[dropout_desc_name].get()
@@ -35,7 +35,7 @@ static inline const char* deviceTypeString(torch::DeviceType device_type) {
 PyObject *THPDevice_repr(THPDevice *self)
 {
   std::ostringstream oss;
-  oss << "Device(device_type=\'" << deviceTypeString(self->device.type) << "\'";
+  oss << "device(device_type=\'" << deviceTypeString(self->device.type) << "\'";
   if (!self->device.is_default) {
     oss << ", device_index=" << self->device.index;
   }
@@ -8,33 +8,18 @@
 #include "torch/csrc/utils/tensor_dtypes.h"
 #include "torch/csrc/utils/tensor_types.h"

-PyObject * THPDtype_New(at::ScalarType scalar_type, bool is_cuda, const std::string& name)
+PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name)
 {
   auto type = (PyTypeObject*)&THPDtypeType;
   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
   if (!self) throw python_error();
   auto self_ = reinterpret_cast<THPDtype*>(self.get());
   self_->scalar_type = scalar_type;
-  self_->is_cuda = is_cuda;
   std::strncpy (self_->name, name.c_str(), DTYPE_NAME_LEN);
   self_->name[DTYPE_NAME_LEN] = '\0';
   return self.release();
 }

-PyObject *THPDtype_repr(THPDtype *self)
-{
-  return THPUtils_packString(self->name);
-}
-
-PyObject *THPDtype_is_cuda(THPDtype *self)
-{
-  if (self->is_cuda) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-}
-
 PyObject *THPDtype_is_floating_point(THPDtype *self)
 {
   if (at::isFloatingType(self->scalar_type)) {
@@ -47,11 +32,15 @@ PyObject *THPDtype_is_floating_point(THPDtype *self)
 typedef PyObject *(*getter)(PyObject *, void *);

 static struct PyGetSetDef THPDtype_properties[] = {
-  {"is_cuda", (getter)THPDtype_is_cuda, nullptr, nullptr, nullptr},
   {"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
   {nullptr}
 };

+PyObject *THPDtype_repr(THPDtype *self)
+{
+  return THPUtils_packString(self->name);
+}
+
 PyTypeObject THPDtypeType = {
   PyVarObject_HEAD_INIT(nullptr, 0)
   "torch.dtype", /* tp_name */
@@ -8,7 +8,6 @@ const int DTYPE_NAME_LEN = 64;
 struct THPDtype {
   PyObject_HEAD
   at::ScalarType scalar_type;
-  bool is_cuda;
   char name[DTYPE_NAME_LEN + 1];
 };

@@ -18,6 +17,6 @@ inline bool THPDtype_Check(PyObject *obj) {
   return Py_TYPE(obj) == &THPDtypeType;
 }

-PyObject * THPDtype_New(at::ScalarType scalar_type, bool is_cuda, const std::string& name);
+PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name);

 void THPDtype_init(PyObject *module);
@@ -33,8 +33,7 @@ static std::unordered_map<PyTypeObject*, at::Type*> py_storage_type_to_attype;

 static const int NumBoolOptions = 2;
 static THPDtype* dtype_registry
-  [static_cast<int>(at::ScalarType::NumOptions)]
-  [NumBoolOptions] = {};
+  [static_cast<int>(at::ScalarType::NumOptions)] = {};

 static THPLayout* layout_registry
   [static_cast<int>(at::Backend::NumOptions)] = {};
@@ -72,8 +71,8 @@ void registerStoragePyTypeObject(PyTypeObject *pytype, const std::string& name,
   }
 }

-void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType, bool is_cuda) {
-  dtype_registry[static_cast<int>(scalarType)][is_cuda] = dtype;
+void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType) {
+  dtype_registry[static_cast<int>(scalarType)] = dtype;
 }

 void registerLayoutObject(THPLayout *layout, at::Backend backend) {
@@ -89,15 +88,16 @@ static PyTypeObject* getPyTypeObject(const at::Storage& storage)
   throw std::invalid_argument("unsupported Storage type");
 }

-at::Type& getType(const THPDtype &dtype, const THPLayout& layout) {
-  at::Backend backend = get_backend(dtype.is_cuda, !layout.is_strided);
+at::Type& getType(at::ScalarType scalarType, const THPLayout& layout, const DeviceType& deviceType) {
+  at::Backend backend = get_backend(deviceType == DeviceType::CUDA, !layout.is_strided);
   // use type_registry rather than context.getType() because getType throws exceptions.
   auto baseType = at::globalContext().type_registry[static_cast<int>(backend)]
-                                                   [static_cast<int>(dtype.scalar_type)].get();
+                                                   [static_cast<int>(scalarType)].get();
   if (!baseType) {
     std::ostringstream oss;
-    oss << "Error attempting to use dtype " << dtype.name << " with layout " << layout.name << ".";
-    if (!torch::utils::cuda_enabled()) {
+    oss << "Error attempting to use dtype " << getDtype(scalarType)->name << " with layout " << layout.name
+        << " and device type " << (deviceType == DeviceType::CPU ? "CPU" : "CUDA") << ".";
+    if (deviceType == DeviceType::CUDA && !torch::utils::cuda_enabled()) {
       oss << " Torch not compiled with CUDA enabled." << std::endl;
     }
     throw std::runtime_error(oss.str());
@@ -105,10 +105,10 @@ at::Type& getType(const THPDtype &dtype, const THPLayout& layout) {
   return *torch::autograd::VariableType::getType(*baseType);
 }

-THPDtype* getDtype(at::ScalarType scalarType, bool is_cuda) {
-  auto dtype = dtype_registry[static_cast<int>(scalarType)][is_cuda];
+THPDtype* getDtype(at::ScalarType scalarType) {
+  auto dtype = dtype_registry[static_cast<int>(scalarType)];
   if (!dtype) {
-    throw std::invalid_argument("unsupported backend, scalarType");
+    throw std::invalid_argument("unsupported scalarType");
   }
   return dtype;
 }
@@ -121,6 +121,10 @@ THPLayout* getLayout(at::Backend backend) {
   return layout;
 }

+DeviceType getDeviceType(const at::Type& type) {
+  return type.is_cuda() ? torch::DeviceType::CUDA : torch::DeviceType::CPU;
+}
+
 PyObject* createPyObject(const at::Storage& storage)
 {
   auto type = getPyTypeObject(storage);
@@ -8,6 +8,7 @@
 #include <ATen/ATen.h>
 #include "torch/csrc/Dtype.h"
 #include "torch/csrc/Layout.h"
+#include "torch/csrc/utils/device.h"
 
 namespace torch {
 
@@ -16,15 +17,16 @@ void registerStoragePyTypeObject(
     PyTypeObject *pytype, const std::string& name,
     bool is_cuda, bool is_sparse);
 
-void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType, bool is_cuda);
+void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
 void registerLayoutObject(THPLayout *layout, at::Backend backend);
 
 PyObject* createPyObject(const at::Storage& storage);
 std::unique_ptr<at::Storage> createStorage(PyObject* obj);
 bool isStorage(PyObject* obj);
 
-THPDtype* getDtype(at::ScalarType scalarType, bool is_cuda);
+THPDtype* getDtype(at::ScalarType scalarType);
 THPLayout* getLayout(at::Backend backend);
-at::Type& getType(const THPDtype &dtype, const THPLayout& layout);
+at::Type& getType(at::ScalarType scalarType, const THPLayout& layout, const DeviceType& deviceType);
+DeviceType getDeviceType(const at::Type& type);
 
 } // namespace torch
@@ -331,8 +331,7 @@ PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) {
 PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) {
   HANDLE_TH_ERRORS
   auto& type = torch::tensor::get_default_tensor_type();
-  bool is_cuda = type.backend() == at::kCUDA;
-  auto dtype = (PyObject*)torch::getDtype(type.scalarType(), is_cuda);
+  auto dtype = (PyObject*)torch::getDtype(type.scalarType());
   Py_INCREF(dtype);
   return dtype;
   END_HANDLE_TH_ERRORS
@@ -378,7 +378,7 @@ PyObject *THPVariable_dtype(THPVariable *self)
 {
   HANDLE_TH_ERRORS
   auto& self_ = self->cdata;
-  return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType(), self_.type().is_cuda()));
+  return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType()));
   END_HANDLE_TH_ERRORS
 }
 
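Reviewer note: with is_cuda dropped from the getDtype() call, Variable.dtype reports the same dtype object regardless of device; is_cuda carries the device bit separately. A hedged sketch:

    import torch

    x = torch.tensor([1.0])                      # CPU float tensor
    assert x.dtype is torch.float32 and not x.is_cuda
    if torch.cuda.is_available():
        y = torch.tensor([1.0], device='cuda')
        assert y.dtype is x.dtype                # identical dtype object
        assert y.is_cuda                         # device-ness lives elsewhere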
@@ -30,6 +30,7 @@ struct PyTensorType {
   at::Type* aten_type;
   THPDtype* dtype;
   THPLayout* layout;
+  bool is_cuda;
   char name[64];
 };
 
@@ -51,7 +52,7 @@ static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs
   if (!tensor_type.aten_type) {
     throw unavailable_type(tensor_type);
   }
-  if (tensor_type.dtype->is_cuda) {
+  if (tensor_type.aten_type->is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
   return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*tensor_type.aten_type, args, kwargs));
@@ -79,7 +80,7 @@ PyObject *Tensor_layout(PyTensorType* self) {
 }
 
 PyObject *Tensor_is_cuda(PyTensorType* self) {
-  if (self->dtype->is_cuda) {
+  if (self->is_cuda) {
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
@@ -178,7 +179,8 @@ static void set_type(PyTensorType& type_obj, Backend backend, ScalarType scalarT
   auto baseType = globalContext().type_registry[static_cast<int>(backend)][static_cast<int>(scalarType)].get();
   type_obj.aten_type = baseType ? torch::autograd::VariableType::getType(*baseType) : nullptr;
   type_obj.layout = torch::getLayout(backend);
-  type_obj.dtype = torch::getDtype(scalarType, backend == kCUDA || backend == kSparseCUDA);
+  type_obj.dtype = torch::getDtype(scalarType);
+  type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
 }
 
 static void set_name(PyTensorType& type_obj, const std::string& name) {
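Reviewer note: legacy tensor type objects (torch.FloatTensor, torch.cuda.FloatTensor) now cache is_cuda on the type object itself, while their dtype field points at the device-neutral dtype. A sketch, assuming this branch:

    import torch

    assert not torch.FloatTensor(1).is_cuda
    assert torch.FloatTensor(1).dtype is torch.float32
    if torch.cuda.is_available():
        assert torch.cuda.FloatTensor(1).is_cuda
        assert torch.cuda.FloatTensor(1).dtype is torch.float32  # same dtype object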
@@ -11,7 +11,9 @@ struct Device {
   int64_t index;
   bool is_default; // is default device for type.
   Device(DeviceType type, int64_t index, bool is_default);
 
   bool operator==(const Device& rhs);
+
+  inline int64_t deviceInt64() { return (this->is_default || this->type == DeviceType::CPU) ? -1 : this->index; }
 };
 
 }
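Reviewer note: Device::deviceInt64() centralizes the legacy integer convention consumed by AutoGPU below: -1 means CPU or "use the current default device", anything else is an explicit CUDA index. Illustrated from Python, assuming the THPDevice binding exposes type and index as it does in current releases:

    import torch

    d = torch.device('cuda:1')
    assert d.type == 'cuda' and d.index == 1   # maps to deviceInt64() == 1
    assert torch.device('cpu').index is None   # CPU carries no index (-1 internally)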
@@ -23,7 +23,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"bool", ParameterType::BOOL},
   {"Storage", ParameterType::STORAGE},
   {"PyObject*", ParameterType::PYOBJECT},
-  {"Dtype", ParameterType::DTYPE},
+  {"ScalarType", ParameterType::SCALARTYPE},
   {"Layout", ParameterType::LAYOUT},
   {"Device", ParameterType::DEVICE},
   {"String", ParameterType::STRING},
@@ -111,7 +111,7 @@ bool FunctionParameter::check(PyObject* obj) {
     case ParameterType::BOOL: return PyBool_Check(obj);
     case ParameterType::STORAGE: return isStorage(obj);
     case ParameterType::PYOBJECT: return true;
-    case ParameterType::DTYPE: return THPDtype_Check(obj);
+    case ParameterType::SCALARTYPE: return THPDtype_Check(obj);
     case ParameterType::LAYOUT: return THPLayout_Check(obj);
     case ParameterType::DEVICE:
       return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
@@ -132,7 +132,7 @@ std::string FunctionParameter::type_name() const {
     case ParameterType::BOOL: return "bool";
     case ParameterType::STORAGE: return "torch.Storage";
     case ParameterType::PYOBJECT: return "object";
-    case ParameterType::DTYPE: return "torch.dtype";
+    case ParameterType::SCALARTYPE: return "torch.dtype";
     case ParameterType::LAYOUT: return "torch.layout";
     case ParameterType::DEVICE: return "torch.device";
     case ParameterType::STRING: return "str";
@@ -166,21 +166,23 @@ void FunctionParameter::set_default_str(const std::string& str) {
     if (str != "None") {
       default_intlist.assign(size, std::stoi(str));
     }
-  } else if (type_ == ParameterType::DTYPE) {
+  } else if (type_ == ParameterType::SCALARTYPE) {
     if (str == "None") {
-      default_dtype = nullptr;
+      default_scalartype = at::ScalarType::Undefined;
     } else if (str == "torch.int64") {
-      default_dtype = torch::getDtype(kLong, false);
+      default_scalartype = at::ScalarType::Long;
     } else {
-      throw std::runtime_error("invalid default value for dtype: " + str);
+      throw std::runtime_error("invalid default value for ScalarType: " + str);
     }
   } else if (type_ == ParameterType::LAYOUT) {
-    if (str == "torch.strided") {
+    if (str == "None") {
+      default_layout = nullptr;
+    } else if (str == "torch.strided") {
       default_layout = torch::getLayout(at::Backend::CPU);
     } else if (str == "torch.sparse_coo") {
       default_layout = torch::getLayout(at::Backend::SparseCPU);
     } else {
-      throw std::runtime_error("invalid default value for dtype: " + str);
+      throw std::runtime_error("invalid default value for layout: " + str);
     }
   } else if (type_ == ParameterType::DEVICE) {
     if (str != "None") {
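Reviewer note: a declared default of dtype=None now parses to ScalarType::Undefined, which scalartype(i) later resolves against the default tensor type instead of baking a THPDtype pointer into the signature. A hedged sketch of the observable behavior:

    import torch

    # dtype=None means "ask the default tensor type at call time".
    torch.set_default_tensor_type('torch.DoubleTensor')
    try:
        assert torch.tensor([1.0]).dtype is torch.float64
    finally:
        torch.set_default_tensor_type('torch.FloatTensor')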
@@ -44,7 +44,7 @@ namespace torch {
 
 enum class ParameterType {
   TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
-  BOOL, STORAGE, PYOBJECT, DTYPE, LAYOUT, DEVICE, STRING
+  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, DEVICE, STRING
 };
 
 struct FunctionParameter;
@@ -93,10 +93,12 @@ struct PythonArgs {
   inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
   inline at::Generator* generator(int i);
   inline std::unique_ptr<at::Storage> storage(int i);
-  inline const THPDtype& dtype(int i);
-  inline const THPDtype& dtypeWithDefault(int i, const THPDtype& default_dtype);
+  inline at::ScalarType scalartype(int i);
+  inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
   inline const THPLayout& layout(int i);
+  inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
   inline Device device(int i);
+  inline Device deviceWithDefault(int i, const Device& default_device);
   inline int64_t deviceInt64(int i);
   inline std::string string(int i);
   inline PyObject* pyobject(int i);
@@ -146,7 +148,7 @@ struct FunctionParameter {
     bool default_bool;
     int64_t default_int;
     double default_double;
-    THPDtype* default_dtype;
+    at::ScalarType default_scalartype;
     THPLayout* default_layout;
   };
 };
@@ -256,21 +258,18 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(int i, std::vector<in
   return res;
 }
 
-inline const THPDtype& PythonArgs::dtypeWithDefault(int i, const THPDtype& default_dtype) {
-  if (!args[i]) return default_dtype;
-  return dtype(i);
+inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
+  if (!args[i]) return default_scalartype;
+  return scalartype(i);
 }
 
-inline const THPDtype& PythonArgs::dtype(int i) {
+inline at::ScalarType PythonArgs::scalartype(int i) {
   if (!args[i]) {
-    auto dtype = signature.params[i].default_dtype;
-    if (!dtype) {
-      const auto& type = torch::tensor::get_default_tensor_type();
-      dtype = torch::getDtype(type.scalarType(), type.is_cuda());
-    }
-    return *dtype;
+    auto scalartype = signature.params[i].default_scalartype;
+    return (scalartype == at::ScalarType::Undefined) ?
+      torch::tensor::get_default_tensor_type().scalarType() : scalartype;
   }
-  return *reinterpret_cast<THPDtype*>(args[i]);
+  return reinterpret_cast<THPDtype*>(args[i])->scalar_type;
 }
 
 inline const THPLayout& PythonArgs::layout(int i) {
@@ -278,13 +277,22 @@ inline const THPLayout& PythonArgs::layout(int i) {
   return *reinterpret_cast<THPLayout*>(args[i]);
 }
 
+inline const THPLayout& PythonArgs::layoutWithDefault(int i, const THPLayout& default_layout) {
+  if (!args[i]) return default_layout;
+  return layout(i);
+}
+
 static std::string cuda_str = "cuda";
 static std::string cpu_str = "cpu";
 static std::string cuda_prefix = "cuda:";
 static std::string cpu_prefix = "cpu:";
 
 inline Device PythonArgs::device(int i) {
-  if (!args[i]) return Device(DeviceType::CPU, -1, true); // TODO: use CUDA if default type is a cuda type.
+  if (!args[i]) {
+    const auto& default_tensor_type = torch::tensor::get_default_tensor_type();
+    const auto device_type = torch::getDeviceType(default_tensor_type);
+    return Device(device_type, -1, true);
+  }
   if (THPDevice_Check(args[i])) {
     auto device = reinterpret_cast<THPDevice*>(args[i]);
     return device->device;
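Reviewer note: PythonArgs::device() now derives the implicit device from the default tensor type, resolving the old TODO that hard-coded CPU. A sketch, assuming a CUDA build of this branch:

    import torch

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        try:
            # No device= given: the parser follows the default tensor type.
            assert torch.tensor([1.0]).is_cuda
        finally:
            torch.set_default_tensor_type('torch.FloatTensor')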
@@ -308,9 +316,14 @@ inline Device PythonArgs::device(int i) {
   throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
 }
 
+inline Device PythonArgs::deviceWithDefault(int i, const Device& default_device) {
+  if (!args[i]) return default_device;
+  return device(i);
+}
+
 inline int64_t PythonArgs::deviceInt64(int i) {
   auto dev = device(i);
-  return (dev.is_default || dev.type == DeviceType::CPU) ? -1 : dev.index;
+  return dev.deviceInt64();
 }
 
 inline std::string PythonArgs::string(int i) {
torch/csrc/utils/tensor_devices.h (new file)
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "DynamicTypes.h"
+#include "device.h"
+
+namespace torch { namespace utils {
+
+Device getDevice(const at::Tensor tensor) {
+  return torch::Device(torch::getDeviceType(tensor.type()), tensor.type().is_cuda() ? tensor.get_device() : 0, false);
+}
+
+}} // namespace torch::utils
@@ -38,45 +38,26 @@ static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarTy
 void initializeDtypes() {
   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
   if (!torch_module) python_error();
-  auto cuda_module = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
-  if (!cuda_module) python_error();
-  for (auto type_pair : torch::utils::all_declared_types()) {
-    at::Backend backend;
-    at::ScalarType scalarType;
-    std::tie(backend, scalarType) = type_pair;
+
+#define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,
+
+  at::ScalarType all_scalar_types[] = {
+    AT_FORALL_SCALAR_TYPES(DEFINE_SCALAR_TYPE)
+  };
+
+  for (at::ScalarType scalarType: all_scalar_types) {
     std::string primary_name, legacy_name;
     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
-    PyObject *module = nullptr;
-    bool is_cuda;
-    switch (backend) {
-      case at::kCPU: {
-        module = torch_module.get();
-        is_cuda = false;
-        break;
-      }
-      case at::kCUDA: {
-        module = cuda_module.get();
-        is_cuda = true;
-        break;
-      }
-      case at::kSparseCPU: {
-        continue;
-      }
-      case at::kSparseCUDA: {
-        continue;
-      }
-      default: throw std::runtime_error("Unimplemented backend");
-    }
-    std::string name = std::string(PyModule_GetName(module)) + '.' + primary_name;
-    PyObject *dtype = THPDtype_New(scalarType, is_cuda, name);
-    torch::registerDtypeObject((THPDtype*)dtype, scalarType, is_cuda);
+    std::string name = std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
+    PyObject *dtype = THPDtype_New(scalarType, name);
+    torch::registerDtypeObject((THPDtype*)dtype, scalarType);
     Py_INCREF(dtype);
-    if (PyModule_AddObject(module, primary_name.c_str(), dtype) != 0) {
+    if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
       throw python_error();
     }
     if (legacy_name != "") {
       Py_INCREF(dtype);
-      if (PyModule_AddObject(module, legacy_name.c_str(), dtype) != 0) {
+      if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {
         throw python_error();
       }
     }
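Reviewer note: initializeDtypes() now registers exactly one dtype object per scalar type, under the torch module only; the per-backend switch (and the torch.cuda.* dtypes it produced) is gone. A hedged sketch:

    import torch

    assert torch.int16 is torch.short             # legacy alias, same object
    assert not hasattr(torch.cuda, 'float32')     # no separate CUDA dtypes anymore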
@@ -371,9 +371,11 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
   throw std::runtime_error("new(): invalid arguments");
 }
 
-static const Type& typeWithDefault(PythonArgs& r, int64_t idx, const Type& type) {
-  auto dtype = r.dtypeWithDefault(idx, *torch::getDtype(type.scalarType(), type.is_cuda()));
-  return torch::getType(dtype, *torch::getLayout(type.backend()));
+static const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type) {
+  auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
+  auto types_device_type = torch::getDeviceType(type);
+  auto device_type = r.isNone(device_idx) ? types_device_type : r.device(device_idx).type;
+  return torch::getType(scalartype, *torch::getLayout(type.backend()), device_type);
 }
 
 static Tensor set_requires_grad(Tensor self, bool requires_grad) {
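Reviewer note: typeWithDefault() now threads a device index alongside the dtype index, so device= can move a constructor across backends independently of dtype. A sketch, assuming this branch:

    import torch

    a = torch.tensor([1, 2], dtype=torch.int32)   # explicit dtype, default device
    if torch.cuda.is_available():
        b = torch.tensor([1, 2], dtype=torch.int32, device='cuda')
        assert a.dtype is b.dtype and b.is_cuda   # independent knobs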
@@ -386,15 +388,15 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
   const auto& default_sparse_type = type.toBackend(sparse_backend);
 
   static PythonArgParser parser({
-    "sparse_coo_tensor(PyObject* indices, PyObject* values, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
-    "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<6> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     bool type_inference = r.isNone(2);
-    const auto& sparse_type = typeWithDefault(r, 2, default_sparse_type);
+    const auto& sparse_type = typeWithDefault(r, 2, 3, default_sparse_type);
     const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
     const auto& index_type = dense_type.toScalarType(kLong);
     AutoGPU autogpu(r.deviceInt64(3));
@@ -405,7 +407,7 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
     return set_requires_grad(sparse_type_to_use.sparse_coo_tensor(indices, values), r.toBool(4));
   } else if (r.idx == 1) {
     bool type_inference = r.isNone(3);
-    const auto& sparse_type = typeWithDefault(r, 3, default_sparse_type);
+    const auto& sparse_type = typeWithDefault(r, 3, 4, default_sparse_type);
     const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
     const auto& index_type = dense_type.toScalarType(kLong);
     AutoGPU autogpu(r.deviceInt64(4));
@@ -420,7 +422,7 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
 
 Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
@@ -428,7 +430,7 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   if (r.idx == 0) {
     bool type_inference = r.isNone(1);
     return set_requires_grad(internal_new_from_data(
-      typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
+      typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
   }
   throw std::runtime_error("tensor(): invalid arguments");
 }
@@ -436,27 +438,27 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
 
 Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     return set_requires_grad(new_from_data_copy(
-      typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
+      typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
   }
   throw std::runtime_error("new_tensor(): invalid arguments");
 }
 
 Tensor new_empty(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_empty(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_empty(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(new_with_sizes(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
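Reviewer note: the new_* variants keep defaulting to self's type, but the dtype and device slots are now resolved separately by typeWithDefault(). A hedged usage sketch:

    import torch

    x = torch.ones(2, 3, dtype=torch.float64)
    y = x.new_zeros(4)                        # inherits float64 (and x's device)
    assert y.dtype is torch.float64
    z = x.new_zeros(4, dtype=torch.int64)     # dtype overridden, device still x's
    assert z.dtype is torch.int64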
@@ -464,13 +466,13 @@ Tensor new_empty(const at::Type& type, PyObject* args, PyObject* kwargs) {
 
 Tensor new_full(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_full(IntList size, Scalar fill_value, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_full(IntList size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 2, type);
+    const auto& actual_type = typeWithDefault(r, 2, 3, type);
     return set_requires_grad(dispatch_full(actual_type, r.scalar(1), r.deviceInt64(3), r.intlist(0)), r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
@@ -478,13 +480,13 @@ Tensor new_full(const at::Type& type, PyObject* args, PyObject* kwargs) {
 
 Tensor new_ones(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_ones(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_ones(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(dispatch_ones(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
@@ -492,13 +494,13 @@ Tensor new_ones(const at::Type& type, PyObject* args, PyObject* kwargs) {
 
 Tensor new_zeros(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_zeros(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_zeros(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(dispatch_zeros(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_zeros(): invalid arguments");
@@ -270,7 +270,8 @@ def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
         cx = None
 
     handle = cudnn.get_handle()
-    dropout_ts = cudnn.rnn.init_dropout_state(torch.cuda.uint8, dropout, train, dropout_seed, dropout_state)
+    dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
+                                              train, dropout_seed, dropout_state)
 
     weight_arr = list(itertools.chain.from_iterable(weight))
     weight_stride0 = len(weight[0])
@@ -85,11 +85,8 @@ def make_non_contiguous(tensor):
 
 
 def get_all_dtypes():
-    cpu_dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
-                  torch.float16, torch.float32, torch.float64]
-    cuda_dtypes = [torch.cuda.uint8, torch.cuda.int8, torch.cuda.int16, torch.cuda.int32, torch.cuda.int64,
-                   torch.cuda.float16, torch.cuda.float32, torch.cuda.float64]
-    return cpu_dtypes + cuda_dtypes
+    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
+            torch.float16, torch.float32, torch.float64]
 
 
 # 'dtype': (rtol, atol)
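Reviewer note: with the torch.cuda.* dtypes gone, tests can enumerate one device-neutral dtype list and pair it with devices explicitly. A hedged sketch of the resulting idiom:

    import itertools
    import torch

    dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
              torch.float16, torch.float32, torch.float64]
    devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
    for dtype, device in itertools.product(dtypes, devices):
        t = torch.tensor([0, 1], dtype=dtype, device=device)
        assert t.dtype is dtype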