pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)
This PR adds support for the `enable_grad`/`no_grad`/`autocast` context managers being properly traced in `pre_dispatch` tracing. It includes:

- I added a torch function mode that runs during `make_fx` pre_dispatch tracing, `PreDispatchTorchFunctionMode`. It directly intercepts the torch ops that run inside the above context managers and adds them to the current graph instead of executing them.
- `enable_grad` and `no_grad` currently desugar into `torch._C._set_grad_enabled(bool)`, but this API isn't overridable by torch function, so I added the ability to interpose there.
- The `torch.amp` context managers don't have a single equivalent like `set_autocast_enabled(state)`, so I added two new APIs: `torch.amp._enter_autocast` and `torch.amp._exit_autocast`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend, so I figured it would be cleaner to add a new API (which should probably only be used by tracing) - but open to feedback.
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`, along with a test file for it (`test/dynamo/test_pre_dispatch.py`). When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
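For orientation, a minimal sketch (not part of the commit) of how the two entry points touched by this PR can be exercised; it mirrors the tests added in `test/dynamo/test_pre_dispatch.py` below, and the comments describe the expected behavior rather than verified output:

```python
# Sketch only: assumes a PyTorch build that contains this change.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    b = a.sin()
    with torch.no_grad():
        c = b.cos()
    return b * c.sin()

# 1) Direct pre_dispatch tracing: the no_grad region should now show up in the
#    captured graph as explicit torch._C._set_grad_enabled(False)/(True) calls.
gm = make_fx(f, pre_dispatch=True)(torch.randn(4, requires_grad=True))
gm.print_readable()

# 2) End to end through dynamo, using the debug backend added in this PR.
f_compiled = torch.compile(f, backend="pre_dispatch_eager")
out = f_compiled(torch.randn(4, requires_grad=True))
```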
Committed by: PyTorch MergeBot
Parent: ebb8aa9c0b
Commit: 875f60399e
test/dynamo/test_pre_dispatch.py (new file, +76 lines)
@@ -0,0 +1,76 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo
+import torch._dynamo.test_case
+
+
+class PreDispatchTests(torch._dynamo.test_case.TestCase):
+    def test_no_grad_simple(self):
+        def f(a):
+            b = a.sin()
+            with torch.no_grad():
+                c = b.cos()
+            return b * c.sin()
+
+        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
+
+        a_ref = torch.randn(4, requires_grad=True)
+        a_test = a_ref.clone().detach().requires_grad_(True)
+
+        out_ref = f(a_ref)
+        out_test = f_compiled(a_test)
+        self.assertEqual(out_ref, out_test)
+
+        out_ref.sum().backward()
+        out_test.sum().backward()
+        self.assertEqual(a_ref.grad, a_test.grad)
+
+    def test_enable_grad_and_no_grad(self):
+        def f(a):
+            b = a * 2
+            with torch.no_grad():
+                c = b * 3
+                with torch.enable_grad():
+                    d = c * 4
+                e = d * 5
+            return b + c + d + e
+
+        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
+
+        a_ref = torch.randn(4, requires_grad=True)
+        a_test = a_ref.clone().detach().requires_grad_(True)
+
+        out_ref = f(a_ref)
+        out_test = f_compiled(a_test)
+        self.assertEqual(out_ref, out_test)
+
+        out_ref.sum().backward()
+        out_test.sum().backward()
+        self.assertEqual(a_ref.grad, a_test.grad)
+
+    def test_autocast_simple(self):
+        def f(a):
+            b = a * 2
+            with torch.amp.autocast(device_type="cpu"):
+                c = torch.matmul(b, b)
+            return b + c
+
+        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
+
+        a_ref = torch.randn(4, device="cpu", requires_grad=True)
+        a_test = a_ref.clone().detach().requires_grad_(True)
+
+        out_ref = f(a_ref)
+        out_test = f_compiled(a_test)
+        self.assertEqual(out_ref, out_test)
+
+        out_ref.sum().backward()
+        out_test.sum().backward()
+        self.assertEqual(a_ref.grad, a_test.grad)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
@@ -178,6 +178,23 @@ def forward(self, a_1):
         out2 = fx_g(a, b, c)
         self.assertEqual(out1, out2)
 
+    def test_pre_dispatch_no_grad(self):
+        def f(a):
+            b = a.sin()
+            torch.set_grad_enabled(False)
+            c = b.cos()
+            torch.set_grad_enabled(True)
+            return b + c.sin()
+        a1 = torch.randn(4, requires_grad=True)
+        a2 = a1.clone().detach().requires_grad_(True)
+        a_tmp = a1.clone().detach().requires_grad_(True)
+        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
+        out1 = f(a1)
+        out2 = fx_g(a2)
+        self.assertEqual(out1, out2)
+        out1.sum().backward()
+        out2.sum().backward()
+        self.assertEqual(a1.grad, a2.grad)
 
     def test_make_fx_simple(self):
         def f(x):
@@ -21,6 +21,19 @@ def eager(gm, fake_tensor_inputs):
     return gm
 
 
+@register_backend
+def pre_dispatch_eager(gm, fake_tensor_inputs):
+    from torch.fx.experimental.proxy_tensor import make_fx
+
+    def runnable_gm(*args):
+        return torch.fx.Interpreter(gm).run(*args)
+
+    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
+    pre_dispatch_gm.print_readable()
+
+    return pre_dispatch_gm
+
+
 @register_backend
 def eager_debug(gm, fake_tensor_inputs):
     from torch._subclasses.schema_check_mode import SchemaCheckMode
@@ -231,17 +231,17 @@ class AutocastModeVariable(ContextWrappingVariable):
 
     def exit(self, tx, *args):
         self.mode = (
-            exit_functional_autocast(self.mode[0]),
+            torch.amp._exit_autocast(self.mode[0]),
             tx.output.create_node(
-                "call_function", exit_functional_autocast, (self.mode[1],), {}
+                "call_function", torch.amp._exit_autocast, (self.mode[1],), {}
             ),
         )
 
     def enter(self, tx):
         self.mode = (
-            enter_functional_autocast(*self.target_values),
+            torch.amp._enter_autocast(*self.target_values),
             tx.output.create_node(
-                "call_function", enter_functional_autocast, (*self.target_values,), {}
+                "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
             ),
         )
 
@@ -252,16 +252,6 @@ class AutocastModeVariable(ContextWrappingVariable):
         return "autocast"
 
 
-def enter_functional_autocast(*vals):
-    mode = torch.amp.autocast(*vals)
-    mode.__enter__()
-    return mode
-
-
-def exit_functional_autocast(mode):
-    mode.__exit__(None, None, None)
-
-
 class NullContextVariable(ContextWrappingVariable):
     """
     This class represents Python contextlib.nullcontext.
@@ -1 +1 @@
-from .autocast_mode import autocast
+from .autocast_mode import autocast, _enter_autocast, _exit_autocast
@@ -374,3 +374,20 @@ class autocast:
         if torch._jit_internal.is_scripting():
             return func
         return autocast_decorator(self, func)
+
+# These functions aren't meant for public usage.
+# They are what we trace into a graph during pre_dispatch tracing
+# when we encounter an autocast context manager.
+def _enter_autocast(*vals):
+    # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
+    if torch._C._is_torch_function_mode_enabled():
+        return torch.overrides.handle_torch_function(torch.amp._enter_autocast, [], *vals)
+    mode = torch.amp.autocast(*vals)
+    mode.__enter__()
+    return mode
+
+
+def _exit_autocast(mode):
+    if torch._C._is_torch_function_mode_enabled():
+        return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
+    mode.__exit__(None, None, None)
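As a hedged illustration (these helpers are private and only meant for tracing, per the comment in the hunk above): when no torch function mode is active they simply fall through to entering/exiting the `torch.amp.autocast` context manager manually, so eager behavior is unchanged.

```python
# Illustrative only; mirrors the fallback path of _enter_autocast/_exit_autocast above.
import torch

x = torch.randn(4, 4)
mode = torch.amp._enter_autocast("cpu")   # same as torch.amp.autocast("cpu").__enter__()
y = torch.matmul(x, x)
print(y.dtype)                            # expected torch.bfloat16, the default CPU autocast dtype
torch.amp._exit_autocast(mode)            # same as mode.__exit__(None, None, None)
print(torch.matmul(x, x).dtype)           # back to torch.float32
```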
@@ -670,12 +670,24 @@ static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) {
+static PyObject* set_grad_enabled(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  if (!PyBool_Check(arg)) {
-    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  static PythonArgParser parser({
+      "set_grad_enabled(bool enabled)",
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+
+  if (at::impl::torch_function_mode_enabled()) {
+    auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+    return handle_torch_function(
+        r, args, kwargs, torch_C_module, "torch._C", "_set_grad_enabled");
   }
-  GradMode::set_enabled(arg == Py_True);
+  auto grad_enabled = r.toBool(0);
+  GradMode::set_enabled(grad_enabled);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
@@ -908,7 +920,10 @@ PyObject* THPModule_increment_version(PyObject* _unused, PyObject* tensor) {
 
 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
-    {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
+    {"_set_grad_enabled",
+     castPyCFunctionWithKeywords(set_grad_enabled),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
     {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
     {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
@@ -21,6 +21,8 @@ import operator
 from torch.utils._stats import count
 import logging
 
+from torch.overrides import TorchFunctionMode
+
 from torch.utils._python_dispatch import (
     TorchDispatchMode,
     _pop_mode_temporarily,
@@ -512,6 +514,26 @@ def set_original_aten_op(func):
 
 
 
+# This mode is **only** used for pre_dispatch tracing.
+# In particular, we need to make sure that autograd/autocast APIs
+# that do not desugar into dispatcher operators stay in the graph.
+class PreDispatchTorchFunctionMode(TorchFunctionMode):
+
+    def __init__(self, tracer):
+        self.tracer = tracer
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        pre_dispatch_ops = [
+            torch._C._set_grad_enabled,
+            torch.amp._enter_autocast,
+            torch.amp._exit_autocast,
+        ]
+        if func in pre_dispatch_ops:
+            # Don't actually run the function! We just want to trace the calls
+            # into a graph. We don't actually want to change global autograd state.
+            return self.tracer.create_node("call_function", func, args, {})
+        return func(*args, **kwargs)
 
 class ProxyTorchDispatchMode(TorchDispatchMode):
     def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False):
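A hedged sketch of the interception mechanism that the `set_grad_enabled` change and `PreDispatchTorchFunctionMode` above rely on: once `torch._C._set_grad_enabled` routes through `__torch_function__`, any active `TorchFunctionMode` sees the call and can record it instead of letting it flip global state, which is what `PreDispatchTorchFunctionMode` does via `tracer.create_node`. The `LogGradToggles` mode below is a hypothetical helper written only for this example, not part of the PR.

```python
# Illustrative only; assumes a build containing the set_grad_enabled change above.
import torch
from torch.overrides import TorchFunctionMode

class LogGradToggles(TorchFunctionMode):  # hypothetical, for demonstration
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._C._set_grad_enabled:
            print("intercepted _set_grad_enabled", args)
            return None  # swallow the call; global grad mode is left untouched
        return func(*args, **kwargs)

with LogGradToggles():
    with torch.no_grad():  # expected to print twice: once on enter, once on exit
        pass
```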
@@ -735,6 +757,10 @@ def make_fx(f,
     if pre_dispatch:
         pre_dispatch_mode = enable_pre_dispatch()
 
+    proxy_function_mode: Any = nullcontext()
+    if pre_dispatch:
+        proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)
+
     proxy_mode = ProxyTorchDispatchMode(fx_tracer,
                                         tracing_mode,
                                         pre_dispatch=pre_dispatch,
@@ -778,7 +804,7 @@
     # We also disable tracing by any other tensor proxy-based tracers except the current. The
     # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
     # thus irrelevant to any external functional trace.
-    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
+    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \
         sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
         t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
 