mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Most typos were introduced in #131077 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131890 Approved by: https://github.com/Skylion007
		
			
				
	
	
		
			394 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			394 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #pragma once
 | |
| 
 | |
| #include <exception>
 | |
| #include <memory>
 | |
| #include <string>
 | |
| #include <system_error>
 | |
| 
 | |
| #include <ATen/detail/FunctionTraits.h>
 | |
| #include <c10/util/C++17.h>
 | |
| #include <c10/util/Exception.h>
 | |
| #include <c10/util/StringUtil.h>
 | |
| #include <pybind11/pybind11.h>
 | |
| #include <torch/csrc/Export.h>
 | |
| #include <torch/csrc/jit/runtime/jit_exception.h>
 | |
| #include <torch/csrc/utils/cpp_stacktraces.h>
 | |
| #include <torch/csrc/utils/pybind.h>
 | |
| 
 | |
| #if defined(USE_DISTRIBUTED)
 | |
| #include <torch/csrc/distributed/c10d/exception.h>
 | |
| #endif
 | |
| 
 | |
| inline void PyErr_SetString(PyObject* type, const std::string& message) {
 | |
|   PyErr_SetString(type, message.c_str());
 | |
| }
 | |
| /// NOTE [ Conversion Cpp Python Warning ]
 | |
| /// The warning handler cannot set python warnings immediately
 | |
| /// as it requires acquiring the GIL (potential deadlock)
 | |
| /// and would need to cleanly exit if the warning raised a
 | |
| /// python error. To solve this, we buffer the warnings and
 | |
| /// process them when we go back to python.
 | |
| /// This requires the two try/catch blocks below to handle the
 | |
| /// following cases:
 | |
| ///   - If there is no Error raised in the inner try/catch, the
 | |
| ///     buffered warnings are processed as python warnings.
 | |
| ///     - If they don't raise an error, the function proceeds with the
 | |
| ///       original return code.
 | |
| ///     - If any of them raise an error, the error is set (PyErr_*) and
 | |
| ///       the destructor will raise a cpp exception python_error() that
 | |
| ///       will be caught by the outer try/catch that will be able to change
 | |
| ///       the return value of the function to reflect the error.
 | |
| ///   - If an Error was raised in the inner try/catch, the inner try/catch
 | |
| ///     must set the python error. The buffered warnings are then
 | |
| ///     processed as cpp warnings as we cannot predict beforehand
 | |
| ///     whether a python warning will raise an error or not and we
 | |
| ///     cannot handle two errors at the same time.
 | |
| /// This advanced handler will only be used in the current thread.
 | |
| /// If any other thread is used, warnings will be processed as
 | |
| /// cpp warnings.
 | |
// Opens the error-handling scope used around code called from Python.
// Expands to an outer `try {`, a torch::PyWarningHandler local named
// __enforce_warning_buffer, and an inner `try {`. Both try blocks are left
// open on purpose: a matching END_HANDLE_TH_ERRORS* macro must follow to
// close them and supply the catch clauses.
| #define HANDLE_TH_ERRORS                              \
 | |
