// NOTE(review): the #include targets below — and many template argument
// lists later in this file (e.g. "std::unordered_map" with no <...>) —
// appear to have been stripped by the tool that produced this copy of the
// file. They are preserved exactly as found; restore against upstream.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace torch {

// Maps the type-name token used in a signature format string (e.g. the
// "Tensor" in "Tensor other") to the parser's internal ParameterType tag.
// Consulted once per parameter by the FunctionParameter constructor below.
static std::unordered_map type_map = {
    {"Tensor", ParameterType::TENSOR},
    {"Scalar", ParameterType::SCALAR},
    {"int64_t", ParameterType::INT64},
    {"SymInt", ParameterType::SYM_INT},
    {"double", ParameterType::DOUBLE},
    {"complex", ParameterType::COMPLEX},
    {"TensorList", ParameterType::TENSOR_LIST},
    {"c10::List>", ParameterType::TENSOR_LIST},
    {"IntArrayRef", ParameterType::INT_LIST},
    {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
    {"ArrayRef", ParameterType::FLOAT_LIST},
    {"Generator", ParameterType::GENERATOR},
    {"bool", ParameterType::BOOL},
    {"Storage", ParameterType::STORAGE},
    {"PyObject*", ParameterType::PYOBJECT},
    {"ScalarType", ParameterType::SCALARTYPE},
    {"Layout", ParameterType::LAYOUT},
    {"MemoryFormat", ParameterType::MEMORY_FORMAT},
    {"QScheme", ParameterType::QSCHEME},
    {"Device", ParameterType::DEVICE},
    {"Stream", ParameterType::STREAM},
    {"std::string", ParameterType::STRING},
    {"c10::string_view", ParameterType::STRING},
    {"Dimname", ParameterType::DIMNAME},
    {"DimnameList", ParameterType::DIMNAME_LIST},
    {"ScalarList", ParameterType::SCALAR_LIST},
};

// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
// NumPy-compatible alternative spellings for a parameter name; tried in
// order during keyword-argument lookup when the canonical name is absent.
static const std::unordered_map> numpy_compatibility_arg_names = {
    {"dim", {"axis"}},
    {"keepdim", {"keepdims"}},
    {"input", {"x", "a", "x1"}},
    {"other", {"x2"}},
};

// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blocklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
bool should_allow_numbers_as_tensors(const std::string& name) {
  // Membership test against a fixed allowlist of op names (and their
  // in-place/out variants and aliases).
  static std::unordered_set allowed = {
      "add",         "add_",         "add_out",
      "div",         "div_",         "div_out",
      "divide",      "divide_",      "divide_out", // alias of div
      "mul",         "mul_",         "mul_out",
      "multiply",    "multiply_",    "multiply_out", // alias of mul
      "sub",         "sub_",         "sub_out",
      "subtract",    "subtract_",    "subtract_out", // alias of sub
      "true_divide", "true_divide_", "true_divide_out",
      "to",          "_to_copy",     "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out"};
  return allowed.find(name) != allowed.end();
}

// Parses a single parameter declaration from a signature format string.
// `fmt` has the shape "<Type>[?][<size>] <name>[=<default>]", e.g.
// "Tensor? out=None" or "IntArrayRef[2] stride". A trailing '?' on the type
// permits None; a bracketed size fixes the expected list length; an '='
// introduces a default value (parsed by set_default_str).
// Throws std::runtime_error on a malformed declaration.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
    : optional(false),
      allow_none(false),
      keyword_only(keyword_only),
      size(0),
      default_scalar(0) {
  auto space = fmt.find(' ');
  if (space == std::string::npos) {
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
  }
  auto type_str = fmt.substr(0, space);

  // '?' suffix on the type means the argument may be None.
  auto question = type_str.find('?');
  if (question != std::string::npos) {
    allow_none = true;
    type_str = type_str.substr(0, question);
  }

  // Parse and remove brackets from type_str
  auto bracket = type_str.find('[');
  if (bracket != std::string::npos) {
    auto size_str =
        type_str.substr(bracket + 1, type_str.length() - bracket - 2);
    size = atoi(size_str.c_str());
    type_str = type_str.substr(0, bracket);
  }

  auto name_str = fmt.substr(space + 1);
  auto it = type_map.find(type_str);
  if (it == type_map.end()) {
    throw std::runtime_error(
        "FunctionParameter(): invalid type string: " + type_str);
  }
  type_ = it->second;

  // "name=default" — everything after '=' is the textual default value.
  auto eq = name_str.find('=');
  if (eq != std::string::npos) {
    name = name_str.substr(0, eq);
    optional = true;
    set_default_str(name_str.substr(eq + 1));
  } else {
    name = name_str;
  }
  // Interned Python string for fast keyword lookup at call time.
  python_name = THPUtils_internString(name);
  auto np_compat_it = numpy_compatibility_arg_names.find(name);
  if (np_compat_it != numpy_compatibility_arg_names.end()) {
    for (const auto& str : np_compat_it->second) {
      numpy_python_names.push_back(THPUtils_internString(str));
    }
  }
}

// Routes a Tensor property read (e.g. `t.grad`) through __torch_function__
// by invoking the property descriptor's __get__ via handle_torch_function.
auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject* {
  py::object torch_api = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  std::string module_name = "torch.Tensor." + property_name;
  return handle_torch_function(
      (PyObject*)self,
      "__get__",
      nullptr,
      nullptr,
      torch_api.ptr(),
      module_name);
}

// Routes a Tensor property write/delete through __torch_function__:
// __set__ when a value is supplied, __delete__ when value is nullptr.
// Always returns 0 (the CPython setter success code).
auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int {
  py::object torch_api = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  std::string module_name = "torch.Tensor." + property_name;
  if (value != nullptr) {
    py::tuple args_ = py::make_tuple(py::handle(value));
    handle_torch_function(
        (PyObject*)self,
        "__set__",
        args_.ptr(),
        nullptr,
        torch_api.ptr(),
        module_name);
  } else {
    handle_torch_function(
        (PyObject*)self,
        "__delete__",
        nullptr,
        nullptr,
        torch_api.ptr(),
        module_name);
  }
  return 0;
}

// Combines self and args into one tuple.
auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple { if (args == nullptr) { return py::make_tuple(py::handle(self)); } else if (self == nullptr) { return py::reinterpret_borrow(args); } auto py_args = py::reinterpret_borrow(args); size_t n = py_args.size(); auto args_ = py::tuple(n + 1); args_[0] = py::handle(self); for (const auto i : c10::irange(n)) { args_[i + 1] = py_args[i]; } return args_; } // TODO: I'm not sure if I should call this __torch_function__ or // torch_function. The former makes it easier to take an existing // Tensor-like __torch_function__ object and turn it into a mode; // but in general modes don't have to be Tensor-like (and we will // improperly accept mode objects as arguments when they shouldn't // be passed around in this way). const char* torch_function_mode_name = "__torch_function__"; auto handle_torch_function( PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); TORCH_INTERNAL_ASSERT( torch_api_function.ptr() != nullptr, "torch API function must exist"); py::tuple args_ = combine_self_args(self, args); return handle_torch_function_no_python_arg_parser( {py::handle(self)}, args_.ptr(), kwargs, func_name.c_str(), torch_api_function.ptr(), module_name.c_str(), TorchFunctionName::TorchFunction); } // Note: [Overloaded args] // An overloaded arg may be one of the following: // - an instance of an object that has a __torch_function__ method // - an instance of an object that has a __torch_dispatch__ classmethod // - a class type that has a __torch_dispatch__ classmethod // // This function returns the type of the arg (if the arg is an instance), // otherwise, it returns the arg. 
// See Note: [Overloaded args] above — returns the object itself when it is
// already a type, otherwise the object's Python type.
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
  if (PyType_Check(obj_or_type)) {
    return obj_or_type;
  }
  return (PyObject*)Py_TYPE(obj_or_type);
}

