mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 20:34:54 +08:00 
			
		
		
		
	This PR removes unused header inclusions from C++ files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165763 Approved by: https://github.com/Skylion007
		
			
				
	
	
		
			196 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			196 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include <c10/core/DispatchKey.h>
 | |
| #include <c10/core/impl/LocalDispatchKeySet.h>
 | |
| #include <c10/core/impl/TorchDispatchModeTLS.h>
 | |
| #include <c10/util/irange.h>
 | |
| 
 | |
| #include <utility>
 | |
| 
 | |
| namespace c10::impl {
 | |
| 
 | |
| thread_local static TorchDispatchModeTLS torchDispatchModeState;
 | |
| 
 | |
| bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
 | |
|   if (!torchDispatchModeState.stack_.empty())
 | |
|     return true;
 | |
|   if (!skip_infra_modes) {
 | |
|     for (const auto i : c10::irange(
 | |
|              static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
 | |
|       if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
 | |
|         return true;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
 | |
|     std::shared_ptr<PyObject_TorchDispatchMode> mode) {
 | |
|   if (!any_modes_set()) {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, true);
 | |
|   }
 | |
|   torchDispatchModeState.stack_.push_back(std::move(mode));
 | |
| }
 | |
| 
 | |
| const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
 | |
|     pop_stack() {
 | |
|   std::shared_ptr<PyObject_TorchDispatchMode> out;
 | |
|   if (!torchDispatchModeState.stack_.empty()) {
 | |
|     out = torchDispatchModeState.stack_.back();
 | |
|     torchDispatchModeState.stack_.pop_back();
 | |
|   } else {
 | |
|     for (int64_t i =
 | |
|              static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
 | |
|          i >= 0;
 | |
|          --i) {
 | |
|       if (torchDispatchModeState.infra_modes_[i].has_value()) {
 | |
|         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 | |
|         out = std::move(torchDispatchModeState.infra_modes_[i].value());
 | |
|         torchDispatchModeState.infra_modes_[i] = std::nullopt;
 | |
|         break;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   TORCH_CHECK(out, "trying to pop from empty mode stack");
 | |
|   if (!any_modes_set()) {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, false);
 | |
|   }
 | |
|   return out;
 | |
| }
 | |
| const std::
 | |
|     tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
 | |
|     TorchDispatchModeTLS::pop_highest_infra_mode() {
 | |
|   for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
 | |
|        i >= 0;
 | |
|        --i) {
 | |
|     if (torchDispatchModeState.infra_modes_[i].has_value()) {
 | |
|       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 | |
|       auto out_mode = torchDispatchModeState.infra_modes_[i].value();
 | |
|       torchDispatchModeState.infra_modes_[i] = std::nullopt;
 | |
|       if (!any_modes_set()) {
 | |
|         c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
 | |
|         c10::impl::tls_set_dispatch_key_included(
 | |
|             DispatchKey::PythonTLSSnapshot, false);
 | |
|       }
 | |
|       return std::make_tuple(
 | |
|           std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
 | |
|     }
 | |
|   }
 | |
|   TORCH_CHECK(
 | |
|       false, "Called pop_highest_infra_mode, but no infra modes were active.")
 | |
| }
 | |
| 
 | |
| const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
 | |
|     get_stack_at(int64_t idx) {
 | |
|   TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
 | |
|   // Our "logical" stack includes both:
 | |
|   // - any user modes (the entire torchDispatchModeState.stack_)
 | |
|   // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
 | |
|   // not None)
 | |
| 
 | |
|   // idx == 0 means the "bottom" of the stack, which starts with any infra
 | |
|   // modes (iterating from lowest-priority to highest-priority).
 | |
|   auto curr_idx = idx;
 | |
|   for (const auto i :
 | |
|        c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
 | |
|     if (torchDispatchModeState.infra_modes_[i].has_value()) {
 | |
|       if (curr_idx == 0) {
 | |
|         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 | |
|         return torchDispatchModeState.infra_modes_[i].value();
 | |
|       }
 | |
|       curr_idx -= 1;
 | |
|     }
 | |
|   }
 | |
|   // At this point, we're guaranteed that curr_idx < stack_.size()
 | |
|   return torchDispatchModeState.stack_[curr_idx];
 | |
| }
 | |
| 
 | |
| int64_t TorchDispatchModeTLS::stack_len() {
 | |
|   auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
 | |
|   int64_t infra_modes_len = 0;
 | |
|   for (const auto i :
 | |
|        c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
 | |
|     if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
 | |
|       infra_modes_len += 1;
 | |
|     }
 | |
|   }
 | |
|   return stack_len + infra_modes_len;
 | |
| }
 | |
| 
 | |
| const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
 | |
| TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
 | |
|   return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
 | |
| }
 | |
| 
 | |
| void TorchDispatchModeTLS::set_mode(
 | |
|     const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
 | |
|     TorchDispatchModeKey mode_key) {
 | |
|   TORCH_CHECK(
 | |
|       torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
 | |
|           std::nullopt,
 | |
|       "trying to set the current ",
 | |
|       to_string(mode_key),
 | |
|       ", but one already exists");
 | |
| 
 | |
|   if (!any_modes_set()) {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, true);
 | |
|   }
 | |
| 
 | |
|   torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
 | |
| }
 | |
| 
 | |
| const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
 | |
| TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
 | |
|   auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
 | |
|   torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
 | |
|       std::nullopt;
 | |
|   if (out.has_value() && !any_modes_set()) {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, false);
 | |
|   }
 | |
|   return out;
 | |
| }
 | |
| 
 | |
| const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
 | |
|   return torchDispatchModeState;
 | |
| }
 | |
| 
 | |
| void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
 | |
|   torchDispatchModeState = std::move(state);
 | |
|   if (!any_modes_set()) {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, false);
 | |
|   } else {
 | |
|     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
 | |
|     c10::impl::tls_set_dispatch_key_included(
 | |
|         DispatchKey::PythonTLSSnapshot, true);
 | |
|   }
 | |
| }
 | |
| 
 | |
| // UTIL
 | |
| 
 | |
| bool dispatch_mode_enabled() {
 | |
|   return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
 | |
|       TorchDispatchModeTLS::any_modes_set();
 | |
| }
 | |
| 
 | |
| std::string to_string(TorchDispatchModeKey mode_key) {
 | |
|   switch (mode_key) {
 | |
|     case TorchDispatchModeKey::PROXY:
 | |
|       return "ProxyTorchDispatchMode";
 | |
|     case TorchDispatchModeKey::FAKE:
 | |
|       return "FakeTensorMode";
 | |
|     default:
 | |
|       return "UNKNOWN_MODE";
 | |
|   }
 | |
| }
 | |
| 
 | |
| } // namespace c10::impl
 |