|   try {                                               \
 | |
|     torch::PyWarningHandler __enforce_warning_buffer; \
 | |
|     try {
 | |
// Expands to one catch clause for `c10::ErrorType`.
//   ErrorType       - name of an exception type inside namespace c10
//   PythonErrorType - PyObject* exception type to raise on the Python side
//   retstmnt        - statement executed after the Python error has been set
// The message is e.what() when torch::get_cpp_stacktraces_enabled() is true
// and e.what_without_backtrace() otherwise; it is passed through
// torch::processErrorMsg before being handed to PyErr_SetString.
| #define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
 | |
|   catch (const c10::ErrorType& e) {                                \
 | |
|     auto msg = torch::get_cpp_stacktraces_enabled()                \
 | |
|         ? e.what()                                                 \
 | |
|         : e.what_without_backtrace();                              \
 | |
|     PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
 | |
|     retstmnt;                                                      \
 | |
|   }
 | |
| 
 | |
| // Only catch torch-specific exceptions
 | |
// Expands to a chain of catch clauses; each one sets the Python error state
// and then executes `retstmnt`:
//   - python_error and py::error_already_set: call e.restore().
//   - c10::IndexError / ValueError / TypeError / NotImplementedError /
//     LinAlgError / OutOfMemoryError / Dist*Error: raise the Python exception
//     type listed next to each of them.
//   - c10::Error: raise RuntimeError.
//   - torch::PyTorchError: raise e.python_type() with the processed message.
// NOTE(review): the c10::Error clause comes after the specific c10 clauses,
// presumably because those types derive from c10::Error - keep this order
// when adding new entries.
| #define CATCH_CORE_ERRORS(retstmnt)                                           \
 | |
|   catch (python_error & e) {                                                  \
 | |
|     e.restore();                                                              \
 | |
|     retstmnt;                                                                 \
 | |
|   }                                                                           \
 | |
|   catch (py::error_already_set & e) {                                         \
 | |
|     e.restore();                                                              \
 | |
|     retstmnt;                                                                 \
 | |
|   }                                                                           \
 | |
|   _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt)                \
 | |
|   _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt)                \
 | |
|   _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt)                  \
 | |
|   _CATCH_GENERIC_ERROR(                                                       \
 | |
|       NotImplementedError, PyExc_NotImplementedError, retstmnt)               \
 | |
|   _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt)       \
 | |
|   _CATCH_GENERIC_ERROR(                                                       \
 | |
|       OutOfMemoryError, THPException_OutOfMemoryError, retstmnt)              \
 | |
|   _CATCH_GENERIC_ERROR(                                                       \
 | |
|       DistBackendError, THPException_DistBackendError, retstmnt)              \
 | |
|   _CATCH_GENERIC_ERROR(                                                       \
 | |
|       DistNetworkError, THPException_DistNetworkError, retstmnt)              \
 | |
|   _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
 | |
|   _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt)           \
 | |
|   _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)                   \
 | |
|   catch (torch::PyTorchError & e) {                                           \
 | |
|     auto msg = torch::processErrorMsg(e.what());                              \
 | |
|     PyErr_SetString(e.python_type(), msg);                                    \
 | |
|     retstmnt;                                                                 \
 | |
|   }
 | |
| 
 | |
// CATCH_TH_ERRORS is an alias that expands to CATCH_CORE_ERRORS(retstmnt).
//
// CATCH_ALL_ERRORS expands to the CATCH_TH_ERRORS clauses followed by a final
// clause for any other std::exception, which is reported to Python as a
// RuntimeError carrying the processed what() text before `retstmnt` runs.
| #define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
 | |
| 
 | |
| #define CATCH_ALL_ERRORS(retstmnt)               \
 | |
|   CATCH_TH_ERRORS(retstmnt)                      \
 | |
|   catch (const std::exception& e) {              \
 | |
|     auto msg = torch::processErrorMsg(e.what()); \
 | |
|     PyErr_SetString(PyExc_RuntimeError, msg);    \
 | |
|     retstmnt;                                    \
 | |
|   }
 | |
| 
 | |
// Closes a scope opened by HANDLE_TH_ERRORS, reporting failures by throwing.
//   1. Closes the inner try. Its catch (...) calls
//      __enforce_warning_buffer.set_in_exception() and rethrows.
//   2. Closes the outer try. py::error_already_set, py::builtin_exception and
//      torch::jit::JITException are rethrown unchanged.
//   3. Any other std::exception is passed to
//      torch::translate_exception_to_python, after which a fresh
//      py::error_already_set is thrown.
// Because the clauses reference `py::`, that namespace alias must be visible
// at the expansion site.
| #define END_HANDLE_TH_ERRORS_PYBIND                                 \
 | |
|   }                                                                 \
 | |
|   catch (...) {                                                     \
 | |
|     __enforce_warning_buffer.set_in_exception();                    \
 | |
|     throw;                                                          \
 | |
|   }                                                                 \
 | |
|   }                                                                 \
 | |
|   catch (py::error_already_set & e) {                               \
 | |
|     throw;                                                          \
 | |
|   }                                                                 \
 | |
