[custom_ops] torch.library.{custom_op, register_kernel} disable Dynamo (#133125)

We promise the user that these custom ops (and their kernels) are black
boxes w.r.t. torch.compile. Unfortunately, Dynamo can turn itself back
on inside the implementation of a custom operator (e.g. when it recursively
processes inner frames after hitting the cache size limit), so we force it
off by disabling Dynamo around the kernel.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133125
Approved by: https://github.com/ezyang
rzou
2024-08-12 07:17:02 -07:00
committed by PyTorch MergeBot
parent d53dfa4680
commit afb73d253c
3 changed files with 84 additions and 3 deletions


@@ -217,6 +217,71 @@ class MiscTests(torch._inductor.test_case.TestCase):
        self.assertTrue(same(val4, correct1))
        self.assertEqual(counter.frame_count, 3)

    @torch._dynamo.config.patch(accumulated_cache_size_limit=1)
    def test_dynamo_disabled_in_custom_op_kernels(self):
        counters.clear()

        @torch.library.custom_op("mylib::foo9", mutates_args={})
        def foo(x: torch.Tensor) -> torch.Tensor:
            torch._dynamo.graph_break()
            return x.clone()

        foo.register_fake(torch.clone)

        @torch.compile(backend="eager")
        def f(x):
            return foo._opoverload(x)

        x = torch.randn(2)
        f(x)
        x = torch.randn(3)
        # Recompile hits the cache size limit, which will cause Dynamo to
        # recurse into the frames. The only frame is the implementation
        # of foo. If Dynamo was not turned off correctly, then
        # we'll see a graph break.
        f(x)
        self.assertEqual(len(counters["graph_break"]), 0)

        counters.clear()
        called = 0

        # test register_kernel
        @foo.register_kernel("cpu")
        def _(x):
            nonlocal called
            called += 1
            torch._dynamo.graph_break()
            return x.clone()

        f(x)
        self.assertEqual(called, 1)
        self.assertEqual(len(counters["graph_break"]), 0)

        # test torch.library.register_kernel
        counters.clear()
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo2(Tensor x) -> Tensor")

            @torch.library.register_fake("mylib::foo2", lib=m)
            def _(x):
                return x.clone()

            @torch.library.register_kernel("mylib::foo2", "cpu", lib=m)
            def _(x):
                torch._dynamo.graph_break()
                return x.clone()

            @torch.compile(backend="eager")
            def g(x):
                return torch.ops.mylib.foo2.default(x)

            x = torch.randn(2)
            g(x)  # compiles
            x = torch.randn(3)
            g(x)  # dynamo falls back on the outermost frame
            self.assertEqual(len(counters["graph_break"]), 0)

    def test_invalid_args_builtin(self):
        @torch.compile(backend="eager")
        def fn(x):


@@ -355,6 +355,7 @@ class CustomOpDef:
        # Wrap function to choose between the default implementation or the device-specific
        # implementation depending on if the kernel is disabled.
        @torch._disable_dynamo
        def wrapped_fn(*args, **kwargs):
            if device_type in self._disabled_kernel:
                return self._init_fn(*args, **kwargs)
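
As a rough sketch of what the decorator buys (using the public torch.compiler.disable, which behaves similarly to the internal torch._disable_dynamo used here): a disabled frame is never processed by Dynamo, so code that would otherwise trip tracing simply runs eagerly.

import torch

@torch.compiler.disable  # Dynamo skips this frame entirely
def kernel_like(x):
    torch._dynamo.graph_break()  # harmless here: the body runs eagerly
    return x.clone()

@torch.compile(backend="eager")
def f(x):
    # The compiled region is split around the call; kernel_like itself is never traced.
    return kernel_like(x) + 1

f(torch.randn(3))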


@@ -531,6 +531,10 @@ def impl(qualname, types, func=None, *, lib=None):
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
    """
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)


def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
    if isinstance(types, str):
        types = (types,)
    keys = set({})
@@ -549,13 +553,23 @@ def impl(qualname, types, func=None, *, lib=None):
    def register(func):
        namespace, _ = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            for key in keys:
                use_lib.impl(qualname, func_no_dynamo, key)
        else:
            for key in keys:
                use_lib.impl(qualname, func, key)

    if func is None:
        return register
@@ -663,7 +677,8 @@ def register_kernel(
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"
-   return impl(op, device_types, func, lib=lib)
+   return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_fake(
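
End to end, the changed path above means a kernel registered via torch.library.register_kernel is installed with Dynamo disabled. A hedged usage sketch (the "mylib::offset" op and its namespace are assumptions for illustration, not from this PR):

import torch

lib = torch.library.Library("mylib", "FRAGMENT")  # assumes the "mylib" namespace is available
lib.define("offset(Tensor x) -> Tensor")

@torch.library.register_kernel("mylib::offset", "cpu", lib=lib)
def _(x):
    # Arbitrary Python; after this PR the registered kernel is wrapped with
    # Dynamo disabled, so Dynamo never recurses into this frame.
    torch._dynamo.graph_break()
    return x + 1

@torch.library.register_fake("mylib::offset", lib=lib)
def _(x):
    return torch.empty_like(x)

@torch.compile(backend="eager")
def f(x):
    return torch.ops.mylib.offset(x)

f(torch.randn(2))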