// See Note: [Overloaded args] for what they hold
//
// Core dispatcher shared by __torch_function__ and __torch_dispatch__.
// Tries the active mode (if any) first, then each overloaded argument's
// implementation in precedence order, until one returns something other
// than NotImplemented. Raises TypeError when every implementation returns
// NotImplemented; propagates user exceptions as python_error.
// Returns a NEW reference to the result.
auto handle_torch_function_no_python_arg_parser(
    at::ArrayRef overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name) -> PyObject* {
  // Pick the dunder name for the protocol being dispatched.
  const char* torch_function_name_str = nullptr;
  switch (torch_function_name) {
    case TorchFunctionName::TorchFunction:
      torch_function_name_str = "__torch_function__";
      break;
    case TorchFunctionName::TorchDispatch:
      torch_function_name_str = "__torch_dispatch__";
      break;
    default:
      TORCH_INTERNAL_ASSERT(0, static_cast(torch_function_name));
  }
  // overloaded_args already all have unique types
  // nb: modes don't go in the overloaded types list, as they are not
  // necessarily types
  std::vector overloaded_types;
  overloaded_types.reserve(overloaded_args.size());
  for (auto& arg : overloaded_args) {
    overloaded_types.push_back(py::reinterpret_borrow(
        get_type_of_overloaded_arg(arg.ptr())));
  }
  py::tuple py_types = py::cast(overloaded_types);
  py::object ret;
  PyObject* mode_obj = nullptr;

  const bool is_torch_function =
      torch_function_name == TorchFunctionName::TorchFunction;
  const auto is_mode_active = [&]() {
    return is_torch_function ? at::impl::torch_function_mode_enabled()
                             : c10::impl::dispatch_mode_enabled();
  };

  if (is_mode_active()) {
    // A mode is active: pop it and call its handler before consulting the
    // per-argument implementations.
    // Disable mode on the inside; this makes for a more user-friendly
    // experience if you try to, e.g., print your tensors.
    at::optional tf_g;
    at::optional td_g;
    if (is_torch_function) {
      tf_g.emplace();
      mode_obj = tf_g->get_cur_mode()->ptr(getPyInterpreter());
    } else {
      td_g.emplace();
      mode_obj = td_g->get_cur_mode()->ptr(getPyInterpreter());
    }
    py::object torch_function =
        PyObject_FastGetAttrString(mode_obj, torch_function_name_str);
    if (!torch_function) {
      TORCH_INTERNAL_ASSERT(0);
    }
    TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
    TORCH_INTERNAL_ASSERT(args != nullptr);

    // A classmethod handler would not be bound to the mode instance, which
    // breaks the contract; reject it with an actionable message.
    TORCH_CHECK(
        PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
            .is(py::reinterpret_borrow(mode_obj)),
        "Defining your mode's `",
        torch_function_name_str,
        "` as a classmethod is not supported, please make it a plain method");

    // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
    // because the nullptr terminates the argument list ick ick ick.
    if (kwargs == nullptr) {
      ret = py::reinterpret_steal(PyObject_CallMethod(
          mode_obj,
          torch_function_name_str,
          "OOO",
          torch_api_function,
          py_types.ptr(),
          args));
    } else {
      ret = py::reinterpret_steal(PyObject_CallMethod(
          mode_obj,
          torch_function_name_str,
          "OOOO",
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs));
    }
    if (ret.ptr() == nullptr) {
      throw python_error();
    }
  }
  if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
    // No mode handled the call (or it returned NotImplemented): try each
    // overloaded argument's implementation in precedence order.
    for (auto& arg : overloaded_args) {
      // NOLINTNEXTLINE(clang-diagnostic-writable-strings)
      py::object torch_function =
          PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str);
      if (!torch_function) {
        TORCH_INTERNAL_ASSERT(0);
      }

      // See https://github.com/pytorch/pytorch/issues/63767
      if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
              .is(arg) &&
          torch_function.ptr() != torch::disabled_torch_function_impl()) {
        TORCH_WARN(
            "Defining your `",
            torch_function_name_str,
            "` as a plain method is deprecated ",
            "and will be an error in future, please define it as a classmethod.");
      }

      ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(
          torch_function.ptr(),
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs,
          NULL));
      if (ret.ptr() != Py_NotImplemented) {
        // Return the reference to the result. This also covers the case where
        // ret is NULL and __torch_function__/__torch_dispatch raised an
        // exception, which we throw below
        break;
      }
    }
  }
  if (ret.ptr() == nullptr) {
    // if an exception occurred in a user's implementation of
    // __torch_function__, throw it
    throw python_error();
  } else if (ret.ptr() == Py_NotImplemented) {
    // all __torch_function__ implementations in overloaded_args
    // returned NotImplemented, so we raise a TypeError.
    std::stringstream ss;
    ss << "no implementation found for '";
    if (module_name && func_name) {
      ss << module_name << "." << func_name;
    } else {
      py::handle fn = torch_api_function;
      ss << py::str(fn.attr("__module__")) << "."
         << py::str(fn.attr("__name__"));
    }
    ss << "' on types that implement " << torch_function_name_str << ": [";
    for (auto& arg : overloaded_args) {
      ss << py::repr(get_type_of_overloaded_arg(arg.ptr()));
      if (!arg.is(overloaded_args.back())) {
        ss << ", ";
      }
    }
    ss << "]";
    const std::string& tmp = ss.str();
    PyErr_SetString(PyExc_TypeError, tmp.c_str());
    throw python_error();
  }
  return ret.release().ptr();
}

