Allow direct Tensor constructor to return preexisting PyObject (#92754)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92754
Approved by: https://github.com/albanD, https://github.com/voznesenskym
Author: Edward Z. Yang
Date:   2023-01-23 16:53:00 +00:00
Committer: PyTorch MergeBot
Parent: e994e78397
Commit: 1237cf6b6c

3 changed files with 105 additions and 27 deletions
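
As an illustration of the user-visible behavior this change enables, here is a minimal Python sketch (not part of this commit; it assumes FakeTensorMode is importable from torch._subclasses.fake_tensor, a private module whose location may change):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        # torch.Tensor.__new__ builds a real tensor and then calls
        # lift_fresh() on it; with the mode active, lift_fresh() returns a
        # FakeTensor, which already carries a Python object. Previously
        # THPVariable_NewWithVar raised in this situation; with this change
        # the preexisting FakeTensor is returned directly.
        t = torch.Tensor([1.0, 2.0, 3.0])
        print(type(t))  # a FakeTensor, which is a subclass of torch.Tensor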

torch/csrc/autograd/python_variable.cpp

@@ -390,7 +390,8 @@ PyObject* ParameterClass = nullptr;
 static PyObject* THPVariable_NewWithVar(
     PyTypeObject* type,
     Variable _var,
-    c10::impl::PyInterpreterStatus status);
+    c10::impl::PyInterpreterStatus status,
+    bool allow_preexisting_pyobj = false);
 
 // clang-tidy gets confused by static const
 static const char* VOLATILE_WARNING =
@@ -1804,10 +1805,14 @@ PyObject* THPVariable_pynew(
   auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
   // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
   // given a raw pointer that will refcount bump
+  // NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
+  // alias(), lift_fresh()) which can return Tensor subclasses. We allow
+  // these to be passed on directly.
   return THPVariable_NewWithVar(
       type,
       std::move(tensor),
-      c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED);
+      c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
+      /*allow_preexisting_pyobj=*/true);
   END_HANDLE_TH_ERRORS
 }
@@ -1940,25 +1945,78 @@ void THPVariable_subclass_dealloc(PyObject* self) {
 static PyObject* THPVariable_NewWithVar(
     PyTypeObject* type,
     Variable _var,
-    c10::impl::PyInterpreterStatus status) {
-  // This function overwrites the Tensor's pyobj field without extra checks.
-  // Make sure it is not already set; otherwise we would leak memory.
-  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
-      self_interpreter.get());
-  TORCH_CHECK(
-      !mb_obj.has_value() || !mb_obj.value(),
-      "Creating a new Tensor subclass ",
-      type->tp_name,
-      " but the raw Tensor object is already associated to a python object ",
-      "of type ",
-      mb_obj.value()->ob_type->tp_name);
+    c10::impl::PyInterpreterStatus status,
+    bool allow_preexisting_pyobj) {
   // Make sure that the reinterpret into a THPVariable* will be valid
   TORCH_CHECK(
       PyType_IsSubtype(type, &THPVariableType),
       "Creating a Tensor subclass from a class ",
       "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
+  // This function overwrites the Tensor's pyobj field without extra checks.
+  // Make sure it is not already set; otherwise we would leak memory.
+  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+      self_interpreter.get());
+
+  // Under some circumstances, we may attempt to create a new Python
+  // object for a variable that already has one. The most common situation
+  // in which this occurs is when a TorchDispatchMode is active that returns
+  // a subclass from lift_fresh (which is invoked to appropriately "wrap" a
+  // constant tensor into whatever ambient modes are active).
+  //
+  // In general, it is impossible to handle this case compositionally.
+  // Suppose a user calls ATensor([1, 2, 3]) while a mode is active that
+  // transforms all ops (including the internal lift_fresh call that turns
+  // [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output BTensor, where
+  // ATensor and BTensor are completely unrelated subclasses with no way to
+  // compose them. There is no way to satisfy the user's request here: in
+  // particular, you cannot just re-invoke the ATensor constructor on the
+  // returned BTensor, because (1) this could cause an infinite loop (we are
+  // already in ATensor.__new__) and (2) there is no guarantee that
+  // ATensor.__new__ supports a single-element constructor anyway.
+  //
+  // However, a more common case is that the user just called
+  // torch.Tensor([1, 2, 3]) while a fake tensor mode was active. Really,
+  // all you want is to get back a FakeTensor, in the same way
+  // torch.tensor([1, 2, 3]) or torch.arange(3) would have returned one
+  // (concretely, we create a *real* tensor torch.tensor([1., 2., 3.]), and
+  // it turns into a FakeTensor when we call lift_fresh on it). This case is
+  // compositional because FakeTensor is a subclass of Tensor, so it is
+  // valid for us to return it in place of a Tensor. So this is what we do.
+  if (mb_obj.has_value() && mb_obj.value()) {
+    TORCH_CHECK(
+        allow_preexisting_pyobj,
+        "Creating a new Tensor subclass ",
+        type->tp_name,
+        " but the raw Tensor object is already associated to a python object ",
+        "of type ",
+        mb_obj.value()->ob_type->tp_name);
+    // Even if we allow a pre-existing PyObject, we don't allow completely
+    // ignoring the requested type. Check that a subtype relation is
+    // fulfilled here. In the common case the requested type is Tensor and
+    // this always succeeds.
+    PyObject* obj = *mb_obj;
+    // Check whether it is OK to just directly return the Python object
+    // without allocating a new variable: the existing Python object must be
+    // a subclass of the requested type.
+    PyTypeObject* obj_type = Py_TYPE(obj);
+    TORCH_CHECK(
+        obj_type == type || PyType_IsSubtype(obj_type, type),
+        "Creating a new Tensor subclass ",
+        type->tp_name,
+        " but the raw Tensor object is already associated to a python object ",
+        "of type ",
+        mb_obj.value()->ob_type->tp_name,
+        " which is not a subclass of the requested type");
+    // We may (in fact, we typically will) need to resurrect this PyObject.
+    return THPVariable_Wrap(std::move(_var));
+  }
+
   PyObject* obj = type->tp_alloc(type, 0);
   if (obj) {
     auto v = (THPVariable*)obj;
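
To make the compositionality discussion in the new comment concrete, here is a hedged Python sketch of the two cases the subtype check distinguishes. The names ATensor, BTensor, and WrapAsBTensorMode are hypothetical, invented for illustration; the mode rewraps every op's output, including the constructor's internal lift_fresh call:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class ATensor(torch.Tensor):
        pass

    class BTensor(torch.Tensor):
        pass

    class WrapAsBTensorMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Run the real op (the mode is popped while we are inside
            # __torch_dispatch__), then rewrap the result as BTensor.
            out = func(*args, **(kwargs or {}))
            if isinstance(out, torch.Tensor):
                return torch.Tensor._make_subclass(BTensor, out)
            return out

    with WrapAsBTensorMode():
        t = torch.Tensor([1, 2, 3])  # OK: the preexisting BTensor is a
        print(type(t))               # subclass of the requested torch.Tensor
        # ATensor([1, 2, 3])         # would raise: BTensor is not a
        #                            # subclass of the requested ATensor

The first call succeeds because it takes the compositional path described above for FakeTensor; the second would trip the new "not a subclass of the requested type" check, since there is no way to turn a BTensor into the unrelated ATensor the caller asked for.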