Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[custom_ops] torch.library.{custom_op, register_kernel} disable Dynamo (#133125)
We promise the user that these custom ops (and their kernels) are black boxes w.r.t. torch.compile. Unfortunately, Dynamo can turn itself back on in the implementation of the custom operator, so we force it off by disabling Dynamo.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133125
Approved by: https://github.com/ezyang
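The mechanism, condensed from the library changes below: before a custom-op kernel is registered, it is wrapped in a function guarded by torch._disable_dynamo, so Dynamo skips that frame even if it later tries to recurse into it. A rough sketch of the wrapping pattern (the helper name is illustrative, not the actual code in this diff):

    import torch

    def _register_kernel_with_dynamo_disabled(use_lib, qualname, kernel, key):
        # Wrap the user's kernel so Dynamo never traces or recurses into it,
        # even if tracing gets re-enabled while the op is executing.
        @torch._disable_dynamo
        def kernel_no_dynamo(*args, **kwargs):
            return kernel(*args, **kwargs)

        use_lib.impl(qualname, kernel_no_dynamo, key)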
@@ -217,6 +217,71 @@ class MiscTests(torch._inductor.test_case.TestCase):
         self.assertTrue(same(val4, correct1))
         self.assertEqual(counter.frame_count, 3)

+    @torch._dynamo.config.patch(accumulated_cache_size_limit=1)
+    def test_dynamo_disabled_in_custom_op_kernels(self):
+        counters.clear()
+
+        @torch.library.custom_op("mylib::foo9", mutates_args={})
+        def foo(x: torch.Tensor) -> torch.Tensor:
+            torch._dynamo.graph_break()
+            return x.clone()
+
+        foo.register_fake(torch.clone)
+
+        @torch.compile(backend="eager")
+        def f(x):
+            return foo._opoverload(x)
+
+        x = torch.randn(2)
+        f(x)
+        x = torch.randn(3)
+        # Recompile hits the cache size limit, which will cause Dynamo to
+        # recurse into the frames. The only frame is the implementation
+        # of foo. If Dynamo was not turned off correctly, then
+        # we'll see a graph break
+        f(x)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        counters.clear()
+
+        called = 0
+
+        # test register_kernel
+        @foo.register_kernel("cpu")
+        def _(x):
+            nonlocal called
+            called += 1
+            torch._dynamo.graph_break()
+            return x.clone()
+
+        f(x)
+        self.assertEqual(called, 1)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        # test torch.library.register_kernel
+        counters.clear()
+        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
+            m.define("foo2(Tensor x) -> Tensor")
+
+            @torch.library.register_fake("mylib::foo2", lib=m)
+            def _(x):
+                return x.clone()
+
+            @torch.library.register_kernel("mylib::foo2", "cpu", lib=m)
+            def _(x):
+                torch._dynamo.graph_break()
+                return x.clone()
+
+            @torch.compile(backend="eager")
+            def g(x):
+                return torch.ops.mylib.foo2.default(x)
+
+            x = torch.randn(2)
+            g(x)  # compiles
+            x = torch.randn(3)
+            g(x)  # dynamo falls back on the outermost frame
+            self.assertEqual(len(counters["graph_break"]), 0)
+
     def test_invalid_args_builtin(self):
         @torch.compile(backend="eager")
         def fn(x):
@@ -355,6 +355,7 @@ class CustomOpDef:

         # Wrap function to choose between the default implementation or the device-specific
         # implementation depending on if the kernel is disabled.
+        @torch._disable_dynamo
         def wrapped_fn(*args, **kwargs):
             if device_type in self._disabled_kernel:
                 return self._init_fn(*args, **kwargs)
@@ -531,6 +531,10 @@ def impl(qualname, types, func=None, *, lib=None):
         >>> y = torch.ops.mylib.mysin(x)
         >>> assert torch.allclose(y, x.sin())
     """
+    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
+
+
+def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
     if isinstance(types, str):
         types = (types,)
     keys = set({})
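Note the split introduced here: the public impl() keeps its existing behavior by delegating with disable_dynamo=False, while register_kernel() (in a later hunk) opts in through the private _impl() with disable_dynamo=True, so only the custom-op kernel registration path gets the Dynamo-disabling wrapper.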
@@ -549,13 +553,23 @@ def impl(qualname, types, func=None, *, lib=None):

     def register(func):
         namespace, _ = torch._library.utils.parse_namespace(qualname)

         if lib is None:
             use_lib = Library(namespace, "FRAGMENT")
             _keep_alive.append(use_lib)
         else:
             use_lib = lib
-        for key in keys:
-            use_lib.impl(qualname, func, key)
+        if disable_dynamo:
+
+            @torch._disable_dynamo
+            def func_no_dynamo(*args, **kwargs):
+                return func(*args, **kwargs)
+
+            for key in keys:
+                use_lib.impl(qualname, func_no_dynamo, key)
+        else:
+            for key in keys:
+                use_lib.impl(qualname, func, key)

     if func is None:
         return register
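The primitive this wrapper leans on is torch._disable_dynamo: Dynamo neither traces nor recurses into a frame decorated with it, so anything inside, including a graph_break(), only ever runs eagerly, where it is a no-op. A minimal standalone sketch of that behavior (not part of this diff):

    import torch

    @torch._disable_dynamo
    def kernel(x):
        # Dynamo never traces this frame, so the graph_break() below is
        # only executed eagerly, where it is a harmless no-op.
        torch._dynamo.graph_break()
        return x.clone()

    @torch.compile(backend="eager")
    def f(x):
        return kernel(x)

    print(f(torch.randn(2)))  # runs; the kernel body stays opaque to Dynamo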
@@ -663,7 +677,8 @@ def register_kernel(
     assert isinstance(op, str)
     if device_types is None:
         device_types = "CompositeExplicitAutograd"
-    return impl(op, device_types, func, lib=lib)
+
+    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


 def register_fake(