// Convenience overload used after python_arg_parser has already collected
// the overloaded args into `r.signature.overloaded_args`. Prepends `self`
// to `args` and forwards to the core dispatcher above.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  py::object torch_api_function = PyObject_FastGetAttrString(
      torch_api,
      (char*)(func_name_override ? func_name_override
                                 : r.get_func_name().c_str()));
  TORCH_INTERNAL_ASSERT(
      torch_api_function.ptr() != nullptr, "torch API function must exist");
  py::object ret;
  py::tuple args_ = combine_self_args(self, args);
  // overloaded_args already all have unique types
  std::vector overloaded_types;
  overloaded_types.reserve(r.signature.overloaded_args.size());
  for (auto& arg : r.signature.overloaded_args) {
    overloaded_types.push_back(
        py::reinterpret_borrow((PyObject*)Py_TYPE(arg.ptr())));
  }
  py::tuple py_types = py::cast(overloaded_types);
  return handle_torch_function_no_python_arg_parser(
      r.signature.overloaded_args,
      args_.ptr(),
      kwargs,
      r.get_func_name().c_str(),
      torch_api_function.ptr(),
      module_name);
}

// Same as above but with no `self` receiver (free-function call sites).
auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  return handle_torch_function(
      r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
}

// Dispatches Tensor indexing (`t[index]` / `t[index] = val`) through
// __torch_function__. Collects overloaded args from self, every element of
// the index tuple, and (for __setitem__) the assigned value.
auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val) -> PyObject* {
  // val == nullptr distinguishes a read (__getitem__) from a write.
  const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
  py::object index_tup;
  if (PyTuple_Check(index)) {
    index_tup = py::reinterpret_borrow(index);
  } else {
    index_tup = py::make_tuple(py::handle(index));
  }
  std::vector overridable_args;
  is_tensor_and_append_overloaded(self, &overridable_args);
  auto size = PyTuple_GET_SIZE(index_tup.ptr());
  for (auto i : c10::irange(size)) {
    auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
    is_tensor_and_append_overloaded(obj, &overridable_args);
  }
  if (val != nullptr) {
    is_tensor_and_append_overloaded(val, &overridable_args);
  }
  py::object func =
      PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
  py::object args = (val == nullptr)
      ? py::make_tuple(py::handle(self), py::handle(index))
      : py::make_tuple(
            py::handle(self), py::handle(index), py::handle(val));
  return handle_torch_function_no_python_arg_parser(
      overridable_args,
      args.ptr(),
      nullptr,
      func_name,
      func.ptr(),
      "torch.Tensor");
}

/*
 * obj has a __torch_function__ implementation and may either be a
 * subclass of Tensor or a Tensor-like duck type. We may need to
 * append this object to the overloaded_args vector, which tracks all
 * of the arguments with distinct __torch_function__ implementations
 * we've seen so far.
 *
 * If this is the first argument we've seen with __torch_function__
 * defined, we unconditionally add obj to the overloaded_args vector.
 *
 * If we've already seen arguments with __torch_function__ defined,
 * then we first need to check if obj is the same type as any of the
 * entries in overloaded_args. If so, we can ignore obj since we
 * already have an entry in overloaded_args with the same
 * __torch_function__ implementation.
 *
 * If it's a different type, we then need to check if it's a subclass
 * of one of the types we've already seen. If so, we need to insert an
 * entry in overloaded_args for this type with higher precedence than
 * the superclass.
 *
 * See torch._overrides._get_overloaded_types_and_args for the equivalent
 * function in the Python __torch_function__ implementation.
 *
 * The precedence-determining algorithm implemented in this function is
 * described in NEP-0018:
 * https://numpy.org/neps/nep-0018-array-function-protocol.html
 *
 * 'overloaded_args' is a raw pointer to a vector of pybind11 handles
 * that have distinct __torch_function__ implementations, in order of calling
 * precedence.
 *
 * 'obj' is an object to check for a __torch_function__ implementation
 *
 * If changing this file in a way that can affect the __torch_function__
 * overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
 * See the instructions in the 'README.md' in that directory.
 *
 */
static void append_overloaded_arg(
    std::vector* overloaded_args,
    PyObject* obj,
    bool obj_is_type) {
  bool class_not_seen_yet = true;
  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
  for (auto& arg : *overloaded_args) {
    if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
      // obj is the same type as another parameter we've seen in a prior
      // iteration of the loop over parameters so we already have an entry
      // with the proper __torch_function__ implementation to call, so skip
      // this parameter
      class_not_seen_yet = false;
      break;
    }
  }
  if (class_not_seen_yet) {
    auto arg_index = overloaded_args->size();
    for (const auto j : c10::irange(arg_index)) {
      if (PyObject_IsSubclass(
              obj_type,
              (PyObject*)(get_type_of_overloaded_arg(
                  (*overloaded_args)[j].ptr())))) {
        // obj is a subclass of another object we've seen already so its
        // __torch_function__ should be called first, therefore we
        // insert it into overloaded_args before the superclass
        arg_index = j;
        break;
      }
    }
    // add object to overloaded_args. If it's a subclass of another class
    // we've already seen it will be inserted before the superclass,
    // otherwise it will be inserted at the end of the array
    overloaded_args->insert(
        overloaded_args->begin() + static_cast(arg_index), obj);
  }
}

