mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
89 lines
2.4 KiB
Python
89 lines
2.4 KiB
Python
import torch
|
|
from torch import _C
|
|
from ..tensor import _TensorBase
|
|
from torch.sparse import _SparseBase, _sparse_tensor_classes
|
|
from . import _lazy_init, device, _dummy_type
|
|
|
|
|
|
# When torch._C does not expose the CUDA sparse base types (presumably a build
# without CUDA support -- confirm), register placeholder types under the same
# names so the class statements below can still name them as bases.
# Only the Double variant is probed; the others are assumed to be missing
# whenever it is.
if not hasattr(torch._C, 'CudaSparseDoubleTensorBase'):
    for t in ('Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half'):
        tensor_name = 'CudaSparse' + t + 'TensorBase'
        setattr(torch._C, tensor_name, _dummy_type(tensor_name))
|
|
|
|
|
|
class _CudaSparseBase(object):
    """Mixin listed first in the bases of each CUDA sparse tensor class.

    Marks the class with ``is_cuda``/``is_sparse`` flags, triggers the
    deferred ``_lazy_init`` the first time an instance is constructed, and
    runs ``type`` conversions inside a ``device`` context for the tensor's
    own device.
    """

    is_cuda = True
    is_sparse = True

    def type(self, *args, **kwargs):
        """Delegate to the next ``type`` in the MRO under this tensor's device.

        ``device(self.get_device())`` is entered around the call; presumably
        this makes the tensor's GPU current so the conversion happens there
        (TODO confirm against the ``device`` context manager).
        """
        with device(self.get_device()):
            return super(_CudaSparseBase, self).type(*args, **kwargs)

    def __new__(cls, *args, **kwargs):
        """Run ``_lazy_init`` on first construction, then remove this hook.

        All arguments are forwarded unchanged to the next ``__new__`` in the
        MRO.
        """
        _lazy_init()
        # This method exists only to trigger lazy init, so it deletes itself;
        # afterwards construction resolves directly to the next __new__ in the
        # MRO.  Two threads constructing their first tensor at the same time
        # can both reach this point. In that case the second ``del`` finds the
        # attribute already gone, so tolerate that instead of failing the
        # construction with AttributeError.
        try:
            del _CudaSparseBase.__new__
        except AttributeError:
            pass
        return super(_CudaSparseBase, cls).__new__(cls, *args, **kwargs)
|
|
|
|
|
|
class DoubleTensor(_CudaSparseBase,
                   torch._C.CudaSparseDoubleTensorBase,
                   _SparseBase,
                   _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseDoubleTensorBase``."""

    def is_signed(self):
        """Report Double elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
class FloatTensor(_CudaSparseBase,
                  torch._C.CudaSparseFloatTensorBase,
                  _SparseBase,
                  _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseFloatTensorBase``."""

    def is_signed(self):
        """Report Float elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
class LongTensor(_CudaSparseBase,
                 torch._C.CudaSparseLongTensorBase,
                 _SparseBase,
                 _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseLongTensorBase``."""

    def is_signed(self):
        """Report Long elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
class IntTensor(_CudaSparseBase,
                torch._C.CudaSparseIntTensorBase,
                _SparseBase,
                _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseIntTensorBase``."""

    def is_signed(self):
        """Report Int elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
class ShortTensor(_CudaSparseBase,
                  torch._C.CudaSparseShortTensorBase,
                  _SparseBase,
                  _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseShortTensorBase``."""

    def is_signed(self):
        """Report Short elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
class CharTensor(_CudaSparseBase,
                 torch._C.CudaSparseCharTensorBase,
                 _SparseBase,
                 _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseCharTensorBase``."""

    def is_signed(self):
        """Report Char elements as unsigned (always ``False``)."""
        # TODO / NOTE(review): Char elements are presumably signed 8-bit
        # integers, in which case this should return True. The value is kept
        # as-is pending confirmation against the dense CharTensor behavior.
        signed = False
        return signed
|
|
|
|
|
|
class ByteTensor(_CudaSparseBase,
                 torch._C.CudaSparseByteTensorBase,
                 _SparseBase,
                 _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseByteTensorBase``."""

    def is_signed(self):
        """Report Byte elements as unsigned (always ``False``)."""
        signed = False
        return signed
|
|
|
|
|
|
class HalfTensor(_CudaSparseBase,
                 torch._C.CudaSparseHalfTensorBase,
                 _SparseBase,
                 _TensorBase):
    """CUDA sparse tensor class built on ``CudaSparseHalfTensorBase``."""

    def is_signed(self):
        """Report Half elements as signed (always ``True``)."""
        signed = True
        return signed
|
|
|
|
|
|
# Record each CUDA sparse tensor class in the sparse-class registry, then
# merge that whole registry into torch._tensor_classes.
for _sparse_cls in (DoubleTensor, FloatTensor, LongTensor, IntTensor,
                    ShortTensor, CharTensor, ByteTensor, HalfTensor):
    _sparse_tensor_classes.add(_sparse_cls)
# Drop the loop variable so no extra module-level name is left behind.
del _sparse_cls

torch._tensor_classes.update(_sparse_tensor_classes)
|