Mirror of https://github.com/pytorch/pytorch.git
change pre_autograd to pre_dispatch tracing (#101818)
We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")? ~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem.~~ Update: I added a dedicated dispatch key, `PreDispatch`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
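For a quick sense of the renamed surface, here is a minimal sketch based on the tests updated in this PR (the module path and the exact captured graph are assumptions and can differ across PyTorch versions):

    # Sketch: trace at the new PreDispatch key, i.e. before autograd/autocast/etc.
    # run in the dispatcher. Previously this flag was spelled pre_autograd.
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        y = torch.ones_like(x)
        return torch.matmul(x, y)

    gm = make_fx(f, pre_dispatch=True)(torch.randn(5, 5))
    print(gm.code)  # ATen ops captured before any dispatcher logic has run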
Committed by: PyTorch MergeBot
Parent: 1c3a7d9a7e
Commit: b0392de2c3
@@ -107,6 +107,15 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySe
   op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
 }
 
+// The PreDispatch key gets a no-op fallback that just redispatches.
+// The main way this key is used is that we can register a mode to it from python (e.g. TorchProxyDispatchMode, for pre_dispatch tracing)
+// Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
+// Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
+// Alternatively, we could have hardcoded this kernel (in C++) to directly call in TorchProxyDispatchMode.
+// Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
+void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
+}
+
 } // anonymous namespace
 
@@ -152,3 +161,7 @@ TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
 TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
 }
+
+TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
+}
@@ -180,6 +180,9 @@ const char* toString(DispatchKey t) {
     case DispatchKey::TESTING_ONLY_GenericMode:
       return "TESTING_ONLY_GenericMode";
 
+    case DispatchKey::PreDispatch:
+      return "PreDispatch";
+
     case DispatchKey::PythonDispatcher:
       return "PythonDispatcher";
 
@@ -300,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
           c10::DispatchKey::TESTING_ONLY_GenericWrapper},
       {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
       {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
+      {"PreDispatch", c10::DispatchKey::PreDispatch},
 
       {"CPU", c10::DispatchKey::CPU},
       {"CUDA", c10::DispatchKey::CUDA},

@@ -406,6 +406,12 @@ enum class DispatchKey : uint16_t {
   // for a usage example
   TESTING_ONLY_GenericMode,
 
+  // This key is used for pre-dispatch tracing in make_fx.
+  // It has lower priority than the PythonDispatcher key
+  // because we use the PythonDispatcher to intercept the key from python,
+  // and avoid having to implement it in C++.
+  PreDispatch,
+
   // This is a bypass that allows you to skip running the C++ dispatcher
   // entirely
   PythonDispatcher,
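Because the key is added to toString/parseDispatchKey above and exposed through the Python dispatch bindings later in this diff (DEF_ONE(PreDispatch)), it is visible from Python, which is what lets the PythonDispatcher intercept it instead of a hand-written C++ kernel. A minimal sketch, assuming a build that contains this change:

    # Sketch: the new key is a first-class DispatchKey from the Python side.
    import torch

    k = torch._C.DispatchKey.PreDispatch
    print(k)  # DispatchKey.PreDispatch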
@@ -1893,9 +1893,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         inp = torch.randn(6, 7)
         self.assertEqual(gm(inp), f(inp))
 
-    # pre_autograd seems to violate new fake tensor invariants
-    @unittest.expectedFailure
-    def test_pre_autograd_simple(self):
+    def test_pre_dispatch_simple(self):
         def f(x):
             y = torch.ones_like(x)
             return torch.matmul(x, y)
@@ -1904,7 +1902,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
             f,
             torch.randn(5, 5),
             aten_graph=True,
-            pre_autograd=True,
+            pre_dispatch=True,
             tracing_mode="fake",
         )
 
@@ -146,7 +146,7 @@ class TestGenericProxyTensor(TestCase):
         r2 = f(*new_inps)
         self.assertEqual(r1, r2)
 
-    def test_pre_autograd_mode_stack(self):
+    def test_pre_dispatch_mode_stack(self):
         def f(a):
             b = torch.ones(4, 4)
             return torch.matmul(a, b)
@@ -155,17 +155,28 @@ class TestGenericProxyTensor(TestCase):
         # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
         # so our mode never sees it - it goes directly to the BackendSelect key.
         inp = torch.ones(4, 4)
-        # Test that make_fx(pre_autograd=True) clears caches properly.
+        # Test that make_fx(pre_dispatch=True) clears caches properly.
         from torch._dispatch.python import enable_python_dispatcher
         with enable_python_dispatcher():
             out1 = f(inp)
-        fx_g = make_fx(f, pre_autograd=True)(inp)
+        fx_g = make_fx(f, pre_dispatch=True)(inp)
         self.assertExpectedInline(fx_g.code.strip(), """\
 def forward(self, a_1):
     ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
     matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
     return matmul""")
 
+    def test_pre_dispatch_linear(self):
+        def f(a, b, c):
+            return torch.nn.functional.linear(a, b, c)
+        a = torch.ones(4, 4)
+        b = torch.ones(4, 4)
+        c = torch.ones(4)
+        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
+        out1 = f(a, b, c)
+        out2 = fx_g(a, b, c)
+        self.assertEqual(out1, out2)
+
 
     def test_make_fx_simple(self):
         def f(x):
@@ -1191,6 +1191,11 @@ class _DisablePythonDispatcher:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...
 
+class _EnablePreDispatch:
+    def __init__(self): ...
+    def __enter__(self): ...
+    def __exit__(self, exc_type, exc_value, traceback): ...
+
 class _DisableFuncTorch:
     def __init__(self): ...
     def __enter__(self): ...
@@ -7,10 +7,11 @@ import itertools
 from typing import Iterator
 import torch._ops
 
-__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
+__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
 
 no_python_dispatcher = torch._C._DisablePythonDispatcher
 enable_python_dispatcher = torch._C._EnablePythonDispatcher
+enable_pre_dispatch = torch._C._EnablePreDispatch
 
 CROSSREF_FUNCTIONALIZE = False
 
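The new export is just torch._C._EnablePreDispatch (an IncludeDispatchKeyGuard for DispatchKey::PreDispatch, added further down in this diff) re-exported as a context manager. A hedged sketch of how make_fx stacks it with the python dispatcher; these are private helpers, not a public API:

    # Sketch mirroring make_fx's internals (private helpers, subject to change).
    from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch

    with enable_python_dispatcher(), enable_pre_dispatch():
        # Inside this block the PreDispatch key sits in the TLS include set, so
        # operators are seen at the top of the dispatcher; make_fx additionally
        # pushes ProxyTorchDispatchMode onto the PreDispatch mode stack here.
        pass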
@@ -833,7 +833,7 @@ def export(
     f: Callable[..., Any],
     *args,
     aten_graph: bool = False,
-    pre_autograd: bool = False,
+    pre_dispatch: bool = False,
     decomposition_table: Optional[
         Dict[torch._ops.OpOverload, Callable[..., Any]]
     ] = None,
@@ -853,9 +853,10 @@ def export(
         aten_graph (bool): If True, exports a graph with ATen operators.
             If False, exports a graph with Python operators. Default is False.
 
-        pre_autograd (bool): If True, exports a graph with ATen operators,
-            but before autograd has run. This can be useful if you want to apply further tranformations
-            on a graph before running it through autograd.
+        pre_dispatch (bool): If True, exports a graph with ATen operators,
+            but before any logic in the PyTorch dispatcher has run.
+            This can be useful if you want to apply further tranformations on a graph before running it
+            through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
+            This flag is only valid if aten_graph=True is set.
+            Default is False.
 
@@ -885,8 +886,8 @@ def export(
         assert (
             aten_graph
         ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
-    if pre_autograd:
-        assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
+    if pre_dispatch:
+        assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
     f = innermost_fn(f)
     call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
     original_signature = inspect.signature(call_to_inspect)
@@ -1041,7 +1042,7 @@ def export(
             decomposition_table=decomposition_table,
             tracing_mode="real",
             _allow_non_fake_inputs=True,
-            pre_autograd=pre_autograd,
+            pre_dispatch=pre_dispatch,
         )(*example_fake_inputs)
     except CondOpArgsMismatchError as e:
         # Wrap the internal error to the user-facing error
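A usage sketch for the renamed export flag, following test_pre_dispatch_simple above (the exact return value of export is an assumption here, treated as a (graph_module, guards) pair, and may vary by version):

    # Sketch: export an ATen-level graph captured before any dispatcher logic
    # (autograd, autocast, functionalization) has run.
    import torch
    import torch._dynamo

    def f(x):
        y = torch.ones_like(x)
        return torch.matmul(x, y)

    gm, guards = torch._dynamo.export(
        f,
        torch.randn(5, 5),
        aten_graph=True,    # pre_dispatch is only valid together with aten_graph=True
        pre_dispatch=True,  # renamed from pre_autograd in this PR
    )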
@@ -98,6 +98,11 @@ struct EnablePythonDispatcher {
   c10::impl::PyInterpreter* old_;
 };
 
+struct EnablePreDispatch {
+  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
+  c10::impl::IncludeDispatchKeyGuard guard_;
+};
+
 } // namespace
 
 PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
@@ -419,6 +424,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       _C_m, "_EnablePythonDispatcher");
   py_context_manager<c10::impl::DisablePythonDispatcher>(
      _C_m, "_DisablePythonDispatcher");
+  py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
   py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
   py_context_manager_DEPRECATED<MultithreadingEnabled, bool>(
       _C_m, "_MultithreadingEnabled");
@@ -561,6 +561,7 @@ void initDispatchBindings(PyObject* module) {
       DEF_ONE(FuncTorchVmapMode)
       DEF_ONE(FuncTorchGradWrapper)
       DEF_ONE(PythonDispatcher)
+      DEF_ONE(PreDispatch)
       DEF_ONE(Functionalize)
       DEF_ONE(AutocastCPU)
       DEF_ONE(AutocastXPU)
@@ -10,7 +10,7 @@ import torch
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch._dispatch.python import enable_python_dispatcher
+from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
 import torch.fx as fx
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from contextlib import contextmanager, nullcontext
@@ -246,16 +246,7 @@ def fetch_tensor_proxy(tracer):
 
 HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
 
-@contextlib.contextmanager
-def inside_mode(proxy_mode):
-    old = proxy_mode.is_inside_mode
-    proxy_mode.is_inside_mode = True
-    try:
-        yield
-    finally:
-        proxy_mode.is_inside_mode = old
-
-def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
+def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
     unrecognized_types = []
 
     def can_handle_tensor(x):
@@ -277,7 +268,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
         return r
 
     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
-    if not pre_autograd:
+    if not pre_dispatch:
         with proxy_mode:
             r = func.decompose(*args, **kwargs)
             if r is not NotImplemented:
@@ -356,12 +347,6 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     if func is torch.ops.aten.lift_fresh.default:
         func = torch.ops.aten.lift_fresh_copy.default
 
-    # See Note [Per-Dispatch-Key Modes Must Be Reentrant]
-    # If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
-    # then we only want it to trace out proxies the first time that we hit an op.
-    if proxy_mode.is_inside_mode:
-        return func(*args, **kwargs)
-
     proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
                                                name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
 
@@ -377,8 +362,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     else:
         args[0].proxy = proxy_out
 
-    with inside_mode(proxy_mode):
-        out = func(*args, **kwargs)
+    out = func(*args, **kwargs)
 
     # In some circumstances, we will be tracing in a situation where a tensor
     # is *statically* known to be a constant (currently, this only happens if
@@ -426,12 +410,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     else:
         constant = None
 
-    # See Note [Per-Dispatch-Key Modes Must Be Reentrant]
-    # If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
-    # then we only want it to trace out proxies the first time that we hit an op.
-    # In particular, track_tensor_tree can call detach().
-    with inside_mode(proxy_mode):
-        track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
     return out
 
 
@@ -486,9 +465,9 @@ def dispatch_trace(
     return GraphModule(tracer.root, graph, name)
 
 
-def wrap_key(f, tensors, tracer, pre_autograd: bool):
+def wrap_key(f, tensors, tracer, pre_dispatch: bool):
     flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
-    dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
+    dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
 
     @functools.wraps(f)
     def wrapped(*proxies):
@@ -532,14 +511,13 @@ def set_original_aten_op(func):
 
 
 class ProxyTorchDispatchMode(TorchDispatchMode):
-    def __init__(self, tracer, tracing_mode, pre_autograd=False):
-        dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
+    def __init__(self, tracer, tracing_mode, pre_dispatch=False):
+        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
         super().__init__(dk)
         self.tracer = tracer
         self.tracing_mode = tracing_mode
         self.enable_tracing = True
-        self.pre_autograd = pre_autograd
-        self.is_inside_mode = False
+        self.pre_dispatch = pre_dispatch
         self.sym_mode = ProxySymDispatchMode(tracer)
         self.trace_state = {}
         self._managers = []
@@ -572,7 +550,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         if func in [prim.device.default]:
             return func(*args, **kwargs)
 
-        return proxy_call(self, func, self.pre_autograd, args, kwargs)
+        return proxy_call(self, func, self.pre_dispatch, args, kwargs)
 
 
 class ProxySymDispatchMode(SymDispatchMode):
@@ -701,7 +679,7 @@ def disable_autocast_cache():
         torch.set_autocast_cache_enabled(old_value)
 
 
-def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_inputs=False, *, pre_autograd=False):
+def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_inputs=False, *, pre_dispatch=False):
     assert tracing_mode in ["real", "fake", "symbolic"]
 
     if decomposition_table is None:
@@ -738,12 +716,15 @@ def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_in
             raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
 
     python_dispatcher_mode: Any = nullcontext()
+    pre_dispatch_mode: Any = nullcontext()
     # pre-autograd tracing uses per-dispatch-key modes,
    # which requires the python dispatcher
-    if tracing_mode == "symbolic" or pre_autograd:
+    if tracing_mode == "symbolic" or pre_dispatch:
         python_dispatcher_mode = enable_python_dispatcher()
+    if pre_dispatch:
+        pre_dispatch_mode = enable_pre_dispatch()
 
-    proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode, pre_autograd=pre_autograd)
+    proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode, pre_dispatch=pre_dispatch)
 
     arg_count = 0
 
@@ -783,9 +764,9 @@ def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_in
     # We also disable tracing by any other tensor proxy-based tracers except the current. The
     # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
     # thus irrelevant to any external functional trace.
-    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
+    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
          sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
-        t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
+        t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
 
     # TODO: kind of a bad way to do it, should maybe figure out a better way
     if tracing_mode == "symbolic":
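The proxy_call change above (decomposition only runs when pre_dispatch is False) is what keeps pre-dispatch graphs at a higher level: CompositeImplicit decompositions are skipped. A hedged sketch of the observable difference (exactly which ops land in each graph is version-dependent):

    # Sketch: compare a default make_fx trace with a pre_dispatch trace of F.linear.
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a, b, c):
        return torch.nn.functional.linear(a, b, c)

    a, b, c = torch.ones(4, 4), torch.ones(4, 4), torch.ones(4)

    default_gm = make_fx(f)(a, b, c)                          # implicit ops may decompose
    pre_dispatch_gm = make_fx(f, pre_dispatch=True)(a, b, c)  # decomposition step skipped

    print(default_gm.code)
    print(pre_dispatch_gm.code)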
@@ -83,24 +83,16 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
         for key in ks:
             op._uncache_dispatch(key)
         push_mode_for_key(k, mode)
-        # Note [Per-Dispatch-Key Modes Must Be Reentrant]
-        # The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
-        # (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
-        #     kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
-        #     to guarantee that every op will be seen by our mode.
-        # (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
-        #     at both the Autograd key and the Python key, nothing bad should happen.
-        # The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
-        _push_on_torch_dispatch_stack(mode)
     else:
         _push_on_torch_dispatch_stack(mode)
 
 
 def _pop_mode(k: Optional[DispatchKey] = None):
-    m = _pop_torch_dispatch_stack()
     if k is not None:
         from torch._ops import pop_mode_for_key
-        tmp = pop_mode_for_key(k)
-        assert m is tmp
-    return m
+        return pop_mode_for_key(k)
+    else:
+        return _pop_torch_dispatch_stack()
 
 
 @contextlib.contextmanager