change pre_autograd to pre_dispatch tracing (#101818)

We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")?

~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem.~~

Update: I added a dedicated dispatch key, `PreDispatch`.
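
For anyone skimming: after this change the flag is spelled `pre_dispatch` everywhere (`make_fx`, `torch._dynamo.export`, `ProxyTorchDispatchMode`). A minimal sketch of the renamed `make_fx` entry point, based on the updated tests below (illustrative, not canonical):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    b = torch.ones(4, 4)
    return torch.matmul(a, b)

# Previously spelled make_fx(f, pre_autograd=True); now:
gm = make_fx(f, pre_dispatch=True)(torch.ones(4, 4))

# matmul shows up un-decomposed (torch.ops.aten.matmul.default), because tracing
# happens above the dispatcher, before CompositeImplicit decompositions run.
print(gm.code)
```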

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
Brian Hirsh authored on 2023-06-08 23:05:52 +00:00 (committed by PyTorch MergeBot)
parent 1c3a7d9a7e
commit b0392de2c3
12 changed files with 85 additions and 66 deletions

View File

@ -107,6 +107,15 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySe
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
}
// The PreDispatch key gets a no-op fallback that just redispatches.
// The main way this key is used is that we can register a mode to it from python (e.g. ProxyTorchDispatchMode, for pre_dispatch tracing)
// Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
// Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
// Alternatively, we could have hardcoded this kernel (in C++) to directly call into ProxyTorchDispatchMode.
// Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
}
} // anonymous namespace
@ -152,3 +161,7 @@ TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}
TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
}

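A quick Python-level illustration of the point made in the comment above (a minimal sketch, assuming the fallback behaves exactly as described): with `PreDispatch` merely included in TLS and no mode registered to it, ops hit `preDispatchFallback`, which just redispatches, so eager execution is unaffected. `enable_pre_dispatch` is the Python wrapper added later in this PR.

```python
import torch
from torch._dispatch.python import enable_pre_dispatch  # added by this PR

x = torch.randn(2, 2)
with enable_pre_dispatch():
    # matmul dispatches to PreDispatch first; the boxed fallback strips the key
    # and redispatches, so the usual autograd/backend kernels still run.
    y = x @ x
assert torch.allclose(y, x @ x)
```
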
View File

@ -180,6 +180,9 @@ const char* toString(DispatchKey t) {
case DispatchKey::TESTING_ONLY_GenericMode:
return "TESTING_ONLY_GenericMode";
case DispatchKey::PreDispatch:
return "PreDispatch";
case DispatchKey::PythonDispatcher:
return "PythonDispatcher";
@ -300,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
{"PreDispatch", c10::DispatchKey::PreDispatch},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},

View File

@ -406,6 +406,12 @@ enum class DispatchKey : uint16_t {
// for a usage example
TESTING_ONLY_GenericMode,
// This key is used for pre-dispatch tracing in make_fx.
// It has lower priority than the PythonDispatcher key
// because we use the PythonDispatcher to intercept the key from python,
// and avoid having to implement it in C++.
PreDispatch,
// This is a bypass that allows you to skip running the C++ dispatcher
// entirely
PythonDispatcher,

View File

@ -1893,9 +1893,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
inp = torch.randn(6, 7)
self.assertEqual(gm(inp), f(inp))
# pre_autograd seems to violate new fake tensor invariants
@unittest.expectedFailure
def test_pre_autograd_simple(self):
def test_pre_dispatch_simple(self):
def f(x):
y = torch.ones_like(x)
return torch.matmul(x, y)
@ -1904,7 +1902,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
f,
torch.randn(5, 5),
aten_graph=True,
pre_autograd=True,
pre_dispatch=True,
tracing_mode="fake",
)

View File

@ -146,7 +146,7 @@ class TestGenericProxyTensor(TestCase):
r2 = f(*new_inps)
self.assertEqual(r1, r2)
def test_pre_autograd_mode_stack(self):
def test_pre_dispatch_mode_stack(self):
def f(a):
b = torch.ones(4, 4)
return torch.matmul(a, b)
@ -155,17 +155,28 @@ class TestGenericProxyTensor(TestCase):
# This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
# so our mode never sees it - it goes directly to the BackendSelect key.
inp = torch.ones(4, 4)
# Test that make_fx(pre_autograd=True) clears caches properly.
# Test that make_fx(pre_dispatch=True) clears caches properly.
from torch._dispatch.python import enable_python_dispatcher
with enable_python_dispatcher():
out1 = f(inp)
fx_g = make_fx(f, pre_autograd=True)(inp)
fx_g = make_fx(f, pre_dispatch=True)(inp)
self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
return matmul""")
def test_pre_dispatch_linear(self):
def f(a, b, c):
return torch.nn.functional.linear(a, b, c)
a = torch.ones(4, 4)
b = torch.ones(4, 4)
c = torch.ones(4)
fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
out1 = f(a, b, c)
out2 = fx_g(a, b, c)
self.assertEqual(out1, out2)
def test_make_fx_simple(self):
def f(x):

View File

@ -1191,6 +1191,11 @@ class _DisablePythonDispatcher:
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...
class _EnablePreDispatch:
def __init__(self): ...
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...
class _DisableFuncTorch:
def __init__(self): ...
def __enter__(self): ...

View File

@ -7,10 +7,11 @@ import itertools
from typing import Iterator
import torch._ops
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch
CROSSREF_FUNCTIONALIZE = False

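As a rough sketch of how the new export is meant to be combined (mirroring what `make_fx` does further down in this PR when `pre_dispatch=True`): the guard puts `PreDispatch` into the TLS include set, while the Python dispatcher is what lets Python-side code interpose on that key.

```python
from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch

with enable_python_dispatcher(), enable_pre_dispatch():
    # pre-dispatch tracing machinery (e.g. ProxyTorchDispatchMode on the
    # PreDispatch mode stack) would run in here.
    pass
```
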
View File

@ -833,7 +833,7 @@ def export(
f: Callable[..., Any],
*args,
aten_graph: bool = False,
pre_autograd: bool = False,
pre_dispatch: bool = False,
decomposition_table: Optional[
Dict[torch._ops.OpOverload, Callable[..., Any]]
] = None,
@ -853,9 +853,10 @@ def export(
aten_graph (bool): If True, exports a graph with ATen operators.
If False, exports a graph with Python operators. Default is False.
pre_autograd (bool): If True, exports a graph with ATen operators,
but before autograd has run. This can be useful if you want to apply further transformations
on a graph before running it through autograd.
pre_dispatch (bool): If True, exports a graph with ATen operators,
but before any logic in the PyTorch dispatcher has run.
This can be useful if you want to apply further transformations on a graph before running it
through autograd, autocast, or any other functionality that is integrated into the dispatcher.
This flag is only valid if aten_graph=True is set.
Default is False.
@ -885,8 +886,8 @@ def export(
assert (
aten_graph
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
if pre_autograd:
assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
if pre_dispatch:
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
f = innermost_fn(f)
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
original_signature = inspect.signature(call_to_inspect)
@ -1041,7 +1042,7 @@ def export(
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
pre_autograd=pre_autograd,
pre_dispatch=pre_dispatch,
)(*example_fake_inputs)
except CondOpArgsMismatchError as e:
# Wrap the internal error to the user-facing error

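A usage sketch of the renamed flag on the dynamo side, closely mirroring the updated `test_pre_dispatch_simple` above (treat it as illustrative rather than a canonical recipe):

```python
import torch
import torch._dynamo

def f(x):
    y = torch.ones_like(x)
    return torch.matmul(x, y)

# pre_dispatch=True is only legal together with aten_graph=True (see the assert above).
gm, guards = torch._dynamo.export(
    f,
    torch.randn(5, 5),
    aten_graph=True,
    pre_dispatch=True,
    tracing_mode="fake",
)
print(gm.code)
```
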
View File

@ -98,6 +98,11 @@ struct EnablePythonDispatcher {
c10::impl::PyInterpreter* old_;
};
struct EnablePreDispatch {
EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
c10::impl::IncludeDispatchKeyGuard guard_;
};
} // namespace
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
@ -419,6 +424,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
_C_m, "_EnablePythonDispatcher");
py_context_manager<c10::impl::DisablePythonDispatcher>(
_C_m, "_DisablePythonDispatcher");
py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
py_context_manager_DEPRECATED<MultithreadingEnabled, bool>(
_C_m, "_MultithreadingEnabled");

View File

@ -561,6 +561,7 @@ void initDispatchBindings(PyObject* module) {
DEF_ONE(FuncTorchVmapMode)
DEF_ONE(FuncTorchGradWrapper)
DEF_ONE(PythonDispatcher)
DEF_ONE(PreDispatch)
DEF_ONE(Functionalize)
DEF_ONE(AutocastCPU)
DEF_ONE(AutocastXPU)

View File

@ -10,7 +10,7 @@ import torch
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dispatch.python import enable_python_dispatcher
from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
@ -246,16 +246,7 @@ def fetch_tensor_proxy(tracer):
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
@contextlib.contextmanager
def inside_mode(proxy_mode):
old = proxy_mode.is_inside_mode
proxy_mode.is_inside_mode = True
try:
yield
finally:
proxy_mode.is_inside_mode = old
def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
unrecognized_types = []
def can_handle_tensor(x):
@ -277,7 +268,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
return r
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if not pre_autograd:
if not pre_dispatch:
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
@ -356,12 +347,6 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
if func is torch.ops.aten.lift_fresh.default:
func = torch.ops.aten.lift_fresh_copy.default
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
# then we only want it to trace out proxies the first time that we hit an op.
if proxy_mode.is_inside_mode:
return func(*args, **kwargs)
proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
@ -377,8 +362,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
else:
args[0].proxy = proxy_out
with inside_mode(proxy_mode):
out = func(*args, **kwargs)
out = func(*args, **kwargs)
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@ -426,12 +410,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
else:
constant = None
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
# then we only want it to trace out proxies the first time that we hit an op.
# In particular, track_tensor_tree can call detach().
with inside_mode(proxy_mode):
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
return out
@ -486,9 +465,9 @@ def dispatch_trace(
return GraphModule(tracer.root, graph, name)
def wrap_key(f, tensors, tracer, pre_autograd: bool):
def wrap_key(f, tensors, tracer, pre_dispatch: bool):
flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
@functools.wraps(f)
def wrapped(*proxies):
@ -532,14 +511,13 @@ def set_original_aten_op(func):
class ProxyTorchDispatchMode(TorchDispatchMode):
def __init__(self, tracer, tracing_mode, pre_autograd=False):
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
def __init__(self, tracer, tracing_mode, pre_dispatch=False):
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
super().__init__(dk)
self.tracer = tracer
self.tracing_mode = tracing_mode
self.enable_tracing = True
self.pre_autograd = pre_autograd
self.is_inside_mode = False
self.pre_dispatch = pre_dispatch
self.sym_mode = ProxySymDispatchMode(tracer)
self.trace_state = {}
self._managers = []
@ -572,7 +550,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
if func in [prim.device.default]:
return func(*args, **kwargs)
return proxy_call(self, func, self.pre_autograd, args, kwargs)
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
class ProxySymDispatchMode(SymDispatchMode):
@ -701,7 +679,7 @@ def disable_autocast_cache():
torch.set_autocast_cache_enabled(old_value)
def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_inputs=False, *, pre_autograd=False):
def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_inputs=False, *, pre_dispatch=False):
assert tracing_mode in ["real", "fake", "symbolic"]
if decomposition_table is None:
@ -738,12 +716,15 @@ def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_in
raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
python_dispatcher_mode: Any = nullcontext()
pre_dispatch_mode: Any = nullcontext()
# pre-autograd tracing uses per-dispatch-key modes,
# which requires the python dispatcher
if tracing_mode == "symbolic" or pre_autograd:
if tracing_mode == "symbolic" or pre_dispatch:
python_dispatcher_mode = enable_python_dispatcher()
if pre_dispatch:
pre_dispatch_mode = enable_pre_dispatch()
proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode, pre_autograd=pre_autograd)
proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode, pre_dispatch=pre_dispatch)
arg_count = 0
@ -783,9 +764,9 @@ def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_in
# We also disable tracing by any other tensor proxy-based tracers except the current. The
# purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
# thus irrelevant to any external functional trace.
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
# TODO: kind of a bad way to do it, should maybe figure out a better way
if tracing_mode == "symbolic":

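For completeness, a runnable end-to-end sketch of the `make_fx(..., pre_dispatch=True)` path wired up in this file, closely following `test_pre_dispatch_linear` above:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b, c):
    return torch.nn.functional.linear(a, b, c)

a, b, c = torch.ones(4, 4), torch.ones(4, 4), torch.ones(4)

# Tracing happens above the dispatcher, i.e. before autograd, autocast, etc. run.
gm = make_fx(f, pre_dispatch=True)(a, b, c)

# The result is an ordinary fx.GraphModule that can be re-executed directly.
assert torch.allclose(gm(a, b, c), f(a, b, c))
```
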
View File

@ -83,24 +83,16 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
for key in ks:
op._uncache_dispatch(key)
push_mode_for_key(k, mode)
# Note [Per-Dispatch-Key Modes Must Be Reentrant]
# The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
# (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
# kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
# to guarantee that every op will be seen by our mode.
# (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
# at both the Autograd key and the Python key, nothing bad should happen.
# The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
_push_on_torch_dispatch_stack(mode)
else:
_push_on_torch_dispatch_stack(mode)
def _pop_mode(k: Optional[DispatchKey] = None):
m = _pop_torch_dispatch_stack()
if k is not None:
from torch._ops import pop_mode_for_key
tmp = pop_mode_for_key(k)
assert m is tmp
return m
return pop_mode_for_key(k)
else:
return _pop_torch_dispatch_stack()
@contextlib.contextmanager