[custom_ops] torch.library.{custom_op, register_kernel} disable Dynamo (#133125)
We promise the user that these custom ops (and their kernels) are black boxes w.r.t. torch.compile. Unfortunately, Dynamo can turn itself back on inside the implementation of the custom operator, so we force it off by disabling Dynamo.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133125
Approved by: https://github.com/ezyang
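The guarantee under test can be seen end to end in a short sketch (the op name and bodies below are hypothetical; the APIs are the ones exercised by this PR):

import torch

@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    # Arbitrary eager implementation. torch.compile must treat the op
    # as a single opaque call and never trace into this frame.
    return torch.sin(x)

@my_sin.register_fake
def _(x):
    return torch.empty_like(x)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return my_sin(x)

f(torch.randn(3))  # compiles to one graph; the op stays a black box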
@@ -217,6 +217,71 @@ class MiscTests(torch._inductor.test_case.TestCase):
         self.assertTrue(same(val4, correct1))
         self.assertEqual(counter.frame_count, 3)
 
+    @torch._dynamo.config.patch(accumulated_cache_size_limit=1)
+    def test_dynamo_disabled_in_custom_op_kernels(self):
+        counters.clear()
+
+        @torch.library.custom_op("mylib::foo9", mutates_args={})
+        def foo(x: torch.Tensor) -> torch.Tensor:
+            torch._dynamo.graph_break()
+            return x.clone()
+
+        foo.register_fake(torch.clone)
+
+        @torch.compile(backend="eager")
+        def f(x):
+            return foo._opoverload(x)
+
+        x = torch.randn(2)
+        f(x)
+        x = torch.randn(3)
+        # Recompile hits the cache size limit, which will cause Dynamo to
+        # recurse into the frames. The only frame is the implementation
+        # of foo. If Dynamo was not turned off correctly, then
+        # we'll see a graph break
+        f(x)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        counters.clear()
+
+        called = 0
+
+        # test register_kernel
+        @foo.register_kernel("cpu")
+        def _(x):
+            nonlocal called
+            called += 1
+            torch._dynamo.graph_break()
+            return x.clone()
+
+        f(x)
+        self.assertEqual(called, 1)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        # test torch.library.register_kernel
+        counters.clear()
+        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
+            m.define("foo2(Tensor x) -> Tensor")
+
+            @torch.library.register_fake("mylib::foo2", lib=m)
+            def _(x):
+                return x.clone()
+
+            @torch.library.register_kernel("mylib::foo2", "cpu", lib=m)
+            def _(x):
+                torch._dynamo.graph_break()
+                return x.clone()
+
+            @torch.compile(backend="eager")
+            def g(x):
+                return torch.ops.mylib.foo2.default(x)
+
+            x = torch.randn(2)
+            g(x)  # compiles
+            x = torch.randn(3)
+            g(x)  # dynamo falls back on the outermost frame
+            self.assertEqual(len(counters["graph_break"]), 0)
+
     def test_invalid_args_builtin(self):
         @torch.compile(backend="eager")
         def fn(x):
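For contrast, a minimal sketch of the failure mode the test guards against: when Dynamo does trace into a frame that calls torch._dynamo.graph_break(), the break is recorded in the same counters the test checks (not_a_black_box is a hypothetical plain function, not a custom op):

import torch
from torch._dynamo.utils import counters

def not_a_black_box(x):
    torch._dynamo.graph_break()  # Dynamo traces in and hits this break
    return x.clone()

@torch.compile(backend="eager")
def f(x):
    return not_a_black_box(x)

counters.clear()
f(torch.randn(2))
assert len(counters["graph_break"]) >= 1  # the break is visible here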
@@ -355,6 +355,7 @@ class CustomOpDef:
 
         # Wrap function to choose between the default implementation or the device-specific
         # implementation depending on if the kernel is disabled.
+        @torch._disable_dynamo
         def wrapped_fn(*args, **kwargs):
             if device_type in self._disabled_kernel:
                 return self._init_fn(*args, **kwargs)
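A standalone sketch of what the added decorator buys (opaque_fn is hypothetical; torch._disable_dynamo is the internal decorator used in the diff): a decorated frame is skipped rather than traced, which is what keeps the kernel safe when Dynamo recursively compiles frames after a cache-limit fallback.

import torch

@torch._disable_dynamo
def opaque_fn(x):
    # Dynamo never traces this frame. graph_break() is harmless here
    # because the body always runs eagerly.
    torch._dynamo.graph_break()
    return x.clone()

opaque_fn(torch.randn(2))  # runs eagerly; no compilation is attempted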
@@ -531,6 +531,10 @@ def impl(qualname, types, func=None, *, lib=None):
         >>> y = torch.ops.mylib.mysin(x)
         >>> assert torch.allclose(y, x.sin())
     """
+    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
+
+
+def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
     if isinstance(types, str):
         types = (types,)
     keys = set({})
@@ -549,13 +553,23 @@ def impl(qualname, types, func=None, *, lib=None):
 
     def register(func):
         namespace, _ = torch._library.utils.parse_namespace(qualname)
 
         if lib is None:
             use_lib = Library(namespace, "FRAGMENT")
             _keep_alive.append(use_lib)
         else:
             use_lib = lib
-        for key in keys:
-            use_lib.impl(qualname, func, key)
+        if disable_dynamo:
+
+            @torch._disable_dynamo
+            def func_no_dynamo(*args, **kwargs):
+                return func(*args, **kwargs)
+
+            for key in keys:
+                use_lib.impl(qualname, func_no_dynamo, key)
+        else:
+            for key in keys:
+                use_lib.impl(qualname, func, key)
 
     if func is None:
         return register
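The registration path wraps the user kernel once, at registration time, so every dispatcher call enters a frame Dynamo has been told to skip. A minimal sketch of the same pattern outside the library machinery (register_opaque and my_kernel are hypothetical):

import torch

def register_opaque(func):
    # Mirrors func_no_dynamo above: forward all arguments unchanged,
    # but mark the wrapper frame so Dynamo never traces into it.
    @torch._disable_dynamo
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

@register_opaque
def my_kernel(x):
    return x.clone()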
@@ -663,7 +677,8 @@ def register_kernel(
     assert isinstance(op, str)
     if device_types is None:
         device_types = "CompositeExplicitAutograd"
-    return impl(op, device_types, func, lib=lib)
+
+    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
 
 
 def register_fake(
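A usage sketch for the changed code path, mirroring the second half of the new test (the mylib::scale op is hypothetical):

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("scale(Tensor x, float factor) -> Tensor")

@torch.library.register_fake("mylib::scale", lib=lib)
def _(x, factor):
    return torch.empty_like(x)

# This registration now routes through _impl(..., disable_dynamo=True),
# so torch.compile treats the kernel body as opaque.
@torch.library.register_kernel("mylib::scale", "cpu", lib=lib)
def _(x, factor):
    return x * factor

y = torch.ops.mylib.scale(torch.randn(2), 2.0)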