From 1605d4aeb80c15c48f74ca8a82485addf26c9e53 Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 26 Oct 2024 00:13:19 +0000 Subject: [PATCH] Fix object slice (#138880) To avoid casting Tensor to Tensorbase Pull Request resolved: https://github.com/pytorch/pytorch/pull/138880 Approved by: https://github.com/Skylion007 --- functorch/csrc/dim/dim.cpp | 2 +- functorch/csrc/dim/python_variable_simple.h | 2 +- torch/csrc/autograd/python_variable.cpp | 31 ++++++++++----------- torch/csrc/autograd/python_variable.h | 2 +- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 47fe87c23526..304839cbaeed 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -867,7 +867,7 @@ mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice self = Tensor::create(); diff --git a/functorch/csrc/dim/python_variable_simple.h b/functorch/csrc/dim/python_variable_simple.h index caae56610760..fbd5cfd82815 100644 --- a/functorch/csrc/dim/python_variable_simple.h +++ b/functorch/csrc/dim/python_variable_simple.h @@ -26,7 +26,7 @@ struct THPVariable { TORCH_PYTHON_API extern PyObject *THPVariableClass; TORCH_PYTHON_API extern PyObject *ParameterClass; -TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var); +TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var); inline bool THPVariable_Check(PyObject *obj) { diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index d9c4ca0dc065..8f113a6a7028 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -207,7 +207,7 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, - Variable _var, + const at::TensorBase& _var, c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); @@ -254,8 +254,7 @@ void activateGPUTrace() { c10::impl::GPUTrace::set_trace(getPyInterpreter()); } -// TODO: Make this take Variable by const reference -PyObject* THPVariable_Wrap(at::TensorBase var) { +PyObject* THPVariable_Wrap(const at::TensorBase& var) { if (!var.defined()) { Py_RETURN_NONE; } @@ -263,7 +262,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { if (c10::impl::HermeticPyObjectTLS::get_state()) { return THPVariable_NewWithVar( (PyTypeObject*)THPVariableClass, - std::move(var), + var, c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); } @@ -282,7 +281,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { // object if all C++ references go to zero var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false); reinterpret_cast(obj)->cdata = - MaybeOwned::owned(std::move(var)); + MaybeOwned::owned(Variable(var)); // NB: incref is not necessary, because we are "stealing" the previous // ownership from the Variable to return it here for the wrap return obj; @@ -308,16 +307,14 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { } if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); } - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } bool isResurrectable(THPVariable* self) { @@ -619,7 +616,7 @@ static PyObject* view_func_impl( } } } - return THPVariable_Wrap(std::move(out)); + return THPVariable_Wrap(out); END_HANDLE_TH_ERRORS } @@ -655,7 +652,7 @@ static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) { TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found"); out = view_info.rev_view_fn()(new_view); } - return THPVariable_Wrap(std::move(out)); + return THPVariable_Wrap(out); END_HANDLE_TH_ERRORS } @@ -1898,7 +1895,7 @@ PyObject* THPVariable_pynew( // these to be passed on directly. return THPVariable_NewWithVar( type, - std::move(tensor), + tensor, c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS @@ -2012,7 +2009,7 @@ void THPVariable_subclass_dealloc(PyObject* self) { // It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. static PyObject* THPVariable_NewWithVar( PyTypeObject* type, - Variable _var, + const at::TensorBase& _var, c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { // Make sure that the reinterpret into a THPVariable* will be valid @@ -2082,7 +2079,7 @@ static PyObject* THPVariable_NewWithVar( " which is not a subclass of the " "requested type"); // We may (in fact, we typically will) need to resurrect this - return THPVariable_Wrap(std::move(_var)); + return THPVariable_Wrap(_var); } PyObject* obj = type->tp_alloc(type, 0); @@ -2092,7 +2089,7 @@ static PyObject* THPVariable_NewWithVar( new (&v->cdata) MaybeOwned(); if (c10::impl::HermeticPyObjectTLS::get_state()) { // Do NOT initialize pyobj field on the tensor, you own the C++ - v->cdata = MaybeOwned::owned(std::move(_var)); + v->cdata = MaybeOwned::owned(Variable(_var)); TORCH_INTERNAL_ASSERT( !check_has_torch_dispatch(obj), "While HermeticPyObject was enabled, we attempted to create a tensor " @@ -2104,7 +2101,7 @@ static PyObject* THPVariable_NewWithVar( "Python op registration."); } else { // Normal codepath - v->cdata = MaybeOwned::owned(std::move(_var)); + v->cdata = MaybeOwned::owned(Variable(_var)); const auto& var = THPVariable_Unpack(v); var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( getPyInterpreter(), obj, status); diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 51ade77f03ec..32cc5c930ca0 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -37,7 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass; TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); -TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); +TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var); inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass.