mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Python Dispatcher integration with C++ dispatcher (#84826)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> From @ezyang's original PR: There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while using preexisting C++ registrations otherwise. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can be interposed over ordinary C++ dispatch. The ingredients: We introduce a new PythonDispatcher dispatch key, that has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set, and shunting to a Python implementation. The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guard EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory manage it and so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch. I need to reimplement dispatch table computation logic in Python. To do this, I expose a lot more helper functions for doing computations on alias dispatch keys and similar. I also improve the pybind11 handling for DispatchKey so that you can either accept the pybind11 bound enum or a string; this simplifies our binding code. See https://github.com/pybind/pybind11/issues/483#issuecomment-1237418106 for how this works; the technique is generally useful. I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key which doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead. Signed-off-by: Edward Z. 
Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84826 Approved by: https://github.com/ezyang
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							44c30c5d1c
						
					
				
				
					commit
					35f6a69191
				
			| @ -14,6 +14,7 @@ ThreadLocalState::ThreadLocalState() | ||||
|       debug_info_(c10::ThreadLocalDebugInfo::current()), | ||||
|       functorch_tls_(functorch::getCopyOfFuncTorchTLS()), | ||||
|       autograd_tls_(c10::AutogradState::get_tls_state()), | ||||
|       python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), | ||||
|       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) { | ||||
|   rf_tls_ = at::get_record_function_tls_(); | ||||
|  | ||||
| @ -41,6 +42,8 @@ void ThreadLocalState::setThreadLocalState( | ||||
|  | ||||
|   at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_); | ||||
|  | ||||
|   c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_); | ||||
|  | ||||
|   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_); | ||||
|  | ||||
|   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); | ||||
|  | ||||
| @ -10,6 +10,7 @@ | ||||
| #include <ATen/FuncTorchTLS.h> | ||||
| #include <ATen/PythonTorchFunctionTLS.h> | ||||
| #include <ATen/record_function.h> | ||||
| #include <c10/core/impl/PythonDispatcherTLS.h> | ||||
| #include <c10/core/impl/TorchDispatchModeTLS.h> | ||||
|  | ||||
| namespace at { | ||||
| @ -57,6 +58,9 @@ class TORCH_API ThreadLocalState { | ||||
|   // TLS for enable_torch_dispatch_mode | ||||
|   std::shared_ptr<SafePyObject> torch_dispatch_mode_state_; | ||||
|  | ||||
|   // TLS for enable_python_dispatcher | ||||
|   SafePyHandle python_dispatcher_state_; | ||||
|  | ||||
|   // TLS for __torch_function__ (mode and disable_torch_function) | ||||
|   at::impl::PythonTorchFunctionTLS python_torch_function_state_; | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| #include <c10/core/impl/TorchDispatchModeTLS.h> | ||||
| #include <c10/core/impl/PythonDispatcherTLS.h> | ||||
| #include <ATen/core/PythonFallbackKernel.h> | ||||
| #include <c10/core/SafePyObject.h> | ||||
|  | ||||
| @ -87,6 +88,12 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { | ||||
|   TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)"); | ||||
| } | ||||
|  | ||||
| // Boxed fallback registered for the PythonDispatcher key: forwards the call to | ||||
| // the interpreter recorded in PythonDispatcherTLS. | ||||
| void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { | ||||
|   auto state = c10::impl::PythonDispatcherTLS::get_state(); | ||||
|   TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set"); | ||||
|   // The Python side receives the key set with the PythonDispatcher key removed. | ||||
|   const auto remaining_keys = dispatch_keys.remove(c10::DispatchKey::PythonDispatcher); | ||||
|   state.pyinterpreter()->python_dispatcher(op, remaining_keys, stack); | ||||
| } | ||||
|  | ||||
| void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { | ||||
|   // It is ok for the tls to be already set here. | ||||
|   // It means that there are multiple calls into the dispatcher not originating from python code. | ||||
| @ -134,6 +141,10 @@ TORCH_LIBRARY_IMPL(_, Python, m) { | ||||
|   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>()); | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) { | ||||
|   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>()); | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) { | ||||
|   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>()); | ||||
| } | ||||
|  | ||||
| @ -168,6 +168,12 @@ public: | ||||
|   // See Note [Plumbing Keys Through The Dispatcher] | ||||
|   void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const; | ||||
|  | ||||
|   // Returns true iff a valid backend fallback kernel is registered in the | ||||
|   // fallback table slot for `dk`.  Keys that have no dispatch table index | ||||
|   // (negative index) never have a fallback.  This is a read-only query, so it | ||||
|   // is const like the other query/dispatch entry points on this class. | ||||
|   bool hasBackendFallbackForDispatchKey(DispatchKey dk) const { | ||||
|     const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk); | ||||
|     if (dispatch_ix < 0) return false; | ||||
|     return backendFallbackKernels_[dispatch_ix].kernel.isValid(); | ||||
|   } | ||||
|  | ||||
|  | ||||
|   // ------------------------------------------------------------------------ | ||||
|   // | ||||
| @ -333,6 +339,10 @@ public: | ||||
|     return operatorDef_->op.hasKernelForDispatchKey(k); | ||||
|   } | ||||
|  | ||||
|   bool hasKernelForAnyDispatchKey(DispatchKeySet k) const { | ||||
|     return operatorDef_->op.hasKernelForAnyDispatchKey(k); | ||||
|   } | ||||
|  | ||||
|   bool hasComputedKernelForDispatchKey(DispatchKey k) const { | ||||
|     return operatorDef_->op.hasComputedKernelForDispatchKey(k); | ||||
|   } | ||||
| @ -635,11 +645,18 @@ inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, Dispat | ||||
|   // We still compute this as we're obligated to pass it on to the internal | ||||
|   // kernel, if it is a boxed fallback | ||||
|   auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); | ||||
|   const auto& kernel = entry.kernelForDispatchKey(dk); | ||||
|   const auto& kernel = ([&]() { | ||||
|     if (op.hasKernelForDispatchKey(dk)) { | ||||
|       return entry.kernelForDispatchKey(dk); | ||||
|     } else { | ||||
|       auto idx = getDispatchTableIndexForDispatchKey(dk); | ||||
|       TORCH_INTERNAL_ASSERT(idx >= 0); | ||||
|       return backendFallbackKernels_[idx].kernel; | ||||
|     } | ||||
|   })(); | ||||
|   kernel.callBoxed(op, dispatchKeySet, stack); | ||||
| } | ||||
|  | ||||
|  | ||||
| inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const { | ||||
|   // note: this doesn't need the mutex because write operations on the list keep iterators intact. | ||||
|   const auto& entry = op.operatorDef_->op; | ||||
|  | ||||
| @ -172,6 +172,9 @@ const char* toString(DispatchKey t) { | ||||
|     case DispatchKey::TESTING_ONLY_GenericMode: | ||||
|       return "TESTING_ONLY_GenericMode"; | ||||
|  | ||||
|     case DispatchKey::PythonDispatcher: | ||||
|       return "PythonDispatcher"; | ||||
|  | ||||
|       // Aliases | ||||
|  | ||||
|     case DispatchKey::Autograd: | ||||
| @ -283,6 +286,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { | ||||
|       {"TESTING_ONLY_GenericWrapper", | ||||
|        c10::DispatchKey::TESTING_ONLY_GenericWrapper}, | ||||
|       {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode}, | ||||
|       {"PythonDispatcher", c10::DispatchKey::PythonDispatcher}, | ||||
|  | ||||
|       {"CPU", c10::DispatchKey::CPU}, | ||||
|       {"CUDA", c10::DispatchKey::CUDA}, | ||||
|  | ||||
| @ -401,6 +401,10 @@ enum class DispatchKey : uint16_t { | ||||
|   // for a usage example | ||||
|   TESTING_ONLY_GenericMode, | ||||
|  | ||||
|   // This is a bypass that allows you to skip running the C++ dispatcher | ||||
|   // entirely | ||||
|   PythonDispatcher, | ||||
|  | ||||
|   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // | ||||
|   EndOfFunctionalityKeys, // End of functionality keys. | ||||
|  | ||||
|  | ||||
| @ -172,7 +172,9 @@ class DispatchKeySet final { | ||||
|             (1ULL | ||||
|              << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) - | ||||
|                  1)) - | ||||
|             1) {} | ||||
|             1) { | ||||
|     *this = add(DispatchKey::PythonDispatcher); | ||||
|   } | ||||
|  | ||||
|   // Public version of DispatchKeySet(uint64_t) API; external users | ||||
|   // must be explicit when they do this! | ||||
|  | ||||
| @ -8,4 +8,9 @@ PyObject* SafePyObject::ptr(const c10::impl::PyInterpreter* interpreter) const { | ||||
|   return data_; | ||||
| } | ||||
|  | ||||
| PyObject* SafePyHandle::ptr(const c10::impl::PyInterpreter* interpreter) const { | ||||
|   TORCH_INTERNAL_ASSERT(interpreter == pyinterpreter_); | ||||
|   return data_; | ||||
| } | ||||
|  | ||||
| } // namespace c10 | ||||
|  | ||||
| @ -42,4 +42,29 @@ struct C10_API SafePyObject { | ||||
|   c10::impl::PyInterpreter* pyinterpreter_; | ||||
| }; | ||||
|  | ||||
| // Like SafePyObject, but non-owning.  Good for references to global PyObjects | ||||
| // that will be leaked on interpreter exit.  You get a copy constructor/assign | ||||
| // this way. | ||||
| // | ||||
| // A default-constructed (or reset()) handle is "empty": both pointers are | ||||
| // null and the handle converts to false. | ||||
| struct C10_API SafePyHandle { | ||||
|   SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {} | ||||
|   SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) | ||||
|       : data_(data), pyinterpreter_(pyinterpreter) {} | ||||
|  | ||||
|   // Precondition: the handle is not empty.  The stored interpreter pointer is | ||||
|   // dereferenced without a null check. | ||||
|   c10::impl::PyInterpreter& pyinterpreter() const { | ||||
|     return *pyinterpreter_; | ||||
|   } | ||||
|   // Returns the wrapped PyObject*.  The out-of-line definition asserts that | ||||
|   // the interpreter passed in is the one stored in this handle. | ||||
|   PyObject* ptr(const c10::impl::PyInterpreter*) const; | ||||
|   // Makes the handle empty again; nothing is released since it is non-owning. | ||||
|   void reset() { | ||||
|     data_ = nullptr; | ||||
|     pyinterpreter_ = nullptr; | ||||
|   } | ||||
|   // True iff the handle refers to an object.  Declared const so that const | ||||
|   // handles can be tested as well; deliberately left implicit so that every | ||||
|   // existing use of the conversion keeps compiling. | ||||
|   operator bool() const { | ||||
|     return data_ != nullptr; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   PyObject* data_; | ||||
|   c10::impl::PyInterpreter* pyinterpreter_; | ||||
| }; | ||||
|  | ||||
| } // namespace c10 | ||||
|  | ||||
| @ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { | ||||
|     PANIC(dispatch); | ||||
|   } | ||||
|  | ||||
|   void python_dispatcher( | ||||
|       const c10::OperatorHandle& op, | ||||
|       c10::DispatchKeySet, | ||||
|       torch::jit::Stack* stack) const override { | ||||
|     PANIC(python_dispatcher); | ||||
|   } | ||||
|  | ||||
|   bool is_contiguous(const TensorImpl* self) const override { | ||||
|     PANIC(is_contiguous); | ||||
|   } | ||||
|  | ||||
| @ -140,6 +140,12 @@ struct C10_API PyInterpreterVTable { | ||||
|   virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) | ||||
|       const = 0; | ||||
|  | ||||
|   // Invoke the Python dispatcher to handle this call | ||||
|   virtual void python_dispatcher( | ||||
|       const c10::OperatorHandle& op, | ||||
|       c10::DispatchKeySet, | ||||
|       torch::jit::Stack* stack) const = 0; | ||||
|  | ||||
|   virtual bool is_contiguous(const TensorImpl* self) const = 0; | ||||
|   virtual c10::Device device(const TensorImpl* self) const = 0; | ||||
|   virtual int64_t dim(const TensorImpl* self) const = 0; | ||||
|  | ||||
							
								
								
									
										32
									
								
								c10/core/impl/PythonDispatcherTLS.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								c10/core/impl/PythonDispatcherTLS.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,32 @@ | ||||
| #include <c10/core/DispatchKeySet.h> | ||||
| #include <c10/core/SafePyObject.h> | ||||
| #include <c10/core/impl/LocalDispatchKeySet.h> | ||||
| #include <c10/core/impl/PythonDispatcherTLS.h> | ||||
|  | ||||
| namespace c10 { | ||||
| namespace impl { | ||||
|  | ||||
| thread_local SafePyHandle pythonDispatcherState; | ||||
|  | ||||
| // Installs `state` as this thread's Python dispatcher handle and keeps the | ||||
| // PythonDispatcher key in the TLS include set in sync with it: an empty handle | ||||
| // goes through reset_state(), a non-empty one turns the key on. | ||||
| void PythonDispatcherTLS::set_state(SafePyHandle state) { | ||||
|   if (!state) { | ||||
|     reset_state(); | ||||
|   } else { | ||||
|     tls_set_dispatch_key_included(DispatchKey::PythonDispatcher, true); | ||||
|   } | ||||
|   pythonDispatcherState = state; | ||||
| } | ||||
|  | ||||
| SafePyHandle PythonDispatcherTLS::get_state() { | ||||
|   return pythonDispatcherState; | ||||
| } | ||||
|  | ||||
| void PythonDispatcherTLS::reset_state() { | ||||
|   pythonDispatcherState.reset(); | ||||
|   c10::impl::tls_set_dispatch_key_included( | ||||
|       DispatchKey::PythonDispatcher, false); | ||||
| } | ||||
|  | ||||
| } // namespace impl | ||||
| } // namespace c10 | ||||
							
								
								
									
										27
									
								
								c10/core/impl/PythonDispatcherTLS.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								c10/core/impl/PythonDispatcherTLS.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/SafePyObject.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/Optional.h> | ||||
