Fix refcount handling for dtype, layout and memory format (#125271)

Finish fixing https://github.com/pytorch/pytorch/issues/124868
re-use our wrap() utils as much as possible and NewRef in other places.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125271
Approved by: https://github.com/colesbury
This commit is contained in:
albanD
2024-05-02 02:34:30 +00:00
committed by PyTorch MergeBot
parent 4731130ea8
commit b119e1bcc2
10 changed files with 28 additions and 24 deletions

View File

@ -60,14 +60,14 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje
PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
}
if (scalarType) {
PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getTHPDtype(*scalarType)));
PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType)));
} else {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
}
PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
if (opt_memory_format.has_value()) {
PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()));
PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value())));
} else {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(tuple.get(), 3, Py_None);

View File

@ -31,6 +31,7 @@ std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
PyObject* obj);
bool isStorage(PyObject* obj);
// Both methods below return a borrowed reference!
TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
THPLayout* getTHPLayout(at::Layout layout);
} // namespace torch

View File

@ -536,7 +536,7 @@ static PyObject* get_autocast_dtype(
auto r = parser.parse(args, kwargs, parsed_args);
auto device_type = at::Device(r.string(0)).type();
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(device_type);
return Py_NewRef(torch::getTHPDtype(current_dtype));
return utils::wrap(current_dtype);
END_HANDLE_TH_ERRORS
}
@ -733,7 +733,7 @@ static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
TORCH_WARN_DEPRECATION(
"torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
return Py_NewRef(torch::getTHPDtype(current_dtype));
return utils::wrap(current_dtype);
END_HANDLE_TH_ERRORS
}
@ -742,7 +742,7 @@ static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
TORCH_WARN_DEPRECATION(
"torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
return Py_NewRef(torch::getTHPDtype(current_dtype));
return utils::wrap(current_dtype);
END_HANDLE_TH_ERRORS
}
@ -751,7 +751,7 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
TORCH_WARN_DEPRECATION(
"torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
return Py_NewRef(torch::getTHPDtype(current_dtype));
return utils::wrap(current_dtype);
END_HANDLE_TH_ERRORS
}
@ -760,7 +760,7 @@ static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
TORCH_WARN_DEPRECATION(
"torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
return Py_NewRef(torch::getTHPDtype(current_dtype));
return utils::wrap(current_dtype);
END_HANDLE_TH_ERRORS
}

View File

@ -1527,7 +1527,7 @@ static PyObject* THPVariable_dtype(THPVariable* self, void* unused) {
return handle_torch_function_getter(self, "dtype");
}
auto& self_ = THPVariable_Unpack(self);
return torch::autograd::utils::wrap(torch::getTHPDtype(self_.scalar_type()));
return torch::autograd::utils::wrap(self_.scalar_type());
END_HANDLE_TH_ERRORS
}
@ -1537,7 +1537,7 @@ static PyObject* THPVariable_layout(THPVariable* self, void* unused) {
return handle_torch_function_getter(self, "layout");
}
auto& self_ = THPVariable_Unpack(self);
return torch::autograd::utils::wrap(torch::getTHPLayout(self_.layout()));
return torch::autograd::utils::wrap(self_.layout());
END_HANDLE_TH_ERRORS
}

View File

@ -53,21 +53,19 @@ inline PyObject* wrap(void* value) {
}
inline PyObject* wrap(THPDtype* dtype) {
Py_INCREF(dtype);
return (PyObject*)dtype;
return Py_NewRef(dtype);
}
inline PyObject* wrap(at::ScalarType scalarType) {
return wrap(getTHPDtype(scalarType));
return Py_NewRef(getTHPDtype(scalarType));
}
inline PyObject* wrap(THPLayout* layout) {
Py_INCREF(layout);
return (PyObject*)layout;
return Py_NewRef(layout);
}
inline PyObject* wrap(at::Layout layout) {
return wrap(getTHPLayout(layout));
return Py_NewRef(getTHPLayout(layout));
}
inline PyObject* wrap(at::Tensor tensor) {

View File

@ -440,8 +440,7 @@ void initPythonBindings(PyObject* module) {
"dtype",
[](const TensorMetadata& metadata) {
return py::reinterpret_borrow<py::object>(
torch::autograd::utils::wrap(
torch::getTHPDtype(metadata.dtype_)));
torch::autograd::utils::wrap(metadata.dtype_));
})
.def_readonly("dim", &TensorMetadata::dim_)
.def_readonly("sizes", &TensorMetadata::sizes_)

View File

@ -242,8 +242,9 @@ static void set_type(
// This field is lazily initialized from backend and scalar_type
type_obj.backend = static_cast<int>(backend);
type_obj.scalar_type = static_cast<int>(scalarType);
type_obj.layout = torch::getTHPLayout(layout_from_backend(backend));
type_obj.dtype = torch::getTHPDtype(scalarType);
type_obj.layout =
(THPLayout*)Py_NewRef(torch::getTHPLayout(layout_from_backend(backend)));
type_obj.dtype = (THPDtype*)Py_NewRef(torch::getTHPDtype(scalarType));
type_obj.is_cuda =
(backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
type_obj.is_xpu =

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/jit_type_base.h>
@ -155,7 +156,7 @@ struct type_caster<at::MemoryFormat> {
at::MemoryFormat src,
return_value_policy /* policy */,
handle /* parent */) {
return handle(torch::utils::getTHPMemoryFormat(src));
return handle(Py_NewRef(torch::utils::getTHPMemoryFormat(src)));
}
};

View File

@ -18,10 +18,12 @@ std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)>
} // anonymous namespace
PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) {
return py::reinterpret_borrow<py::object>(
memory_format_registry[static_cast<size_t>(memory_format)])
.release()
.ptr();
auto py_memory_format =
memory_format_registry[static_cast<int>(memory_format)];
if (!py_memory_format) {
throw std::invalid_argument("unsupported memory_format");
}
return py_memory_format;
}
void initializeMemoryFormats() {

View File

@ -7,6 +7,8 @@
namespace torch::utils {
void initializeMemoryFormats();
// This methods returns a borrowed reference!
TORCH_PYTHON_API PyObject* getTHPMemoryFormat(c10::MemoryFormat);
} // namespace torch::utils