pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)

This PR adds support for the `enable_grad`/`no_grad`/`autocast` context managers being properly traced in `pre_dispatch` tracing. The changes in this PR:
- I added a torch function mode that runs during `make_fx` pre_dispatch tracing, `PreDispatchTorchFunctionMode`. It directly intercepts the torch ops that run inside the above context managers and adds them to the current graph instead of executing them (a usage sketch follows this list).
- `enable_grad` and `no_grad` currently desugar into `torch._C._set_grad_enabled(bool)`, but this API isn't currently overridable via torch function, so I added the ability to interpose there.
- The `torch.amp` context managers don't currently have a nice functional equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new APIs: `torch.amp._enter_autocast` and `torch.amp._exit_autocast`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend, so I figured it would be cleaner to just add new APIs (which should probably only be used by tracing) - but open to feedback.
- I added a new dynamo backend, `torch.compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).
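
For context, here is a minimal usage sketch (illustrative only; the input shape and the exact printed graph are not asserted anywhere in this PR):

```python
import torch

def f(a):
    b = a.sin()
    with torch.no_grad():
        c = b.cos()
    return b * c.sin()

# The new backend round-trips dynamo's captured graph through
# make_fx(..., pre_dispatch=True) and prints the resulting pre_dispatch graph,
# in which the no_grad() region shows up as torch._C._set_grad_enabled calls
# instead of being baked in at trace time.
f_compiled = torch.compile(f, backend="pre_dispatch_eager")
out = f_compiled(torch.randn(4, requires_grad=True))
```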

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
Brian Hirsh
2023-06-28 18:01:10 +00:00
committed by PyTorch MergeBot
parent ebb8aa9c0b
commit 875f60399e
8 changed files with 175 additions and 21 deletions


@ -0,0 +1,76 @@
# Owner(s): ["module: dynamo"]
import torch

import torch._dynamo
import torch._dynamo.test_case


class PreDispatchTests(torch._dynamo.test_case.TestCase):
    def test_no_grad_simple(self):
        def f(a):
            b = a.sin()
            with torch.no_grad():
                c = b.cos()
            return b * c.sin()

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)

    def test_enable_grad_and_no_grad(self):
        def f(a):
            b = a * 2
            with torch.no_grad():
                c = b * 3
                with torch.enable_grad():
                    d = c * 4
                e = d * 5
            return b + c + d + e

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)

    def test_autocast_simple(self):
        def f(a):
            b = a * 2
            with torch.amp.autocast(device_type="cpu"):
                c = torch.matmul(b, b)
            return b + c

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, device="cpu", requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()


@ -178,6 +178,23 @@ def forward(self, a_1):
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()

        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.clone().detach().requires_grad_(True)
        a_tmp = a1.clone().detach().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):


@ -21,6 +21,19 @@ def eager(gm, fake_tensor_inputs):
    return gm


@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs):
    from torch.fx.experimental.proxy_tensor import make_fx

    def runnable_gm(*args):
        return torch.fx.Interpreter(gm).run(*args)

    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()

    return pre_dispatch_gm


@register_backend
def eager_debug(gm, fake_tensor_inputs):
    from torch._subclasses.schema_check_mode import SchemaCheckMode


@ -231,17 +231,17 @@ class AutocastModeVariable(ContextWrappingVariable):
     def exit(self, tx, *args):
         self.mode = (
-            exit_functional_autocast(self.mode[0]),
+            torch.amp._exit_autocast(self.mode[0]),
             tx.output.create_node(
-                "call_function", exit_functional_autocast, (self.mode[1],), {}
+                "call_function", torch.amp._exit_autocast, (self.mode[1],), {}
             ),
         )

     def enter(self, tx):
         self.mode = (
-            enter_functional_autocast(*self.target_values),
+            torch.amp._enter_autocast(*self.target_values),
             tx.output.create_node(
-                "call_function", enter_functional_autocast, (*self.target_values,), {}
+                "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
             ),
         )
@ -252,16 +252,6 @@ class AutocastModeVariable(ContextWrappingVariable):
         return "autocast"

-def enter_functional_autocast(*vals):
-    mode = torch.amp.autocast(*vals)
-    mode.__enter__()
-    return mode
-
-def exit_functional_autocast(mode):
-    mode.__exit__(None, None, None)

 class NullContextVariable(ContextWrappingVariable):
     """
     This class represents Python contextlib.nullcontext.


@ -1 +1 @@
-from .autocast_mode import autocast
+from .autocast_mode import autocast, _enter_autocast, _exit_autocast


@ -374,3 +374,20 @@ class autocast:
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)


# These functions aren't meant for public usage.
# They are what we trace into a graph during pre_dispatch tracing
# when we encounter an autocast context manager.
def _enter_autocast(*vals):
    # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
    if torch._C._is_torch_function_mode_enabled():
        return torch.overrides.handle_torch_function(torch.amp._enter_autocast, [], *vals)
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode


def _exit_autocast(mode):
    if torch._C._is_torch_function_mode_enabled():
        return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
    mode.__exit__(None, None, None)
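
Outside of tracing (no torch function mode active), these helpers just drive the autocast context manager by hand. A rough sketch of that eager path (the bfloat16 default for CPU autocast is an assumption of this example, and the snippet is not taken from this PR's tests):

    import torch

    a, b = torch.randn(8, 8), torch.randn(8, 8)
    state = torch.amp._enter_autocast("cpu")   # same as entering torch.amp.autocast("cpu")
    try:
        c = a @ b                              # runs under autocast (bfloat16 matmul on CPU)
    finally:
        torch.amp._exit_autocast(state)        # restores the previous autocast state
    # Under pre_dispatch tracing, the same two calls are instead recorded as graph nodes.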


@ -670,12 +670,24 @@ static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
-static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) {
+static PyObject* set_grad_enabled(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  if (!PyBool_Check(arg)) {
-    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  static PythonArgParser parser({
+      "set_grad_enabled(bool enabled)",
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  if (at::impl::torch_function_mode_enabled()) {
+    auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+    return handle_torch_function(
+        r, args, kwargs, torch_C_module, "torch._C", "_set_grad_enabled");
   }
-  GradMode::set_enabled(arg == Py_True);
+  auto grad_enabled = r.toBool(0);
+  GradMode::set_enabled(grad_enabled);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
@ -908,7 +920,10 @@ PyObject* THPModule_increment_version(PyObject* _unused, PyObject* tensor) {
 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
-    {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
+    {"_set_grad_enabled",
+     castPyCFunctionWithKeywords(set_grad_enabled),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
     {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
     {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},


@ -21,6 +21,8 @@ import operator
from torch.utils._stats import count
import logging
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import (
TorchDispatchMode,
_pop_mode_temporarily,
@ -512,6 +514,26 @@ def set_original_aten_op(func):
# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
class PreDispatchTorchFunctionMode(TorchFunctionMode):
    def __init__(self, tracer):
        self.tracer = tracer

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        pre_dispatch_ops = [
            torch._C._set_grad_enabled,
            torch.amp._enter_autocast,
            torch.amp._exit_autocast,
        ]
        if func in pre_dispatch_ops:
            # Don't actually run the function! We just want to trace the calls
            # into a graph. We don't actually want to change global autograd state.
            return self.tracer.create_node("call_function", func, args, {})
        return func(*args, **kwargs)


class ProxyTorchDispatchMode(TorchDispatchMode):
    def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False):
@ -735,6 +757,10 @@ def make_fx(f,
        if pre_dispatch:
            pre_dispatch_mode = enable_pre_dispatch()

        proxy_function_mode: Any = nullcontext()
        if pre_dispatch:
            proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)

        proxy_mode = ProxyTorchDispatchMode(fx_tracer,
                                            tracing_mode,
                                            pre_dispatch=pre_dispatch,
@ -778,7 +804,7 @@ def make_fx(f,
         # We also disable tracing by any other tensor proxy-based tracers except the current. The
         # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
         # thus irrelevant to any external functional trace.
-        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
+        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \
             sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
             t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
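
To tie the pre_dispatch pieces together, a rough end-to-end sketch of inspecting the traced graph (this mirrors the new test_pre_dispatch_no_grad test above; the node-count check and input shape are illustrative, and exact node names/ordering may differ):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        b = a.sin()
        torch.set_grad_enabled(False)
        c = b.cos()
        torch.set_grad_enabled(True)
        return b + c.sin()

    gm = make_fx(f, pre_dispatch=True)(torch.randn(4, requires_grad=True))
    # The two grad-mode toggles are expected to survive as call_function nodes
    # targeting torch._C._set_grad_enabled instead of being executed away at trace time.
    toggles = [n for n in gm.graph.nodes if n.target is torch._C._set_grad_enabled]
    print(len(toggles))  # expected: 2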