|  | ||||
| namespace c10 { | ||||
| namespace impl { | ||||
|  | ||||
| struct C10_API PythonDispatcherTLS { | ||||
|   static void set_state(SafePyHandle state); | ||||
|   static SafePyHandle get_state(); | ||||
|   static void reset_state(); | ||||
| }; | ||||
|  | ||||
| // RAII guard: on construction it saves the current PythonDispatcherTLS state | ||||
| // and installs an empty one; on destruction it puts the saved state back. | ||||
| // | ||||
| // The guard is neither copyable nor movable: a copy would run the restoring | ||||
| // destructor a second time and could overwrite state set in between. | ||||
| struct C10_API DisablePythonDispatcher { | ||||
|   DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) { | ||||
|     PythonDispatcherTLS::set_state({}); | ||||
|   } | ||||
|   DisablePythonDispatcher(const DisablePythonDispatcher&) = delete; | ||||
|   DisablePythonDispatcher& operator=(const DisablePythonDispatcher&) = delete; | ||||
|   DisablePythonDispatcher(DisablePythonDispatcher&&) = delete; | ||||
|   DisablePythonDispatcher& operator=(DisablePythonDispatcher&&) = delete; | ||||
|   ~DisablePythonDispatcher() { | ||||
|     PythonDispatcherTLS::set_state(old_); | ||||
|   } | ||||
|   // State that was active when the guard was created. | ||||
|   c10::SafePyHandle old_; | ||||
| }; | ||||
|  | ||||
| } // namespace impl | ||||
| } // namespace c10 | ||||
| @ -1,11 +1,12 @@ | ||||
| import torch | ||||
| from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard | ||||
| from functorch.experimental.ops import PyOperator, fallthrough_fn | ||||
| from torch._ops import PyOperator | ||||
| from torch.utils._pytree import tree_flatten | ||||
| from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot | ||||
| import torch.utils._pytree as pytree | ||||
| from torch.utils._python_dispatch import TorchDispatchMode | ||||
| from torch.fx.experimental.proxy_tensor import track_tensor_tree | ||||
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode | ||||
|  | ||||
|  | ||||
| """ | ||||
| @ -14,6 +15,9 @@ In order to do this, we need implementations for each of the dispatch keys. | ||||
| """ | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| cond = PyOperator('cond') | ||||
|  | ||||
|  | ||||
| # TODO(voz): Move out somewhere else once other py dispatched ops need it | ||||
| @contextmanager | ||||
| def suspend_mode(mode): | ||||
| @ -98,6 +102,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): | ||||
|     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) | ||||
|  | ||||
|  | ||||
| @cond.py_impl(DispatchKey.CPU) | ||||
| def cond_dense(pred, true_fn, false_fn, operands): | ||||
|     mode = torch._C._get_torch_dispatch_mode() | ||||
|     assert (mode is None), "Mode should never be enabled for CPU key" | ||||
| @ -107,6 +112,7 @@ def cond_dense(pred, true_fn, false_fn, operands): | ||||
|         return false_fn(*operands) | ||||
|  | ||||
|  | ||||
| @cond.py_impl(DispatchKey.AutogradCPU) | ||||
| def cond_autograd(pred, true_fn, false_fn, *operands): | ||||
|     # TODO: support autograd | ||||
|     flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands]) | ||||
| @ -117,21 +123,16 @@ def cond_autograd(pred, true_fn, false_fn, *operands): | ||||
|     return cond(pred, true_fn, false_fn, *operands) | ||||
|  | ||||
|  | ||||
| def python_fallback(op): | ||||
| @cond.py_impl(ProxyTorchDispatchMode) | ||||
| def inner(pred, true_fn, false_fn, operands): | ||||
|     mode = torch._C._get_torch_dispatch_mode() | ||||
|     assert (mode is not None), "Mode should always be enabled for python fallback key" | ||||
|     with suspend_mode(mode): | ||||
|             res = trace_cond(mode, op, pred, true_fn, false_fn, operands) | ||||
|         res = trace_cond(mode, cond, pred, true_fn, false_fn, operands) | ||||
|     return res | ||||
|  | ||||
|     return inner | ||||
|  | ||||
|  | ||||
| cond = PyOperator('cond') | ||||
| cond.impl(DispatchKey.CPU, cond_dense) | ||||
| cond.impl(DispatchKey.Python, python_fallback(cond)) | ||||
| cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn) | ||||
| cond.impl(DispatchKey.AutogradCPU, cond_autograd) | ||||
| cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn) | ||||
| cond.impl(DispatchKey.BackendSelect, fallthrough_fn) | ||||
| # TODO(voz): Make this automatic for keys, this is very ugly atm | ||||
| cond.fallthrough(DispatchKey.PythonTLSSnapshot) | ||||
| cond.fallthrough(DispatchKey.ADInplaceOrView) | ||||
| cond.fallthrough(DispatchKey.BackendSelect) | ||||
|  | ||||
| @ -1,36 +1 @@ | ||||
| from torch._dispatch._dispatcher import PyDispatcher, to_flat_tuple, compute_keyset | ||||
| from torch.nn.functional import handle_torch_function | ||||
| from torch.overrides import has_torch_function | ||||
| import torch._C as _C | ||||
|  | ||||
| class PyOperator: | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|         self.table = {} | ||||
|  | ||||
|         self.__name__ = name | ||||
|  | ||||
|     def impl(self, dispatch_key, fn): | ||||
|         assert dispatch_key not in self.table | ||||
|         if fn is fallthrough_fn: | ||||
|             self.table[dispatch_key] = fn(self, dispatch_key) | ||||
|         else: | ||||
|             self.table[dispatch_key] = fn | ||||
|  | ||||
|     def lookup(self, keyset): | ||||
|         dispatch_key = keyset.highestPriorityTypeId() | ||||
|         return self.table[dispatch_key] | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         flat_args = to_flat_tuple(args, kwargs) | ||||
|         if has_torch_function(flat_args): | ||||
|             return handle_torch_function(self, flat_args, *args, **kwargs) | ||||
|  | ||||
|         return PyDispatcher.call(self, *args, **kwargs) | ||||
|  | ||||
| def fallthrough_fn(operator, dispatch_key): | ||||
|     def inner(*args, **kwargs): | ||||
|         all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key) | ||||
|         all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs) | ||||
|         return PyDispatcher.redispatch(operator, all_keys_after_current_masked, *args, **kwargs) | ||||
|     return inner | ||||
| from torch._ops import PyOperator  # noqa: F401 | ||||
|  | ||||
| @ -19,7 +19,7 @@ class TestControlFlow(TestCase): | ||||
|  | ||||
|  | ||||
| class TestControlFlowTraced(TestCase): | ||||
|     def test_cond_traced(self): | ||||
|     def test_cond_traced_not_nested(self): | ||||
|         def true_fn(x): | ||||
|             return x.sin() | ||||
|  | ||||
| @ -126,14 +126,12 @@ class TestControlFlowTraced(TestCase): | ||||
|         def forward(self, x_1, pred_1, pred2_1): | ||||
|             true_graph_0 = self.true_graph_0 | ||||
|             false_graph_0 = self.false_graph_0 | ||||
|             conditional = functorch_experimental_ops_cond(pred_1, | ||||
|             true_graph_0, false_graph_0, [[x_1]]);  pred_1 = true_graph_0 = false_graph_0 = None | ||||
|             conditional = torch.ops.pyop.cond(pred_1, true_graph_0, false_graph_0, [[x_1]]);  pred_1 = true_graph_0 = false_graph_0 = None  # noqa: B950 | ||||
|             true_graph_1 = self.true_graph_1 | ||||
|             false_graph_1 = self.false_graph_1 | ||||
|             conditional_1 = functorch_experimental_ops_cond(pred2_1, | ||||
|             true_graph_1, false_graph_1, [[x_1, x_1]]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None | ||||
|             add_tensor = torch.ops.aten.add.Tensor(conditional, conditional_1);  conditional = conditional_1 = None | ||||
|             return add_tensor | ||||
|             conditional_1 = torch.ops.pyop.cond(pred2_1, true_graph_1, false_graph_1, [[x_1, x_1]]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None  # noqa: B950 | ||||
|             add = torch.ops.aten.add.Tensor(conditional, conditional_1);  conditional = conditional_1 = None | ||||
|             return add | ||||
|         """ | ||||
|         code = graph.code | ||||
|         # Normalization hack, cause .code makes some weird whitespace | ||||
| @ -145,8 +143,8 @@ class TestControlFlowTraced(TestCase): | ||||
|         out = """ | ||||
|         def forward(self, flat_args): | ||||
|             flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec) | ||||
|             mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1);  flat_args_1 = None | ||||
|             return pytree.tree_unflatten([mul_tensor], self._out_spec) | ||||
|             mul = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1);  flat_args_1 = None | ||||
|             return pytree.tree_unflatten([mul], self._out_spec) | ||||
|         """ | ||||
|         # Normalization hack, cause .code makes some weird whitespace | ||||
|         code = "".join(code.split()) | ||||
|  | ||||
| @ -327,7 +327,7 @@ class TestPrims(TestCase): | ||||
|         # Check that all call_function nodes are prims | ||||
|         call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) | ||||
|         all_prims_namespace = all( | ||||
|             node.target.name.startswith("prims") for node in call_function_nodes | ||||
|             node.target.name().startswith("prims") for node in call_function_nodes | ||||
|         ) | ||||
|         self.assertTrue(all_prims_namespace) | ||||
|  | ||||
|  | ||||
| @ -952,7 +952,6 @@ symbolic_tensor_failures = { | ||||
|     xfail('bfloat16', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('bmm', ''),  # aten.bmm.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('bool', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('broadcast_tensors', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('bucketize', ''),  # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition | ||||
|     xfail('byte', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('cartesian_prod', ''),  # Tensors of type TensorImpl do not have numel | ||||
| @ -1120,7 +1119,6 @@ symbolic_tensor_failures = { | ||||
|     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos... | ||||
|     xfail('nn.functional.group_norm', ''),  # 'torch._C.SymIntNode' and 'int' | ||||
|     xfail('nn.functional.hinge_embedding_loss', ''),  # aten.empty_like.default - couldn't find symbolic meta function/deco... | ||||
|     xfail('nn.functional.huber_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.instance_norm', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.interpolate', 'area'),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.interpolate', 'bicubic'),  # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... | ||||
| @ -1128,7 +1126,6 @@ symbolic_tensor_failures = { | ||||
|     xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... | ||||
|     xfail('nn.functional.interpolate', 'nearest'),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... | ||||
|     xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... | ||||
|     xfail('nn.functional.l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.linear', ''),  # aten.mv.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel | ||||
|     xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema' | ||||
| @ -1138,7 +1135,6 @@ symbolic_tensor_failures = { | ||||
|     xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... | ||||
|     xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... | ||||
|     xfail('nn.functional.max_unpool3d', 'grad'),  # aten.max_unpool3d.default - couldn't find symbolic meta function/decom... | ||||
|     xfail('nn.functional.mse_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|     xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the... | ||||
|     xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ... | ||||
|     xfail('nn.functional.normalize', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | ||||
|  | ||||
| @ -1833,5 +1833,11 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") | ||||
|             e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) | ||||
|             self.assertEqual(e.layout, torch.strided) | ||||
|  | ||||
| class TestPythonDispatcher(TestCase): | ||||
|     def test_basic(self): | ||||
|         x = torch.randn(2, requires_grad=True) | ||||
|         r = torch._C._EnablePythonDispatcher() | ||||
|         torch.add(x, x) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run_tests() | ||||
|  | ||||
| @ -136,6 +136,9 @@ blocklist = [ | ||||
|     "floor_divide", | ||||
|     "floor_divide_", | ||||
|     "floor_divide_out", | ||||
|     "to", | ||||
|     "_to_copy", | ||||
|     "copy_", | ||||
| ] | ||||
|  | ||||
| binary_ops = ( | ||||
|  | ||||
| @ -997,11 +997,26 @@ class Generator(object): | ||||
| def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ... | ||||
| def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ... | ||||
| def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ... | ||||
| def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: object) -> _bool: ... | ||||
| def _dispatch_has_kernel(name: str) -> _bool: ... | ||||
| def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ... | ||||
| def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ... | ||||
| def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ... | ||||
| def _dispatch_dump(dispatch: str) -> str: ... | ||||
| def _dispatch_is_included_in_alias(dispatch_a: str, dispatch_b: str) -> _bool: ... | ||||
| def _dispatch_get_backend_keyset_from_autograd(dispatch: str) -> Any: ... | ||||
| def _dispatch_has_backend_fallback(dispatch: str) -> _bool: ... | ||||
| def _dispatch_key_name(dispatch: str) -> str: ... | ||||
|  | ||||
| # Define in torch/csrc/autograd/init.cpp | ||||
| class _DisablePythonDispatcher(object): | ||||
|     pass | ||||
|  | ||||
| class _EnablePythonDispatcher(object): | ||||
|     pass | ||||
|  | ||||
| def _set_python_dispatcher(dispatcher: object) -> None: ... | ||||
|  | ||||
|  | ||||
| class _AutoDispatchBelowAutograd: | ||||
|     pass | ||||
|  | ||||
| @ -962,3 +962,5 @@ if 'TORCH_CUDA_SANITIZER' in os.environ: | ||||
|     import torch.cuda._sanitizer as csan | ||||
|  | ||||
|     csan.enable_cuda_sanitizer() | ||||
|  | ||||
| from ._dispatch import python | ||||
|  | ||||
| @ -1,50 +0,0 @@ | ||||
| import torch | ||||
| import torch._C as _C | ||||
| from torch.utils._pytree import tree_flatten | ||||
|  | ||||
| """ | ||||
| This is a dispatcher (in Python) | ||||
| - You can define new operations (in Python) without schemas | ||||
| - It interfaces with the PyTorch dispatcher | ||||
| """ | ||||
|  | ||||
| class PyDispatcher: | ||||
|     # operator is a PyOperator | ||||
|     @staticmethod | ||||
|     def call(operator, *args, **kwargs): | ||||
|         dispatch_key_set = compute_keyset(args, kwargs) | ||||
|         kernel = operator.lookup(dispatch_key_set) | ||||
|         return kernel(*args, **kwargs) | ||||
|  | ||||
|     # operator is a PyOperator | ||||
|     @staticmethod | ||||
|     def redispatch(operator, dispatch_key_set, *args, **kwargs): | ||||
|         kernel = operator.lookup(dispatch_key_set) | ||||
|         return kernel(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| def compute_keyset(args, kwargs): | ||||
|     tensors = get_tensors(args, kwargs) | ||||
|     return key_extractor(tensors) | ||||
|  | ||||
|  | ||||
| # Note - this should maintain identical impl to the C++ dispatcher key extraction logic | ||||
| # at ATen/core/dispatch/DispatchKeyExtractor.h | ||||
| def key_extractor(tensors): | ||||
|     key_set = _C._dispatch_tls_local_include_set()  # type: ignore[attr-defined] | ||||
|     for tensor in tensors: | ||||
|         key_set = key_set | _C._dispatch_keys(tensor)  # type: ignore[attr-defined] | ||||
|     key_set = key_set - _C._dispatch_tls_local_exclude_set()  # type: ignore[attr-defined] | ||||
|     return key_set | ||||
|  | ||||
|  | ||||
| def to_flat_tuple(args, kwargs): | ||||
|     flat_args, _ = tree_flatten(args) | ||||
|     flat_kwargs, _ = tree_flatten(kwargs) | ||||
|     flat_all = flat_args + flat_kwargs | ||||
|     return flat_all | ||||
|  | ||||
| def get_tensors(args, kwargs): | ||||
|     flat_all = to_flat_tuple(args, kwargs) | ||||
|     tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)] | ||||
|     return tuple(tensor_args) | ||||
							
								
								
									
										88
									
								
								torch/_dispatch/python.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								torch/_dispatch/python.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,88 @@ | ||||
| import torch | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| __all__ = ['enable_python_dispatcher', 'no_python_dispatcher'] | ||||
|  | ||||
| DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined] | ||||
|  | ||||
def has_key(op, k):
    """Return True if ``op`` has a kernel for key ``k``.

    The C++ registration is consulted first; Python kernels registered in
    ``op.py_kernels`` are checked only when no C++ kernel exists.
    """
    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k):
        return True
    return k in op.py_kernels
|  | ||||
| is_included_in_alias = torch._C._dispatch_is_included_in_alias | ||||
|  | ||||
# Equivalent to computeDispatchTableEntryWithDebug
# TODO: memoize this or something
def resolve_key(op: torch._ops.PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    """Pick the dispatch key whose kernel should serve ``op`` when dispatched at ``k``.

    Candidates are tried in a fixed priority order; the first match wins:
    a direct registration at ``k``, then the Composite* alias keys, then
    ``Autograd``, and finally ``k`` itself if a backend fallback is registered.

    Args:
        op: operator to resolve for; must provide ``name()`` and ``py_kernels``
            (both are read through ``has_key``).
        k: the dispatch key the call is being dispatched at.

    Returns:
        Either ``k`` or the alias key whose kernel should be used.

    Raises:
        RuntimeError: if ``k`` is ``AutogradOther`` and a CompositeImplicitAutograd
            kernel is ambiguous with a kernel registered for one of the
            AutogradOther backends, or if no kernel or backend fallback exists.
    """
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        return cand
    # True when some backend associated with `k` (or the CompositeExplicitAutograd
    # alias) already has a kernel; in that case the implicit composite kernels
    # below are not selected.
    has_backend_kernel = (
        torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k))
        or has_key(op, DispatchKey.CompositeExplicitAutograd)
    )
    # 2.3. Use CompositeImplicitAutograd kernel if available
    # The NestedTensor variant is tried first and, unlike the other alias
    # checks, is never chosen for the Undefined key.
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))  # type: ignore[attr-defined]
            and has_key(op, cand) and not has_backend_kernel):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_autogradother_backends)  # type: ignore[attr-defined] # noqa: B950
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
        # Otherwise a backend kernel exists: fall through to the Autograd check.
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and has_key(op, cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise RuntimeError("could not find kernel")
|  | ||||
@contextmanager
def no_python_dispatcher():
    """Context manager that holds a ``_DisablePythonDispatcher`` guard for the ``with`` body."""
    guard = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        # Drop the reference on exit (normal or exceptional) so the guard is released.
        del guard
|  | ||||
@contextmanager
def enable_python_dispatcher():
    """Context manager that holds an ``_EnablePythonDispatcher`` guard for the ``with`` body."""
    guard = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        # Drop the reference on exit (normal or exceptional) so the guard is released.
        del guard
|  | ||||
# The Python dispatcher
def python_dispatcher(op, ks, args, kwargs):
    """Resolve the kernel key for ``op`` under key set ``ks`` and dispatch to it.

    Args:
        op: the operator overload object; it is formatted into the expression
            ``torch.ops.{op}`` and must expose ``dispatch``.
        ks: dispatch key set for this call; only its highest-priority key is used.
        args: positional arguments forwarded to the kernel.
        kwargs: keyword arguments forwarded to the kernel.

    Returns:
        Whatever ``torch.ops.<op>.dispatch(k, *args, **kwargs)`` returns.

    Debugging tip: to print the call, wrap the print in ``no_python_dispatcher()``.
    """
    k = resolve_key(op, ks.highestPriorityTypeId())
    # The call is built as source text and compiled under a synthetic filename
    # of the form "<op>[<key name>]" -- presumably so the op and resolved key
    # show up in tracebacks (NOTE(review): confirm intent).
    source = f'torch.ops.{op}.dispatch(k, *args, **kwargs)'
    filename = f'{op}[{torch._C._dispatch_key_name(k)}]'
    compiled = compile(source, filename, 'eval')  # TODO: maybe cache?
    return eval(compiled, {'torch': torch, 'k': k, 'args': args, 'kwargs': kwargs})
|  | ||||
| torch._C._set_python_dispatcher(python_dispatcher) | ||||
							
								
								
									
										202
									
								
								torch/_ops.py
									
									
									
									
									
								
							
							
						
						
									
										202
									
								
								torch/_ops.py
									
									
									
									
									
								
							| @ -1,12 +1,19 @@ | ||||
| import contextlib | ||||
| import ctypes | ||||
| import inspect | ||||
| import sys | ||||
| import types | ||||
| from abc import ABC | ||||
| from typing import Any, Dict | ||||
|  | ||||
| import torch._C | ||||
|  | ||||
| import torch.jit | ||||
| from torch import _utils_internal | ||||
| from torch._C import DispatchKey  # type: ignore[attr-defined] | ||||
| from torch.overrides import handle_torch_function, has_torch_function | ||||
| from torch.utils._python_dispatch import TorchDispatchMode | ||||
| from torch.utils._pytree import tree_flatten | ||||
|  | ||||
| # Query `hasattr` only once. | ||||
| _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") | ||||
| @ -26,9 +33,129 @@ def dl_open_guard(): | ||||
|         sys.setdlopenflags(old_flags) | ||||
|  | ||||
|  | ||||
# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
    """Common interface for operators that can be driven from Python.

    Every method here is a no-op placeholder returning ``None``; subclasses
    supply the real behavior.
    """

    def __call__(self, *args, **kwargs):
        """Invoke the operator."""
        return None

    def py_impl(self, dispatch_key, fn):
        """Register ``fn`` as the Python kernel for ``dispatch_key``."""
        return None

    def name(self):
        """Return the operator's name."""
        return None

    def dispatch(self, dispatch_key, *args, **kwargs):
        """Run the kernel registered for ``dispatch_key``."""
        return None
