Fix optional dtype/layout/memory_format pycall; fix memory format

Double-header bug fix:

- As reported by jansel, dtypes are still showing up as integers
  when the schema is an optional dtype.  This is simple enough to
  fix and I added a test for it.  But while I was at it...

- I noticed that the THPMemoryFormat_new idiom with "unused" name
  doesn't actually work, the repr of the returned memory format
  object is wrong and this shows up when we try to log the args/kwargs.
  So I fixed memory format to do it properly along with everything
  else.

Fixes https://github.com/pytorch/pytorch/issues/77135

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77543

Approved by: https://github.com/albanD, https://github.com/jansel
This commit is contained in:
Edward Z. Yang
2022-05-16 09:38:52 -07:00
committed by PyTorch MergeBot
parent 14e59edd02
commit b5bc954a71
9 changed files with 78 additions and 25 deletions

View File

@ -30,6 +30,7 @@
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/library.h>
@ -101,16 +102,24 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHa
auto schemaAwareToPyObject = [&](int64_t idx) -> py::object {
const auto& arg = schema.arguments()[idx];
if (arg.real_type()->kind() == c10::ScalarTypeType::Kind) {
auto match = [&](c10::TypeKind kind) {
const auto& t = arg.real_type();
if (t->kind() == kind) return true;
if (auto opt_t = t->cast<c10::OptionalType>()) {
if (opt_t->getElementType()->kind() == kind) return true;
}
return false;
};
if (arguments[idx].isNone()) {
return py::none();
} else if (match(c10::ScalarTypeType::Kind)) {
auto* obj = getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
return py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject*>(obj));
} else if (arg.real_type()->kind() == c10::LayoutType::Kind) {
} else if (match(c10::LayoutType::Kind)) {
auto* obj = getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
return py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject*>(obj));
} else if (arg.real_type()->kind() == c10::MemoryFormatType::Kind) {
// TODO: https://github.com/pytorch/pytorch/issues/77135
auto* obj = THPMemoryFormat_New(static_cast<c10::MemoryFormat>(arguments[idx].toInt()), "unused");
return py::reinterpret_steal<py::object>(reinterpret_cast<PyObject*>(obj));
} else if (match(c10::MemoryFormatType::Kind)) {
return torch::utils::getTHPMemoryFormat(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
} else {
return torch::jit::toPyObject(arguments[idx]);
}