Rename torch._C._TensorBase to TensorBase (#109940)

I have gone ahead and implemented the renaming of the type `torch._C._TensorBase` to a non-private class name `TensorBase`.
The changes also include leaving `torch._C._TensorBase` as an alias to the new type: 70458768fb/torch/csrc/autograd/python_variable.cpp (L2196-L2197) both in the c++ code and in the corresponding `__init__.pyi.in` file:
70458768fb/torch/_C/__init__.pyi.in (L1522)

Fixes #109438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109940
Approved by: https://github.com/ezyang
This commit is contained in:
Moritz Hennen
2023-09-25 19:10:19 +00:00
committed by PyTorch MergeBot
parent a565f1bee6
commit 09c598745c
12 changed files with 54 additions and 42 deletions

View File

@ -103,10 +103,10 @@ void initializeGlobals(Arena & A) {
torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
torch_Tensor_expand = torch.attr("_C").attr("_TensorBase").attr("expand");
torch_Tensor_split = torch.attr("_C").attr("_TensorBase").attr("split");
torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
auto py_TensorBase = torch.attr("_C").attr("TensorBase");
auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
@ -3188,7 +3188,7 @@ PyObject* _patch_tensor_class(PyObject * self_,
PY_BEGIN
auto torch = mpy::import("torch");
auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
auto py_TensorBase = torch.attr("_C").attr("TensorBase");
replaceMappingIfMatches(py_TensorBase);
Py_RETURN_NONE;

View File

@ -220,6 +220,12 @@ class TestPublicBindings(TestCase):
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
# torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
explicitly_removed_torch_C_bindings = {
"TensorBase",
}
torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
# Check that the torch._C bindings are all in the allowlist. Since
# bindings can change based on how PyTorch was compiled (e.g. with/without
# CUDA), the two may not be an exact match but the bindings should be

View File

@ -695,7 +695,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
for list_type in (list, tuple):
class A(torch._C._TensorBase):
class A(torch._C.TensorBase):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@ -715,7 +715,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
def test_invalid_ret(self) -> None:
# test invalid return gets reasonable error message
class A(torch._C._TensorBase):
class A(torch._C.TensorBase):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

View File

@ -8823,10 +8823,10 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
def test_tensor_base_init(self):
# Direct construction not OK
self.assertRaises(RuntimeError, lambda: torch._C._TensorBase())
self.assertRaises(RuntimeError, lambda: torch._C.TensorBase())
# But construction of subclass is OK
class T(torch._C._TensorBase):
class T(torch._C.TensorBase):
pass
T()
@ -8845,7 +8845,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
# OK to call super().__new__, see
# https://github.com/pytorch/pytorch/issues/57421
class TestTensor(torch._C._TensorBase):
class TestTensor(torch._C.TensorBase):
@staticmethod
def __new__(cls, x, *args, **kwargs):
return super().__new__(cls, x, *args, **kwargs)
@ -9045,7 +9045,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
def test_tensor_slot_dealloc(self):
class SlotTensor1(torch._C._TensorBase):
class SlotTensor1(torch._C.TensorBase):
__slots__ = ['slot1']
class SlotTensor2(SlotTensor1):
@ -9108,7 +9108,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
def test_tensor_finalizer_dealloc(self):
m = [False]
class FinalizerTensor(torch._C._TensorBase):
class FinalizerTensor(torch._C.TensorBase):
def __del__(self):
m[0] = True
@ -9246,7 +9246,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
m1 = [False]
m2 = [False]
class SlotTensor1(torch._C._TensorBase):
class SlotTensor1(torch._C.TensorBase):
__slots__ = ['slot1']
def __del__(self):

View File

@ -1560,7 +1560,7 @@ class _ImperativeEngine:
class _TensorMeta(type): ...
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(metaclass=_TensorMeta):
class TensorBase(metaclass=_TensorMeta):
requires_grad: _bool
retains_grad: _bool
shape: Size
@ -1589,6 +1589,8 @@ class _TensorBase(metaclass=_TensorMeta):
itemsize: _int
${tensor_method_hints}
_TensorBase = TensorBase
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...

View File

@ -489,6 +489,9 @@ for name in dir(_C):
# TODO: fix their module from C++ side
if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
obj.__module__ = 'torch'
elif name == 'TensorBase':
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
delattr(sys.modules[__name__], name)
if not TYPE_CHECKING:
# issue 38137 and python issue 43367. Submodules of a C extension are

View File

@ -1110,7 +1110,7 @@ class BuiltinVariable(VariableTracker):
for i, b in enumerate(bases)
]
elif len(bases) == 1 and (
bases[0] is object or bases[0] is torch._C._TensorBase
bases[0] is object or bases[0] is torch._C.TensorBase
):
tuple_args = [SourcelessBuilder()(tx, bases[0])]
else:

View File

@ -51,11 +51,11 @@ tensor_dunder_fns = [
torch.Tensor.__rpow__,
torch.Tensor.__rsub__,
torch.Tensor.__rdiv__,
torch._C._TensorBase.__radd__,
torch._C._TensorBase.__rmul__,
torch._C._TensorBase.__ror__,
torch._C._TensorBase.__rxor__,
torch._C._TensorBase.__rand__,
torch._C.TensorBase.__radd__,
torch._C.TensorBase.__rmul__,
torch._C.TensorBase.__ror__,
torch._C.TensorBase.__rxor__,
torch._C.TensorBase.__rand__,
]
torch_special_class_types = (torch._C.Generator,)
@ -97,31 +97,31 @@ if torch.distributed.is_available():
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
return torch._C._TensorBase.__radd__(*args)
return torch._C.TensorBase.__radd__(*args)
def remap_as_fn___rmul__(*args):
return torch._C._TensorBase.__rmul__(*args)
return torch._C.TensorBase.__rmul__(*args)
def remap_as_fn___ror__(*args):
return torch._C._TensorBase.__ror__(*args)
return torch._C.TensorBase.__ror__(*args)
def remap_as_fn___rxor__(*args):
return torch._C._TensorBase.__rxor__(*args)
return torch._C.TensorBase.__rxor__(*args)
def remap_as_fn___rand__(*args):
return torch._C._TensorBase.__rand__(*args)
return torch._C.TensorBase.__rand__(*args)
tensor_dunder_fns_remap = {
torch._C._TensorBase.__radd__: remap_as_fn___radd__,
torch._C._TensorBase.__rmul__: remap_as_fn___rmul__,
torch._C._TensorBase.__ror__: remap_as_fn___ror__,
torch._C._TensorBase.__rxor__: remap_as_fn___rxor__,
torch._C._TensorBase.__rand__: remap_as_fn___rand__,
torch._C.TensorBase.__radd__: remap_as_fn___radd__,
torch._C.TensorBase.__rmul__: remap_as_fn___rmul__,
torch._C.TensorBase.__ror__: remap_as_fn___ror__,
torch._C.TensorBase.__rxor__: remap_as_fn___rxor__,
torch._C.TensorBase.__rand__: remap_as_fn___rand__,
}

View File

@ -1858,7 +1858,7 @@ class FakeCopyMode(TorchFunctionMode):
kwargs = kwargs if kwargs else {}
# clone will get called in Parameter deepcopy
if func == torch._C._TensorBase.clone:
if func == torch._C.TensorBase.clone:
return func(
self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
)

View File

@ -78,7 +78,7 @@ def _rebuild_from_type_v2(func, new_type, args, state):
# NB: If you add a new method to Tensor, you must update
# torch/__init__.py.in to add a type annotation for your method;
# otherwise, it will not show up in autocomplete.
class Tensor(torch._C._TensorBase):
class Tensor(torch._C.TensorBase):
def __deepcopy__(self, memo):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
@ -628,7 +628,7 @@ class Tensor(torch._C._TensorBase):
)
detach = _C._add_docstr(
_C._TensorBase.detach,
_C.TensorBase.detach,
r"""
Returns a new Tensor, detached from the current graph.
@ -654,7 +654,7 @@ class Tensor(torch._C._TensorBase):
)
detach_ = _C._add_docstr(
_C._TensorBase.detach_,
_C.TensorBase.detach_,
r"""
Detaches the Tensor from the graph that created it, making it a leaf.
Views cannot be detached in-place.
@ -913,13 +913,13 @@ class Tensor(torch._C._TensorBase):
return self.reciprocal() * other
__rtruediv__ = __rdiv__
__itruediv__ = _C._TensorBase.__idiv__
__itruediv__ = _C.TensorBase.__idiv__
__pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
_C._TensorBase.pow
_C.TensorBase.pow
)
__ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
_C._TensorBase.pow_
_C.TensorBase.pow_
)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
@ -957,9 +957,9 @@ class Tensor(torch._C._TensorBase):
def __rmatmul__(self, other):
return torch.matmul(other, self)
__pos__ = _C._TensorBase.positive
__neg__ = _C._TensorBase.neg
__abs__ = _C._TensorBase.abs
__pos__ = _C.TensorBase.positive
__neg__ = _C.TensorBase.neg
__abs__ = _C.TensorBase.abs
def __len__(self):
if has_torch_function_unary(self):

