mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Allow direct Tensor constructor to return preexisting PyObject (#92754)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/92754 Approved by: https://github.com/albanD, https://github.com/voznesenskym
This commit is contained in:
committed by
PyTorch MergeBot
parent
e994e78397
commit
1237cf6b6c
@ -390,7 +390,8 @@ PyObject* ParameterClass = nullptr;
|
||||
static PyObject* THPVariable_NewWithVar(
|
||||
PyTypeObject* type,
|
||||
Variable _var,
|
||||
c10::impl::PyInterpreterStatus status);
|
||||
c10::impl::PyInterpreterStatus status,
|
||||
bool allow_preexisting_pyobj = false);
|
||||
|
||||
// clang-tidy gets confused by static const
|
||||
static const char* VOLATILE_WARNING =
|
||||
@ -1804,10 +1805,14 @@ PyObject* THPVariable_pynew(
|
||||
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
|
||||
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
|
||||
// given a raw pointer that will refcount bump
|
||||
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
|
||||
// alias(), lift_fresh()) which can return Tensor subclasses. We allow
|
||||
// these to be passed on directly.
|
||||
return THPVariable_NewWithVar(
|
||||
type,
|
||||
std::move(tensor),
|
||||
c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED);
|
||||
c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
|
||||
/*allow_preexisting_pyobj=*/true);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -1940,25 +1945,78 @@ void THPVariable_subclass_dealloc(PyObject* self) {
|
||||
static PyObject* THPVariable_NewWithVar(
|
||||
PyTypeObject* type,
|
||||
Variable _var,
|
||||
c10::impl::PyInterpreterStatus status) {
|
||||
// This function overwrite the Tensor's pyobj field without extra checks
|
||||
// Make sure it is not set otherwise we would leak memory
|
||||
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||
self_interpreter.get());
|
||||
TORCH_CHECK(
|
||||
!mb_obj.has_value() || !mb_obj.value(),
|
||||
"Creating a new Tensor subclass ",
|
||||
type->tp_name,
|
||||
" but the raw Tensor object is already associated to a python object ",
|
||||
"of type ",
|
||||
mb_obj.value()->ob_type->tp_name);
|
||||
|
||||
c10::impl::PyInterpreterStatus status,
|
||||
bool allow_preexisting_pyobj) {
|
||||
// Make sure that the reinterpret into a THPVariable* will be valid
|
||||
TORCH_CHECK(
|
||||
PyType_IsSubtype(type, &THPVariableType),
|
||||
"Creating a Tensor subclass from a class ",
|
||||
"that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
|
||||
|
||||
// This function overwrite the Tensor's pyobj field without extra checks
|
||||
// Make sure it is not set otherwise we would leak memory
|
||||
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||
self_interpreter.get());
|
||||
|
||||
// Under some circumstances, we may attempt to create a new Python
|
||||
// object for a variable that already has a Python object. The most common
|
||||
// situation this can occur is if you have a TorchDispatchMode active that
|
||||
// is returning a subclass from lift_fresh (which is invoked to
|
||||
// appropriately "wrap" a constant tensor into whatever ambient modes are
|
||||
// active.)
|
||||
//
|
||||
// In general, it is impossible to handle this case compositionally.
|
||||
// Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
|
||||
// that is transforming all ops (including the internal lift_fresh call that
|
||||
// transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
|
||||
// BTensor, where ATensor and BTensor are completely unrelated subclasses
|
||||
// and there is no way to compose them. There is no way to satisfy the user
|
||||
// request here: in particular, you can't just try to re-invoke the ATensor
|
||||
// constructor on the returned BTensor, because (1) this could cause an
|
||||
// infinite loop--we are already in ATensor.__new__ and (2) there isn't any
|
||||
// guarantee that ATensor.__new__ supports a single element constructor
|
||||
// anyway.
|
||||
//
|
||||
// However, a more common case is a user just called torch.Tensor([1, 2, 3]),
|
||||
// and a fake tensor mode is active. Really, all you want is to get back
|
||||
// a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
|
||||
// would have returned a fake tensor (concretely, the way this happens
|
||||
// is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
|
||||
// turns into a FakeTensor when we call lift_fresh on this real tensor).
|
||||
// This case is compositional because FakeTensor is a subclass of Tensor, so
|
||||
// it's valid for us to return it in place of a Tensor. So this is what we
|
||||
// do.
|
||||
|
||||
if (mb_obj.has_value() && mb_obj.value()) {
|
||||
TORCH_CHECK(
|
||||
allow_preexisting_pyobj,
|
||||
"Creating a new Tensor subclass ",
|
||||
type->tp_name,
|
||||
" but the raw Tensor object is already associated to a python object ",
|
||||
"of type ",
|
||||
mb_obj.value()->ob_type->tp_name);
|
||||
// Even if we allow pre-existing PyObject, we don't allow completely
|
||||
// ignoring the requested type. Check that we fulfilled a subtype
|
||||
// relation here. In the common case the requested type is Tensor and
|
||||
// this always succeeds.
|
||||
PyObject* obj = *mb_obj;
|
||||
// Check if it's OK to just directly return the Python object without
|
||||
// allocating a new variable. We just check that the existing Python
|
||||
// object is a subclass of the requested type.
|
||||
PyTypeObject* obj_type = Py_TYPE(obj);
|
||||
TORCH_CHECK(
|
||||
obj_type == type || PyType_IsSubtype(obj_type, type),
|
||||
"Creating a new Tensor subclass ",
|
||||
type->tp_name,
|
||||
" but the raw Tensor object is already associated to a python object ",
|
||||
"of type ",
|
||||
mb_obj.value()->ob_type->tp_name,
|
||||
" which is not a subclass of the "
|
||||
"requested type");
|
||||
// We may (in fact, we typically will) need to resurrect this
|
||||
return THPVariable_Wrap(std::move(_var));
|
||||
}
|
||||
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
if (obj) {
|
||||
auto v = (THPVariable*)obj;
|
||||
|
Reference in New Issue
Block a user