pytorch/test/dynamo/test_callback.py
Commit fa90090735 by Yuanyuan Chen: Use dataclass features in two classes (#164221)
This PR completes two TODO items by using features of `dataclass`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221
Approved by: https://github.com/Skylion007, https://github.com/mlazos

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-10-01 03:20:39 +00:00
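
The diff itself is not shown on this page. As a hedged illustration only (the `Handler` class below is hypothetical, not one of the two classes touched by #164221), a TODO of this kind is typically resolved by letting dataclass machinery supply per-instance defaults instead of a hand-written `__init__`:

# Hypothetical sketch, not the code changed in #164221: a "use dataclass
# features" TODO resolved by replacing a manual __init__ with field defaults.
from dataclasses import dataclass, field


@dataclass
class Handler:
    # Previously these would be assigned imperatively, e.g.:
    #     def __init__(self):
    #         self.callbacks = []
    #         self._pending = 0
    callbacks: list = field(default_factory=list)
    _pending: int = field(default=0, init=False)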


# Owner(s): ["module: dynamo"]
import unittest
from unittest.mock import Mock

import torch
from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger
from torch._dynamo.test_case import run_tests, TestCase
from torch._guards import CompileId
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu


device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class CallbackTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._on_compile_start = Mock()
        self._on_compile_end = Mock()
        callback_handler.register_start_callback(self._on_compile_start)
        callback_handler.register_end_callback(self._on_compile_end)

    def tearDown(self) -> None:
        callback_handler.clear()
        return super().tearDown()

    def test_callbacks_with_duplicate_prevention(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(frame_id=0, frame_compile_id=0)
        # Installing callbacks twice for the same compile id must fire each
        # callback only once; the end callback fires when the context exits.
        with (
            callback_handler.install_callbacks(trigger, compile_id),
            callback_handler.install_callbacks(trigger, compile_id),
        ):
            self._on_compile_start.assert_called_once()
        self._on_compile_end.assert_called_once()

    def test_counter(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(frame_id=0, frame_compile_id=0)
        with callback_handler.install_callbacks(trigger, compile_id):
            # Name-mangled access to the handler's private counter.
            self.assertEqual(
                callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
                1,
            )
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    def test_counter_assertion(self) -> None:
        callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1
        with self.assertRaisesRegex(
            AssertionError, "Pending callbacks counter cannot become negative."
        ):
            trigger = CallbackTrigger.DYNAMO
            compile_id = CompileId(frame_id=0, frame_compile_id=0)
            with callback_handler.install_callbacks(trigger, str(compile_id)):
                pass
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    @unittest.skipIf(
        TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
    )
    @requires_gpu
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_triggers(self) -> None:
        torch._dynamo.reset()
        order = []

        def on_start(args: CallbackArgs):
            nonlocal order
            order.append(f"start={args}")

        def on_end(args: CallbackArgs):
            nonlocal order
            order.append(f"end={args}")

        torch._dynamo.callback.on_compile_start(on_start)
        torch._dynamo.callback.on_compile_end(on_end)

        class TinyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                temp = self.fc1(x)
                temp = self.relu(temp)
                # Force a graph break so a second compile id ('1/0') appears below.
                torch._dynamo.graph_break()
                return self.fc2(temp)

        model = TinyModel().to(device_type)
        compiled_model = torch.compile(model, mode="max-autotune")

        x = torch.randn(10, 10, device=device_type)
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        if not HAS_CUDA_AND_TRITON:
            return

        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        # Once the cudagraphs are recorded, later iterations replay them and
        # no further callbacks fire.
        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertEqual(len(order), 0)


if __name__ == "__main__":
    run_tests()
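
For context, here is a minimal usage sketch of the callback API this test exercises. It is an illustration under stated assumptions (the toy function and the `log_start`/`log_end` names are invented), using only calls that appear in the test above:

# Minimal sketch: register compile start/end callbacks, then trigger one
# compilation so a single start/end pair is printed.
import torch
from torch._dynamo.callback import CallbackArgs


def log_start(args: CallbackArgs) -> None:
    print(f"compile started: {args}")


def log_end(args: CallbackArgs) -> None:
    print(f"compile finished: {args}")


torch._dynamo.callback.on_compile_start(log_start)
torch._dynamo.callback.on_compile_end(log_end)

fn = torch.compile(lambda t: t.sin() + t.cos())
fn(torch.randn(4))  # first call compiles and fires one start/end pair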