// Records an instance with __torch_function__ (precedence keyed on its type).
void append_overloaded_tensor(
    std::vector* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
}

// Records a class type with __torch_dispatch__ (used as-is, not Py_TYPE'd).
void append_overloaded_type(
    std::vector* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
}

// Returns true iff obj can bind to a Tensor parameter; as a side effect,
// appends obj to overloaded_args when it carries a __torch_function__
// override that must participate in dispatch.
bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector* overloaded_args) {
  if (THPVariable_CheckExact(obj)) {
    // torch.Tensor instances (not subclasses, except for Parameter)
    return true;
  }

  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
    // tensor subclasses and unrelated objects with __torch_function__
    append_overloaded_tensor(overloaded_args, obj);
    return true;
  } else if (THPVariable_Check(obj)) {
    // tensor subclasses without __torch_function__
    return true;
  }

  return false;
}

// True iff obj is a tuple/list whose every element passes
// THPUtils_checkScalar (i.e. binds to a ScalarList parameter).
bool is_scalar_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (const auto idx : c10::irange(size)) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!THPUtils_checkScalar(iobj)) {
      return false;
    }
  }
  return true;
}

// True iff obj is a tuple/list whose every element is Tensor-bindable
// (each element is also run through is_tensor_and_append_overloaded so
// its __torch_function__ override, if any, is recorded). When throw_error
// is set, a non-Tensor element raises TypeError instead of returning false.
bool is_tensor_list_and_append_overloaded(
    PyObject* obj,
    std::vector* overloaded_args,
    int argnum,
    bool throw_error) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (long idx = 0; idx < size; idx++) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
      if (throw_error) {
        throw TypeError(
            "expected Tensor as element %d in argument %d, but got %s",
            static_cast(idx),
            argnum,
            Py_TYPE(iobj)->tp_name);
      }
      return false;
    }
  }
  return true;
}

// True iff obj is a tuple/list acceptable as a float/complex list.
// NOTE: only the FIRST element is type-checked; later elements are assumed
// to match (presumably a deliberate fast path — the downstream unpacking
// validates fully).
bool is_float_or_complex_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }

  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (size > 0) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
    if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
      return false;
    }
  }

  return true;
}

// True iff obj can bind to an IntArrayRef parameter: either a tuple/list
// whose first element is int-like (and with NO SymInt anywhere — SymInts
// must route to the SymIntArrayRef overload), or — when a fixed size is
// declared — a single int to be broadcast to that size. On failure,
// *failed_idx (if provided) receives the offending element index.
static bool is_int_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
    auto len = PySequence_Size(obj);
    if (len == 0) {
      return true;
    }

    auto item = py::reinterpret_steal(PySequence_GetItem(obj, 0));
    bool int_first = false;
    if (THPUtils_checkIndex(item.ptr())) {
      // we still have to check that the rest of items are NOT symint nodes
      int_first = true;
    }

    // Make sure none of the later arguments are SymInt
    // NB: do NOT check that the later arguments are ints, as this is
    // BC-breaking for FX
    for (int i = 1; i < len; i++) {
      if (torch::is_symint(
              py::reinterpret_steal(PySequence_GetItem(obj, i)))) {
        if (failed_idx != nullptr) {
          *failed_idx = i;
        }
        return false;
      }
    }

    if (int_first) {
      return true;
    }

    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
    // in an intlist argument. Even float or complex scalar tensors.
    bool r =
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
         THPVariable_Unpack(item.ptr()).sizes().empty());
    if (!r && failed_idx != nullptr) {
      *failed_idx = 0;
    }
    return r;
  }
  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a
  // single int
  return broadcast_size > 0 && THPUtils_checkLong(obj);
}

// True iff obj is an int or a SymInt node.
static bool is_int_or_symint(PyObject* obj) {
  // THPUtils_checkIndex may call __index__ or __int__
  // which may have side effects if obj is a symint node
  // so we do `is_symint` check first
  // TODO: maybe we should be using checkLong here?
  return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
}

// SymInt-tolerant counterpart of is_int_list: only the first element is
// checked (ints AND SymInts accepted), plus the tracer scalar-tensor
// escape hatch and the single-int broadcast case.
static bool is_int_or_symint_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
    if (PySequence_Size(obj) == 0) {
      return true;
    }
    auto item = py::reinterpret_steal(PySequence_GetItem(obj, 0));

    if (is_int_or_symint(item.ptr())) {
      return true;
    }
    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
    // in an intlist argument. Even float or complex scalar tensors.
    bool r =
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
         THPVariable_Unpack(item.ptr()).sizes().empty());
    if (!r && failed_idx != nullptr) {
      *failed_idx = 0;
    }
    return r;
  }
  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a
  // single int
  return broadcast_size > 0 && THPUtils_checkLong(obj);
}