|   catch (py::builtin_exception & e) {                               \
 | |
|     throw;                                                          \
 | |
|   }                                                                 \
 | |
|   catch (torch::jit::JITException & e) {                            \
 | |
|     throw;                                                          \
 | |
|   }                                                                 \
 | |
|   catch (const std::exception& e) {                                 \
 | |
|     torch::translate_exception_to_python(std::current_exception()); \
 | |
|     throw py::error_already_set();                                  \
 | |
|   }
 | |
| 
 | |
// Closes a scope opened by HANDLE_TH_ERRORS, reporting failures through the
// function's return value.
//   retval - expression returned after a std::exception has been translated
// The inner catch (...) calls __enforce_warning_buffer.set_in_exception() and
// rethrows; the outer clause hands the in-flight std::exception to
// torch::translate_exception_to_python and then executes `return retval;`.
//
// END_HANDLE_TH_ERRORS is the same macro with `nullptr` as the return value.
| #define END_HANDLE_TH_ERRORS_RET(retval)                            \
 | |
|   }                                                                 \
 | |
|   catch (...) {                                                     \
 | |
|     __enforce_warning_buffer.set_in_exception();                    \
 | |
|     throw;                                                          \
 | |
|   }                                                                 \
 | |
|   }                                                                 \
 | |
|   catch (const std::exception& e) {                                 \
 | |
|     torch::translate_exception_to_python(std::current_exception()); \
 | |
|     return retval;                                                  \
 | |
|   }
 | |
| 
 | |
| #define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
 | |
| 
 | |
// Python exception type objects declared here and defined in another
// translation unit. Several of them are the targets used by
// CATCH_CORE_ERRORS when translating the matching c10 error types.
| extern PyObject *THPException_FatalError, *THPException_LinAlgError,
 | |
|     *THPException_OutOfMemoryError, *THPException_DistError,
 | |
|     *THPException_DistBackendError, *THPException_DistNetworkError,
 | |
|     *THPException_DistStoreError;
 | |
| 
 | |
| // Throwing this exception means that the python error flags have been already
 | |
| // set and control should be immediately returned to the interpreter.
 | |
