mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 15:44:58 +08:00
Move Python argument related functions to cpp file (#84919)
No changes to contents, just moving things out of header. I only moved the stuff I suspected I'd be editing; maybe more things from this header could migrate out. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84919 Approved by: https://github.com/suo
This commit is contained in:
committed by
PyTorch MergeBot
parent
99cfaf9eee
commit
7a9ab5c232
@ -455,5 +455,290 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
"toIValue() cannot handle converting to type: ", type->repr_str()));
|
||||
}
|
||||
|
||||
py::object toPyObject(IValue ivalue) {
|
||||
if (ivalue.isNone()) {
|
||||
return py::none();
|
||||
} else if (ivalue.isTensor()) {
|
||||
auto tensor = std::move(ivalue).toTensor();
|
||||
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
|
||||
auto py_tensor = py::cast(tensor);
|
||||
if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
|
||||
return py_tensor.attr("_wrapped_number");
|
||||
}
|
||||
auto scalar_type = tensor.scalar_type();
|
||||
switch (scalar_type) {
|
||||
case at::ScalarType::Bool:
|
||||
return py::cast(*tensor.data_ptr<bool>());
|
||||
case at::ScalarType::Long:
|
||||
return py::cast(*tensor.data_ptr<int64_t>());
|
||||
case at::ScalarType::Double:
|
||||
return py::cast(*tensor.data_ptr<double>());
|
||||
case at::ScalarType::ComplexDouble:
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/77134
|
||||
return py::cast(static_cast<std::complex<double>>(
|
||||
*tensor.data_ptr<c10::complex<double>>()));
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Missing cases in 'toPyObject' wrapped number handling! Can't convert ",
|
||||
scalar_type,
|
||||
" to a Python object");
|
||||
}
|
||||
} else {
|
||||
guardAgainstNamedTensor<at::Tensor>(tensor);
|
||||
return py::cast(autograd::Variable(std::move(tensor)));
|
||||
}
|
||||
} else if (ivalue.isStorage()) {
|
||||
return py::cast(ivalue.toStorage());
|
||||
} else if (ivalue.isGenerator()) {
|
||||
return py::cast(ivalue.toGenerator());
|
||||
} else if (ivalue.isDouble()) {
|
||||
return py::cast(std::move(ivalue).toDouble());
|
||||
} else if (ivalue.isComplexDouble()) {
|
||||
return py::cast(
|
||||
static_cast<std::complex<double>>(std::move(ivalue).toComplexDouble()));
|
||||
} else if (ivalue.isInt()) {
|
||||
return py::cast(std::move(ivalue).toInt());
|
||||
} else if (ivalue.isBool()) {
|
||||
return py::cast(std::move(ivalue).toBool());
|
||||
} else if (ivalue.isString()) {
|
||||
return py::cast(std::move(ivalue).toStringRef());
|
||||
} else if (ivalue.isList()) {
|
||||
auto list = std::move(ivalue).toList();
|
||||
py::list t{list.size()};
|
||||
for (const auto i : c10::irange(list.size())) {
|
||||
t[i] = toPyObject(IValue{list.get(i)});
|
||||
}
|
||||
return std::move(t);
|
||||
} else if (ivalue.isTuple()) {
|
||||
auto tuple = std::move(ivalue).toTuple();
|
||||
const auto& elements = tuple->elements();
|
||||
|
||||
py::tuple t{elements.size()};
|
||||
for (const auto i : c10::irange(elements.size())) {
|
||||
t[i] = toPyObject(IValue{elements.at(i)});
|
||||
}
|
||||
|
||||
// If we have a NamedTuple
|
||||
if (tuple->type() && tuple->type()->schema() &&
|
||||
tuple->type()->schema()->name() != "") {
|
||||
auto unqualName = tuple->type()->name()->name();
|
||||
|
||||
const std::vector<Argument>& tuple_args =
|
||||
tuple->type()->schema()->arguments();
|
||||
|
||||
std::vector<pybind11::object> defaults;
|
||||
auto it = std::find_if(
|
||||
tuple_args.begin(), tuple_args.end(), [](const Argument& arg) {
|
||||
return arg.default_value().has_value();
|
||||
});
|
||||
std::transform(
|
||||
it,
|
||||
tuple_args.end(),
|
||||
std::back_inserter(defaults),
|
||||
[](const Argument& arg) { return toPyObject(*arg.default_value()); });
|
||||
|
||||
std::vector<std::string> fieldNames =
|
||||
fmap(tuple_args, [](const Argument& arg) { return arg.name(); });
|
||||
|
||||
return py::module::import("torch._jit_internal")
|
||||
.attr("_create_named_tuple")(
|
||||
t, unqualName, fieldNames, py::make_tuple(defaults));
|
||||
} else {
|
||||
return std::move(t);
|
||||
}
|
||||
} else if (ivalue.isDevice()) {
|
||||
return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
|
||||
} else if (ivalue.isGenericDict()) {
|
||||
auto dict = std::move(ivalue).toGenericDict();
|
||||
py::dict py_dict;
|
||||
for (auto& pair : dict) {
|
||||
py_dict[toPyObject(IValue{pair.key()})] =
|
||||
toPyObject(IValue{pair.value()});
|
||||
}
|
||||
return std::move(py_dict);
|
||||
} else if (ivalue.isRRef()) {
|
||||
#ifdef USE_RPC
|
||||
auto RRefPtr =
|
||||
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
|
||||
std::move(ivalue).toRRef());
|
||||
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
|
||||
#else
|
||||
AT_ERROR("RRef is only supported with the distributed package");
|
||||
#endif
|
||||
} else if (ivalue.isObject()) {
|
||||
const auto obj = std::move(ivalue).toObject();
|
||||
if (obj->type()->is_module()) {
|
||||
return py::cast(Module(obj));
|
||||
}
|
||||
|
||||
auto pyCu = get_python_cu();
|
||||
if (obj->name().find("__torch__.torch.classes") == 0) {
|
||||
return py::cast(Object(obj));
|
||||
}
|
||||
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
|
||||
AT_ASSERT(classType);
|
||||
auto pyClass = getScriptedClassOrError(obj->type());
|
||||
auto pyObj = pyClass.attr("__new__")(pyClass);
|
||||
|
||||
const auto numAttrs = classType->numAttributes();
|
||||
|
||||
for (const auto slot : c10::irange(numAttrs)) {
|
||||
const auto& attrName = classType->getAttributeName(slot);
|
||||
IValue v = obj->getSlot(slot);
|
||||
py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v)));
|
||||
}
|
||||
return pyObj;
|
||||
} else if (ivalue.isPyObject()) {
|
||||
// return borrowed reference to ensure it correctly incref the underlying
|
||||
// PyObject
|
||||
return py::reinterpret_borrow<py::object>(ivalue.toPyObject());
|
||||
} else if (ivalue.isCapsule()) {
|
||||
return py::cast(c10::Capsule(ivalue.toCapsule()));
|
||||
} else if (ivalue.isFuture()) {
|
||||
return py::cast(std::make_shared<PythonFutureWrapper>(ivalue.toFuture()));
|
||||
} else if (ivalue.isEnum()) {
|
||||
auto enum_holder = ivalue.toEnumHolder();
|
||||
auto py_class = getScriptedClassOrError(enum_holder->type());
|
||||
return py_class.attr(enum_holder->name().c_str());
|
||||
} else if (ivalue.isRRef()) {
|
||||
#ifdef USE_RPC
|
||||
return py::cast(torch::distributed::rpc::PyRRef(
|
||||
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(
|
||||
ivalue.toRRef())));
|
||||
#else
|
||||
TORCH_CHECK(false, "RRef is only supported with the distributed package");
|
||||
#endif
|
||||
} else if (ivalue.isSymInt()) {
|
||||
auto si = ivalue.toSymInt();
|
||||
return py::cast(si);
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"Missing cases in 'toPyObject'! Can't convert ",
|
||||
ivalue.tagKind(),
|
||||
" to a Python object");
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs) {
|
||||
Stack stack;
|
||||
if (operations.size() == 1) {
|
||||
std::shared_ptr<Operator> op = operations.at(0);
|
||||
// Create a stack full of the arguments and keyword arguments.
|
||||
stack = createStackForSchema(
|
||||
op->schema(), std::move(args), kwargs, c10::nullopt);
|
||||
|
||||
return std::make_pair(op, stack);
|
||||
} else {
|
||||
std::vector<schema_match_error> errors;
|
||||
std::shared_ptr<Operator> found_op = nullptr;
|
||||
for (const auto& op : operations) {
|
||||
try {
|
||||
stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt);
|
||||
found_op = op;
|
||||
break;
|
||||
} catch (schema_match_error& error) {
|
||||
errors.push_back(std::move(error));
|
||||
}
|
||||
}
|
||||
if (!found_op) {
|
||||
std::stringstream ss;
|
||||
ss << "Overloaded torch operator invoked from Python failed to many any schema:\n";
|
||||
for (const auto& err : errors) {
|
||||
ss << err.what() << "\n\n";
|
||||
}
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
return std::make_pair(found_op, stack);
|
||||
}
|
||||
}
|
||||
|
||||
py::object invokeOperatorFromPython(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
c10::optional<c10::DispatchKey> dk) {
|
||||
auto opWithStack = getOpWithStack(operations, args, kwargs);
|
||||
std::shared_ptr<Operator> found_op = std::get<0>(opWithStack);
|
||||
Stack stack = std::get<1>(opWithStack);
|
||||
{
|
||||
pybind11::gil_scoped_release no_gil_guard;
|
||||
if (dk) {
|
||||
found_op->getOperationForDispatchKey (*dk)(stack);
|
||||
} else {
|
||||
found_op->getOperation()(stack);
|
||||
}
|
||||
}
|
||||
|
||||
return createPyObjectForStack(std::move(stack));
|
||||
}
|
||||
|
||||
py::object _get_operation_for_overload_or_packet(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
Symbol symbol,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
bool is_overload,
|
||||
c10::optional<c10::DispatchKey> dk) {
|
||||
std::vector<py::handle> overloaded_args;
|
||||
size_t total_arg_num = args.size() + kwargs.size();
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
args[i].ptr(),
|
||||
&overloaded_args,
|
||||
static_cast<int>(total_arg_num),
|
||||
false /* throw_error */);
|
||||
}
|
||||
// NB: for kwargs, we cannot guarantee the order of appending
|
||||
// is the same as the argument order in operator's schema.
|
||||
// This is suboptimal, but should be fine. Later when we have
|
||||
// better schema matching and argument parsing, we could
|
||||
// match the operator in `operations` first, then the order will
|
||||
// be guaranteed.
|
||||
for (auto item : kwargs) {
|
||||
is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
item.second.ptr(),
|
||||
&overloaded_args,
|
||||
total_arg_num,
|
||||
false /* throw_error */);
|
||||
}
|
||||
if (overloaded_args.size() > 0 ||
|
||||
at::impl::PythonTorchFunctionTLS::get_mode()) {
|
||||
py::object ret;
|
||||
std::string ns = symbol.ns().toUnqualString();
|
||||
std::string method_name = symbol.toUnqualString();
|
||||
auto self_func = py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr(ns.c_str())
|
||||
.attr(method_name.c_str());
|
||||
if (is_overload) {
|
||||
auto overload_name = operations[0]->schema().overload_name();
|
||||
if (overload_name == "") {
|
||||
self_func = self_func.attr("default");
|
||||
} else {
|
||||
self_func = self_func.attr(overload_name.c_str());
|
||||
}
|
||||
}
|
||||
std::string module_name("torch.ops");
|
||||
module_name.append(ns);
|
||||
return pybind11::reinterpret_steal<py::object>(
|
||||
handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
method_name.c_str(),
|
||||
self_func.ptr(),
|
||||
module_name.c_str()));
|
||||
}
|
||||
return invokeOperatorFromPython(operations, args, kwargs, dk);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
Reference in New Issue
Block a user