// argnum is needed for raising the TypeError, it's used in the error message.
// Returns true iff `obj` can bind to this parameter. Tensor-typed cases
// also append __torch_function__-overriding objects to `overloaded_args`.
// `argnum` is only used for error messages; `failed_idx` (optional out)
// reports which list element rejected an int-list candidate.
auto FunctionParameter::check(
    PyObject* obj,
    std::vector& overloaded_args,
    int argnum,
    int64_t* failed_idx) -> bool {
  switch (type_) {
    case ParameterType::TENSOR: {
      if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
        return true;
      }
      // Some ops allow plain Python numbers to bind to Tensor params (see
      // should_allow_numbers_as_tensors above).
      if (allow_numbers_as_tensors) {
        return THPUtils_checkScalar(obj);
      }
      return false;
    }
    case ParameterType::SCALAR:
      if (THPUtils_checkScalar(obj)) {
        return true;
      }
      // fallthrough
    case ParameterType::COMPLEX:
      if (PyComplex_Check(obj)) {
        return true;
      }
      // fallthrough
    case ParameterType::DOUBLE: {
      if (THPUtils_checkDouble(obj)) {
        return true;
      }
      // A 0-dim, non-differentiable tensor may bind as a double/scalar.
      if (THPVariable_Check(obj)) {
        const auto& var = THPVariable_Unpack(obj);
        return !var.requires_grad() && var.dim() == 0;
      }
      return false;
    }
    case ParameterType::INT64: {
      if (THPUtils_checkLong(obj)) {
        return true;
      }
      // A 0-dim integral (non-bool), non-differentiable tensor also binds.
      if (THPVariable_Check(obj)) {
        const auto& var = THPVariable_Unpack(obj);
        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
            !var.requires_grad() && var.dim() == 0;
      }
      return false;
    }
    case ParameterType::DIMNAME:
      return THPUtils_checkDimname(obj);
    case ParameterType::DIMNAME_LIST: {
      if (THPUtils_checkDimnameList(obj)) {
        return true;
      }
      // if a size is specified (e.g. DimnameList[1]) we also allow passing a
      // single Dimname
      return size == 1 && THPUtils_checkDimname(obj);
    }
    case ParameterType::TENSOR_LIST: {
      return is_tensor_list_and_append_overloaded(
          obj, &overloaded_args, argnum, true /* throw_error */);
    }
    case ParameterType::INT_LIST:
      return is_int_list(obj, size, failed_idx);
    case ParameterType::FLOAT_LIST:
      return is_float_or_complex_list(obj);
    case ParameterType::GENERATOR:
      return THPGenerator_Check(obj);
    case ParameterType::BOOL:
      return PyBool_Check(obj);
    case ParameterType::STORAGE:
      return isStorage(obj);
    case ParameterType::PYOBJECT:
      return true;
    case ParameterType::SCALARTYPE:
      return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
    case ParameterType::LAYOUT:
      return THPLayout_Check(obj);
    case ParameterType::MEMORY_FORMAT:
      return THPMemoryFormat_Check(obj);
    case ParameterType::QSCHEME:
      return THPQScheme_Check(obj);
    case ParameterType::DEVICE:
      // Devices may be given as an index, a string like "cuda:0", or a
      // torch.device object.
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
          THPDevice_Check(obj);
    case ParameterType::STREAM:
      return THPStream_Check(obj);
    case ParameterType::STRING:
      return THPUtils_checkString(obj);
    case ParameterType::SCALAR_LIST:
      return is_scalar_list(obj);
    case ParameterType::SYM_INT:
      return is_int_or_symint(obj);
    case ParameterType::SYM_INT_LIST:
      return is_int_or_symint_list(obj, size, failed_idx);
    default:
      throw std::runtime_error("unknown parameter type");
  }
}

// WARNING: these strings are parsed invalid_arguments.cpp
std::string FunctionParameter::type_name() const {
  switch (type_) {
    case ParameterType::TENSOR:
      return "Tensor";
    case ParameterType::SCALAR:
      return "Number";
    case ParameterType::INT64:
    // NB: SymInt is intentionally not mentioned here, as conventional user
    // use will only know about ints
    case ParameterType::SYM_INT:
      return "int";
    case ParameterType::DOUBLE:
      return "float";
    case ParameterType::COMPLEX:
      return "complex";
    case ParameterType::TENSOR_LIST:
      return "tuple of Tensors";
    case ParameterType::INT_LIST:
      return "tuple of ints";
    case ParameterType::FLOAT_LIST:
      return "tuple of floats";
    case ParameterType::GENERATOR:
      return "torch.Generator";
    case ParameterType::BOOL:
      return "bool";
    case ParameterType::STORAGE:
      return "torch.Storage";
    case ParameterType::PYOBJECT:
      return "object";
    case ParameterType::SCALARTYPE:
      return "torch.dtype";
    case ParameterType::LAYOUT:
      return "torch.layout";
    case ParameterType::MEMORY_FORMAT:
      return "torch.memory_format";
    case ParameterType::QSCHEME:
      return "torch.qscheme";
    case ParameterType::DEVICE:
      return "torch.device";
    case ParameterType::STRING:
      return "str";
    case ParameterType::DIMNAME:
      return "name";
    case ParameterType::DIMNAME_LIST:
      return "tuple of names";
    case ParameterType::SCALAR_LIST:
      return "tuple of Scalars";
    case ParameterType::SYM_INT_LIST:
      return "tuple of ints";
    default:
      throw std::runtime_error("unknown parameter type");
  }
}

// Parses `s` as an integer (base auto-detected by strtol). Returns nullopt
// for the empty string or when any trailing characters remain unparsed.
static inline c10::optional parse_as_integer(const std::string& s) {
  if (s.empty())
    return c10::nullopt;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  char* str_end;
  long ans = strtol(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? c10::optional(ans) : c10::nullopt;
}

/*
Parse default value of IntArrayRef declared at native_functions.yaml

There are two kinds of default values:
1. IntArrayRef[2] x=1 (where size=2, value={1,1}
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot
be space after comma since native_parse.py uses ', ' to split args)
*/
static inline std::vector parse_intlist_args(
    const std::string& s,
    int64_t size) {
  size_t n = s.size();

  if (s.empty())
    return std::vector();

  // case 1. s is an int (e.g., s=2)
  if (s[0] != '{') {
    TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
    return std::vector(size, std::stol(s));
  }

  // case 2.
  // s is a list of dims (e.g., s={1,2})
  // since already checked left brace '{' above, here only checks right brace
  // '}'
  TORCH_CHECK(
      s[n - 1] == '}',
      "Default value of IntArrayRef is missing right brace '}', found ",
      s[n - 1]);

  auto args = std::vector();
  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
  std::string tok;
  while (std::getline(ss, tok, ',')) {
    args.emplace_back(std::stol(tok));
  }
  return args;
}

// Parse a string literal to remove quotes and escape sequences
static std::string parse_string_literal(c10::string_view str) {
  TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
  // Accept either "..." or '...' quoting, but the two ends must match.
  if (str.front() == '"') {
    TORCH_CHECK(
        str.back() == '"', "Mismatched quotes in string default: ", str);
  } else {
    TORCH_CHECK(
        str.front() == '\'' && str.back() == '\'',
        "Invalid quotes in string default: ",
        str)
  }

  std::string parsed;
  parsed.reserve(str.size());
  // Walk the interior (between the quotes), expanding backslash escapes.
  for (size_t i = 1; i < str.size() - 1;) {
    if (str[i] != '\\') {
      parsed.push_back(str[i]);
      ++i;
      continue;
    }

    // Handle escape sequences
    TORCH_CHECK(
        i < str.size() - 2, "String ends with escaped final quote: ", str)
    char c = str[i + 1];
    switch (c) {
      case '\\':
      case '\'':
      case '\"':
        break;
      case 'a':
        c = '\a';
        break;
      case 'b':
        c = '\b';
        break;
      case 'f':
        c = '\f';
        break;
      case 'n':
        c = '\n';
        break;
      case 'v':
        c = '\v';
        break;
      case 't':
        c = '\t';
        break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported escape sequence in string default: \\",
            str[i + 1]);
    }
    parsed.push_back(c);
    i += 2;
  }
  return parsed;
}