| struct python_error : public std::exception {
 | |
|   python_error() = default;
 | |
| 
 | |
|   python_error(const python_error& other)
 | |
|       : type(other.type),
 | |
|         value(other.value),
 | |
|         traceback(other.traceback),
 | |
|         message(other.message) {
 | |
|     pybind11::gil_scoped_acquire gil;
 | |
|     Py_XINCREF(type);
 | |
|     Py_XINCREF(value);
 | |
|     Py_XINCREF(traceback);
 | |
|   }
 | |
| 
 | |
|   python_error(python_error&& other) noexcept
 | |
|       : type(other.type),
 | |
|         value(other.value),
 | |
|         traceback(other.traceback),
 | |
|         message(std::move(other.message)) {
 | |
|     other.type = nullptr;
 | |
|     other.value = nullptr;
 | |
|     other.traceback = nullptr;
 | |
|   }
 | |
| 
 | |
|   python_error& operator=(const python_error& other) = delete;
 | |
|   python_error& operator=(python_error&& other) = delete;
 | |
| 
 | |
|   // NOLINTNEXTLINE(bugprone-exception-escape)
 | |
|   ~python_error() override {
 | |
|     if (type || value || traceback) {
 | |
|       pybind11::gil_scoped_acquire gil;
 | |
|       Py_XDECREF(type);
 | |
|       Py_XDECREF(value);
 | |
|       Py_XDECREF(traceback);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   const char* what() const noexcept override {
 | |
|     return message.c_str();
 | |
|   }
 | |
| 
 | |
|   void build_message() {
 | |
|     // Ensure we have the GIL.
 | |
|     pybind11::gil_scoped_acquire gil;
 | |
| 
 | |
|     // No errors should be set when we enter the function since PyErr_Fetch
 | |
|     // clears the error indicator.
 | |
|     TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
 | |
| 
 | |
|     // Default message.
 | |
|     message = "python_error";
 | |
| 
 | |
|     // Try to retrieve the error message from the value.
 | |
|     if (value != nullptr) {
 | |
|       // Reference count should not be zero.
 | |
|       TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
 | |
| 
 | |
|       PyObject* pyStr = PyObject_Str(value);
 | |
|       if (pyStr != nullptr) {
 | |
|         PyObject* encodedString =
 | |
|             PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
 | |
|         if (encodedString != nullptr) {
 | |
|           char* bytes = PyBytes_AS_STRING(encodedString);
 | |
|           if (bytes != nullptr) {
 | |
|             // Set the message.
 | |
|             message = std::string(bytes);
 | |
|           }
 | |
|           Py_XDECREF(encodedString);
 | |
|         }
 | |
|         Py_XDECREF(pyStr);
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     // Clear any errors since we don't want to propagate errors for functions
 | |
|     // that are trying to build a string for the error message.
 | |
|     PyErr_Clear();
 | |
|   }
 | |
| 
 | |
|   /** Saves the exception so that it can be re-thrown on a different thread */
 | |
|   inline void persist() {
 | |
|     if (type)
 | |
|       return; // Don't overwrite exceptions
 | |
|     // PyErr_Fetch overwrites the pointers
 | |
|     pybind11::gil_scoped_acquire gil;
 | |
|     Py_XDECREF(type);
 | |
|     Py_XDECREF(value);
 | |
|     Py_XDECREF(traceback);
 | |
|     PyErr_Fetch(&type, &value, &traceback);
 | |
|     build_message();
 | |
|   }
 | |
| 
 | |
|   /** Sets the current Python error from this exception */
 | |
|   inline void restore() {
 | |
|     if (!type)
 | |
|       return;
 | |
|     // PyErr_Restore steals references
 | |
|     pybind11::gil_scoped_acquire gil;
 | |
|     Py_XINCREF(type);
 | |
|     Py_XINCREF(value);
 | |
|     Py_XINCREF(traceback);
 | |
|     PyErr_Restore(type, value, traceback);
 | |
|   }
 | |
| 
 | |
|   PyObject* type{nullptr};
 | |
|   PyObject* value{nullptr};
 | |
|   PyObject* traceback{nullptr};
 | |
| 
 | |
|   // Message to return to the user when 'what()' is invoked.
 | |
|   std::string message;
 | |
| };
 | |
| 
 | |
| bool THPException_init(PyObject* module);
 | |
| 
 | |
| namespace torch {
 | |
| 
 | |
| // Set python current exception from a C++ exception
 | |
| TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);
 | |
| 
 | |
| TORCH_PYTHON_API std::string processErrorMsg(std::string str);
 | |
| 
 | |
| // Abstract base class for exceptions which translate to specific Python types
 | |
| struct PyTorchError : public std::exception {
 | |
|   PyTorchError() = default;
 | |
|   PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
 | |
|   virtual PyObject* python_type() = 0;
 | |
|   const char* what() const noexcept override {
 | |
|     return msg.c_str();
 | |
|   }
 | |
|   std::string msg;
 | |
| };
 | |
| 
 | |
// Marks a function as printf-like on gcc & clang so the compiler can warn
// about format specifiers that do not match the arguments.
//   FORMAT_INDEX  - 1-based position of the format-string parameter
//   VA_ARGS_INDEX - 1-based position of the first variadic argument
// Expands to nothing on compilers that do not define __GNUC__.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
 | |
| 
 | |
| // Translates to Python TypeError
 | |
| struct TypeError : public PyTorchError {
 | |
|   using PyTorchError::PyTorchError;
 | |
|   TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
 | |
|   PyObject* python_type() override {
 | |
|     return PyExc_TypeError;
 | |
|   }
 | |
| };
 | |
| 
 | |
| // Translates to Python AttributeError
 | |
| struct AttributeError : public PyTorchError {
 | |
|   AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
 | |
|   PyObject* python_type() override {
 | |
|     return PyExc_AttributeError;
 | |
|   }
 | |
| };
 | |
| 
 | |
| // ATen warning handler for Python
 | |
| struct PyWarningHandler {
 | |
|   // Move actual handler into a separate class with a noexcept
 | |
|   // destructor. Otherwise, we need to force all WarningHandler
 | |
|   // subclasses to have a noexcept(false) destructor.
 | |
|   struct InternalHandler : at::WarningHandler {
 | |
|     ~InternalHandler() override = default;
 | |
|     void process(const c10::Warning& warning) override;
 | |
| 
 | |
|     std::vector<c10::Warning> warning_buffer_;
 | |
|   };
 | |
| 
 | |
|  public:
 | |
|   /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
 | |
|   TORCH_PYTHON_API PyWarningHandler() noexcept(true);
 | |
|   // NOLINTNEXTLINE(bugprone-exception-escape)
 | |
|   TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);
 | |
| 
 | |
|   /** Call if an exception has been thrown
 | |
| 
 | |
|    *  Necessary to determine if it is safe to throw from the desctructor since
 | |
|    *  std::uncaught_exception is buggy on some platforms and generally
 | |
|    *  unreliable across dynamic library calls.
 | |
|    */
 | |
|   void set_in_exception() {
 | |
|     in_exception_ = true;
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   InternalHandler internal_handler_;
 | |
|   at::WarningHandler* prev_handler_;
 | |
|   bool in_exception_;
 | |
| };
 | |
| 
 | |
| namespace detail {
 | |
| 
 | |
// Stand-in for a GIL-releasing guard that does nothing at all.
struct noop_gil_scoped_release {
  // Deliberately user-provided rather than `= default`: this avoids
  // unused-variable warnings where an instance is declared only for its
  // scope.
  noop_gil_scoped_release() { /* intentionally empty */ }
};
 | |
| 
 | |
| template <bool release_gil>
 | |
| using conditional_gil_scoped_release = std::conditional_t<
 | |
|     release_gil,
 | |
|     pybind11::gil_scoped_release,
 | |
|     noop_gil_scoped_release>;
 | |
| 
 | |
| template <typename Func, size_t i>
 | |
| using Arg = typename invoke_traits<Func>::template arg<i>::type;
 | |
| 
 | |
| template <typename Func, size_t... Is, bool release_gil>
 | |
| auto wrap_pybind_function_impl_(
 | |
|     // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
 | |
|     Func&& f,
 | |
|     std::index_sequence<Is...>,
 | |
|     std::bool_constant<release_gil>) {
 | |
|   namespace py = pybind11;
 | |
| 
 | |
|   // f=f is needed to handle function references on older compilers
 | |
|   return [f = std::forward<Func>(f)](Arg<Func, Is>... args) {
 | |
|     HANDLE_TH_ERRORS
 | |
|     conditional_gil_scoped_release<release_gil> no_gil;
 | |
|     return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
 | |
|     END_HANDLE_TH_ERRORS_PYBIND
 | |
|   };
 | |
| }
 | |
| } // namespace detail
 | |
| 
 | |
| // Wrap a function with TH error and warning handling.
 | |
| // Returns a function object suitable for registering with pybind11.
 | |
| template <typename Func>
 | |
| auto wrap_pybind_function(Func&& f) {
 | |
|   using traits = invoke_traits<Func>;
 | |
|   return torch::detail::wrap_pybind_function_impl_(
 | |
|       std::forward<Func>(f),
 | |
|       std::make_index_sequence<traits::arity>{},
 | |
|       std::false_type{});
 | |
| }
 | |
| 
 | |
| // Wrap a function with TH error, warning handling and releases the GIL.
 | |
| // Returns a function object suitable for registering with pybind11.
 | |
| template <typename Func>
 | |
| auto wrap_pybind_function_no_gil(Func&& f) {
 | |
|   using traits = invoke_traits<Func>;
 | |
|   return torch::detail::wrap_pybind_function_impl_(
 | |
|       std::forward<Func>(f),
 | |
|       std::make_index_sequence<traits::arity>{},
 | |
|       std::true_type{});
 | |
| }
 | |
| 
 | |
| } // namespace torch
 |