Move Python argument related functions to cpp file (#84919)

No changes to contents, just moving things out of the header.
I only moved the stuff I suspected I'd be editing; more
things from this header could probably migrate out too.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84919
Approved by: https://github.com/suo
Edward Z. Yang
2022-09-12 22:02:31 -07:00
committed by PyTorch MergeBot
parent 99cfaf9eee
commit 7a9ab5c232
2 changed files with 292 additions and 272 deletions


@@ -455,5 +455,290 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
"toIValue() cannot handle converting to type: ", type->repr_str()));
}
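// Convert an IValue into its corresponding Python object, recursing into
// lists, tuples, dicts, and script objects. This is roughly the inverse of
// toIValue above.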
py::object toPyObject(IValue ivalue) {
if (ivalue.isNone()) {
return py::none();
} else if (ivalue.isTensor()) {
auto tensor = std::move(ivalue).toTensor();
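// A wrapped number is a 0-dim CPU tensor created from a Python scalar;
// convert it back into a plain Python number rather than a Tensor.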
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
auto py_tensor = py::cast(tensor);
if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
return py_tensor.attr("_wrapped_number");
}
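// No memoized Python number on the tensor; read the scalar out of the
// 0-dim tensor and cast it to the matching Python type.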
auto scalar_type = tensor.scalar_type();
switch (scalar_type) {
case at::ScalarType::Bool:
return py::cast(*tensor.data_ptr<bool>());
case at::ScalarType::Long:
return py::cast(*tensor.data_ptr<int64_t>());
case at::ScalarType::Double:
return py::cast(*tensor.data_ptr<double>());
case at::ScalarType::ComplexDouble:
// TODO: https://github.com/pytorch/pytorch/issues/77134
return py::cast(static_cast<std::complex<double>>(
*tensor.data_ptr<c10::complex<double>>()));
default:
TORCH_CHECK(
false,
"Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
scalar_type,
" to a Python object");
}
} else {
guardAgainstNamedTensor<at::Tensor>(tensor);
return py::cast(autograd::Variable(std::move(tensor)));
}
} else if (ivalue.isStorage()) {
return py::cast(ivalue.toStorage());
} else if (ivalue.isGenerator()) {
return py::cast(ivalue.toGenerator());
} else if (ivalue.isDouble()) {
return py::cast(std::move(ivalue).toDouble());
} else if (ivalue.isComplexDouble()) {
return py::cast(
static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
} else if (ivalue.isInt()) {
return py::cast(std::move(ivalue).toInt());
} else if (ivalue.isBool()) {
return py::cast(std::move(ivalue).toBool());
} else if (ivalue.isString()) {
return py::cast(std::move(ivalue).toStringRef());
} else if (ivalue.isList()) {
auto list = std::move(ivalue).toList();
py::list t{list.size()};
for (const auto i : c10::irange(list.size())) {
t[i] = toPyObject(IValue{list.get(i)});
}
return std::move(t);
} else if (ivalue.isTuple()) {
auto tuple = std::move(ivalue).toTuple();
const auto& elements = tuple->elements();
py::tuple t{elements.size()};
for (const auto i : c10::irange(elements.size())) {
t[i] = toPyObject(IValue{elements.at(i)});
}
// If this tuple type carries a named schema, it is a NamedTuple
// (e.g. torch.return_types.max); rebuild the Python NamedTuple class
// from its field names and default values.
if (tuple->type() && tuple->type()->schema() &&
tuple->type()->schema()->name() != "") {
auto unqualName = tuple->type()->name()->name();
const std::vector<Argument>& tuple_args =
tuple->type()->schema()->arguments();
std::vector<pybind11::object> defaults;
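// Collect defaults from the first defaulted argument to the end; the code
// assumes defaulted arguments are contiguous at the tail of the schema.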
auto it = std::find_if(
tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
return arg.default_value().has_value();
});
std::transform(
it,
tuple_args.end(),
std::back_inserter(defaults),
[](const Argument& arg) { return toPyObject(*arg.default_value()); });
std::vector<std::string> fieldNames =
fmap(tuple_args, [](const Argument& arg) { return arg.name(); });
return py::module::import("torch._jit_internal")
.attr("_create_named_tuple")(
t, unqualName, fieldNames, py::make_tuple(defaults));
} else {
return std::move(t);
}
} else if (ivalue.isDevice()) {
return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
} else if (ivalue.isGenericDict()) {
auto dict = std::move(ivalue).toGenericDict();
py::dict py_dict;
for (auto& pair : dict) {
py_dict[toPyObject(IValue{pair.key()})] =
toPyObject(IValue{pair.value()});
}
return std::move(py_dict);
} else if (ivalue.isRRef()) {
#ifdef USE_RPC
auto RRefPtr =
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
std::move(ivalue).toRRef());
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
#else
AT_ERROR("RRef is only supported with the distributed package");
#endif
} else if (ivalue.isObject()) {
const auto obj = std::move(ivalue).toObject();
if (obj->type()->is_module()) {
return py::cast(Module(obj));
}
auto pyCu = get_python_cu();
if (obj->name().find("__torch__.torch.classes") == 0) {
return py::cast(Object(obj));
}
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
AT_ASSERT(classType);
auto pyClass = getScriptedClassOrError(obj->type());
auto pyObj = pyClass.attr("__new__")(pyClass);
const auto numAttrs = classType->numAttributes();
for (const auto slot : c10::irange(numAttrs)) {
const auto& attrName = classType->getAttributeName(slot);
IValue v = obj->getSlot(slot);
py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
}
return pyObj;
} else if (ivalue.isPyObject()) {
// Return a borrowed reference to ensure the underlying PyObject's
// refcount is correctly incremented.
return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
} else if (ivalue.isCapsule()) {
return py::cast(c10::Capsule(ivalue.toCapsule()));
} else if (ivalue.isFuture()) {
return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
} else if (ivalue.isEnum()) {
auto enum_holder = ivalue.toEnumHolder();
auto py_class = getScriptedClassOrError(enum_holder->type());
return py_class.attr(enum_holder->name().c_str());
} else if (ivalue.isRRef()) {
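// NB: this branch is unreachable; isRRef is already handled above.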
#ifdef USE_RPC
return py::cast(torch::distributed::rpc::PyRRef(
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
ivalue.toRRef())));
#else
TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
} else if (ivalue.isSymInt()) {
auto si = ivalue.toSymInt();
return py::cast(si);
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",
ivalue.tagKind(),
" to a Python object");
}
}
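// Resolve a set of overloaded operators against Python args/kwargs by
// trying each schema in turn; return the first operator that matches,
// together with the stack of converted arguments.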
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
const std::vector<std::shared_ptr<Operator>>& operations,
py::args args,
const py::kwargs& kwargs) {
Stack stack;
if (operations.size() == 1) {
std::shared_ptr<Operator> op = operations.at(0);
// Create a stack full of the arguments and keyword arguments.
stack = createStackForSchema(
op->schema(), std::move(args), kwargs, c10::nullopt);
return std::make_pair(op, stack);
} else {
std::vector<schema_match_error> errors;
std::shared_ptr<Operator> found_op = nullptr;
for (const auto& op : operations) {
try {
stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt);
found_op = op;
break;
} catch (schema_match_error& error) {
errors.push_back(std::move(error));
}
}
if (!found_op) {
std::stringstream ss;
ss << "Overloaded torch operator invoked from Python failed to match any schema:\n";
for (const auto& err : errors) {
ss << err.what() << "\n\n";
}
throw std::runtime_error(ss.str());
}
return std::make_pair(found_op, stack);
}
}
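// Match the Python arguments against the overloads, run the selected
// operation (optionally through an explicit dispatch key) with the GIL
// released, and convert the resulting stack back into a Python object.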
py::object invokeOperatorFromPython(
const std::vector<std::shared_ptr<Operator>>& operations,
py::args args,
const py::kwargs& kwargs,
c10::optional<c10::DispatchKey> dk) {
auto opWithStack = getOpWithStack(operations, args, kwargs);
std::shared_ptr<Operator> found_op = std::get<0>(opWithStack);
Stack stack = std::get<1>(opWithStack);
{
pybind11::gil_scoped_release no_gil_guard;
if (dk) {
found_op->getOperationForDispatchKey(*dk)(stack);
} else {
found_op->getOperation()(stack);
}
}
return createPyObjectForStack(std::move(stack));
}
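// Entry point for torch.ops calls from Python: scan the arguments for
// Tensor subclasses that override __torch_function__ and, if any are found
// (or a torch function mode is active), delegate to the Python handler;
// otherwise invoke the operator directly.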
py::object _get_operation_for_overload_or_packet(
const std::vector<std::shared_ptr<Operator>>& operations,
Symbol symbol,
py::args args,
const py::kwargs& kwargs,
bool is_overload,
c10::optional<c10::DispatchKey> dk) {
std::vector<py::handle> overloaded_args;
size_t total_arg_num = args.size() + kwargs.size();
for (const auto i : c10::irange(args.size())) {
is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
is_tensor_list_and_append_overloaded(
args[i].ptr(),
&overloaded_args,
static_cast<int>(total_arg_num),
false /* throw_error */);
}
// NB: for kwargs, we cannot guarantee that the order of appending
// matches the argument order in the operator's schema. This is
// suboptimal, but should be fine. Later, when we have better schema
// matching and argument parsing, we could match the operator in
// `operations` first; then the order would be guaranteed.
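// e.g. for op(x, bias=b, weight=w), b and w are appended in kwargs
// (call-site) order, which need not match the schema's declared order.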
for (auto item : kwargs) {
is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
is_tensor_list_and_append_overloaded(
item.second.ptr(),
&overloaded_args,
static_cast<int>(total_arg_num),
false /* throw_error */);
}
if (overloaded_args.size() > 0 ||
at::impl::PythonTorchFunctionTLS::get_mode()) {
py::object ret;
std::string ns = symbol.ns().toUnqualString();
std::string method_name = symbol.toUnqualString();
auto self_func = py::module::import("torch")
.attr("ops")
.attr(ns.c_str())
.attr(method_name.c_str());
if (is_overload) {
auto overload_name = operations[0]->schema().overload_name();
if (overload_name == "") {
self_func = self_func.attr("default");
} else {
self_func = self_func.attr(overload_name.c_str());
}
}
std::string module_name("torch.ops");
module_name.append(ns);
return pybind11::reinterpret_steal<py::object>(
handle_torch_function_no_python_arg_parser(
overloaded_args,
args.ptr(),
kwargs.ptr(),
method_name.c_str(),
self_func.ptr(),
module_name.c_str()));
}
return invokeOperatorFromPython(operations, args, kwargs, dk);
}
} // namespace jit
} // namespace torch