|  | ||||
|  | ||||
class PyOperator(PyOperatorABC):
    """An operator defined purely in Python, with a per-dispatch-key kernel table.

    Kernels are registered with ``py_impl`` either for a ``DispatchKey`` (stored
    in ``self.table``) or for a ``TorchDispatchMode`` subclass (stored in
    ``self.python_key_mode_table`` and used when dispatching on the Python key).
    Constructing an instance registers it in ``pyop_namespace.py_ops`` by name.
    """

    def __init__(self, name):
        self._name = name
        self.table = {}
        self.python_key_mode_table = {}

        # Make _OPNamespace not scream, this whole name based association needs a good hard look
        self.__name__ = "pyop." + name
        pyop_namespace.py_ops[name] = self

    def fallthrough(self, dispatch_key):
        """Register a kernel for ``dispatch_key`` that re-dispatches to the next lower key."""
        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)

    def py_impl(self, dispatch_key_or_mode):
        """Return a decorator registering a kernel for a dispatch key or a mode class.

        Registering the same key or mode twice, or registering directly for
        ``DispatchKey.Python``, fails an assertion.
        """

        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, TorchDispatchMode
            ):
                mode = dispatch_key_or_mode
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                return fn

            dispatch_key = dispatch_key_or_mode
            assert (
                dispatch_key != torch._C.DispatchKey.Python  # type: ignore[attr-defined]
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
            assert isinstance(dispatch_key, torch._C.DispatchKey)  # type: ignore[attr-defined]
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn

        return inner

    def dispatch(self, dispatch_key, *args, **kwargs):
        """Call the kernel registered for ``dispatch_key`` with the given arguments.

        For ``DispatchKey.Python`` the kernel is looked up by the type of the
        currently active torch dispatch mode instead of by key.
        """
        if dispatch_key == torch._C.DispatchKey.Python:  # type: ignore[attr-defined]
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            active_mode = torch._C._get_torch_dispatch_mode()
            # Check the mode object itself: type(None) is NoneType, never None,
            # so testing the type would let the "no mode" case slip through.
            assert (
                active_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            curr_mode = type(active_mode)
            assert (
                curr_mode in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[curr_mode](*args, **kwargs)

        assert dispatch_key in self.table
        return self.table[dispatch_key](*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Dispatch on the highest-priority key of the arguments, honoring ``__torch_function__``."""
        flat_args = _to_flat_tuple(args, kwargs)
        if has_torch_function(flat_args):
            return handle_torch_function(self, flat_args, *args, **kwargs)

        dispatch_key_set = _compute_keyset(args, kwargs)
        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

    def name(self):
        """Return the name this operator was constructed with."""
        # The name is stored in `_name`; `self.name` would be this bound method.
        return self._name

    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
    # as opposed to being this sort of explicit thing where ops are a little too key aware...
    def _fallthrough_fn(self, operator, dispatch_key):
        """Build a kernel that re-dispatches to the highest remaining key after ``dispatch_key``."""

        def inner(*args, **kwargs):
            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)  # type: ignore[attr-defined]
            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
                args, kwargs
            )  # type: ignore[attr-defined]
            return self.dispatch(
                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
            )

        return inner
|  | ||||
|  | ||||
| def _to_flat_tuple(args, kwargs): | ||||
|     flat_args, _ = tree_flatten(args) | ||||
|     flat_kwargs, _ = tree_flatten(kwargs) | ||||
|     flat_all = flat_args + flat_kwargs | ||||
|     return flat_all | ||||
|  | ||||
|  | ||||
def _compute_keyset(args, kwargs):
    """Return the dispatch key set computed from every tensor found in ``args``/``kwargs``."""
    return key_extractor(_get_tensors(args, kwargs))
|  | ||||
|  | ||||
def _get_tensors(args, kwargs):
    """Return a tuple of every ``torch.Tensor`` leaf in ``args``/``kwargs``, in flattening order."""
    return tuple(
        leaf for leaf in _to_flat_tuple(args, kwargs) if isinstance(leaf, torch.Tensor)
    )
|  | ||||
|  | ||||
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    """Union the TLS include set with each tensor's keys, then remove the TLS exclude set."""
    combined = torch._C._dispatch_tls_local_include_set()  # type: ignore[attr-defined]
    for t in tensors:
        combined = combined | torch._C._dispatch_keys(t)  # type: ignore[attr-defined]
    # The exclude set is applied last so it masks keys contributed by tensors too.
    return combined - torch._C._dispatch_tls_local_exclude_set()  # type: ignore[attr-defined]
|  | ||||
|  | ||||
| # Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. | ||||
| # You can obtain an OpOverload object through attribute query on OpOverloadPacket. | ||||
| class OpOverload: | ||||
| class OpOverload(PyOperatorABC): | ||||
|     def __init__(self, overloadpacket, op, op_dk, schema, tags): | ||||
|         self._op = op | ||||
|         self._op_dk = op_dk | ||||
| @ -38,12 +165,15 @@ class OpOverload: | ||||
|         self._overloadname = ( | ||||
|             "default" if schema.overload_name == "" else schema.overload_name | ||||
|         ) | ||||
|         self.name = self._schema.name | ||||
|         self._name = self._schema.name | ||||
|         if schema.overload_name: | ||||
|             self.name += "." + schema.overload_name | ||||
|             self._name += "." + schema.overload_name | ||||
|         self.py_kernels: Dict[DispatchKey, Any] = {} | ||||
|         self.__name__ = "{}.{}".format( | ||||
|             self._schema.name.split("::")[1], self._overloadname | ||||
|         ) | ||||
|         # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base... | ||||
|         self.python_key_mode_table = {} | ||||
|         self.__module__ = overloadpacket.__module__ | ||||
|         op.__module__ = overloadpacket.__module__ | ||||
|  | ||||
| @ -70,12 +200,56 @@ class OpOverload: | ||||
|         return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) | ||||
|  | ||||
|     def decompose(self, *args, **kwargs): | ||||
|         dk = "CompositeImplicitAutograd" | ||||
|         if torch._C._dispatch_has_kernel_for_dispatch_key(self.name, dk): | ||||
|             return self._op_dk(dk, *args, **kwargs) | ||||
|         dk = torch._C.DispatchKey.CompositeImplicitAutograd  # type: ignore[attr-defined] | ||||
|         if ( | ||||
|             torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk) | ||||
|             or dk in self.py_kernels | ||||
|         ): | ||||
|             return self.dispatch(dk, *args, **kwargs) | ||||
|         else: | ||||
|             return NotImplemented | ||||
|  | ||||
|     def py_impl(self, dispatch_key_or_mode): | ||||
|         def inner(fn): | ||||
|             if inspect.isclass(dispatch_key_or_mode) and issubclass( | ||||
|                 dispatch_key_or_mode, TorchDispatchMode | ||||
|             ): | ||||
|                 mode = dispatch_key_or_mode | ||||
|                 assert mode not in self.python_key_mode_table | ||||
|                 # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys? | ||||
|                 self.python_key_mode_table[mode] = fn | ||||
|                 return fn | ||||
|  | ||||
|             assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)  # type: ignore[attr-defined] | ||||
|             assert ( | ||||
|                 dispatch_key_or_mode != torch._C.DispatchKey.Python  # type: ignore[attr-defined] | ||||
|             ), "Please register a mode for the torch._C.DispatchKey.Python key instead." | ||||
|  | ||||
|             self.py_kernels[dispatch_key_or_mode] = fn | ||||
|             return fn | ||||
|  | ||||
|         return inner | ||||
|  | ||||
    def dispatch(self, dispatch_key, *args, **kwargs):
        """Run this overload at ``dispatch_key``, preferring Python-registered kernels.

        For ``DispatchKey.Python`` the kernel registered for the type of the
        current torch dispatch mode is used when present. For any other key a
        kernel in ``py_kernels`` is used when present. In both cases a miss
        falls back to ``self._op_dk(dispatch_key, ...)``.
        """
        if dispatch_key == torch._C.DispatchKey.Python:  # type: ignore[attr-defined]
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            curr_mode = type(torch._C._get_torch_dispatch_mode())
            # NOTE(review): `type(...)` always returns a class, so this assertion
            # can never fail. If no mode is active the lookup below simply misses
            # and falls back to `_op_dk` -- confirm that is intended before
            # tightening this check.
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            if curr_mode not in self.python_key_mode_table:
                return self._op_dk(dispatch_key, *args, **kwargs)
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[curr_mode](*args, **kwargs)

        if dispatch_key in self.py_kernels:
            return self.py_kernels[dispatch_key](*args, **kwargs)
        else:
            return self._op_dk(dispatch_key, *args, **kwargs)
|  | ||||
    def name(self):
        """Return the operator name stored in ``self._name``."""
        return self._name
|  | ||||
|     @property | ||||
|     def overloadpacket(self): | ||||
|         return self._overloadpacket | ||||
| @ -218,8 +392,15 @@ class _OpNamespace(types.ModuleType): | ||||
|     def __init__(self, name): | ||||
|         super(_OpNamespace, self).__init__("torch.ops." + name) | ||||
|         self.name = name | ||||
|         if self.name == "pyop": | ||||
|             self.pyops = pyop_namespace | ||||
|         else: | ||||
|             self.pyops = None  # type: ignore[assignment] | ||||
|  | ||||
|     def __getattr__(self, op_name): | ||||
|         pyops = object.__getattribute__(self, "pyops") | ||||
|         if pyops is not None: | ||||
|             return pyops.py_ops[op_name] | ||||
|         # It is not a valid op_name when __file__ is passed in | ||||
|         if op_name == "__file__": | ||||
|             return "torch.ops" | ||||
| @ -253,6 +434,15 @@ class _OpNamespace(types.ModuleType): | ||||
|         return opoverloadpacket | ||||
|  | ||||
|  | ||||
| class _PyOpNamespace(_OpNamespace): | ||||
|     def __init__(self): | ||||
|         super(_PyOpNamespace, self).__init__("torch.ops.pyop") | ||||
|         self.py_ops = {} | ||||
|  | ||||
|  | ||||
| pyop_namespace = _PyOpNamespace() | ||||
|  | ||||
|  | ||||
| class _Ops(types.ModuleType): | ||||
|     __file__ = "_ops.py" | ||||
|  | ||||
|  | ||||
| @ -271,6 +271,7 @@ __all__ = [ | ||||
| ] | ||||
|  | ||||
| Tensor = torch.Tensor | ||||
| DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined] | ||||
|  | ||||
|  | ||||
| def _broadcast_shapes(*_shapes): | ||||
| @ -2190,7 +2191,11 @@ def broadcast_shapes(*shapes) -> ShapeType: | ||||
|     return torch.Size(_broadcast_shapes(*shapes)) | ||||
|  | ||||
|  | ||||
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
    """Broadcast the given tensors against each other and return them as a list.

    Accepts either several tensors as separate arguments or a single
    non-Tensor argument holding them, which is unpacked first.
    """
    passed_as_single_collection = len(tensors) == 1 and not isinstance(
        tensors[0], Tensor
    )
    if passed_as_single_collection:
        tensors = tensors[0]
    broadcasted = _maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)
    return list(broadcasted)
