Python Dispatcher integration with C++ dispatcher (#85050)

Same as #84826, but without ghstack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85050
Approved by: https://github.com/malfet
Author: Michael Voznesensky
Date: 2022-09-15 00:43:36 +00:00
Committed by: PyTorch MergeBot
Parent: 3a107bc9be
Commit: 8ca1839d32
33 changed files with 685 additions and 161 deletions
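
A minimal usage sketch of the feature this PR adds, based on the enable_python_dispatcher context manager and the TestPythonDispatcher test below:

import torch
from torch._dispatch.python import enable_python_dispatcher

x = torch.randn(2, requires_grad=True)
with enable_python_dispatcher():
    # Dispatch for this call is routed through the Python-side dispatcher
    # (resolve_key + torch.ops re-invocation) instead of the C++ fast path.
    y = torch.add(x, x)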

View File

@ -14,6 +14,7 @@ ThreadLocalState::ThreadLocalState()
debug_info_(c10::ThreadLocalDebugInfo::current()),
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
autograd_tls_(c10::AutogradState::get_tls_state()),
python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
rf_tls_ = at::get_record_function_tls_();
@ -41,6 +42,8 @@ void ThreadLocalState::setThreadLocalState(
at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

View File

@ -10,6 +10,7 @@
#include <ATen/FuncTorchTLS.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/record_function.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
namespace at {
@ -57,6 +58,9 @@ class TORCH_API ThreadLocalState {
// TLS for enable_torch_dispatch_mode
std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;
// TLS for enable_python_dispatcher
SafePyHandle python_dispatcher_state_;
// TLS for __torch_function__ (mode and disable_torch_function)
at::impl::PythonTorchFunctionTLS python_torch_function_state_;

View File

@ -1,4 +1,5 @@
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <c10/core/SafePyObject.h>
@ -87,6 +88,12 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}
void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
auto state = c10::impl::PythonDispatcherTLS::get_state();
TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");
state.pyinterpreter()->python_dispatcher(op, dispatch_keys.remove(c10::DispatchKey::PythonDispatcher), stack);
}
void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// It is ok for the tls to be already set here.
// It means that there are multiple calls into the dispatcher not originating from python code.
@ -134,6 +141,10 @@ TORCH_LIBRARY_IMPL(_, Python, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}
TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
}
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}

View File

@ -168,6 +168,12 @@ public:
// See Note [Plumbing Keys Through The Dispatcher]
void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;
bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
if (dispatch_ix < 0) return false;
return backendFallbackKernels_[dispatch_ix].kernel.isValid();
}
// ------------------------------------------------------------------------
//
@ -333,6 +339,10 @@ public:
return operatorDef_->op.hasKernelForDispatchKey(k);
}
bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
return operatorDef_->op.hasKernelForAnyDispatchKey(k);
}
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
}
@ -635,11 +645,18 @@ inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, Dispat
// We still compute this as we're obligated to pass it on to the internal
// kernel, if it is a boxed fallback
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.kernelForDispatchKey(dk);
const auto& kernel = ([&]() {
if (op.hasKernelForDispatchKey(dk)) {
return entry.kernelForDispatchKey(dk);
} else {
auto idx = getDispatchTableIndexForDispatchKey(dk);
TORCH_INTERNAL_ASSERT(idx >= 0);
return backendFallbackKernels_[idx].kernel;
}
})();
kernel.callBoxed(op, dispatchKeySet, stack);
}
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorDef_->op;
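
Why the callBoxedForDispatchKey change above: with the Python dispatcher enabled, resolve_key may legitimately pick a key whose only implementation is a boxed backend fallback, so the lookup must be able to find it. A Python-flavored sketch of the new lookup (names here are illustrative, not the real C++ API):

def kernel_for(op, dk):
    # Direct registration wins; otherwise use the per-key backend fallback
    # (e.g. the pythonDispatcherFallback registered above).
    if op.has_kernel_for_dispatch_key(dk):
        return op.kernel_for_dispatch_key(dk)
    return backend_fallback_kernels[dk]  # assumed populated, as in the C++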

View File

@ -172,6 +172,9 @@ const char* toString(DispatchKey t) {
case DispatchKey::TESTING_ONLY_GenericMode:
return "TESTING_ONLY_GenericMode";
case DispatchKey::PythonDispatcher:
return "PythonDispatcher";
// Aliases
case DispatchKey::Autograd:
@ -283,6 +286,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"TESTING_ONLY_GenericWrapper",
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},

View File

@ -401,6 +401,10 @@ enum class DispatchKey : uint16_t {
// for a usage example
TESTING_ONLY_GenericMode,
// This is a bypass that allows you to skip running the C++ dispatcher
// entirely
PythonDispatcher,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
EndOfFunctionalityKeys, // End of functionality keys.

View File

@ -172,7 +172,9 @@ class DispatchKeySet final {
(1ULL
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
1)) -
1) {}
1) {
*this = add(DispatchKey::PythonDispatcher);
}
// Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this!

View File

@ -8,4 +8,9 @@ PyObject* SafePyObject::ptr(const c10::impl::PyInterpreter* interpreter) const {
return data_;
}
PyObject* SafePyHandle::ptr(const c10::impl::PyInterpreter* interpreter) const {
TORCH_INTERNAL_ASSERT(interpreter == pyinterpreter_);
return data_;
}
} // namespace c10

View File

@ -42,4 +42,29 @@ struct C10_API SafePyObject {
c10::impl::PyInterpreter* pyinterpreter_;
};
// Like SafePyObject, but non-owning. Good for references to global PyObjects
// that will be leaked on interpreter exit. Unlike SafePyObject, you get a
// copy constructor/assignment operator this way.
struct C10_API SafePyHandle {
SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
: data_(data), pyinterpreter_(pyinterpreter) {}
c10::impl::PyInterpreter& pyinterpreter() const {
return *pyinterpreter_;
}
PyObject* ptr(const c10::impl::PyInterpreter*) const;
void reset() {
data_ = nullptr;
pyinterpreter_ = nullptr;
}
operator bool() {
return data_;
}
private:
PyObject* data_;
c10::impl::PyInterpreter* pyinterpreter_;
};
} // namespace c10

View File

@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
PANIC(dispatch);
}
void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const override {
PANIC(python_dispatcher);
}
bool is_contiguous(const TensorImpl* self) const override {
PANIC(is_contiguous);
}

View File

@ -140,6 +140,12 @@ struct C10_API PyInterpreterVTable {
virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const = 0;
// Invoke the Python dispatcher to handle this call
virtual void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const = 0;
virtual bool is_contiguous(const TensorImpl* self) const = 0;
virtual c10::Device device(const TensorImpl* self) const = 0;
virtual int64_t dim(const TensorImpl* self) const = 0;

View File

@ -0,0 +1,32 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
namespace c10 {
namespace impl {
thread_local SafePyHandle pythonDispatcherState;
void PythonDispatcherTLS::set_state(SafePyHandle state) {
if (state) {
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonDispatcher, true);
} else {
PythonDispatcherTLS::reset_state();
}
pythonDispatcherState = state;
}
SafePyHandle PythonDispatcherTLS::get_state() {
return pythonDispatcherState;
}
void PythonDispatcherTLS::reset_state() {
pythonDispatcherState.reset();
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonDispatcher, false);
}
} // namespace impl
} // namespace c10
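
set_state above keeps the TLS handle and the PythonDispatcher include bit in sync; that include bit is what makes the fallback kernel fire on every dispatch. This is observable from Python with bindings touched by this PR (a sketch):

import torch
from torch._dispatch.python import enable_python_dispatcher

print(torch._C._dispatch_tls_local_include_set())      # PythonDispatcher absent
with enable_python_dispatcher():
    print(torch._C._dispatch_tls_local_include_set())  # PythonDispatcher present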

View File

@ -0,0 +1,27 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>
namespace c10 {
namespace impl {
struct C10_API PythonDispatcherTLS {
static void set_state(SafePyHandle state);
static SafePyHandle get_state();
static void reset_state();
};
struct C10_API DisablePythonDispatcher {
DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) {
PythonDispatcherTLS::set_state({});
}
~DisablePythonDispatcher() {
PythonDispatcherTLS::set_state(old_);
}
c10::SafePyHandle old_;
};
} // namespace impl
} // namespace c10

View File

@ -1,11 +1,12 @@
import torch
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from functorch.experimental.ops import PyOperator, fallthrough_fn
from torch._ops import PyOperator
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.fx.experimental.proxy_tensor import track_tensor_tree
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
"""
@ -14,6 +15,9 @@ In order to do this, we need implementations for each of the dispatch keys.
"""
from contextlib import contextmanager
cond = PyOperator('cond')
# TODO(voz): Move out somewhere else once other py dispatched ops need it
@contextmanager
def suspend_mode(mode):
@ -98,6 +102,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU key"
@ -107,6 +112,7 @@ def cond_dense(pred, true_fn, false_fn, operands):
return false_fn(*operands)
@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
# TODO: support autograd
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
@ -117,21 +123,16 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
return cond(pred, true_fn, false_fn, *operands)
def python_fallback(op):
def inner(pred, true_fn, false_fn, operands):
@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with suspend_mode(mode):
res = trace_cond(mode, op, pred, true_fn, false_fn, operands)
res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
return res
return inner
cond = PyOperator('cond')
cond.impl(DispatchKey.CPU, cond_dense)
cond.impl(DispatchKey.Python, python_fallback(cond))
cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)
cond.impl(DispatchKey.AutogradCPU, cond_autograd)
cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
cond.impl(DispatchKey.BackendSelect, fallthrough_fn)
# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)
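
For reference, the new PyOperator registration idiom used above, sketched with a hypothetical operator (my_op is illustrative only, not part of this PR):

import torch
from torch._ops import PyOperator
from torch._C import DispatchKey

my_op = PyOperator("my_op")  # hypothetical

@my_op.py_impl(DispatchKey.CPU)
def my_op_cpu(*args):
    ...  # dense implementation

# Keys to skip are declared with fallthrough() rather than registering an
# explicit fallthrough_fn impl, as cond does above.
my_op.fallthrough(DispatchKey.BackendSelect)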

View File

@ -1,36 +1 @@
from torch._dispatch._dispatcher import PyDispatcher, to_flat_tuple, compute_keyset
from torch.nn.functional import handle_torch_function
from torch.overrides import has_torch_function
import torch._C as _C
class PyOperator:
def __init__(self, name):
self.name = name
self.table = {}
self.__name__ = name
def impl(self, dispatch_key, fn):
assert dispatch_key not in self.table
if fn is fallthrough_fn:
self.table[dispatch_key] = fn(self, dispatch_key)
else:
self.table[dispatch_key] = fn
def lookup(self, keyset):
dispatch_key = keyset.highestPriorityTypeId()
return self.table[dispatch_key]
def __call__(self, *args, **kwargs):
flat_args = to_flat_tuple(args, kwargs)
if has_torch_function(flat_args):
return handle_torch_function(self, flat_args, *args, **kwargs)
return PyDispatcher.call(self, *args, **kwargs)
def fallthrough_fn(operator, dispatch_key):
def inner(*args, **kwargs):
all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key)
all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs)
return PyDispatcher.redispatch(operator, all_keys_after_current_masked, *args, **kwargs)
return inner
from torch._ops import PyOperator # noqa: F401

View File

@ -19,7 +19,7 @@ class TestControlFlow(TestCase):
class TestControlFlowTraced(TestCase):
def test_cond_traced(self):
def test_cond_traced_not_nested(self):
def true_fn(x):
return x.sin()
@ -126,14 +126,14 @@ class TestControlFlowTraced(TestCase):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = functorch_experimental_ops_cond(pred_1,
true_graph_0, false_graph_0, [[x_1]]); pred_1 = true_graph_0 = false_graph_0 = None
conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [[x_1]]);
pred_1 = true_graph_0 = false_graph_0 = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = functorch_experimental_ops_cond(pred2_1,
true_graph_1, false_graph_1, [[x_1, x_1]]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add_tensor = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add_tensor
conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [[x_1, x_1]]);
pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add
"""
code = graph.code
# Normalization hack, cause .code makes some weird whitespace
@ -145,8 +145,8 @@ class TestControlFlowTraced(TestCase):
out = """
def forward(self, flat_args):
flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
return pytree.tree_unflatten([mul_tensor], self._out_spec)
mul = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
return pytree.tree_unflatten([mul], self._out_spec)
"""
# Normalization hack, cause .code makes some weird whitespace
code = "".join(code.split())

View File

@ -327,7 +327,7 @@ class TestPrims(TestCase):
# Check that all call_function nodes are prims
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
all_prims_namespace = all(
node.target.name.startswith("prims") for node in call_function_nodes
node.target.name().startswith("prims") for node in call_function_nodes
)
self.assertTrue(all_prims_namespace)

View File

@ -963,7 +963,6 @@ symbolic_tensor_failures = {
xfail('bfloat16', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('bmm', ''), # aten.bmm.default - couldn't find symbolic meta function/decomposition
xfail('bool', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('broadcast_tensors', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
xfail('byte', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel
@ -1131,7 +1130,6 @@ symbolic_tensor_failures = {
xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int'
xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.huber_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.instance_norm', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d...
@ -1139,7 +1137,6 @@ symbolic_tensor_failures = {
xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d...
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.linear', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
@ -1149,7 +1146,6 @@ symbolic_tensor_failures = {
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.mse_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
xfail('nn.functional.normalize', ''), # aten.size.default - couldn't find symbolic meta function/decomposition

View File

@ -1833,5 +1833,11 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.layout, torch.strided)
class TestPythonDispatcher(TestCase):
def test_basic(self):
x = torch.randn(2, requires_grad=True)
r = torch._C._EnablePythonDispatcher()
torch.add(x, x)
if __name__ == '__main__':
run_tests()

View File

@ -136,6 +136,9 @@ blocklist = [
"floor_divide",
"floor_divide_",
"floor_divide_out",
"to",
"_to_copy",
"copy_",
]
binary_ops = (

View File

@ -997,11 +997,26 @@ class Generator(object):
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: object) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_dump(dispatch: str) -> str: ...
def _dispatch_is_included_in_alias(dispatch_a: str, dispatch_b: str) -> _bool: ...
def _dispatch_get_backend_keyset_from_autograd(dispatch: str) -> Any: ...
def _dispatch_has_backend_fallback(dispatch: str) -> _bool: ...
def _dispatch_key_name(dispatch: str) -> str: ...
# Defined in torch/csrc/autograd/init.cpp
class _DisablePythonDispatcher(object):
pass
class _EnablePythonDispatcher(object):
pass
def _set_python_dispatcher(dispatcher: object) -> None: ...
class _AutoDispatchBelowAutograd:
pass

View File

@ -29,7 +29,7 @@ else:
from ._six import string_classes as _string_classes
from typing import Set, Type, TYPE_CHECKING, Union, Callable
from typing import Set, Type, TYPE_CHECKING, Union, Callable, Any
import builtins
__all__ = [
@ -962,3 +962,5 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
import torch.cuda._sanitizer as csan
csan.enable_cuda_sanitizer()
from ._dispatch import python

View File

@ -1,50 +0,0 @@
import torch
import torch._C as _C
from torch.utils._pytree import tree_flatten
"""
This is a dispatcher (in Python)
- You can define new operations (in Python) without schemas
- It interfaces with the PyTorch dispatcher
"""
class PyDispatcher:
# operator is a PyOperator
@staticmethod
def call(operator, *args, **kwargs):
dispatch_key_set = compute_keyset(args, kwargs)
kernel = operator.lookup(dispatch_key_set)
return kernel(*args, **kwargs)
# operator is a PyOperator
@staticmethod
def redispatch(operator, dispatch_key_set, *args, **kwargs):
kernel = operator.lookup(dispatch_key_set)
return kernel(*args, **kwargs)
def compute_keyset(args, kwargs):
tensors = get_tensors(args, kwargs)
return key_extractor(tensors)
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
key_set = _C._dispatch_tls_local_include_set() # type: ignore[attr-defined]
for tensor in tensors:
key_set = key_set | _C._dispatch_keys(tensor) # type: ignore[attr-defined]
key_set = key_set - _C._dispatch_tls_local_exclude_set() # type: ignore[attr-defined]
return key_set
def to_flat_tuple(args, kwargs):
flat_args, _ = tree_flatten(args)
flat_kwargs, _ = tree_flatten(kwargs)
flat_all = flat_args + flat_kwargs
return flat_all
def get_tensors(args, kwargs):
flat_all = to_flat_tuple(args, kwargs)
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
return tuple(tensor_args)

torch/_dispatch/python.py (new file, 88 additions)
View File

@ -0,0 +1,88 @@
import torch
from contextlib import contextmanager
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
def has_key(op, k):
return (
torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
or k in op.py_kernels
)
is_included_in_alias = torch._C._dispatch_is_included_in_alias
# Equivalent to computeDispatchTableEntryWithDebug
# TODO: memoize this or something
def resolve_key(op: torch._ops.PyOperatorABC, k: DispatchKey): # type: ignore[valid-type]
# 1. (Direct) operator registration
if has_key(op, k):
return k
# 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
cand = DispatchKey.CompositeExplicitAutogradNonFunctional
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
return cand
# 2.2 Use CompositeExplicitAutograd kernel if available
cand = DispatchKey.CompositeExplicitAutograd
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
return cand
has_backend_kernel = (
torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k))
or has_key(op, DispatchKey.CompositeExplicitAutograd)
)
# 2.3. Use CompositeImplicitAutograd kernel if available
cand = DispatchKey.CompositeImplicitAutogradNestedTensor
if (
(k != DispatchKey.Undefined and is_included_in_alias(k, cand)) # type: ignore[attr-defined]
and has_key(op, cand) and not has_backend_kernel):
return cand
cand = DispatchKey.CompositeImplicitAutograd
if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
if (
k == DispatchKey.AutogradOther
and torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_autogradother_backends) # type: ignore[attr-defined] # noqa: B950
):
raise RuntimeError("ambiguous autogradother kernel")
elif not has_backend_kernel:
return cand
# 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
cand = DispatchKey.Autograd
if is_included_in_alias(k, cand) and has_key(op, cand):
return cand
# Backend fallback
if torch._C._dispatch_has_backend_fallback(k):
# The dispatch key itself will implicitly route to backend fallback.
# This is probably not great for the pure Python implementation.
return k
raise RuntimeError("could not find kernel")
@contextmanager
def no_python_dispatcher():
g = torch._C._DisablePythonDispatcher()
try:
yield
finally:
del g
@contextmanager
def enable_python_dispatcher():
g = torch._C._EnablePythonDispatcher()
try:
yield
finally:
del g
# The Python dispatcher
def python_dispatcher(op, ks, args, kwargs):
"""
with no_python_dispatcher():
print(op, ks, args, kwargs)
"""
k = resolve_key(op, ks.highestPriorityTypeId())
source = f'torch.ops.{op}.dispatch(k, *args, **kwargs)'
filename = f'{op}[{torch._C._dispatch_key_name(k)}]'
compiled = compile(source, filename, 'eval') # TODO: maybe cache?
return eval(compiled, {'torch': torch, 'k': k, 'args': args, 'kwargs': kwargs})
torch._C._set_python_dispatcher(python_dispatcher)
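
resolve_key above mirrors the C++ computeDispatchTableEntryWithDebug; a sketch of what it answers, assuming aten.add.Tensor keeps its usual direct CPU kernel:

import torch
from torch._dispatch.python import resolve_key

DispatchKey = torch._C.DispatchKey
# Case 1 (direct registration) wins: CPU resolves to itself for add.
assert resolve_key(torch.ops.aten.add.Tensor, DispatchKey.CPU) == DispatchKey.CPU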

View File

@ -1,7 +1,10 @@
import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict
import torch._C
@ -26,9 +29,134 @@ def dl_open_guard():
sys.setdlopenflags(old_flags)
# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
def __call__(self, *args, **kwargs):
pass
def py_impl(self, dispatch_key, fn):
pass
def name(self):
pass
def dispatch(self, dispatch_key, *args, **kwargs):
pass
pyop_namespace = {}
class PyOperator(PyOperatorABC):
def __init__(self, name):
self._name = name
self.table = {}
self.python_key_mode_table = {}
# Make _OpNamespace not scream; this whole name-based association needs a good hard look
self.__name__ = name
pyop_namespace[name] = self
def fallthrough(self, dispatch_key):
self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)
def py_impl(self, dispatch_key_or_mode):
def inner(fn):
if inspect.isclass(dispatch_key_or_mode) and issubclass(
dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
):
mode = dispatch_key_or_mode
assert mode not in self.python_key_mode_table
# TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
self.python_key_mode_table[mode] = fn
return fn
dispatch_key = dispatch_key_or_mode
assert (
dispatch_key != torch._C.DispatchKey.Python # type: ignore[attr-defined]
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
assert isinstance(dispatch_key, torch._C.DispatchKey) # type: ignore[attr-defined]
assert dispatch_key not in self.table
self.table[dispatch_key] = fn
return fn
return inner
def dispatch(self, dispatch_key, *args, **kwargs):
if dispatch_key == torch._C.DispatchKey.Python: # type: ignore[attr-defined]
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = torch._C._get_torch_dispatch_mode()
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
curr_mode = type(curr_mode)
assert (
curr_mode in self.python_key_mode_table
), f"Current active mode {curr_mode} not registered"
# TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
return self.python_key_mode_table[curr_mode](*args, **kwargs)
assert dispatch_key in self.table
return self.table[dispatch_key](*args, **kwargs)
def __call__(self, *args, **kwargs):
flat_args = _to_flat_tuple(args, kwargs)
if torch.overrides.has_torch_function(flat_args):
return torch.overrides.handle_torch_function(
self, flat_args, *args, **kwargs
)
dispatch_key_set = _compute_keyset(args, kwargs)
return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
def name(self):
return self._name
# TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
# as opposed to being this sort of explicit thing where ops are a little too key aware...
def _fallthrough_fn(self, operator, dispatch_key):
def inner(*args, **kwargs):
all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key) # type: ignore[attr-defined]
all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
args, kwargs
) # type: ignore[attr-defined]
return self.dispatch(
all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
)
return inner
def _to_flat_tuple(args, kwargs):
flat_args, _ = torch.utils._pytree.tree_flatten(args)
flat_kwargs, _ = torch.utils._pytree.tree_flatten(kwargs)
flat_all = flat_args + flat_kwargs
return flat_all
def _compute_keyset(args, kwargs):
tensors = _get_tensors(args, kwargs)
return key_extractor(tensors)
def _get_tensors(args, kwargs):
flat_all = _to_flat_tuple(args, kwargs)
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
return tuple(tensor_args)
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
key_set = torch._C._dispatch_tls_local_include_set() # type: ignore[attr-defined]
for tensor in tensors:
key_set = key_set | torch._C._dispatch_keys(tensor) # type: ignore[attr-defined]
key_set = key_set - torch._C._dispatch_tls_local_exclude_set() # type: ignore[attr-defined]
return key_set
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload:
class OpOverload(PyOperatorABC):
def __init__(self, overloadpacket, op, op_dk, schema, tags):
self._op = op
self._op_dk = op_dk
@ -38,12 +166,15 @@ class OpOverload:
self._overloadname = (
"default" if schema.overload_name == "" else schema.overload_name
)
self.name = self._schema.name
self._name = self._schema.name
if schema.overload_name:
self.name += "." + schema.overload_name
self._name += "." + schema.overload_name
self.py_kernels: Dict[torch._C.DispatchKey, Any] = {} # type: ignore[name-defined]
self.__name__ = "{}.{}".format(
self._schema.name.split("::")[1], self._overloadname
)
# TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
self.python_key_mode_table = {}
self.__module__ = overloadpacket.__module__
op.__module__ = overloadpacket.__module__
@ -70,12 +201,56 @@ class OpOverload:
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
def decompose(self, *args, **kwargs):
dk = "CompositeImplicitAutograd"
if torch._C._dispatch_has_kernel_for_dispatch_key(self.name, dk):
return self._op_dk(dk, *args, **kwargs)
dk = torch._C.DispatchKey.CompositeImplicitAutograd # type: ignore[attr-defined]
if (
torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk)
or dk in self.py_kernels
):
return self.dispatch(dk, *args, **kwargs)
else:
return NotImplemented
def py_impl(self, dispatch_key_or_mode):
def inner(fn):
if inspect.isclass(dispatch_key_or_mode) and issubclass(
dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
):
mode = dispatch_key_or_mode
assert mode not in self.python_key_mode_table
# TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
self.python_key_mode_table[mode] = fn
return fn
assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey) # type: ignore[attr-defined]
assert (
dispatch_key_or_mode != torch._C.DispatchKey.Python # type: ignore[attr-defined]
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
self.py_kernels[dispatch_key_or_mode] = fn
return fn
return inner
def dispatch(self, dispatch_key, *args, **kwargs):
if dispatch_key == torch._C.DispatchKey.Python: # type: ignore[attr-defined]
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = torch._C._get_torch_dispatch_mode()
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
curr_mode = type(curr_mode)
if curr_mode not in self.python_key_mode_table:
return self._op_dk(dispatch_key, *args, **kwargs)
# TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
return self.python_key_mode_table[curr_mode](*args, **kwargs)
if dispatch_key in self.py_kernels:
return self.py_kernels[dispatch_key](*args, **kwargs)
else:
return self._op_dk(dispatch_key, *args, **kwargs)
def name(self):
return self._name
@property
def overloadpacket(self):
return self._overloadpacket
@ -253,14 +428,25 @@ class _OpNamespace(types.ModuleType):
return opoverloadpacket
class _PyOpNamespace(_OpNamespace):
def __init__(self):
super(_PyOpNamespace, self).__init__("torch.ops")
self.pyop_namespace = pyop_namespace
class _Ops(types.ModuleType):
__file__ = "_ops.py"
def __init__(self):
super(_Ops, self).__init__("torch.ops")
self.loaded_libraries = set()
self.pyops = _PyOpNamespace()
def __getattr__(self, name):
# Check if the name is a pyop
if name in self.pyops.pyop_namespace:
return self.pyops.pyop_namespace[name]
# Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name)
setattr(self, name, namespace)
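
Because the pyop namespace is consulted first in __getattr__, any PyOperator self-registered in pyop_namespace is reachable by bare name, which is how the expected graphs in the control flow tests above can spell torch.ops.cond. A sketch (assumes the module defining cond has been imported):

import torch
# After `cond = PyOperator('cond')` has run somewhere, attribute lookup finds it:
torch.ops.cond  # the PyOperator instance, not a freshly created _OpNamespace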

View File

@ -271,6 +271,7 @@ __all__ = [
]
Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
def _broadcast_shapes(*_shapes):
@ -2190,7 +2191,11 @@ def broadcast_shapes(*shapes) -> ShapeType:
return torch.Size(_broadcast_shapes(*shapes))
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
tensors = tensors[0]
return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))

View File

@ -7,6 +7,7 @@
#include <c10/core/DeviceType.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
@ -52,6 +53,19 @@ struct EnableTorchFunction {
bool old_;
};
PyObject* globalPythonDispatcher = nullptr;
struct EnablePythonDispatcher {
EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
c10::impl::PythonDispatcherTLS::set_state(
{globalPythonDispatcher, getPyInterpreter()});
}
~EnablePythonDispatcher() {
c10::impl::PythonDispatcherTLS::set_state(old_);
}
c10::SafePyHandle old_;
};
} // namespace
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
@ -328,6 +342,18 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
.def(py::init<>());
py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
.def(py::init<>());
py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
.def(py::init<>());
py::class_<c10::impl::DisablePythonDispatcher>(
_C_m, "_DisablePythonDispatcher")
.def(py::init<>());
_C_m.def("_set_python_dispatcher", [](py::object dispatcher) {
TORCH_CHECK(
!globalPythonDispatcher,
"overwriting the global python dispatcher is not supported; if you need this file an issue");
// NB: intentionally leak
globalPythonDispatcher = dispatcher.release().ptr();
});
py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")

View File

@ -3,6 +3,7 @@
#include <c10/core/DeviceType.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
@ -235,6 +236,10 @@ struct ConcretePyInterpreterVTable final
void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const override;
void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const override;
bool is_contiguous(const TensorImpl* self) const override;
c10::Device device(const TensorImpl* self) const override;
@ -701,6 +706,7 @@ static PyObject* THPVariable_make_subclass(
"cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
}
torch_dispatch_mode::StashTorchDispatchModeGuard td_g;
c10::impl::DisablePythonDispatcher dpd_g;
auto data =
r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
// We set `data`'s `allow_tensor_metadata_change` to true here, because we
@ -2265,6 +2271,67 @@ void ConcretePyInterpreterVTable::dispatch(
op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}
void ConcretePyInterpreterVTable::python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet ks,
torch::jit::Stack* stack) const {
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
auto arguments = torch::jit::pop(*stack, num_arguments);
// Parse the name into namespace and name (no overload_name)
// TODO: put this into the library
const auto& qualified_name = op.operator_name().name;
const auto& overload_name = schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
// Make me some null terminated strings
std::string ns_str = qualified_name.substr(0, pos);
const char* ns = ns_str.c_str();
const char* func_name = qualified_name.c_str() + pos + strlen("::");
// The plan: convert all the arguments back into PyObjects,
// extracting out the tensor handles, then call
// handle_torch_function_no_python_arg_parser
// NB: at the point arguments are pushed to the stack, ALL defaults
// are already present
py::gil_scoped_acquire g;
std::vector<py::handle> overloaded_args;
py::handle torch_api_function =
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
py::handle torch_api_function_overload;
if (overload_name == "") {
torch_api_function_overload = torch_api_function.attr("default");
} else {
torch_api_function_overload =
torch_api_function.attr(overload_name.c_str());
}
std::string module_name_str = "torch.ops." + ns_str;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto args = std::move(args_kwargs.first);
auto kwargs = std::move(args_kwargs.second);
auto python_dispatcher =
c10::impl::PythonDispatcherTLS::get_state().ptr(getPyInterpreter());
TORCH_INTERNAL_ASSERT(python_dispatcher);
py::object obj = py::reinterpret_steal<py::object>(PyObject_CallFunction(
python_dispatcher,
"OOOO",
torch_api_function_overload,
py::cast(ks).ptr(),
args.ptr(),
kwargs.ptr()));
if (obj == nullptr)
throw python_error();
pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
}
c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
@ -2377,7 +2444,7 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
if (out == Py_None) {
TORCH_CHECK(
!self->has_symbolic_sizes_strides(),
"Cannot call sizes on a tensor with symbolic shapes/strides");
"Cannot call strides on a tensor with symbolic shapes/strides");
return self->strides_default();
}

View File

@ -1713,13 +1713,11 @@ void initJITBindings(PyObject* module) {
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, /*is_overload*/ true);
});
auto func_dk =
py::cpp_function([op, symbol, allow_numbers_as_tensors](
const std::string& str_dk,
py::args args,
py::kwargs kwargs) {
auto func_dk = py::cpp_function(
[op, symbol, allow_numbers_as_tensors](
c10::DispatchKey dk_, py::args args, py::kwargs kwargs) {
c10::optional<c10::DispatchKey> dk =
c10::make_optional(c10::parseDispatchKey(str_dk));
c10::make_optional(dk_);
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, /*is_overload*/ true, dk);

View File

@ -136,6 +136,33 @@ struct type_caster<at::Device> {
}
};
template <>
struct type_caster<c10::DispatchKey>
: public type_caster_base<c10::DispatchKey> {
using base = type_caster_base<c10::DispatchKey>;
c10::DispatchKey tmp;
public:
bool load(handle src, bool convert) {
if (base::load(src, convert)) {
return true;
} else if (py::isinstance(
src, py::module_::import("builtins").attr("str"))) {
tmp = c10::parseDispatchKey(py::cast<std::string>(src));
value = &tmp;
return true;
}
return false;
}
static handle cast(
c10::DispatchKey src,
return_value_policy policy,
handle parent) {
return base::cast(src, policy, parent);
}
};
// Pybind11 bindings for our optional and variant types.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>

View File

@ -92,6 +92,7 @@ bool should_allow_numbers_as_tensors(const std::string& name) {
"sub", "sub_", "sub_out",
"subtract", "subtract_", "subtract_out", // alias of sub
"true_divide", "true_divide_", "true_divide_out",
"to", "_to_copy", "copy_",
"floor_divide", "floor_divide_", "floor_divide_out"};
return allowed.find(name) != allowed.end();
}

View File

@ -295,11 +295,20 @@ void initDispatchBindings(PyObject* module) {
// Returns whether or not a direct kernel registration exists
// for this <op_name, dispatch_key> pair.
"_dispatch_has_kernel_for_dispatch_key",
[](const char* name, const char* dispatch) -> bool {
[](const char* name, c10::DispatchKey dispatch) -> bool {
auto op =
c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
TORCH_CHECK(op, "operator ", name, " does not exist");
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
return op->hasKernelForDispatchKey(dispatch);
});
m.def(
"_dispatch_has_kernel_for_any_dispatch_key",
[](const char* name, c10::DispatchKeySet ks) -> bool {
auto op =
c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
TORCH_CHECK(op, "operator ", name, " does not exist");
return op->hasKernelForAnyDispatchKey(ks);
});
m.def(
@ -332,38 +341,48 @@ void initDispatchBindings(PyObject* module) {
m.def(
"_dispatch_tls_set_dispatch_key_excluded",
[](const char* dispatch_key, bool desired_state) {
c10::impl::tls_set_dispatch_key_excluded(
c10::parseDispatchKey(dispatch_key), desired_state);
[](c10::DispatchKey dispatch_key, bool desired_state) {
c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
});
m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) {
return c10::impl::tls_is_dispatch_key_excluded(
c10::parseDispatchKey(dispatch_key));
m.def(
"_dispatch_tls_is_dispatch_key_excluded",
[](c10::DispatchKey dispatch_key) {
return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
});
m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
return at::isTensorSubclassLike(tensor);
});
m.def("_dispatch_key_name", [](uint64_t dispatch_key) {
auto dt = (c10::DispatchKey)dispatch_key;
return c10::toString(dt);
m.def("_dispatch_key_name", [](c10::DispatchKey k) {
return c10::toString(k);
});
m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
m.def("_dispatch_num_backends", []() { return c10::num_backends; });
py::enum_<c10::DispatchKey>(m, "DispatchKey")
.value("Undefined", c10::DispatchKey::Undefined)
.value("Dense", c10::DispatchKey::Dense)
.value("BackendSelect", c10::DispatchKey::BackendSelect)
.value("CPU", c10::DispatchKey::CPU)
.value("CUDA", c10::DispatchKey::CUDA)
.value("AutocastCPU", c10::DispatchKey::AutocastCPU)
.value("AutocastCUDA", c10::DispatchKey::AutocastCUDA)
.value("AutogradCPU", c10::DispatchKey::AutogradCPU)
.value("ADInplaceOrView", c10::DispatchKey::ADInplaceOrView)
.value("AutogradCUDA", c10::DispatchKey::AutogradCUDA)
.value("PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot)
.value("Python", c10::DispatchKey::Python);
#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)
py::enum_<c10::DispatchKey>(m, "DispatchKey") DEF_ONE(Undefined)
DEF_ONE(CompositeExplicitAutogradNonFunctional)
DEF_ONE(CompositeExplicitAutograd)
DEF_ONE(CompositeImplicitAutogradNestedTensor)
DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther)
DEF_ONE(Autograd) DEF_ONE(BackendSelect)
DEF_ONE(ADInplaceOrView) DEF_ONE(PythonTLSSnapshot)
DEF_ONE(Python)
#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix) \
DEF_SINGLE(, fullname) \
DEF_SINGLE(, StartOf##fullname##Backends) \
C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
DEF_SINGLE(, EndOf##fullname##Backends)
C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
#undef DEF_MULTIPLE
#undef DEF_SINGLE
;
py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
.def(py::init<c10::DispatchKey>())
@ -371,7 +390,15 @@ void initDispatchBindings(PyObject* module) {
.def("__sub__", &c10::DispatchKeySet::operator-)
.def("__and__", &c10::DispatchKeySet::operator&)
.def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
.def("has", &c10::DispatchKeySet::has);
.def("has", &c10::DispatchKeySet::has)
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
m.attr("_dispatch_autogradother_backends") =
py::cast(c10::autogradother_backends);
m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
});
m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
@ -381,6 +408,10 @@ void initDispatchBindings(PyObject* module) {
return c10::toString(keyset);
});
m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
return c10::getBackendKeySetFromAutograd(k);
});
m.def("_dispatch_keys", [](const at::Tensor& tensor) {
auto* impl = tensor.unsafeGetTensorImpl();
return impl->key_set();
@ -391,6 +422,11 @@ void initDispatchBindings(PyObject* module) {
m.def("_dispatch_tls_local_exclude_set", []() {
return c10::impl::tls_local_dispatch_key_set().excluded_;
});
m.def(
"_dispatch_is_included_in_alias",
[](c10::DispatchKey a, c10::DispatchKey b) {
return c10::isIncludedInAlias(a, b);
});
py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard")
.def(py::init<c10::DispatchKeySet>());
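
With the type_caster added above, every binding that takes c10::DispatchKey accepts either the enum value or its string name, so pre-existing string-based callers keep working. A sketch:

import torch

# Both spellings reach the same binding after this change:
torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", "CPU")
torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::add.Tensor", torch._C.DispatchKey.CPU
)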

View File

@ -10,6 +10,7 @@ import torch
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dispatch.python import enable_python_dispatcher
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
@ -579,6 +580,10 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
else:
raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
python_dispatcher_mode: Any = nullcontext()
if tracing_mode == "symbolic":
python_dispatcher_mode = enable_python_dispatcher()
proxy_mode = ProxyTorchDispatchMode(fx_tracer)
def wrap_fake_concrete(x):
@ -617,7 +622,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
# We disable the autocast cache as the autocast cache causes type conversions on parameters to
# check a cache, which introduces untracked tensors into the graph
with decompose(decomposition_table), fake_tensor_mode, \
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
sym_mode, proxy_mode, disable_autocast_cache(): # type: ignore[attr-defined]
t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
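
End-to-end, symbolic tracing now runs under the Python dispatcher automatically, which is what lets the py_impl decompositions (like broadcast_tensors above) kick in and remove the corresponding xfails. A sketch:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return torch.broadcast_tensors(x, y)

# tracing_mode="symbolic" turns on enable_python_dispatcher() internally.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 1), torch.randn(1, 4))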