// Interprets the textual default value (the part after '=' in the signature
// format string) according to this parameter's type, populating the
// corresponding default_* member. "None" additionally marks the parameter
// as allowing None. Throws std::runtime_error on unsupported defaults.
void FunctionParameter::set_default_str(const std::string& str) {
  if (str == "None") {
    allow_none = true;
  }
  if (type_ == ParameterType::TENSOR) {
    if (str != "None") {
      throw std::runtime_error(
          "default value for Tensor must be none, got: " + str);
    }
  } else if (
      type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
    default_int = atol(str.c_str());
  } else if (type_ == ParameterType::BOOL) {
    default_bool = (str == "True" || str == "true");
  } else if (type_ == ParameterType::DOUBLE) {
    default_double = atof(str.c_str());
  } else if (type_ == ParameterType::COMPLEX) {
    default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
    default_complex[1] = 0;
  } else if (type_ == ParameterType::SCALAR) {
    if (str != "None") {
      // we sometimes rely on integer-vs-float values, e.g. with arange.
      const auto as_integer = parse_as_integer(str);
      default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
                                              : at::Scalar(atof(str.c_str()));
    }
  } else if (
      type_ == ParameterType::INT_LIST ||
      type_ == ParameterType::SYM_INT_LIST) {
    if (str != "None") {
      default_intlist = parse_intlist_args(str, size);
    }
  } else if (type_ == ParameterType::FLOAT_LIST) {
    if (str != "None") {
      throw std::runtime_error("Defaults not supported for float[]");
    }
  } else if (type_ == ParameterType::SCALARTYPE) {
    if (str == "None") {
      default_scalartype = at::ScalarType::Undefined;
    } else if (str == "torch.int64") {
      default_scalartype = at::ScalarType::Long;
    } else {
      throw std::runtime_error("invalid default value for ScalarType: " + str);
    }
  } else if (type_ == ParameterType::LAYOUT) {
    if (str == "None") {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
    } else if (str == "torch.strided") {
      default_layout = at::Layout::Strided;
    } else if (str == "torch.sparse_coo") {
      default_layout = at::Layout::Sparse;
    } else {
      throw std::runtime_error("invalid default value for layout: " + str);
    }
  } else if (type_ == ParameterType::DEVICE) {
    if (str != "None") {
      throw std::runtime_error("invalid device: " + str);
    }
  } else if (type_ == ParameterType::STREAM) {
    if (str != "None") {
      throw std::runtime_error("invalid stream: " + str);
    }
  } else if (type_ == ParameterType::STRING) {
    if (str != "None") {
      default_string = parse_string_literal(str);
    }
  }
  // These types weren't handled here before. Adding a default error
  // led to a lot of test failures so adding this skip for now.
  // We should correctly handle these though because it might be causing
  // silent failures.
  else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
    // throw std::runtime_error("Invalid Tensor List");
  } else if (type_ == ParameterType::GENERATOR) { // NOLINT
    // throw std::runtime_error("ParameterType::GENERATOR");
  } else if (type_ == ParameterType::PYOBJECT) { // NOLINT
    // throw std::runtime_error("ParameterType::PYOBJECT");
  } else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
    // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
  } else if (type_ == ParameterType::DIMNAME) { // NOLINT
    // throw std::runtime_error("ParameterType::DIMNAME");
  } else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
    // throw std::runtime_error("ParameterType::DIMNAME_LIST");
  } else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
    // throw std::runtime_error("ParameterType::SCALAR_LIST");
  } else if (type_ == ParameterType::STORAGE) { // NOLINT
    // throw std::runtime_error("ParameterType::STORAGE");
  } else if (type_ == ParameterType::QSCHEME) { // NOLINT
    // throw std::runtime_error("ParameterType::QSCHEME");
  } else {
    throw std::runtime_error("unknown parameter type");
  }
}

// Parses a whole signature format string of the shape
// "name(Type1 p1, Type2 p2=default, *, Type3 kw1)[|deprecated||hidden]"
// into a FunctionSignature: a name plus an ordered list of
// FunctionParameter, with min/max arg counts computed at the end.
// The bare "*" entry switches subsequent parameters to keyword-only.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
    : min_args(0),
      max_args(0),
      max_pos_args(0),
      index(index),
      hidden(false),
      deprecated(false) {
  auto open_paren = fmt.find('(');
  if (open_paren == std::string::npos) {
    throw std::runtime_error("missing opening parenthesis: " + fmt);
  }
  name = fmt.substr(0, open_paren);

  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);

  auto last_offset = open_paren + 1;
  bool keyword_only = false;
  bool done = false;
  while (!done) {
    // Parameters are separated by ", "; the final one ends at ')'.
    auto offset = fmt.find(", ", last_offset);
    auto next_offset = offset + 2;
    if (offset == std::string::npos) {
      offset = fmt.find(')', last_offset);
      done = true;
      next_offset = offset + 1;
      // this 'if' happens for an empty parameter list, i.e. fn().
if (offset == last_offset) {
        last_offset = next_offset;
        break;
      }
    }
    if (offset == std::string::npos) {
      throw std::runtime_error("missing closing parenthesis: " + fmt);
    }
    if (offset == last_offset) {
      throw std::runtime_error("malformed signature: " + fmt);
    }

    auto param_str = fmt.substr(last_offset, offset - last_offset);
    last_offset = next_offset;
    if (param_str == "*") {
      // a bare '*' marks the start of the keyword-only parameters
      keyword_only = true;
    } else {
      params.emplace_back(param_str, keyword_only);
      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
    }
  }

  if (fmt.substr(last_offset) == "|deprecated") {
    hidden = true;
    // TODO: raise warning when parsing deprecated signatures
    deprecated = true;
  } else if (fmt.substr(last_offset) == "|hidden") {
    hidden = true;
  }

  max_args = params.size();

  // count the number of non-optional args
  for (auto& param : params) {
    if (!param.optional) {
      min_args++;
    }
    if (!param.keyword_only) {
      max_pos_args++;
    }
  }
}

