Fix refcount handling for dtype, layout and memory format (#125271)
Finish fixing https://github.com/pytorch/pytorch/issues/124868: re-use our wrap() utilities wherever possible and fall back to Py_NewRef in the remaining places.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125271
Approved by: https://github.com/colesbury
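The convention the patch converges on: the cached lookups getTHPDtype(), getTHPLayout() and getTHPMemoryFormat() return borrowed references (see the header comments added below), and any call site that needs to own the object goes through torch::autograd::utils::wrap() or Py_NewRef(). A minimal standalone sketch of that split, using a hypothetical cached singleton in place of the real THPDtype/THPLayout objects (Py_NewRef needs Python >= 3.10 or the pythoncapi_compat shim this PR starts including):

    #include <Python.h>

    // Hypothetical module-level cache, analogous to the dtype/layout/memory-format
    // registries: it owns one reference to the cached object for the lifetime of
    // the process.
    static PyObject* cached_obj = nullptr;

    // Borrowed-reference getter: the caller must NOT decref the result and must
    // not store it anywhere without bumping the refcount first.
    static PyObject* get_cached_borrowed() {
      if (!cached_obj) {
        cached_obj = PyLong_FromLong(42);  // placeholder for a THPDtype-like singleton
      }
      return cached_obj;
    }

    // New-reference wrapper, mirroring what wrap()/Py_NewRef do at the API
    // boundary: the caller receives ownership and is responsible for the decref.
    static PyObject* get_cached_owned() {
      return Py_NewRef(get_cached_borrowed());
    }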
@@ -60,14 +60,14 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
     PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
   }
   if (scalarType) {
-    PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getTHPDtype(*scalarType)));
+    PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType)));
   } else {
     Py_INCREF(Py_None);
     PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
   }
   PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
   if (opt_memory_format.has_value()) {
-    PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()));
+    PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value())));
   } else {
     Py_INCREF(Py_None);
     PyTuple_SET_ITEM(tuple.get(), 3, Py_None);
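Why the tuple-building call sites above need Py_NewRef: PyTuple_SET_ITEM steals a reference to the item, so passing it a borrowed pointer from one of the cached getters would let the tuple's destructor decrement a refcount the cache never handed out. A hedged sketch of the pattern with illustrative names (make_pair is not real PyTorch code):

    #include <Python.h>

    // Build a 2-tuple from Py_None and a cached singleton we only hold a
    // borrowed reference to. PyTuple_SET_ITEM steals a reference to each item,
    // so every item must be owned at the moment it is inserted.
    static PyObject* make_pair(PyObject* borrowed_singleton) {
      PyObject* tuple = PyTuple_New(2);
      if (!tuple) {
        return nullptr;
      }
      PyTuple_SET_ITEM(tuple, 0, Py_NewRef(Py_None));             // own Py_None
      PyTuple_SET_ITEM(tuple, 1, Py_NewRef(borrowed_singleton));  // own the cached object
      return tuple;  // caller owns the tuple and, through it, both items
    }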
@@ -31,6 +31,7 @@ std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
     PyObject* obj);
 bool isStorage(PyObject* obj);
 
+// Both methods below return a borrowed reference!
 TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
 THPLayout* getTHPLayout(at::Layout layout);
 } // namespace torch
@@ -536,7 +536,7 @@ static PyObject* get_autocast_dtype(
   auto r = parser.parse(args, kwargs, parsed_args);
   auto device_type = at::Device(r.string(0)).type();
   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(device_type);
-  return Py_NewRef(torch::getTHPDtype(current_dtype));
+  return utils::wrap(current_dtype);
   END_HANDLE_TH_ERRORS
 }
 
@@ -733,7 +733,7 @@ static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
   TORCH_WARN_DEPRECATION(
       "torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
-  return Py_NewRef(torch::getTHPDtype(current_dtype));
+  return utils::wrap(current_dtype);
   END_HANDLE_TH_ERRORS
 }
 
@@ -742,7 +742,7 @@ static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
   TORCH_WARN_DEPRECATION(
       "torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
-  return Py_NewRef(torch::getTHPDtype(current_dtype));
+  return utils::wrap(current_dtype);
   END_HANDLE_TH_ERRORS
 }
 
@@ -751,7 +751,7 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
   TORCH_WARN_DEPRECATION(
       "torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
-  return Py_NewRef(torch::getTHPDtype(current_dtype));
+  return utils::wrap(current_dtype);
   END_HANDLE_TH_ERRORS
 }
 
@@ -760,7 +760,7 @@ static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
   TORCH_WARN_DEPRECATION(
       "torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
-  return Py_NewRef(torch::getTHPDtype(current_dtype));
+  return utils::wrap(current_dtype);
   END_HANDLE_TH_ERRORS
 }
@@ -1527,7 +1527,7 @@ static PyObject* THPVariable_dtype(THPVariable* self, void* unused) {
     return handle_torch_function_getter(self, "dtype");
   }
   auto& self_ = THPVariable_Unpack(self);
-  return torch::autograd::utils::wrap(torch::getTHPDtype(self_.scalar_type()));
+  return torch::autograd::utils::wrap(self_.scalar_type());
   END_HANDLE_TH_ERRORS
 }
 
@@ -1537,7 +1537,7 @@ static PyObject* THPVariable_layout(THPVariable* self, void* unused) {
     return handle_torch_function_getter(self, "layout");
   }
   auto& self_ = THPVariable_Unpack(self);
-  return torch::autograd::utils::wrap(torch::getTHPLayout(self_.layout()));
+  return torch::autograd::utils::wrap(self_.layout());
   END_HANDLE_TH_ERRORS
 }
@@ -53,21 +53,19 @@ inline PyObject* wrap(void* value) {
 }
 
 inline PyObject* wrap(THPDtype* dtype) {
-  Py_INCREF(dtype);
-  return (PyObject*)dtype;
+  return Py_NewRef(dtype);
 }
 
 inline PyObject* wrap(at::ScalarType scalarType) {
-  return wrap(getTHPDtype(scalarType));
+  return Py_NewRef(getTHPDtype(scalarType));
 }
 
 inline PyObject* wrap(THPLayout* layout) {
-  Py_INCREF(layout);
-  return (PyObject*)layout;
+  return Py_NewRef(layout);
 }
 
 inline PyObject* wrap(at::Layout layout) {
-  return wrap(getTHPLayout(layout));
+  return Py_NewRef(getTHPLayout(layout));
 }
 
 inline PyObject* wrap(at::Tensor tensor) {
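With the overloads above, a Python-facing getter can hand back a new reference in one step, which is what the THPVariable_dtype/THPVariable_layout and get_autocast_* hunks switch to. A minimal hypothetical getter in the same style (my_dtype_getter is illustrative only; the include path assumes the overloads live in wrap_outputs.h, as in the current tree):

    #include <ATen/core/Tensor.h>
    #include <torch/csrc/autograd/utils/wrap_outputs.h>

    // wrap(at::ScalarType) resolves the scalar type to the cached THPDtype and
    // returns a new reference, so the Python caller receives ownership.
    static PyObject* my_dtype_getter(const at::Tensor& t) {
      return torch::autograd::utils::wrap(t.scalar_type());
    }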
@@ -440,8 +440,7 @@ void initPythonBindings(PyObject* module) {
           "dtype",
           [](const TensorMetadata& metadata) {
             return py::reinterpret_borrow<py::object>(
-                torch::autograd::utils::wrap(
-                    torch::getTHPDtype(metadata.dtype_)));
+                torch::autograd::utils::wrap(metadata.dtype_));
           })
       .def_readonly("dim", &TensorMetadata::dim_)
       .def_readonly("sizes", &TensorMetadata::sizes_)
@@ -242,8 +242,9 @@ static void set_type(
   // This field is lazily initialized from backend and scalar_type
   type_obj.backend = static_cast<int>(backend);
   type_obj.scalar_type = static_cast<int>(scalarType);
-  type_obj.layout = torch::getTHPLayout(layout_from_backend(backend));
-  type_obj.dtype = torch::getTHPDtype(scalarType);
+  type_obj.layout =
+      (THPLayout*)Py_NewRef(torch::getTHPLayout(layout_from_backend(backend)));
+  type_obj.dtype = (THPDtype*)Py_NewRef(torch::getTHPDtype(scalarType));
   type_obj.is_cuda =
       (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
   type_obj.is_xpu =
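Storing the pointer in a struct field that outlives the call is also a form of ownership, which is why set_type now bumps the refcount and casts the result back to the concrete type (Py_NewRef returns a plain PyObject*). The shape of that idiom, with a hypothetical holder type:

    #include <Python.h>

    struct Holder {
      PyObject* cached = nullptr;  // owned for the lifetime of the holder
    };

    // Keep a borrowed reference alive for as long as the holder does.
    static void store(Holder& h, PyObject* borrowed) {
      Py_XDECREF(h.cached);            // release any previously stored reference
      h.cached = Py_NewRef(borrowed);  // the holder now owns one reference
    }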
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/pythoncapi_compat.h>
 
 #include <ATen/core/Tensor.h>
 #include <ATen/core/jit_type_base.h>
@@ -155,7 +156,7 @@ struct type_caster<at::MemoryFormat> {
       at::MemoryFormat src,
       return_value_policy /* policy */,
       handle /* parent */) {
-    return handle(torch::utils::getTHPMemoryFormat(src));
+    return handle(Py_NewRef(torch::utils::getTHPMemoryFormat(src)));
   }
 };
@@ -18,10 +18,12 @@ std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)>
 } // anonymous namespace
 
 PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) {
-  return py::reinterpret_borrow<py::object>(
-             memory_format_registry[static_cast<size_t>(memory_format)])
-      .release()
-      .ptr();
+  auto py_memory_format =
+      memory_format_registry[static_cast<int>(memory_format)];
+  if (!py_memory_format) {
+    throw std::invalid_argument("unsupported memory_format");
+  }
+  return py_memory_format;
 }
 
 void initializeMemoryFormats() {
@@ -7,6 +7,8 @@
 namespace torch::utils {
 
 void initializeMemoryFormats();
+
+// This methods returns a borrowed reference!
 TORCH_PYTHON_API PyObject* getTHPMemoryFormat(c10::MemoryFormat);
 
 } // namespace torch::utils
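Under the new implementation the registry lookup returns the cached object without touching its refcount, so ownership stays with the registry, and callers that pass the result back to Python must take their own reference, as the pybind11 type_caster hunk does. A hedged caller sketch; the getter is re-declared locally (without TORCH_PYTHON_API) only to keep the example self-contained:

    #include <Python.h>
    #include <c10/core/MemoryFormat.h>

    namespace torch::utils {
    // Declared in the header touched above; returns a borrowed reference.
    PyObject* getTHPMemoryFormat(c10::MemoryFormat);
    } // namespace torch::utils

    // Hand the memory-format singleton to Python: take a new reference before
    // transferring ownership to the caller.
    static PyObject* memory_format_to_python(c10::MemoryFormat fmt) {
      return Py_NewRef(torch::utils::getTHPMemoryFormat(fmt));
    }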