mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix optional dtype/layout/memory_format pycall; fix memory format
Double-header bug fix: - As reported by jansel, dtypes are still showing up as integers when the schema is an optional dtype. This is simple enough to fix and I added a test for it. But while I was at it... - I noticed that the THPMemoryFormat_new idiom with "unused" name doesn't actually work, the repr of the returned memory format object is wrong and this shows up when we try to log the args/kwargs. So I fixed memory format to do it properly along with everything else. Fixes https://github.com/pytorch/pytorch/issues/77135 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/77543 Approved by: https://github.com/albanD, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
14e59edd02
commit
b5bc954a71
@ -30,6 +30,7 @@
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
#include <torch/library.h>
|
||||
@ -101,16 +102,24 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHa
|
||||
|
||||
auto schemaAwareToPyObject = [&](int64_t idx) -> py::object {
|
||||
const auto& arg = schema.arguments()[idx];
|
||||
if (arg.real_type()->kind() == c10::ScalarTypeType::Kind) {
|
||||
auto match = [&](c10::TypeKind kind) {
|
||||
const auto& t = arg.real_type();
|
||||
if (t->kind() == kind) return true;
|
||||
if (auto opt_t = t->cast<c10::OptionalType>()) {
|
||||
if (opt_t->getElementType()->kind() == kind) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
if (arguments[idx].isNone()) {
|
||||
return py::none();
|
||||
} else if (match(c10::ScalarTypeType::Kind)) {
|
||||
auto* obj = getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
|
||||
return py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject*>(obj));
|
||||
} else if (arg.real_type()->kind() == c10::LayoutType::Kind) {
|
||||
} else if (match(c10::LayoutType::Kind)) {
|
||||
auto* obj = getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
|
||||
return py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject*>(obj));
|
||||
} else if (arg.real_type()->kind() == c10::MemoryFormatType::Kind) {
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/77135
|
||||
auto* obj = THPMemoryFormat_New(static_cast<c10::MemoryFormat>(arguments[idx].toInt()), "unused");
|
||||
return py::reinterpret_steal<py::object>(reinterpret_cast<PyObject*>(obj));
|
||||
} else if (match(c10::MemoryFormatType::Kind)) {
|
||||
return torch::utils::getTHPMemoryFormat(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
|
||||
} else {
|
||||
return torch::jit::toPyObject(arguments[idx]);
|
||||
}
|
||||
|
Reference in New Issue
Block a user