// Renders the signature as "(Type1 name1, ..., *, TypeK nameK)" for use in
// error messages and deprecation warnings.
std::string FunctionSignature::toString() const {
  // TODO: consider printing more proper schema strings with defaults,
  // optionals, etc.
  std::ostringstream ss;
  bool keyword_already = false;
  ss << "(";
  int i = 0;
  for (auto& param : params) {
    if (i != 0) {
      ss << ", ";
    }
    // emit "*, " once, before the first keyword-only parameter
    if (param.keyword_only && !keyword_already) {
      ss << "*, ";
      keyword_already = true;
    }
    ss << param.type_name() << " " << param.name;
    i++;
  }
  ss << ")";
  return ss.str();
}

// Raises a Python-style TypeError for too many positional arguments.
[[noreturn]] static void extra_args(
    const FunctionSignature& signature,
    Py_ssize_t nargs) {
  const auto max_pos_args = signature.max_pos_args;
  const auto min_args = signature.min_args;
  const long nargs_ = nargs;
  if (min_args != max_pos_args) {
    throw TypeError(
        "%s() takes from %zu to %zu positional arguments but %ld were given",
        signature.name.c_str(),
        min_args,
        max_pos_args,
        nargs_);
  }
  throw TypeError(
      "%s() takes %zu positional argument%s but %ld %s given",
      signature.name.c_str(),
      max_pos_args,
      max_pos_args == 1 ? "" : "s",
      nargs_,
      nargs == 1 ?
"was" : "were"); } [[noreturn]] static void missing_args( const FunctionSignature& signature, int idx) { int num_missing = 0; std::stringstream ss; auto& params = signature.params; for (auto it = params.begin() + idx; it != params.end(); ++it) { if (!it->optional) { if (num_missing > 0) { ss << ", "; } ss << '"' << it->name << '"'; num_missing++; } } throw TypeError( "%s() missing %d required positional argument%s: %s", signature.name.c_str(), num_missing, num_missing == 1 ? "s" : "", ss.str().c_str()); } static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) { Py_ssize_t i = 0; for (auto& param : signature.params) { int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ); if (cmp < 0) { throw python_error(); } else if (cmp) { return i; } i++; } return -1; } [[noreturn]] static void extra_kwargs( FunctionSignature& signature, PyObject* kwargs, Py_ssize_t num_pos_args) { PyObject* key = nullptr; PyObject* value = nullptr; Py_ssize_t pos = 0; while (PyDict_Next(kwargs, &pos, &key, &value)) { if (!THPUtils_checkString(key)) { throw TypeError("keywords must be strings"); } auto param_idx = find_param(signature, key); if (param_idx < 0) { throw TypeError( "%s() got an unexpected keyword argument '%s'", signature.name.c_str(), THPUtils_unpackString(key).c_str()); } if (param_idx < num_pos_args) { throw TypeError( "%s() got multiple values for argument '%s'", signature.name.c_str(), THPUtils_unpackString(key).c_str()); } } // this should never be hit throw TypeError("invalid keyword arguments"); } bool FunctionSignature::parse( PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], // NOLINT bool raise_exception) { Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0; auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0; size_t arg_pos = 0; bool allow_varargs_intlist = false; // if there is a single positional IntArrayRef argument, i.e. 
expand(..), // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as // expand((5,3)) int int_list_overload = false; if (max_pos_args == 1 && (params[0].type_ == ParameterType::INT_LIST || params[0].type_ == ParameterType::SYM_INT_LIST)) { allow_varargs_intlist = true; if (params[0].type_ == ParameterType::INT_LIST) { int_list_overload = true; } } if (static_cast(nargs) > max_pos_args && !allow_varargs_intlist) { if (raise_exception) { // foo() takes takes 2 positional arguments but 3 were given extra_args(*this, nargs); } return false; } if (!overloaded_args.empty()) { overloaded_args.clear(); } int i = 0; if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) { append_overloaded_tensor(&this->overloaded_args, self); } for (auto& param : params) { PyObject* obj = nullptr; bool is_kwd = false; if (arg_pos < static_cast(nargs)) { // extra positional args given after single positional IntArrayRef arg if (param.keyword_only) { if (raise_exception) { extra_args(*this, nargs); } return false; } obj = PyTuple_GET_ITEM(args, arg_pos); } else if (kwargs) { obj = PyDict_GetItem(kwargs, param.python_name); for (PyObject* numpy_name : param.numpy_python_names) { if (obj) { break; } obj = PyDict_GetItem(kwargs, numpy_name); } is_kwd = true; } int64_t failed_idx = -1; bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd; if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) { dst[i++] = nullptr; } else if (!obj) { if (raise_exception) { // foo() missing 1 required positional argument: "b" missing_args(*this, i); } return false; } else if (param.check(obj, this->overloaded_args, i, &failed_idx)) { dst[i++] = obj; // XXX: the Variable check is necessary because sizes become tensors when // tracer is enabled. This behavior easily leads to ambiguities, and we // should avoid having complex signatures that make use of it... } else if ( varargs_eligible && ((int_list_overload ? 
is_int_list(args, param.size, &failed_idx) : is_int_or_symint_list(args, param.size, &failed_idx)))) { // take all positional arguments as this parameter // e.g. permute(1, 2, 3) -> permute((1, 2, 3)) dst[i++] = args; arg_pos = nargs; continue; } else if (raise_exception) { if (is_kwd) { // foo(): argument 'other' must be str, not int throw TypeError( "%s(): argument '%s' must be %s, not %s", name.c_str(), param.name.c_str(), param.type_name().c_str(), Py_TYPE(obj)->tp_name); } else { // foo(): argument 'other' (position 2) must be str, not int if (failed_idx != -1) { if (!(PyTuple_Check(obj) || PyList_Check(obj))) { TORCH_INTERNAL_ASSERT(varargs_eligible); obj = args; } TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj)); throw TypeError( "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld", name.c_str(), param.name.c_str(), static_cast(arg_pos + 1), param.type_name().c_str(), Py_TYPE(py::reinterpret_steal( PySequence_GetItem(obj, failed_idx)) .ptr()) ->tp_name, static_cast(failed_idx)); } throw TypeError( "%s(): argument '%s' (position %ld) must be %s, not %s", name.c_str(), param.name.c_str(), static_cast(arg_pos + 1), param.type_name().c_str(), Py_TYPE(obj)->tp_name); } } else { return false; } if (!is_kwd) { arg_pos++; } else if (obj) { remaining_kwargs--; } } if (remaining_kwargs > 0) { if (raise_exception) { // foo() got an unexpected keyword argument "b" extra_kwargs(*this, kwargs, nargs); } return false; } return true; } PythonArgParser::PythonArgParser(std::vector fmts, bool traceable) : max_args(0), traceable(traceable) { int index = 0; for (auto& fmt : fmts) { signatures_.emplace_back(fmt, index); ++index; } for (auto& signature : signatures_) { if (signature.max_args > max_args) { max_args = signature.max_args; } } if (!signatures_.empty()) { function_name = signatures_[0].name; } // Check deprecated signatures last std::stable_partition( signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) { 
return !sig.deprecated;
      });
}