View File

@ -6,7 +6,7 @@ from ._torch_docs import parse_kwargs, reproducibility_notes
def add_docstr_all(method, docstr):
add_docstr(getattr(torch._C._TensorBase, method), docstr)
add_docstr(getattr(torch._C.TensorBase, method), docstr)
common_args = parse_kwargs(

View File

@ -1690,7 +1690,7 @@ PyTypeObject THPVariableMetaType = {
PyTypeObject THPVariableType = {
PyVarObject_HEAD_INIT(
&THPVariableMetaType,
0) "torch._C._TensorBase", /* tp_name */
0) "torch._C.TensorBase", /* tp_name */
sizeof(THPVariable), /* tp_basicsize */
0, /* tp_itemsize */
// This is unspecified, because it is illegal to create a THPVariableType
@ -1743,7 +1743,7 @@ PyObject* THPVariable_pynew(
HANDLE_TH_ERRORS
TORCH_CHECK(
type != &THPVariableType,
"Cannot directly construct _TensorBase; subclass it and then construct that");
"Cannot directly construct TensorBase; subclass it and then construct that");
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
@ -2193,6 +2193,7 @@ bool THPVariable_initModule(PyObject* module) {
if (PyType_Ready(&THPVariableType) < 0)
return false;
Py_INCREF(&THPVariableType);
PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType);
PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType);
torch::autograd::initTorchFunctions(module);
torch::autograd::initTensorImplConversion(module);