#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_DISTRIBUTED #include #include #endif #include #include #ifdef USE_C10D_NCCL #include #include #endif #include #include #include #include #include #include #include #include // The visibility attribute is to avoid a warning about storing a field in the // struct that has a different visibility (from pybind) than the struct. #ifdef _WIN32 #define VISIBILITY_HIDDEN #else #define VISIBILITY_HIDDEN __attribute__((visibility("hidden"))) #endif namespace torch { namespace jit { void clear_registered_instances(void* ptr); TORCH_API IValue toIValue( py::handle obj, const TypePtr& type, c10::optional N = c10::nullopt); py::object toPyObject(IValue ivalue); // Wrap Python function to guard deref // NB: Need VISIBILITY_HIDDEN for silencing compiler error, // 'torch::jit::PythonFunctionGuard' declared with greater visibility than the // type of its field 'torch::jit::PythonFunctionGuard::func_' struct VISIBILITY_HIDDEN PythonFunctionGuard { explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {} ~PythonFunctionGuard() { pybind11::gil_scoped_acquire ag; func_.dec_ref(); // explicitly setting PyObject* to nullptr to prevent py::object's dtor to // decref on the PyObject again. // See Note [Destructing py::object] in python_ivalue.h func_.ptr() = nullptr; } py::function func_; }; // The PythonFutureWrapper for ivalue::Future // // NB: VISIBILITY_HIDDEN is for silencing compiling error, // "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility // than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func' // [-Werror=attributes]" // // NB: inherit from enable_shared_from_this because then(py::function) needs to // get a shared_ptr from this pointer. struct VISIBILITY_HIDDEN PythonFutureWrapper : std::enable_shared_from_this { using UnwrapFunc = std::function; explicit PythonFutureWrapper( c10::intrusive_ptr fut, c10::optional unwrap_func = c10::nullopt) : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {} explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete; PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete; bool done() { return fut->completed(); } py::object value() { // acquiring GIL as toPyObject creates new py::object // without grabbing the GIL. py::gil_scoped_acquire acquire; py::object py_obj = toPyObject(fut->value()); // unwrap_func is a general compositional function that takes in a // py::object and executes some python function. It is currently mostly used // to throw python exceptions. if (unwrap_func) { (*unwrap_func)(py_obj); } return py_obj; } py::object wait() { fut->wait(); if (jit::tracer::isTracing()) { auto graph = jit::tracer::getTracingState()->graph; Value* fut_val = jit::tracer::getValueTrace(fut); auto output = graph->insert(aten::wait, {fut_val}); jit::tracer::setValueTrace(fut->value(), output); } return value(); } // The py::function cb arg must take a std::shared_ptr // (i.e., torch._C.Future) as the only argument. If the type mismatches, an // error will be thrown when waiting for the value of this returned Future. std::shared_ptr then(py::function cb) { // We need this an additional layer of wrapper here to guard the // destruction of the py::function object. Because, the // Future owns a reference to the py::function in its callback // vector, but Future does not acquire GIL on destruction. auto pf = std::make_shared(std::move(cb)); return std::make_shared(fut->then( // Capture a copy of the ivalue::Future instead of the `this` pointer // because the PythonFutureWrapper object could have been deleted // when the callbacks are fired. For example, RPC only captures the // ivalue::Future instead of PythonFutureWrapper in JitFuture's // callback functions. Hence, if user code does not hold a reference to // this PythonFutureWrapper object, there is no guarantee that the // PythonFutureWrapper is still valid when running the callback. [pyFut(this->getPtr()), pf(std::move(pf))](c10::ivalue::Future& /* unused */) -> IValue { try { pybind11::gil_scoped_acquire ag; return toIValue(pf->func_(pyFut), PyObjectType::get()); } catch (py::error_already_set& e) { auto err = std::runtime_error(c10::str( "Got the following error when running the callback: ", e.what())); { pybind11::gil_scoped_acquire ag; // Release ownership on py::objects and also restore Python // Error Indicator. e.restore(); // Clear the Python Error Indicator as we has recorded the // exception in the response message. PyErr_Clear(); } throw err; } }, PyObjectType::get())); } void add_done_callback(py::function cb) { auto pf = std::make_shared(std::move(cb)); // NOLINTNEXTLINE(modernize-avoid-bind) fut->addCallback(std::bind( [pyFut(this->getPtr())](std::shared_ptr pf) { try { pybind11::gil_scoped_acquire ag; pf->func_(pyFut); } catch (py::error_already_set& e) { { pybind11::gil_scoped_acquire ag; // Release ownership on py::objects and also restore Python // Error Indicator. e.restore(); // Clear the Python Error Indicator as we has recorded the // exception in the response message. PyErr_Clear(); } // Log and ignore exceptions raised through the callback LOG(ERROR) << "Got the following error when running the callback: " << e.what(); } catch (const std::exception& e) { // Log and ignore exceptions raised through the callback LOG(ERROR) << "Got the following error when running the callback: " << e.what(); } }, std::move(pf))); } void markCompleted(const py::object& pyValue) { DCHECK(PyGILState_Check()); IValue value = toIValue(pyValue, PyObjectType::get()); py::gil_scoped_release release; fut->markCompleted(std::move(value)); } c10::intrusive_ptr fut; // unwrap_func works like a callback for the value returned by // PythonFutureWrapper::wait(). c10::optional unwrap_func; private: std::shared_ptr getPtr() { return shared_from_this(); } }; // error reporting: when reporting user-caused errors, these functions should // not use AT_ERROR macros, since these macros add stack trace information // that is confusing to display to the end user since it always reports // locations in libtorch code rather than user code. inline std::shared_ptr get_python_cu() { return py::module::import("torch.jit._state") .attr("_python_cu") .cast>(); } struct TypedIValue : public std::pair { using pair::pair; IValue& ivalue() { return this->first; } TypePtr& type() { return this->second; } }; inline TypedIValue toDictKeyIValue(py::handle key) { if (py::isinstance(key)) { return TypedIValue( ConstantString::create(py::cast(key)), StringType::get()); } else if (py::isinstance(key)) { return TypedIValue(py::cast(key), IntType::get()); } else if (py::isinstance(key)) { return TypedIValue(py::cast(key), FloatType::get()); } else { AT_ERROR("Dictionary inputs may only have string, int, or float keys"); } } inline c10::optional unifyOrInitializeType( const TypePtr& accum, const TypePtr& unify) { if (!accum) { return unify; } return unifyTypes(accum, unify); } using InferredType = c10::InferredType; InferredType tryToInferContainerType(py::handle input); // Try to infer the type of a Python object // The type cannot be inferred if: // input is an empty container (list, dict) // input is an list with element types that cannot be unified // input is an dict with key or value types that cannot be unified inline InferredType tryToInferType(py::handle input) { // Try tensor types if (THPVariable_Check(input.ptr())) { return InferredType(TensorType::get()); } if (input.is(py::none())) { return InferredType(NoneType::get()); } if (py::isinstance(input)) { auto fn = py::cast(input).function_; return InferredType(FunctionType::create(fn)); } // Try basic types first if (py::isinstance(input)) { return InferredType(BoolType::get()); // NOLINTNEXTLINE(bugprone-branch-clone) } else if (py::isinstance(input)) { return InferredType(IntType::get()); } else if (py::isinstance(input)) { return InferredType(FloatType::get()); } else if (PyComplex_CheckExact(input.ptr())) { return InferredType(ComplexType::get()); } else if (py::isinstance(input)) { return InferredType(StringType::get()); } else if (THPLayout_Check(input.ptr())) { return InferredType(IntType::get()); } else if (THPDevice_Check(input.ptr())) { return InferredType(DeviceObjType::get()); } else if (THPStream_Check(input.ptr())) { return InferredType(StreamObjType::get()); } else if (THPDtype_Check(input.ptr())) { return InferredType(IntType::get()); } else if (THPQScheme_Check(input.ptr())) { return InferredType(IntType::get()); } else if (THPLayout_Check(input.ptr())) { return InferredType(IntType::get()); } auto enum_type = py::module::import("enum").attr("Enum"); py::bool_ isEnumValue = py::isinstance(input, enum_type); if (py::cast(isEnumValue)) { auto enum_class = input.attr("__class__"); auto enum_type = py::cast( py::module::import("torch.jit.annotations") .attr("try_ann_to_type")(enum_class, SourceRange())); return InferredType(enum_type); } py::bool_ isClass = py::module::import("inspect").attr("isclass")(input.get_type()); if (py::cast(isClass)) { // Assume that the class is compiled already or will compile. Invalidate // this later if needed. bool class_compiled = true; // Check if the type is already compiled. py::object existing_ty = py::module::import("torch.jit._state") .attr("_get_script_class")(input.get_type()); if (existing_ty.is_none()) { // If not, try to compile it. py::bool_ can_compile = py::module::import("torch._jit_internal") .attr("can_compile_class")(input.get_type()); if (py::cast(can_compile)) { // Try to compile the class. This is wrapped in a try-catch because // compilation of class types can raise an Exception and in that case, // we want to defer to other attempts at type inference below rather // than fail compilation altogether. try { py::module::import("torch.jit._script") .attr("_recursive_compile_class")( input.get_type(), SourceRange()); } catch (...) { // Invalidate the assumption that the class compiled so that we don't // look up and return its JIT type as the type for the input. class_compiled = false; } } } // If the class compiled successfully, look up the existing JIT type by // qualified name and return it. if (class_compiled) { auto script_class = py::module::import("torch.jit._state") .attr("_get_script_class")(input.get_type()); if (!script_class.is_none()) { auto class_type = py::cast(script_class); if (class_type && !class_type->is_module()) { return InferredType(class_type); } } } } if (py::isinstance(input)) { auto object = py::cast(input); return InferredType(object.type()); #ifdef USE_RPC } else if (py::isinstance(input)) { auto rref_ivalue = input.cast().toIValue(); return InferredType(rref_ivalue.type()); #endif } if (as_module(py::cast(input))) { return InferredType("Cannot infer type of ScriptModule"); } auto module_type = py::module::import("torch.nn").attr("Module"); py::bool_ is_module = py::isinstance(input, module_type); if (py::cast(is_module)) { return InferredType("Cannot infer concrete type of torch.nn.Module"); } // Try container types return tryToInferContainerType(input); } inline InferredType tryToInferContainerType(py::handle input) { if (six::isTuple(input)) { py::tuple tuple = py::cast(input); std::vector element_types; element_types.reserve(tuple.size()); for (py::handle elem : tuple) { auto type_match = tryToInferType(elem); if (type_match.success()) { element_types.push_back(type_match.type()); } else { // Forward error message along return type_match.reason(); } } return InferredType(TupleType::create(element_types)); } else if (PyDict_Check(input.ptr())) { // Check to make sure we can generate useful input/output types auto dict = py::cast(input); size_t len = py::len(dict); if (!len) { return InferredType("Dictionary inputs must have entries"); } TypePtr key_type = nullptr; TypePtr value_type = nullptr; for (auto entry : dict) { // Try to infer the key type and unify it with the existing one auto entry_key_type_match = tryToInferType(entry.first); if (!entry_key_type_match.success()) { return entry_key_type_match.reason(); } auto unified_key = unifyOrInitializeType(key_type, entry_key_type_match.type()); if (!unified_key) { return InferredType(c10::str( "Dictionary inputs to traced functions must have consistent type. Found ", key_type->repr_str(), " and ", (entry_key_type_match.type())->repr_str())); } // Try to infer the value type and unify it with the existing one auto entry_value_type_match = tryToInferType(entry.second); if (!entry_value_type_match.success()) { return entry_value_type_match.reason(); } auto unified_value = unifyOrInitializeType(value_type, entry_value_type_match.type()); if (!unified_value) { return InferredType(c10::str( "Dictionary inputs to traced functions must have consistent type. Found ", value_type->repr_str(), " and ", (entry_value_type_match.type())->repr_str())); } key_type = *unified_key; value_type = *unified_value; } return InferredType(DictType::create(key_type, value_type)); } else if (PyList_Check(input.ptr())) { auto list = py::cast(input); size_t len = py::len(list); if (!len) { return InferredType("List trace inputs must have elements"); } TypePtr element_type = nullptr; for (auto elem : list) { auto element_type_match = tryToInferType(elem); if (!element_type_match.success()) { return InferredType(c10::str( "Could not infer type of list element: ", element_type_match.reason())); } auto unified_type = unifyOrInitializeType(element_type, element_type_match.type()); if (!unified_type) { return InferredType(c10::str( "List inputs to traced functions must have consistent element type. Found ", element_type->repr_str(), " and ", (element_type_match.type())->repr_str())); } element_type = *unified_type; } return InferredType(ListType::create(element_type)); } else { // TODO: this message is not correct anymore, since this InferredType is // used from a bunch of circumstances unrelated to tracing. We can re-use // this instead of the attribute_failure stuff in concreteType return InferredType(c10::str( "Only tensors and (possibly nested) tuples of tensors, lists, or dicts", "are supported ", "as inputs or outputs of traced functions", ", but instead got value of type ", py::str(input.get_type().attr("__name__")), ".")); } } inline bool isTraceableType(const TypePtr& type) { if (type->isSubtypeOf(*TensorType::get())) { return true; } if (auto list_type = type->cast()) { return isTraceableType(list_type->getElementType()); } if (auto tuple_type = type->cast()) { return std::all_of( tuple_type->elements().begin(), tuple_type->elements().end(), [](const TypePtr& element_type) { return isTraceableType(element_type); }); } if (auto dict_type = type->cast()) { return isTraceableType(dict_type->getValueType()); } return false; } inline IValue toTypeInferredIValue(py::handle input) { auto match = tryToInferType(input); if (!match.success()) { AT_ERROR( "Tracer cannot infer type of ", py::str(input), "\n:", match.reason()); } return toIValue(input, match.type()); } inline Stack toTraceableStack(const py::tuple& inputs) { auto info = toTypeInferredIValue(inputs); TORCH_CHECK( isTraceableType(info.type()), "Type '", info.type()->repr_str(), "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and" " Tuples of Tensors can be traced"); return info.toTupleRef().elements().vec(); } inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { elems.push_back(toIValue(elem, elem_type)); } return IValue(std::move(elems)); } inline IValue createGenericDict( const py::dict& obj, const TypePtr& key_type, const TypePtr& value_type) { c10::impl::GenericDict elems(key_type, value_type); elems.reserve(py::len(obj)); for (auto& entry : obj) { elems.insert( toIValue(entry.first, key_type), toIValue(entry.second, value_type)); } return IValue(std::move(elems)); } template inline void guardAgainstNamedTensor(const T& var) { TORCH_CHECK( !var.has_names(), "NYI: Named tensors are currently unsupported in TorchScript. As a " "workaround please drop names via `tensor = tensor.rename(None)`."); } // Defined in pybind_utils.cpp to break a circular dependency with // python_ivalue.h IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N); // Small wrapper around getting the type name string from Python to make // types easier to interpret, e.g. give the structural type for a NamedTuple inline std::string friendlyTypeName(py::handle obj) { if (py::isinstance(obj) && py::hasattr(obj, "_fields")) { auto field_names = py::cast>(py::getattr(obj, "_fields")); std::stringstream ss; ss << py::str(obj.get_type().attr("__name__")); ss << " (aka NamedTuple("; bool first = true; for (auto& field_name : field_names) { if (!first) { ss << ", "; } ss << field_name; first = false; } ss << "))"; return ss.str(); } else { return py::str(obj.get_type().attr("__name__")); } } // Thrown when trying to create a schema for a list of python // arguments that cannot be converted. // Can be caught by the caller to attempt to use other schema // when there is an overloaded operator. struct schema_match_error : public std::runtime_error { using std::runtime_error::runtime_error; }; inline IValue argumentToIValue( const FunctionSchema& schema, size_t argumentPosition, py::handle object) { const auto& argument = schema.arguments().at(argumentPosition); try { return toIValue(object, argument.type(), argument.N()); } catch (const py::cast_error& error) { throw schema_match_error(c10::str( schema.formatTypeMismatchMsg( argument, friendlyTypeName(object), argumentPosition, py::repr(object)), "\nCast error details: ", error.what())); } catch (const py::error_already_set& error) { throw schema_match_error(c10::str( schema.formatTypeMismatchMsg( argument, friendlyTypeName(object), argumentPosition, py::repr(object)), "\n Python error details: ", error.what())); } } inline IValue returnToIValue(const TypePtr& type, py::handle object) { try { return toIValue(object, type); } catch (const py::cast_error& error) { throw std::runtime_error(c10::str( " expected value of type ", type->str(), " for return value but instead got value of type ", py::str(object.get_type().attr("__name__")), ".", "\nValue: ", py::repr(object), "\nCast error details: ", error.what())); } } inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) { auto py_class = py::module::import("torch.jit._state") .attr("_get_python_class")(classType->name()->qualifiedName()); if (py_class.is_none()) { std::stringstream err; err << "Unknown reference to ScriptClass "; err << classType->name()->qualifiedName(); err << ". (Did you forget to import it?)"; throw std::runtime_error(err.str()); } return py_class; } inline py::object toPyObject(IValue ivalue) { if (ivalue.isNone()) { return py::none(); } else if (ivalue.isTensor()) { auto tensor = std::move(ivalue).toTensor(); if (tensor.is_sparse()) { TORCH_WARN_ONCE( "Using sparse tensors in TorchScript is experimental. Many optimization " "pathways have not been thoroughly tested with sparse tensors. Please " "include the fact that the network is running sparse tensors in any bug " "reports submitted."); } guardAgainstNamedTensor(tensor); return py::cast(autograd::Variable(std::move(tensor))); } else if (ivalue.isStorage()) { return py::cast(ivalue.toStorage()); } else if (ivalue.isDouble()) { return py::cast(std::move(ivalue).toDouble()); } else if (ivalue.isComplexDouble()) { return py::cast( static_cast>(std::move(ivalue).toComplexDouble())); } else if (ivalue.isInt()) { return py::cast(std::move(ivalue).toInt()); } else if (ivalue.isBool()) { return py::cast(std::move(ivalue).toBool()); } else if (ivalue.isString()) { return py::cast(std::move(ivalue).toStringRef()); } else if (ivalue.isList()) { auto list = std::move(ivalue).toList(); py::list t{list.size()}; for (const auto i : c10::irange(list.size())) { t[i] = toPyObject(IValue{list.get(i)}); } return std::move(t); } else if (ivalue.isTuple()) { auto tuple = std::move(ivalue).toTuple(); const auto& elements = tuple->elements(); py::tuple t{elements.size()}; for (const auto i : c10::irange(elements.size())) { t[i] = toPyObject(IValue{elements.at(i)}); } // If we have a NamedTuple if (tuple->type() && tuple->type()->schema() && tuple->type()->schema()->name() != "") { auto unqualName = tuple->type()->name()->name(); const std::vector& tuple_args = tuple->type()->schema()->arguments(); std::vector defaults; auto it = std::find_if( tuple_args.begin(), tuple_args.end(), [](const Argument& arg) { return arg.default_value().has_value(); }); std::transform( it, tuple_args.end(), std::back_inserter(defaults), [](const Argument& arg) { return toPyObject(*arg.default_value()); }); std::vector fieldNames = fmap(tuple_args, [](const Argument& arg) { return arg.name(); }); return py::module::import("torch._jit_internal") .attr("_create_named_tuple")( t, unqualName, fieldNames, py::make_tuple(defaults)); } else { return std::move(t); } } else if (ivalue.isDevice()) { return py::cast(THPDevice_New(std::move(ivalue).toDevice())); } else if (ivalue.isGenericDict()) { auto dict = std::move(ivalue).toGenericDict(); py::dict py_dict; for (auto& pair : dict) { py_dict[toPyObject(IValue{pair.key()})] = toPyObject(IValue{pair.value()}); } return std::move(py_dict); } else if (ivalue.isRRef()) { #ifdef USE_RPC auto RRefPtr = c10::dynamic_intrusive_pointer_cast( std::move(ivalue).toRRef()); return py::cast(torch::distributed::rpc::PyRRef(RRefPtr)); #else AT_ERROR("RRef is only supported with the distributed package"); #endif } else if (ivalue.isObject()) { const auto obj = std::move(ivalue).toObject(); if (obj->type()->is_module()) { return py::cast(Module(obj)); } auto pyCu = get_python_cu(); if (obj->name().find("__torch__.torch.classes") == 0) { return py::cast(Object(obj)); } const auto classType = pyCu->get_class(c10::QualifiedName(obj->name())); AT_ASSERT(classType); auto pyClass = getScriptedClassOrError(obj->type()); auto pyObj = pyClass.attr("__new__")(pyClass); const auto numAttrs = classType->numAttributes(); for (const auto slot : c10::irange(numAttrs)) { const auto& attrName = classType->getAttributeName(slot); IValue v = obj->getSlot(slot); py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(v))); } return pyObj; } else if (ivalue.isPyObject()) { // return borrowed reference to ensure it correctly incref the underlying // PyObject return py::reinterpret_borrow(ivalue.toPyObject()); } else if (ivalue.isCapsule()) { return py::cast(c10::Capsule(ivalue.toCapsule())); } else if (ivalue.isFuture()) { return py::cast(std::make_shared(ivalue.toFuture())); } else if (ivalue.isEnum()) { auto enum_holder = ivalue.toEnumHolder(); auto py_class = getScriptedClassOrError(enum_holder->type()); return py_class.attr(enum_holder->name().c_str()); } else if (ivalue.isRRef()) { #ifdef USE_RPC return py::cast(torch::distributed::rpc::PyRRef( c10::static_intrusive_pointer_cast( ivalue.toRRef()))); #else TORCH_CHECK(false, "RRef is only supported with the distributed package"); #endif } else { AT_ERROR( "Missing cases in 'toPyObject'! Can't convert ", ivalue.tagKind(), " to a Python object"); } } struct VISIBILITY_HIDDEN tuple_slice { /*implicit*/ tuple_slice(py::tuple tup_) : tup(std::move(tup_)), b(0), e(tup.size()) {} tuple_slice(py::tuple tup_, int64_t b_) : tup(std::move(tup_)), b(b_), e(tup.size()) {} tuple_slice(py::tuple tup_, int64_t b_, int64_t e_) : tup(std::move(tup_)), b(b_), e(e_) {} py::detail::tuple_iterator begin() const { return {tup, static_cast(b)}; } py::detail::tuple_iterator end() const { return {tup, static_cast(e)}; } size_t size() const { return e - b; } py::detail::tuple_accessor operator[](size_t index) const { return {tup, static_cast(b + index)}; } private: py::tuple tup; int64_t b; int64_t e; }; inline Stack createStackForSchema( const FunctionSchema& schema, const tuple_slice& args, const py::kwargs& kwargs, c10::optional self) { size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size(); if (all_arguments > schema.arguments().size()) { throw schema_match_error(c10::str( schema.name(), "() expected at most ", schema.arguments().size(), " argument(s) but received ", all_arguments, " argument(s). Declaration: ", schema)); } Stack stack; stack.reserve(schema.arguments().size()); int64_t arg_idx = 0; if (self) { push(stack, std::move(*self)); arg_idx++; } // First push all positional args. for (const auto& arg : args) { // ...but refuse to do it if the schema says that this was supposed // to be keyword only if (schema.arguments()[arg_idx].kwarg_only()) { throw schema_match_error(c10::str( schema.name(), "() takes ", arg_idx, " positional argument(s) but ", self ? 1 + args.size() : args.size(), " was/were given. Declaration: ", schema)); } // Use the type information from the schema to convert the PyObject. push(stack, argumentToIValue(schema, stack.size(), arg)); arg_idx++; } // Now for every remaining non-positional argument in the schema, look for it // in the kwargs dict and push it if found, or use its default value if it // has one. size_t consumed_kwargs = 0; for (size_t i = stack.size(); i < schema.arguments().size(); ++i) { const auto& arg = schema.arguments()[i]; if (kwargs.contains(arg.name().c_str())) { push(stack, argumentToIValue(schema, i, kwargs[arg.name().c_str()])); consumed_kwargs += 1; } else if (arg.default_value()) { push(stack, *arg.default_value()); } else { throw schema_match_error(c10::str( schema.name(), "() is missing value for argument '", arg.name(), "'. Declaration: ", schema)); } } if (consumed_kwargs != kwargs.size()) { std::vector names; for (const auto& kwarg : kwargs) { names.emplace_back(py::cast(kwarg.first)); } throw schema_match_error(schema.findErrorInKwargs(names)); } return stack; } inline py::object createPyObjectForStack(Stack&& stack) { if (stack.empty()) { return py::none(); } // Return a simple value and not a single-element tuple if there is only one // return value. if (stack.size() == 1) { return toPyObject(std::move(stack[0])); } // If there is more than one return value, pop them into a py::tuple. py::tuple return_values(stack.size()); for (const auto ret : c10::irange(return_values.size())) { return_values[ret] = toPyObject(std::move(stack[ret])); } return std::move(return_values); } // TODO: Remove once we clean up the GraphExecutor usage. inline Stack evilDeprecatedBadCreateStackDoNotUse( const py::tuple& tuple, at::ArrayRef inputs, size_t reserve_extra_space = 0) { if (tuple.size() != inputs.size()) { AT_ERROR( "expected " + std::to_string(inputs.size()) + " inputs, but got " + std::to_string(tuple.size())); } Stack result; result.reserve(tuple.size() + reserve_extra_space); for (const auto i : c10::irange(inputs.size())) { result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type())); } return result; } // Run `callee`, potentially inserting a CallFunction/CallMethod node into the // tracing graph. inline py::object runAndInsertCall( Function& callee, const tuple_slice& args, const py::kwargs& kwargs, c10::optional self, // Lambda that tells this function how to insert `callee` into the graph if // we're tracing. const std::function& callInserter) { auto stack = createStackForSchema(callee.getSchema(), args, kwargs, std::move(self)); const auto& tracing_state = tracer::getTracingState(); if (!tracing_state) { pybind11::gil_scoped_release no_gil_guard; // If we're not tracing, just run the callee as normal. callee.run(stack); } else { // If we are tracing, insert the appropriate CallFunction or CallMethod node // and then run the callee with tracing disabled. // Get the graph `Value`s that represent the input IValues auto inputs = last(stack, callee.num_inputs()); auto input_values = fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); }); TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1) auto return_type = callee.getSchema().returns().at(0).type(); auto graph = tracing_state->graph; std::vector named_values; named_values.reserve(input_values.size()); for (Value* v : input_values) { named_values.emplace_back(v); } // Add a call node. MatchedSchema match = matchSchema( callee.getSchema(), tracer::getPythonInterpreterSourceRange(), *graph, named_values, {}); auto output_value = callInserter(*graph, match); // Actually run the callee. Pause the tracer so that we don't double-add the // callee nodes. { pybind11::gil_scoped_release no_gil_guard; ResourceGuard guard(tracer::pauseTracing()); callee.run(stack); } // Associate the output IValues with the output `Value`s in the graph tracer::setValueTrace(stack.back(), output_value); } TORCH_CHECK( stack.size() > 0, "Expected values in the stack after execution but found none"); return toPyObject(std::move(stack.back())); } inline c10::optional maybeTorchFunctionDispatch( const py::object& callee, const tuple_slice& args_no_self, const py::kwargs& kwargs, const c10::QualifiedName qualname) { std::vector args_vec; for (const auto& arg : args_no_self) { args_vec.push_back(arg); } py::tuple args = py::cast(args_vec); // Handle __torch_function__ dispatch std::vector overloaded_args; size_t total_arg_num = args.size() + kwargs.size(); for (const auto& arg : args) { is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args); is_tensor_list_and_append_overloaded( arg.ptr(), &overloaded_args, static_cast(total_arg_num), false /* throw_error */); } // NB: for kwargs, we cannot guarantee the order of appending // is the same as the argument order in operator's schema. // This is suboptimal, but should be fine. Later when we have // better schema matching and argument parsing, we could // match the operator in `operations` first, then the order will // be guaranteed. for (auto item : kwargs) { is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args); is_tensor_list_and_append_overloaded( item.second.ptr(), &overloaded_args, total_arg_num, false /* throw_error */); } if (overloaded_args.size() > 0) { return pybind11::reinterpret_steal( handle_torch_function_no_python_arg_parser( /*overloaded_args=*/overloaded_args, /*args=*/args.ptr(), /*kwargs=*/kwargs.ptr(), /*func_name=*/qualname.name().c_str(), /*torch_api_function=*/callee.ptr(), /*module_name=*/qualname.prefix().c_str())); } return c10::nullopt; } inline py::object invokeScriptFunctionFromPython( Function& callee, const tuple_slice& args, const py::kwargs& kwargs) { // TODO: we could add __torch_function__ dispatch here but I don't know // the implications of doing so return runAndInsertCall( callee, args, kwargs, /*self=*/c10::nullopt, [&](Graph& graph, const MatchedSchema& match) { return graph.insertFunctionCall(&callee, match); }); } inline py::object invokeScriptMethodFromPython( Method& callee, const tuple_slice& args, const py::kwargs& kwargs) { auto self = callee.owner()._ivalue(); if (auto torch_fn_result = maybeTorchFunctionDispatch( py::cast(callee), args, kwargs, callee.name())) { return *torch_fn_result; } return runAndInsertCall( callee.function(), args, kwargs, self, [&](Graph& graph, const MatchedSchema& match) { return graph.insertMethodCall(callee.name(), match); }); } inline std::pair, Stack> getOpWithStack( const std::vector>& operations, py::args args, const py::kwargs& kwargs) { Stack stack; if (operations.size() == 1) { std::shared_ptr op = operations.at(0); // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema( op->schema(), std::move(args), kwargs, c10::nullopt); return std::make_pair(op, stack); } else { std::vector errors; std::shared_ptr found_op = nullptr; for (const auto& op : operations) { try { stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt); found_op = op; break; } catch (schema_match_error& error) { errors.push_back(std::move(error)); } } if (!found_op) { std::stringstream ss; ss << "Overloaded torch operator invoked from Python failed to many any schema:\n"; for (const auto& err : errors) { ss << err.what() << "\n\n"; } throw std::runtime_error(ss.str()); } return std::make_pair(found_op, stack); } } inline py::object invokeOperatorFromPython( const std::vector>& operations, py::args args, const py::kwargs& kwargs) { auto opWithStack = getOpWithStack(operations, args, kwargs); std::shared_ptr found_op = std::get<0>(opWithStack); Stack stack = std::get<1>(opWithStack); { pybind11::gil_scoped_release no_gil_guard; found_op->getOperation()(stack); } return createPyObjectForStack(std::move(stack)); } } // namespace jit } // namespace torch