// Warns (once per overload, via TORCH_WARN_ONCE) when a deprecated signature
// was matched, listing the non-hidden signatures as alternatives.
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
  if (signature.deprecated) {
    auto msg = c10::str(
        "This overload of ",
        signature.name,
        " is deprecated:\n\t",
        signature.name,
        signature.toString());
    auto signatures = get_signatures();
    if (!signatures.empty()) {
      msg += "\nConsider using one of the following signatures instead:";
      for (const auto& sig : signatures) {
        msg += "\n\t";
        msg += signature.name;
        msg += sig;
      }
    }
    TORCH_WARN_ONCE(msg);
  }
}

// Tries each signature in order; the first one that binds wins. With a
// single signature, parse failures raise immediately; with several, failures
// fall through to print_error for the best diagnostic.
PythonArgs PythonArgParser::raw_parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  if (signatures_.size() == 1) {
    auto& signature = signatures_[0];
    signature.parse(self, args, kwargs, parsed_args, true);
    check_deprecated(signature);
    return PythonArgs(traceable, signature, parsed_args);
  }

  for (auto& signature : signatures_) {
    if (signature.parse(self, args, kwargs, parsed_args, false)) {
      check_deprecated(signature);
      return PythonArgs(traceable, signature, parsed_args);
    }
  }

  print_error(self, args, kwargs, parsed_args);
}

// Raises a TypeError describing why no signature matched. If exactly one
// non-hidden signature could plausibly accept this many arguments, re-parse
// it with raise_exception=true to produce a precise message first.
void PythonArgParser::print_error(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
  size_t num_args = PyTuple_GET_SIZE(args) + (kwargs ?
PyDict_Size(kwargs) : 0); std::vector plausible_idxs; unsigned i = 0; for (auto& signature : signatures_) { if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) { plausible_idxs.push_back(i); } i++; } if (plausible_idxs.size() == 1) { auto& signature = signatures_[plausible_idxs[0]]; signature.parse(self, args, kwargs, parsed_args, true); } auto options = get_signatures(); auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options); throw TypeError("%s", msg.c_str()); } std::vector PythonArgParser::get_signatures() const { std::vector options; for (auto& signature : signatures_) { if (!signature.hidden) { options.push_back(signature.toString()); } } return options; } at::Tensor PythonArgs::tensor_slow(int i) { PyObject* obj = args[i]; if (!obj) { return at::Tensor(); } if (THPVariable_Check(obj)) { return THPVariable_Unpack(obj); } bool save_symint = false; at::Scalar scalar; if (PyBool_Check(obj)) { scalar = at::Scalar(THPUtils_unpackBool(obj)); } else if (THPUtils_checkLong(obj)) { scalar = at::Scalar(THPUtils_unpackLong(obj)); } else if (PyComplex_Check(obj)) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj)); } else if (THPUtils_checkDouble(obj)) { scalar = at::Scalar(THPUtils_unpackDouble(obj)); // NB: we DO NOT put symbolic ints/floats into the Scalar itself, // because although Scalar supports SymInt/SymFloat, the subsequent // conversion to Tensor does not. Instead, do it out of band. } else if (torch::is_symint(py::handle(obj))) { save_symint = true; // This scalar value doesn't matter, it shouldn't ever actually // get read out. Make it a big and weird looking number to help // people figure out if there's aproblem. 
scalar = at::Scalar(7777777); } else if (torch::is_symfloat(py::handle(obj))) { save_symint = true; scalar = at::Scalar(std::numeric_limits::quiet_NaN()); } else { // NB: Are you here because you passed None to a Variable method, // and you expected an undefined tensor to be returned? Don't add // a test for Py_None here; instead, you need to mark the argument // as *allowing none*; you can do this by writing 'Tensor?' instead // of 'Tensor' in the ATen metadata. throw TypeError( "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name); } at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; at::Tensor tensor = scalar_to_tensor(scalar); tensor.unsafeGetTensorImpl()->set_wrapped_number(true); if (save_symint) { auto py_tensor = py::cast(tensor); if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) { throw python_error(); } } return tensor; } at::Scalar PythonArgs::scalar_slow(int i) { if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { auto& var = THPVariable_Unpack(args[i]); jit::tracer::ArgumentStash::stashValue( signature.params[i].name, idx, var, c10::NumberType::get()); } return scalar_slow(args[i]); } at::Scalar PythonArgs::scalar_slow(PyObject* arg) { // Zero-dim tensors are converted to Scalars as-is. Note this doesn't // currently handle most NumPy scalar types except np.float64. if (THPVariable_Check(arg)) { return THPVariable_Unpack(arg).item(); } if (THPUtils_checkLong(arg)) { return at::Scalar(static_cast(THPUtils_unpackLong(arg))); } if (PyBool_Check(arg)) { return at::Scalar(THPUtils_unpackBool(arg)); } if (PyComplex_Check(arg)) { return at::Scalar(THPUtils_unpackComplexDouble(arg)); } if (torch::is_symint(arg)) { return at::Scalar(py::cast(arg)); } if (torch::is_symfloat(arg)) { return at::Scalar(py::cast(arg)); } return at::Scalar(THPUtils_unpackDouble(arg)); } } // namespace torch