mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix object slice (#138880)
To avoid casting Tensor to Tensorbase Pull Request resolved: https://github.com/pytorch/pytorch/pull/138880 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -867,7 +867,7 @@ mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry
|
||||
}
|
||||
AT_ASSERT(last == 0 || last == -1);
|
||||
if (!seen_dims) {
|
||||
return mpy::object::steal(THPVariable_Wrap(std::move(tensor)));
|
||||
return mpy::object::steal(THPVariable_Wrap(tensor));
|
||||
}
|
||||
|
||||
mpy::obj<Tensor> self = Tensor::create();
|
||||
|
@ -26,7 +26,7 @@ struct THPVariable {
|
||||
TORCH_PYTHON_API extern PyObject *THPVariableClass;
|
||||
TORCH_PYTHON_API extern PyObject *ParameterClass;
|
||||
|
||||
TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var);
|
||||
TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var);
|
||||
|
||||
inline bool THPVariable_Check(PyObject *obj)
|
||||
{
|
||||
|
@ -207,7 +207,7 @@ PyObject* ParameterClass = nullptr;
|
||||
|
||||
static PyObject* THPVariable_NewWithVar(
|
||||
PyTypeObject* type,
|
||||
Variable _var,
|
||||
const at::TensorBase& _var,
|
||||
c10::impl::PyInterpreterStatus status,
|
||||
bool allow_preexisting_pyobj = false);
|
||||
|
||||
@ -254,8 +254,7 @@ void activateGPUTrace() {
|
||||
c10::impl::GPUTrace::set_trace(getPyInterpreter());
|
||||
}
|
||||
|
||||
// TODO: Make this take Variable by const reference
|
||||
PyObject* THPVariable_Wrap(at::TensorBase var) {
|
||||
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
|
||||
if (!var.defined()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
@ -263,7 +262,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)THPVariableClass,
|
||||
std::move(var),
|
||||
var,
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
}
|
||||
|
||||
@ -282,7 +281,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
|
||||
// object if all C++ references go to zero
|
||||
var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
|
||||
reinterpret_cast<THPVariable*>(obj)->cdata =
|
||||
MaybeOwned<Variable>::owned(std::move(var));
|
||||
MaybeOwned<Variable>::owned(Variable(var));
|
||||
// NB: incref is not necessary, because we are "stealing" the previous
|
||||
// ownership from the Variable to return it here for the wrap
|
||||
return obj;
|
||||
@ -308,16 +307,14 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
|
||||
}
|
||||
|
||||
if (C10_LIKELY(var.device().type() != c10::kXLA)) {
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)THPVariableClass, std::move(var), status);
|
||||
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status);
|
||||
}
|
||||
|
||||
if (auto clazz = getPythonTensorClass(var.device())) {
|
||||
return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status);
|
||||
return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status);
|
||||
}
|
||||
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)THPVariableClass, std::move(var), status);
|
||||
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status);
|
||||
}
|
||||
|
||||
bool isResurrectable(THPVariable* self) {
|
||||
@ -619,7 +616,7 @@ static PyObject* view_func_impl(
|
||||
}
|
||||
}
|
||||
}
|
||||
return THPVariable_Wrap(std::move(out));
|
||||
return THPVariable_Wrap(out);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -655,7 +652,7 @@ static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
|
||||
TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found");
|
||||
out = view_info.rev_view_fn()(new_view);
|
||||
}
|
||||
return THPVariable_Wrap(std::move(out));
|
||||
return THPVariable_Wrap(out);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -1898,7 +1895,7 @@ PyObject* THPVariable_pynew(
|
||||
// these to be passed on directly.
|
||||
return THPVariable_NewWithVar(
|
||||
type,
|
||||
std::move(tensor),
|
||||
tensor,
|
||||
c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
|
||||
/*allow_preexisting_pyobj=*/true);
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -2012,7 +2009,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
|
||||
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
|
||||
static PyObject* THPVariable_NewWithVar(
|
||||
PyTypeObject* type,
|
||||
Variable _var,
|
||||
const at::TensorBase& _var,
|
||||
c10::impl::PyInterpreterStatus status,
|
||||
bool allow_preexisting_pyobj) {
|
||||
// Make sure that the reinterpret into a THPVariable* will be valid
|
||||
@ -2082,7 +2079,7 @@ static PyObject* THPVariable_NewWithVar(
|
||||
" which is not a subclass of the "
|
||||
"requested type");
|
||||
// We may (in fact, we typically will) need to resurrect this
|
||||
return THPVariable_Wrap(std::move(_var));
|
||||
return THPVariable_Wrap(_var);
|
||||
}
|
||||
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
@ -2092,7 +2089,7 @@ static PyObject* THPVariable_NewWithVar(
|
||||
new (&v->cdata) MaybeOwned<Variable>();
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
// Do NOT initialize pyobj field on the tensor, you own the C++
|
||||
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
|
||||
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!check_has_torch_dispatch(obj),
|
||||
"While HermeticPyObject was enabled, we attempted to create a tensor "
|
||||
@ -2104,7 +2101,7 @@ static PyObject* THPVariable_NewWithVar(
|
||||
"Python op registration.");
|
||||
} else {
|
||||
// Normal codepath
|
||||
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
|
||||
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
|
||||
const auto& var = THPVariable_Unpack(v);
|
||||
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
|
||||
getPyInterpreter(), obj, status);
|
||||
|
@ -37,7 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
|
||||
TORCH_PYTHON_API extern PyObject* ParameterClass;
|
||||
|
||||
bool THPVariable_initModule(PyObject* module);
|
||||
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var);
|
||||
TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
|
||||
|
||||
inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
|
||||
// Check that a python object is a `Tensor`, but not a `Tensor` subclass.
|
||||
|
Reference in New Issue
Block a user