pytorch/test/dynamo/test_hooks.py
Ryan Guo 6765df052c [dynamo] Emit warning on global module hooks when calling using output of torch.compile(module) (#152740)
When we call `torch.compile(module)`, we ultimately return a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.

`OptimizedModule` also inherits `nn.Module.__call__`, and thus
has its own hook logic. This is useful for torchao, which registers module
forward hooks that run in eager mode for quantization purposes.

However, this can create unexpected behavior for global module hooks:
`torch.compile(module)` causes each hook to fire one extra time (once for
the `OptimizedModule` wrapper) compared to eager mode.

To preserve BC, we simply emit a warning for this behavior and let
users decide what to do. This is reasonable because global module
hooks are documented to be used for debugging/profiling purposes only.
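
A minimal sketch of the resulting behavior (illustrative, not part of
this test file):

    import torch

    class Mod(torch.nn.Module):
        def forward(self, x):
            return x - 1

    handle = torch.nn.modules.module.register_module_forward_pre_hook(
        lambda mod, args: print("global pre-hook fired")
    )
    try:
        mod = Mod()
        x = torch.rand(4)
        mod(x)  # fires once, as in eager
        torch.compile(mod, backend="eager")(x)  # fires twice, and warns
    finally:
        handle.remove()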

Fixes #149502

Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
2025-05-14 17:03:59 +00:00


# Owner(s): ["module: dynamo"]
import contextlib
import functools
import unittest
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._dynamo import compiled_autograd
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils.hooks import RemovableHandle
def compiler_fn(gm):
return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)
def global_hook_0(grad):
return grad * 4
def global_hook_1(grad):
return grad / 2
def global_hook_2(grad):
return grad * 3
h0 = None
class ClassWithVal:
def __init__(self, val):
self.val = val
class HooksTests(torch._dynamo.test_case.TestCase):
def test_tensor_only_register_hook_in_graph_lambda(self):
def fn(x):
x.register_hook(lambda grad: grad * 2)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 0)
def test_tensor_register_hook_in_graph_lambda(self):
def fn(x, y, z):
x.register_hook(lambda grad: grad * 2)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_lambda(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
z = z * z
handle.remove()
x.register_hook(lambda grad: grad * 3)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_multi_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle
z = z * z
return x, y * y, z, handle, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(h, None)
self.assertNotEqual(h2, None)
self.assertEqual(h2, h)
def test_tensor_register_hook_repeated_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle # noqa: F841
z = z * z
return x, y * y, z, handle, handle
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertIsInstance(h, RemovableHandle)
self.assertIs(h2, h)
def test_removed_handle_return(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt, fullgraph=True)
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
z = z * z
handle.remove()
handle.remove()
return x, y * y, z, handle, handle
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(cnt.frame_count, 1)
self.assertIsInstance(h, RemovableHandle)
self.assertIs(h2, h)
def test_tensor_register_hook_repeated_handle_not_local(self):
def fn(x, y, z, mod):
mod.handle = x.register_hook(lambda grad: grad * 2)
z = z * z
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts, fullgraph=True)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
mod = torch.nn.Module()
mod.handle = None
v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(mod.handle, None)
def test_tensor_only_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x):
x.register_hook(local_hook)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 0)
def test_tensor_only_register_hook_in_graph_local_inner(self):
def fn(x):
def local_hook(grad):
return grad * 2
z = x * x
x.register_hook(local_hook)
z.register_hook(local_hook)
return x, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x, y, z):
x.register_hook(local_hook)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_local(self):
def local_hook(grad):
return grad * 2
def local_hook2(grad):
return grad * 3
def fn(x, y, z):
handle = x.register_hook(local_hook)
z = z * z
handle.remove()
x.register_hook(local_hook2)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
def test_tensor_register_global_hook(self):
def fn(x):
x.register_hook(global_hook_0)
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks(self):
def fn(x):
x.register_hook(global_hook_0) # * 4
x.register_hook(global_hook_1) # / 2
x.register_hook(global_hook_2) # * 3
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks_handles_in_list(self):
def fn(x):
h0 = x.register_hook(global_hook_0) # * 4
h1 = x.register_hook(global_hook_1) # / 2
h2 = x.register_hook(global_hook_2) # * 3
return x, x * x, h0, h1, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r, handle_0, handle_1, handle_2 = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
handle_0.remove()
handle_1.remove()
handle_2.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
# Handles gone, grad is just applied as is
self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_global_hooks_handles_in_list(self):
def fn(x):
global h0
h0 = x.register_hook(global_hook_0) # * 4
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r = fn(v)
self.assertIsNotNone(h0)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
h0.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
# Handles gone, grad is just applied as is
self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))
        # NYI! Storing the hook handle into a global is not supported yet,
        # so Dynamo falls back to eager and compiles no frames.
self.assertEqual(cnts.frame_count, 0)
def test_intermediary_hooks(self):
# Graph breaks because compiled_autograd is not set
def simple_hook(g):
return g * 2
def f(x):
y = x + 1
y.register_hook(simple_hook)
z = y + 1
return z
out = torch.randn(1, requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
fn = torch.compile(f, backend=cnts, fullgraph=False)
res = fn(out)
res.backward()
self.assertEqual(res, f(out))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(out.grad, torch.Tensor([2.0]))
def test_intermediary_hooks_same_on_aot_eager(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
dynamo_out = torch.compile(mod, backend="aot_eager", fullgraph=True)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_input_hooks_same(self):
backends = ["eager", "aot_eager", "inductor"]
for backend in backends:
def my_hook(grad, *, k=0):
return grad + k
hook = functools.partial(my_hook, k=3)
class MyMod(torch.nn.Module):
def forward(self, x):
x.register_hook(hook)
y = x.mul(2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
dynamo_out = torch.compile(mod, backend=backend, fullgraph=True)(x2)
with compiled_autograd._enable(compiler_fn):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_intermediary_hooks_same_on_inductor(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
class SomePyClass:
count = 0
def do_stuff(self, grad):
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
        # Two hook invocations from the eager backward
self.assertEqual(obj.count, 2)
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, eager_out)
        # Two eager invocations + two from the compiled backward
self.assertEqual(obj.count, 4)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
self,
):
class SomePyClass:
grad_as_str = "None"
count = 0
def write_grad_as_str_and_do_stuff(self, grad):
self.grad_as_str = str(grad)
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
print("Break!")
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.write_grad_as_str_and_do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "Failed to trace builtin operator"
):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(obj.count, 2)
def test_register_hook_partial_guarding(
self,
):
def some_hook(grad, *, obj):
return grad + obj.val
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(some_hook, obj=obj)
y.register_hook(hook1)
z = y.mul(3)
return (z,)
mod = MyMod()
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
obj3 = ClassWithVal(11)
cnt = torch._dynamo.testing.CompileCounter()
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
self.assertEqual(cnt.frame_count, 1)
def test_hook_with_closure(self):
def fn(x, obj):
y = x.sin()
x.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd._enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_hook_with_nested_closure(self):
def fn(x):
def run():
y = x.sin()
x.register_hook(lambda grad: grad + y)
z = y.sin()
return z
return run()
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
fn(x0).sum().backward()
with compiled_autograd._enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x1).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x1.grad)
def test_intermediate_hook_with_closure_eager(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd._enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_intermediate_hook_with_closure_aot(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd._enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_no_recompile_on_hook_identity_change(self):
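        # Rebinding my_hook to a different function after compilation must
        # not trigger a recompile (frame_count stays at 1).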
def my_hook(grad, k=0):
return grad + k
def my_hook2(grad):
return grad * 2
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
y.register_hook(my_hook)
y.register_hook(my_hook)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
comp_mod = torch.compile(mod, backend=cnts, fullgraph=True)
comp_out = comp_mod(x1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(cnts.frame_count, 1)
my_hook = my_hook2 # noqa: F811
self.assertEqual(x0.grad, x1.grad)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
comp_out = comp_mod(x1)
self.assertEqual(cnts.frame_count, 1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(x0.grad, x1.grad)
def test_functools_arg_vary(self):
def pre_hook(grad, *, k):
return grad * k
hook = functools.partial(pre_hook, k=1)
@torch.compile(backend="eager", fullgraph=True)
def h(x):
y = x.mul(2)
y.register_hook(hook)
return y.mul(3)
with compiled_autograd._enable(torch.compile(backend="eager", fullgraph=True)):
x = torch.randn(2, requires_grad=True)
h(x).sum().backward()
orig_grad = x.grad
x.grad = None
hook = functools.partial(pre_hook, k=2)
h(x).sum().backward()
self.assertEqual(orig_grad * 2, x.grad)
def test_post_acc_grad_hook(self):
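        # Post-accumulate-grad hooks mutate the leaf tensor and its .grad
        # in place; the same assertions must hold in eager and compiled modes.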
def hook(input_t):
input_t.mul_(input_t.grad)
input_t.grad.mul_(5)
def reg_and_mul(x, y):
x.register_post_accumulate_grad_hook(hook)
return x * y
cnts = None
def test_fn(fn):
fn(x, y)
b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
x.backward(b)
if cnts:
self.assertEqual(cnts.frame_count, 1)
            # These assertions run identically for eager and compiled
            # x becomes x * 2 because of the in-place mul_
self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
            # This test proves that grad aliasing works
self.assertEqual(x.grad, b * 5)
# Eager values
x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
test_fn(reg_and_mul)
# Compiled
for backend in ["eager", "aot_eager", "inductor"]:
for compiled_bwd in [False, True]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
compiled_fn = torch.compile(reg_and_mul, backend=cnts, fullgraph=True)
compiled_bwd_ctx = (
compiled_autograd._enable(
torch.compile(backend=backend, fullgraph=True)
)
if compiled_bwd
else contextlib.nullcontext()
)
with compiled_bwd_ctx:
test_fn(compiled_fn)
def test_recompile(self):
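        # error_on_recompile is patched on below, so this test would raise
        # if any of the repeated backward() calls triggered a recompile.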
def hook(param):
param.grad *= 2
x = torch.ones(10)
x.requires_grad = True
def run(input):
return x * input
x.register_post_accumulate_grad_hook(hook)
with compiled_autograd._enable(compiler_fn):
for i in range(5):
with unittest.mock.patch(
"torch._dynamo.config.error_on_recompile", True
):
# Mimic optimizer.zero_grad() to clear the gradient
x.grad = None
run(i).sum().backward()
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_no_recompile_on_same_hook(self):
cnts = torch._dynamo.testing.CompileCounter()
def fw_hook(inp):
return (inp[0] + 1,)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = torch.nn.ModuleList()
for i in range(10):
layer = torch.nn.Linear(16, 16)
layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
layer = torch.compile(layer, backend=cnts)
self.layers.append(layer)
def forward(self, x):
for l in self.layers:
x = l(x)
return x
mod = Mod()
x = torch.ones(16, 16, requires_grad=True)
mod(x)
self.assertEqual(cnts.frame_count, 1)
@torch._dynamo.config.patch(skip_nnmodule_hook_guards=False)
def test_nnmodule_hook_guards(self):
# Compile a model and then apply a hook
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
return self.linear(x)
cnts = torch._dynamo.testing.CompileCounter()
mod = Mod()
def fn(x):
return mod(x)
opt_fn = torch.compile(fn, backend=cnts)
x = torch.ones(16, 16)
opt_fn(x)
# Register a hook
def forward_hook(self, inputs, out):
return out * 2
mod.register_forward_hook(forward_hook)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@torch._dynamo.config.patch(wrap_top_frame=True)
def test_wrap_top_frame_with_hooks(self):
class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.net1 = torch.nn.Linear(18, 18, bias=False)
def forward(self, x):
return self.net1(x)
mod = ToyModel()
mod.register_forward_pre_hook(lambda mod, input: input[0] + 1)
# Case 1: torch.compile(mod)
cnts = torch._dynamo.testing.CompileCounter()
compiled_mod = torch.compile(mod, backend=cnts)
x = torch.rand(18, 18)
ref = mod(x)
res = compiled_mod(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
# Case 2: mod.compile()
cnts = torch._dynamo.testing.CompileCounter()
mod.compile(backend=cnts)
res = mod(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
def test_global_module_forward_pre_hook(self):
class Mod(torch.nn.Module):
def forward(self, x):
return x - 1
counter = 0
def hook(mod, args):
nonlocal counter
counter += 1
return args
x = torch.rand(18, 18)
mod = Mod()
compiled_mod = torch.compile(mod, backend="eager")
try:
hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
ref = mod(x)
self.assertEqual(counter, 1)
with self.assertWarnsRegex(
UserWarning,
r"Using `torch.compile\(module\)` when there are global hooks.*",
):
res = compiled_mod(x)
self.assertEqual(counter, 3)
self.assertEqual(ref, res)
finally:
hook_handle.remove()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()