|  | ||||
|  | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| #include <c10/core/DeviceType.h> | ||||
| #include <c10/core/InferenceMode.h> | ||||
| #include <c10/core/ScalarType.h> | ||||
| #include <c10/core/impl/PythonDispatcherTLS.h> | ||||
| #include <torch/csrc/Exceptions.h> | ||||
| #include <torch/csrc/autograd/autograd.h> | ||||
| #include <torch/csrc/autograd/function.h> | ||||
| @ -52,6 +53,19 @@ struct EnableTorchFunction { | ||||
|   bool old_; | ||||
| }; | ||||
|  | ||||
| PyObject* globalPythonDispatcher = nullptr; | ||||
|  | ||||
| struct EnablePythonDispatcher { | ||||
|   EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) { | ||||
|     c10::impl::PythonDispatcherTLS::set_state( | ||||
|         {globalPythonDispatcher, getPyInterpreter()}); | ||||
|   } | ||||
|   ~EnablePythonDispatcher() { | ||||
|     c10::impl::PythonDispatcherTLS::set_state(old_); | ||||
|   } | ||||
|   c10::SafePyHandle old_; | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { | ||||
| @ -328,6 +342,18 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { | ||||
|       .def(py::init<>()); | ||||
|   py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction") | ||||
|       .def(py::init<>()); | ||||
|   py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher") | ||||
|       .def(py::init<>()); | ||||
|   py::class_<c10::impl::DisablePythonDispatcher>( | ||||
|       _C_m, "_DisablePythonDispatcher") | ||||
|       .def(py::init<>()); | ||||
|   _C_m.def("_set_python_dispatcher", [](py::object dispatcher) { | ||||
|     TORCH_CHECK( | ||||
|         !globalPythonDispatcher, | ||||
|         "overwriting the global python dispatcher is not supported; if you need this file an issue"); | ||||
|     // NB: intentionally leak | ||||
|     globalPythonDispatcher = dispatcher.release().ptr(); | ||||
|   }); | ||||
|   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>()); | ||||
|  | ||||
|   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor") | ||||
|  | ||||
| @ -3,6 +3,7 @@ | ||||
| #include <c10/core/DeviceType.h> | ||||
| #include <c10/core/SafePyObject.h> | ||||
| #include <c10/core/impl/GPUTrace.h> | ||||
| #include <c10/core/impl/PythonDispatcherTLS.h> | ||||
| #include <c10/util/DeadlockDetection.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <pybind11/pytypes.h> | ||||
| @ -235,6 +236,10 @@ struct ConcretePyInterpreterVTable final | ||||
|  | ||||
|   void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) | ||||
|       const override; | ||||
|   void python_dispatcher( | ||||
|       const c10::OperatorHandle& op, | ||||
|       c10::DispatchKeySet, | ||||
|       torch::jit::Stack* stack) const override; | ||||
|  | ||||
|   bool is_contiguous(const TensorImpl* self) const override; | ||||
|   c10::Device device(const TensorImpl* self) const override; | ||||
| @ -701,6 +706,7 @@ static PyObject* THPVariable_make_subclass( | ||||
|         "cls must be a type (got %s)", Py_TYPE(cls)->tp_name); | ||||
|   } | ||||
|   torch_dispatch_mode::StashTorchDispatchModeGuard td_g; | ||||
|   c10::impl::DisablePythonDispatcher dpd_g; | ||||
|   auto data = | ||||
|       r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED) | ||||
|   // We set `data`'s `allow_tensor_metadata_change` to true here, because we | ||||
| @ -2265,6 +2271,67 @@ void ConcretePyInterpreterVTable::dispatch( | ||||
|       op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__"); | ||||
| } | ||||
|  | ||||
| void ConcretePyInterpreterVTable::python_dispatcher( | ||||
|     const c10::OperatorHandle& op, | ||||
|     c10::DispatchKeySet ks, | ||||
|     torch::jit::Stack* stack) const { | ||||
|   const auto& schema = op.schema(); | ||||
|   const auto num_arguments = schema.arguments().size(); | ||||
|   auto arguments = torch::jit::pop(*stack, num_arguments); | ||||
|  | ||||
|   // Parse the name into namespace and name (no overload_name) | ||||
|   // TODO: put this into the library | ||||
|   const auto& qualified_name = op.operator_name().name; | ||||
|   const auto& overload_name = schema.overload_name(); | ||||
|   auto pos = qualified_name.find("::"); | ||||
|   TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); | ||||
|   // Make me some null terminated strings | ||||
|   std::string ns_str = qualified_name.substr(0, pos); | ||||
|   const char* ns = ns_str.c_str(); | ||||
|   const char* func_name = qualified_name.c_str() + pos + strlen("::"); | ||||
|  | ||||
|   // The plan: convert all the arguments back into PyObjects, | ||||
|   // extracting out the tensor handles, then call | ||||
|   // handle_torch_function_no_python_arg_parser | ||||
|   // NB: at the point arguments are pushed to the stack, ALL defaults | ||||
|   // are already present | ||||
|  | ||||
|   py::gil_scoped_acquire g; | ||||
|  | ||||
|   std::vector<py::handle> overloaded_args; | ||||
|   py::handle torch_api_function = | ||||
|       py::module::import("torch").attr("ops").attr(ns).attr(func_name); | ||||
|   py::handle torch_api_function_overload; | ||||
|   if (overload_name == "") { | ||||
|     torch_api_function_overload = torch_api_function.attr("default"); | ||||
|   } else { | ||||
|     torch_api_function_overload = | ||||
|         torch_api_function.attr(overload_name.c_str()); | ||||
|   } | ||||
|   std::string module_name_str = "torch.ops." + ns_str; | ||||
|  | ||||
|   auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); | ||||
|   auto args = std::move(args_kwargs.first); | ||||
|   auto kwargs = std::move(args_kwargs.second); | ||||
|  | ||||
|   auto python_dispatcher = | ||||
|       c10::impl::PythonDispatcherTLS::get_state().ptr(getPyInterpreter()); | ||||
|   TORCH_INTERNAL_ASSERT(python_dispatcher); | ||||
|  | ||||
|   py::object obj = py::reinterpret_steal<py::object>(PyObject_CallFunction( | ||||
|       python_dispatcher, | ||||
|       "OOOO", | ||||
|       torch_api_function_overload, | ||||
|       py::cast(ks).ptr(), | ||||
|       args.ptr(), | ||||
|       kwargs.ptr())); | ||||
|  | ||||
|   if (obj == nullptr) | ||||
|     throw python_error(); | ||||
|  | ||||
|   pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher"); | ||||
| } | ||||
|  | ||||
| c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach( | ||||
|     const c10::TensorImpl* self) const { | ||||
|   pybind11::gil_scoped_acquire gil; | ||||
| @ -2377,7 +2444,7 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides( | ||||
|   if (out == Py_None) { | ||||
|     TORCH_CHECK( | ||||
|         !self->has_symbolic_sizes_strides(), | ||||
|         "Cannot call sizes on a tensor with symbolic shapes/strides"); | ||||
|         "Cannot call strides on a tensor with symbolic shapes/strides"); | ||||
|     return self->strides_default(); | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -1713,13 +1713,11 @@ void initJITBindings(PyObject* module) { | ||||
|                     return _get_operation_for_overload_or_packet( | ||||
|                         {op}, symbol, args, kwargs, /*is_overload*/ true); | ||||
|                   }); | ||||
|               auto func_dk = | ||||
|                   py::cpp_function([op, symbol, allow_numbers_as_tensors]( | ||||
|                                        const std::string& str_dk, | ||||
|                                        py::args args, | ||||
|                                        py::kwargs kwargs) { | ||||
|               auto func_dk = py::cpp_function( | ||||
|                   [op, symbol, allow_numbers_as_tensors]( | ||||
|                       c10::DispatchKey dk_, py::args args, py::kwargs kwargs) { | ||||
|                     c10::optional<c10::DispatchKey> dk = | ||||
|                         c10::make_optional(c10::parseDispatchKey(str_dk)); | ||||
|                         c10::make_optional(dk_); | ||||
|                     ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); | ||||
|                     return _get_operation_for_overload_or_packet( | ||||
|                         {op}, symbol, args, kwargs, /*is_overload*/ true, dk); | ||||
|  | ||||
| @ -136,6 +136,33 @@ struct type_caster<at::Device> { | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct type_caster<c10::DispatchKey> | ||||
|     : public type_caster_base<c10::DispatchKey> { | ||||
|   using base = type_caster_base<c10::DispatchKey>; | ||||
|   c10::DispatchKey tmp; | ||||
|  | ||||
|  public: | ||||
|   bool load(handle src, bool convert) { | ||||
|     if (base::load(src, convert)) { | ||||
|       return true; | ||||
|     } else if (py::isinstance( | ||||
|                    src, py::module_::import("builtins").attr("str"))) { | ||||
|       tmp = c10::parseDispatchKey(py::cast<std::string>(src)); | ||||
|       value = &tmp; | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   static handle cast( | ||||
|       c10::DispatchKey src, | ||||
|       return_value_policy policy, | ||||
|       handle parent) { | ||||
|     return base::cast(src, policy, parent); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Pybind11 bindings for our optional and variant types. | ||||
| // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers | ||||
| template <typename T> | ||||
|  | ||||
| @ -92,6 +92,7 @@ bool should_allow_numbers_as_tensors(const std::string& name) { | ||||
|       "sub",          "sub_",          "sub_out", | ||||
|       "subtract",     "subtract_",     "subtract_out", // alias of sub | ||||
|       "true_divide",  "true_divide_",  "true_divide_out", | ||||
|       "to",           "_to_copy",      "copy_", | ||||
|       "floor_divide", "floor_divide_", "floor_divide_out"}; | ||||
|   return allowed.find(name) != allowed.end(); | ||||
| } | ||||
|  | ||||
| @ -295,11 +295,20 @@ void initDispatchBindings(PyObject* module) { | ||||
|       // Returns whether or not a direct kernel registration exists | ||||
|       // for this <op_name, dispatch_key> pair. | ||||
|       "_dispatch_has_kernel_for_dispatch_key", | ||||
|       [](const char* name, const char* dispatch) -> bool { | ||||
|       [](const char* name, c10::DispatchKey dispatch) -> bool { | ||||
|         auto op = | ||||
|             c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); | ||||
|         TORCH_CHECK(op, "operator ", name, " does not exist"); | ||||
|         return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch)); | ||||
|         return op->hasKernelForDispatchKey(dispatch); | ||||
|       }); | ||||
|  | ||||
|   m.def( | ||||
|       "_dispatch_has_kernel_for_any_dispatch_key", | ||||
|       [](const char* name, c10::DispatchKeySet ks) -> bool { | ||||
|         auto op = | ||||
|             c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); | ||||
|         TORCH_CHECK(op, "operator ", name, " does not exist"); | ||||
|         return op->hasKernelForAnyDispatchKey(ks); | ||||
|       }); | ||||
|  | ||||
|   m.def( | ||||
| @ -332,38 +341,48 @@ void initDispatchBindings(PyObject* module) { | ||||
|  | ||||
|   m.def( | ||||
|       "_dispatch_tls_set_dispatch_key_excluded", | ||||
|       [](const char* dispatch_key, bool desired_state) { | ||||
|         c10::impl::tls_set_dispatch_key_excluded( | ||||
|             c10::parseDispatchKey(dispatch_key), desired_state); | ||||
|       [](c10::DispatchKey dispatch_key, bool desired_state) { | ||||
|         c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state); | ||||
|       }); | ||||
|   m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) { | ||||
|     return c10::impl::tls_is_dispatch_key_excluded( | ||||
|         c10::parseDispatchKey(dispatch_key)); | ||||
|   m.def( | ||||
|       "_dispatch_tls_is_dispatch_key_excluded", | ||||
|       [](c10::DispatchKey dispatch_key) { | ||||
|         return c10::impl::tls_is_dispatch_key_excluded(dispatch_key); | ||||
|       }); | ||||
|  | ||||
|   m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) { | ||||
|     return at::isTensorSubclassLike(tensor); | ||||
|   }); | ||||
|  | ||||
|   m.def("_dispatch_key_name", [](uint64_t dispatch_key) { | ||||
|     auto dt = (c10::DispatchKey)dispatch_key; | ||||
|     return c10::toString(dt); | ||||
|   m.def("_dispatch_key_name", [](c10::DispatchKey k) { | ||||
|     return c10::toString(k); | ||||
|   }); | ||||
|   m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; }); | ||||
|   m.def("_dispatch_num_backends", []() { return c10::num_backends; }); | ||||
|  | ||||
|   py::enum_<c10::DispatchKey>(m, "DispatchKey") | ||||
|       .value("Undefined", c10::DispatchKey::Undefined) | ||||
|       .value("Dense", c10::DispatchKey::Dense) | ||||
|       .value("BackendSelect", c10::DispatchKey::BackendSelect) | ||||
|       .value("CPU", c10::DispatchKey::CPU) | ||||
|       .value("CUDA", c10::DispatchKey::CUDA) | ||||
|       .value("AutocastCPU", c10::DispatchKey::AutocastCPU) | ||||
|       .value("AutocastCUDA", c10::DispatchKey::AutocastCUDA) | ||||
|       .value("AutogradCPU", c10::DispatchKey::AutogradCPU) | ||||
|       .value("ADInplaceOrView", c10::DispatchKey::ADInplaceOrView) | ||||
|       .value("AutogradCUDA", c10::DispatchKey::AutogradCUDA) | ||||
|       .value("PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot) | ||||
|       .value("Python", c10::DispatchKey::Python); | ||||
| #define DEF_ONE(n) .value(#n, c10::DispatchKey::n) | ||||
|  | ||||
|   py::enum_<c10::DispatchKey>(m, "DispatchKey") DEF_ONE(Undefined) | ||||
|       DEF_ONE(CompositeExplicitAutogradNonFunctional) | ||||
|           DEF_ONE(CompositeExplicitAutograd) | ||||
|               DEF_ONE(CompositeImplicitAutogradNestedTensor) | ||||
|                   DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther) | ||||
|                       DEF_ONE(Autograd) DEF_ONE(BackendSelect) | ||||
|                           DEF_ONE(ADInplaceOrView) DEF_ONE(PythonTLSSnapshot) | ||||
|                               DEF_ONE(Python) | ||||
|  | ||||
| #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) | ||||
| #define DEF_MULTIPLE(fullname, prefix)              \ | ||||
|   DEF_SINGLE(, fullname)                            \ | ||||
|   DEF_SINGLE(, StartOf##fullname##Backends)         \ | ||||
|   C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \ | ||||
|   DEF_SINGLE(, EndOf##fullname##Backends) | ||||
|  | ||||
|                                   C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) | ||||
|  | ||||
| #undef DEF_MULTIPLE | ||||
| #undef DEF_SINGLE | ||||
|                                       ; | ||||
|  | ||||
|   py::class_<c10::DispatchKeySet>(m, "DispatchKeySet") | ||||
|       .def(py::init<c10::DispatchKey>()) | ||||
| @ -371,7 +390,15 @@ void initDispatchBindings(PyObject* module) { | ||||
|       .def("__sub__", &c10::DispatchKeySet::operator-) | ||||
|       .def("__and__", &c10::DispatchKeySet::operator&) | ||||
|       .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId) | ||||
|       .def("has", &c10::DispatchKeySet::has); | ||||
|       .def("has", &c10::DispatchKeySet::has) | ||||
|       .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); }); | ||||
|  | ||||
|   m.attr("_dispatch_autogradother_backends") = | ||||
|       py::cast(c10::autogradother_backends); | ||||
|  | ||||
|   m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) { | ||||
|     return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t); | ||||
|   }); | ||||
|  | ||||
|   m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) { | ||||
|     return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t); | ||||
| @ -381,6 +408,10 @@ void initDispatchBindings(PyObject* module) { | ||||
|     return c10::toString(keyset); | ||||
|   }); | ||||
|  | ||||
|   m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) { | ||||
|     return c10::getBackendKeySetFromAutograd(k); | ||||
|   }); | ||||
|  | ||||
|   m.def("_dispatch_keys", [](const at::Tensor& tensor) { | ||||
|     auto* impl = tensor.unsafeGetTensorImpl(); | ||||
|     return impl->key_set(); | ||||
| @ -391,6 +422,11 @@ void initDispatchBindings(PyObject* module) { | ||||
|   m.def("_dispatch_tls_local_exclude_set", []() { | ||||
|     return c10::impl::tls_local_dispatch_key_set().excluded_; | ||||
|   }); | ||||
|   m.def( | ||||
|       "_dispatch_is_included_in_alias", | ||||
|       [](c10::DispatchKey a, c10::DispatchKey b) { | ||||
|         return c10::isIncludedInAlias(a, b); | ||||
|       }); | ||||
|   py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard") | ||||
|       .def(py::init<c10::DispatchKeySet>()); | ||||
|  | ||||
|  | ||||
| @ -10,6 +10,7 @@ import torch | ||||
| import torch.utils._pytree as pytree | ||||
| from torch.fx import Tracer, GraphModule | ||||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||||
| from torch._dispatch.python import enable_python_dispatcher | ||||
| import torch.fx as fx | ||||
| from torch.fx.passes.shape_prop import _extract_tensor_metadata | ||||
| from contextlib import contextmanager, nullcontext | ||||
| @ -579,6 +580,10 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"): | ||||
|         else: | ||||
|             raise AssertionError(f"Unexpected tracing type: {tracing_mode}") | ||||
|  | ||||
|         python_dispatcher_mode: Any = nullcontext() | ||||
|         if tracing_mode == "symbolic": | ||||
|             python_dispatcher_mode = enable_python_dispatcher() | ||||
|  | ||||
|         proxy_mode = ProxyTorchDispatchMode(fx_tracer) | ||||
|  | ||||
|         def wrap_fake_concrete(x): | ||||
| @ -617,7 +622,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"): | ||||
|  | ||||
|         # We disable the autocast cache as the autocast cache causes type conversions on parameters to | ||||
|         # check a cache, which introduces untracked tensors into the graph | ||||
|         with decompose(decomposition_table), fake_tensor_mode, \ | ||||
|         with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \ | ||||
|              sym_mode, proxy_mode, disable_autocast_cache():  # type: ignore[attr-defined] | ||||
|             t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs)) | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user