Fix object slice (#138880)

To avoid casting Tensor to Tensorbase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138880
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2024-10-26 00:13:19 +00:00
committed by PyTorch MergeBot
parent 939fc4e335
commit 1605d4aeb8
4 changed files with 17 additions and 20 deletions

View File

@ -867,7 +867,7 @@ mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry
}
AT_ASSERT(last == 0 || last == -1);
if (!seen_dims) {
return mpy::object::steal(THPVariable_Wrap(std::move(tensor)));
return mpy::object::steal(THPVariable_Wrap(tensor));
}
mpy::obj<Tensor> self = Tensor::create();

View File

@ -26,7 +26,7 @@ struct THPVariable {
TORCH_PYTHON_API extern PyObject *THPVariableClass;
TORCH_PYTHON_API extern PyObject *ParameterClass;
TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var);
TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var);
inline bool THPVariable_Check(PyObject *obj)
{

View File

@ -207,7 +207,7 @@ PyObject* ParameterClass = nullptr;
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
Variable _var,
const at::TensorBase& _var,
c10::impl::PyInterpreterStatus status,
bool allow_preexisting_pyobj = false);
@ -254,8 +254,7 @@ void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}
// TODO: Make this take Variable by const reference
PyObject* THPVariable_Wrap(at::TensorBase var) {
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
if (!var.defined()) {
Py_RETURN_NONE;
}
@ -263,7 +262,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPVariable_NewWithVar(
(PyTypeObject*)THPVariableClass,
std::move(var),
var,
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
}
@ -282,7 +281,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
// object if all C++ references go to zero
var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
reinterpret_cast<THPVariable*>(obj)->cdata =
MaybeOwned<Variable>::owned(std::move(var));
MaybeOwned<Variable>::owned(Variable(var));
// NB: incref is not necessary, because we are "stealing" the previous
// ownership from the Variable to return it here for the wrap
return obj;
@ -308,16 +307,14 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
}
if (C10_LIKELY(var.device().type() != c10::kXLA)) {
return THPVariable_NewWithVar(
(PyTypeObject*)THPVariableClass, std::move(var), status);
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status);
}
if (auto clazz = getPythonTensorClass(var.device())) {
return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status);
return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status);
}
return THPVariable_NewWithVar(
(PyTypeObject*)THPVariableClass, std::move(var), status);
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status);
}
bool isResurrectable(THPVariable* self) {
@ -619,7 +616,7 @@ static PyObject* view_func_impl(
}
}
}
return THPVariable_Wrap(std::move(out));
return THPVariable_Wrap(out);
END_HANDLE_TH_ERRORS
}
@ -655,7 +652,7 @@ static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found");
out = view_info.rev_view_fn()(new_view);
}
return THPVariable_Wrap(std::move(out));
return THPVariable_Wrap(out);
END_HANDLE_TH_ERRORS
}
@ -1898,7 +1895,7 @@ PyObject* THPVariable_pynew(
// these to be passed on directly.
return THPVariable_NewWithVar(
type,
std::move(tensor),
tensor,
c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
/*allow_preexisting_pyobj=*/true);
END_HANDLE_TH_ERRORS
@ -2012,7 +2009,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
Variable _var,
const at::TensorBase& _var,
c10::impl::PyInterpreterStatus status,
bool allow_preexisting_pyobj) {
// Make sure that the reinterpret into a THPVariable* will be valid
@ -2082,7 +2079,7 @@ static PyObject* THPVariable_NewWithVar(
" which is not a subclass of the "
"requested type");
// We may (in fact, we typically will) need to resurrect this
return THPVariable_Wrap(std::move(_var));
return THPVariable_Wrap(_var);
}
PyObject* obj = type->tp_alloc(type, 0);
@ -2092,7 +2089,7 @@ static PyObject* THPVariable_NewWithVar(
new (&v->cdata) MaybeOwned<Variable>();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
// Do NOT initialize pyobj field on the tensor, you own the C++
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
TORCH_INTERNAL_ASSERT(
!check_has_torch_dispatch(obj),
"While HermeticPyObject was enabled, we attempted to create a tensor "
@ -2104,7 +2101,7 @@ static PyObject* THPVariable_NewWithVar(
"Python op registration.");
} else {
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
getPyInterpreter(), obj, status);

View File

@ -37,7 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;
bool THPVariable_initModule(PyObject* module);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
// Check that a python object is a `Tensor`, but not a `Tensor` subclass.