Rename torch._C._TensorBase to TensorBase (#109940)
This renames the type `torch._C._TensorBase` to the non-private class name `TensorBase`. The old name `torch._C._TensorBase` is kept as an alias to the new type, both in the C++ code (70458768fb/torch/csrc/autograd/python_variable.cpp, L2196-L2197) and in the corresponding `__init__.pyi.in` stub file (70458768fb/torch/_C/__init__.pyi.in, L1522).

Fixes #109438
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109940
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: a565f1bee6
Commit: 09c598745c
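The user-visible contract after this change, sketched as a short Python session (illustrative only, assuming a build that includes this commit): the new public name and the old private name are the same type, direct construction is still rejected, and subclassing still works.

    import torch

    # Old and new spellings resolve to the same C type object.
    assert torch._C.TensorBase is torch._C._TensorBase

    # Direct construction is still forbidden (see THPVariable_pynew below).
    try:
        torch._C.TensorBase()
    except RuntimeError as e:
        print(e)  # Cannot directly construct TensorBase; subclass it and then construct that

    # Subclass construction remains allowed, as exercised by test_tensor_base_init.
    class MyTensor(torch._C.TensorBase):
        pass

    MyTensor()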
@@ -103,10 +103,10 @@ void initializeGlobals(Arena & A) {
     torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
     torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
 
-    torch_Tensor_expand = torch.attr("_C").attr("_TensorBase").attr("expand");
-    torch_Tensor_split = torch.attr("_C").attr("_TensorBase").attr("split");
+    torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
+    torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
     torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
-    auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
+    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
     auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
     THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
     THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
@@ -3188,7 +3188,7 @@ PyObject* _patch_tensor_class(PyObject * self_,
     PY_BEGIN
 
     auto torch = mpy::import("torch");
-    auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
+    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
     replaceMappingIfMatches(py_TensorBase);
 
     Py_RETURN_NONE;
@@ -220,6 +220,12 @@ class TestPublicBindings(TestCase):
         }
         torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
 
+        # torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
+        explicitly_removed_torch_C_bindings = {
+            "TensorBase",
+        }
+        torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
+
         # Check that the torch._C bindings are all in the allowlist. Since
         # bindings can change based on how PyTorch was compiled (e.g. with/without
         # CUDA), the two may not be an exact match but the bindings should be
@@ -695,7 +695,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
     def test_list_ret(self) -> None:
         # test all sequence types are permissible returns
         for list_type in (list, tuple):
-            class A(torch._C._TensorBase):
+            class A(torch._C.TensorBase):
                 @staticmethod
                 def __new__(cls, elem):
                     return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@@ -715,7 +715,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
 
     def test_invalid_ret(self) -> None:
         # test invalid return gets reasonable error message
-        class A(torch._C._TensorBase):
+        class A(torch._C.TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@@ -8823,10 +8823,10 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 
    def test_tensor_base_init(self):
        # Direct construction not OK
-       self.assertRaises(RuntimeError, lambda: torch._C._TensorBase())
+       self.assertRaises(RuntimeError, lambda: torch._C.TensorBase())
 
        # But construction of subclass is OK
-       class T(torch._C._TensorBase):
+       class T(torch._C.TensorBase):
            pass
 
        T()
@@ -8845,7 +8845,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 
        # OK to call super().__new__, see
        # https://github.com/pytorch/pytorch/issues/57421
-       class TestTensor(torch._C._TensorBase):
+       class TestTensor(torch._C.TensorBase):
            @staticmethod
            def __new__(cls, x, *args, **kwargs):
                return super().__new__(cls, x, *args, **kwargs)
@@ -9045,7 +9045,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 
    def test_tensor_slot_dealloc(self):
 
-       class SlotTensor1(torch._C._TensorBase):
+       class SlotTensor1(torch._C.TensorBase):
            __slots__ = ['slot1']
 
        class SlotTensor2(SlotTensor1):
@@ -9108,7 +9108,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
    def test_tensor_finalizer_dealloc(self):
        m = [False]
 
-       class FinalizerTensor(torch._C._TensorBase):
+       class FinalizerTensor(torch._C.TensorBase):
            def __del__(self):
                m[0] = True
 
@@ -9246,7 +9246,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        m1 = [False]
        m2 = [False]
 
-       class SlotTensor1(torch._C._TensorBase):
+       class SlotTensor1(torch._C.TensorBase):
            __slots__ = ['slot1']
 
            def __del__(self):
@@ -1560,7 +1560,7 @@ class _ImperativeEngine:
 class _TensorMeta(type): ...
 
 # Defined in torch/csrc/autograd/python_variable.cpp
-class _TensorBase(metaclass=_TensorMeta):
+class TensorBase(metaclass=_TensorMeta):
     requires_grad: _bool
     retains_grad: _bool
     shape: Size
@@ -1589,6 +1589,8 @@ class _TensorBase(metaclass=_TensorMeta):
     itemsize: _int
     ${tensor_method_hints}
 
+_TensorBase = TensorBase
+
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...
 
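Keeping `_TensorBase = TensorBase` in the stub means annotations written against the old private name keep type-checking unchanged. A minimal sketch (hypothetical user code, not part of the diff):

    import torch

    # Either spelling is accepted by a checker consuming the generated stub,
    # since the stub defines _TensorBase as an alias of TensorBase.
    def f(t: torch._C.TensorBase) -> None: ...
    def g(t: torch._C._TensorBase) -> None: ...

    f(torch.ones(2))  # torch.Tensor subclasses TensorBase
    g(torch.ones(2))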
@@ -489,6 +489,9 @@ for name in dir(_C):
         # TODO: fix their module from C++ side
         if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
             obj.__module__ = 'torch'
+    elif name == 'TensorBase':
+        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
+        delattr(sys.modules[__name__], name)
 
 if not TYPE_CHECKING:
     # issue 38137 and python issue 43367. Submodules of a C extension are
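The new `elif` matters because the enclosing `for name in dir(_C)` loop copies public `torch._C` names into the top-level `torch` namespace; without the `delattr`, the rename would have silently introduced a public `torch.TensorBase`. A sketch of the resulting state, assuming this patch is applied:

    import torch

    # Still reachable under torch._C ...
    assert hasattr(torch._C, "TensorBase")

    # ... but deliberately absent from the top-level torch namespace.
    assert not hasattr(torch, "TensorBase")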
@@ -1110,7 +1110,7 @@ class BuiltinVariable(VariableTracker):
                 for i, b in enumerate(bases)
             ]
         elif len(bases) == 1 and (
-            bases[0] is object or bases[0] is torch._C._TensorBase
+            bases[0] is object or bases[0] is torch._C.TensorBase
         ):
             tuple_args = [SourcelessBuilder()(tx, bases[0])]
         else:
@@ -51,11 +51,11 @@ tensor_dunder_fns = [
     torch.Tensor.__rpow__,
     torch.Tensor.__rsub__,
     torch.Tensor.__rdiv__,
-    torch._C._TensorBase.__radd__,
-    torch._C._TensorBase.__rmul__,
-    torch._C._TensorBase.__ror__,
-    torch._C._TensorBase.__rxor__,
-    torch._C._TensorBase.__rand__,
+    torch._C.TensorBase.__radd__,
+    torch._C.TensorBase.__rmul__,
+    torch._C.TensorBase.__ror__,
+    torch._C.TensorBase.__rxor__,
+    torch._C.TensorBase.__rand__,
 ]
 
 torch_special_class_types = (torch._C.Generator,)
@@ -97,31 +97,31 @@ if torch.distributed.is_available():
 
 # TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
 def remap_as_fn___radd__(*args):
-    return torch._C._TensorBase.__radd__(*args)
+    return torch._C.TensorBase.__radd__(*args)
 
 
 def remap_as_fn___rmul__(*args):
-    return torch._C._TensorBase.__rmul__(*args)
+    return torch._C.TensorBase.__rmul__(*args)
 
 
 def remap_as_fn___ror__(*args):
-    return torch._C._TensorBase.__ror__(*args)
+    return torch._C.TensorBase.__ror__(*args)
 
 
 def remap_as_fn___rxor__(*args):
-    return torch._C._TensorBase.__rxor__(*args)
+    return torch._C.TensorBase.__rxor__(*args)
 
 
 def remap_as_fn___rand__(*args):
-    return torch._C._TensorBase.__rand__(*args)
+    return torch._C.TensorBase.__rand__(*args)
 
 
 tensor_dunder_fns_remap = {
-    torch._C._TensorBase.__radd__: remap_as_fn___radd__,
-    torch._C._TensorBase.__rmul__: remap_as_fn___rmul__,
-    torch._C._TensorBase.__ror__: remap_as_fn___ror__,
-    torch._C._TensorBase.__rxor__: remap_as_fn___rxor__,
-    torch._C._TensorBase.__rand__: remap_as_fn___rand__,
+    torch._C.TensorBase.__radd__: remap_as_fn___radd__,
+    torch._C.TensorBase.__rmul__: remap_as_fn___rmul__,
+    torch._C.TensorBase.__ror__: remap_as_fn___ror__,
+    torch._C.TensorBase.__rxor__: remap_as_fn___rxor__,
+    torch._C.TensorBase.__rand__: remap_as_fn___rand__,
 }
 
 
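These wrappers only changed in which attribute path they read; the reason they exist at all is, presumably (going by the surrounding TODO rather than anything stated in this diff), that the reverse dunders on `TensorBase` are C-level descriptors, and wrapping each in a plain `def` gives the tracer an ordinary Python function object to key on. A small illustration of that distinction:

    import torch

    # The raw binding is a C-level descriptor, not a Python function.
    print(type(torch._C.TensorBase.__radd__))  # e.g. <class 'method_descriptor'>

    def remap_as_fn___radd__(*args):
        return torch._C.TensorBase.__radd__(*args)

    # The remapped version is an ordinary function the tracer can treat uniformly.
    print(type(remap_as_fn___radd__))  # <class 'function'>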
@@ -1858,7 +1858,7 @@ class FakeCopyMode(TorchFunctionMode):
         kwargs = kwargs if kwargs else {}
 
         # clone will get called in Parameter deepcopy
-        if func == torch._C._TensorBase.clone:
+        if func == torch._C.TensorBase.clone:
             return func(
                 self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
             )
@@ -78,7 +78,7 @@ def _rebuild_from_type_v2(func, new_type, args, state):
 # NB: If you add a new method to Tensor, you must update
 # torch/_C/__init__.pyi.in to add a type annotation for your method;
 # otherwise, it will not show up in autocomplete.
-class Tensor(torch._C._TensorBase):
+class Tensor(torch._C.TensorBase):
     def __deepcopy__(self, memo):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
@@ -628,7 +628,7 @@ class Tensor(torch._C._TensorBase):
     )
 
     detach = _C._add_docstr(
-        _C._TensorBase.detach,
+        _C.TensorBase.detach,
         r"""
     Returns a new Tensor, detached from the current graph.
 
@@ -654,7 +654,7 @@ class Tensor(torch._C._TensorBase):
     )
 
     detach_ = _C._add_docstr(
-        _C._TensorBase.detach_,
+        _C.TensorBase.detach_,
         r"""
     Detaches the Tensor from the graph that created it, making it a leaf.
     Views cannot be detached in-place.
@@ -913,13 +913,13 @@ class Tensor(torch._C._TensorBase):
         return self.reciprocal() * other
 
     __rtruediv__ = __rdiv__
-    __itruediv__ = _C._TensorBase.__idiv__
+    __itruediv__ = _C.TensorBase.__idiv__
 
     __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow
+        _C.TensorBase.pow
     )
     __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow_
+        _C.TensorBase.pow_
     )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
@@ -957,9 +957,9 @@ class Tensor(torch._C._TensorBase):
     def __rmatmul__(self, other):
         return torch.matmul(other, self)
 
-    __pos__ = _C._TensorBase.positive
-    __neg__ = _C._TensorBase.neg
-    __abs__ = _C._TensorBase.abs
+    __pos__ = _C.TensorBase.positive
+    __neg__ = _C.TensorBase.neg
+    __abs__ = _C.TensorBase.abs
 
     def __len__(self):
         if has_torch_function_unary(self):
@@ -6,7 +6,7 @@ from ._torch_docs import parse_kwargs, reproducibility_notes
 
 
 def add_docstr_all(method, docstr):
-    add_docstr(getattr(torch._C._TensorBase, method), docstr)
+    add_docstr(getattr(torch._C.TensorBase, method), docstr)
 
 
 common_args = parse_kwargs(
@@ -1690,7 +1690,7 @@ PyTypeObject THPVariableMetaType = {
 PyTypeObject THPVariableType = {
     PyVarObject_HEAD_INIT(
         &THPVariableMetaType,
-        0) "torch._C._TensorBase", /* tp_name */
+        0) "torch._C.TensorBase", /* tp_name */
     sizeof(THPVariable), /* tp_basicsize */
     0, /* tp_itemsize */
     // This is unspecified, because it is illegal to create a THPVariableType
@@ -1743,7 +1743,7 @@ PyObject* THPVariable_pynew(
   HANDLE_TH_ERRORS
   TORCH_CHECK(
       type != &THPVariableType,
-      "Cannot directly construct _TensorBase; subclass it and then construct that");
+      "Cannot directly construct TensorBase; subclass it and then construct that");
   jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
   auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
   // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
@@ -2193,6 +2193,7 @@ bool THPVariable_initModule(PyObject* module) {
   if (PyType_Ready(&THPVariableType) < 0)
     return false;
   Py_INCREF(&THPVariableType);
+  PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType);
   PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType);
   torch::autograd::initTorchFunctions(module);
   torch::autograd::initTensorImplConversion(module);
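After `THPVariable_initModule` runs, the single `THPVariableType` object is registered under both module attributes, so identity checks against either spelling keep passing. A quick check, assuming a build with this patch:

    import torch

    # Both module attributes were populated from the same PyTypeObject,
    # so torch.Tensor's MRO contains it regardless of the spelling used.
    assert torch._C.TensorBase in torch.Tensor.__mro__
    assert torch._C._TensorBase in torch.Tensor.__mro__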