This PR completes two TODO items by using features of `dataclass`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221
Approved by: https://github.com/Skylion007, https://github.com/mlazos
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
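The diff itself is not reproduced on this page. As a rough, hedged sketch of what "using features of `dataclass`" typically means (the class and field names below are hypothetical illustrations, not taken from the PR), a hand-maintained mutable default and a hand-written to-dict helper can be replaced by `dataclasses.field(default_factory=...)` and `dataclasses.asdict`:

```python
from dataclasses import asdict, dataclass, field


@dataclass
class CompileMetrics:
    # Hypothetical example only -- not the actual PyTorch change.
    # A mutable default that used to be wired up by hand in __init__
    # becomes a default_factory, and a manual to-dict helper becomes
    # dataclasses.asdict.
    compile_id: str = "0/0"
    tags: list[str] = field(default_factory=list)

    def as_dict(self) -> dict:
        # Replaces a manually maintained {"compile_id": ..., "tags": ...} literal.
        return asdict(self)


metrics = CompileMetrics(compile_id="1/0")
metrics.tags.append("dynamo")
print(metrics.as_dict())  # -> {'compile_id': '1/0', 'tags': ['dynamo']}
```

`field(default_factory=list)` avoids the shared-mutable-default pitfall, and `asdict` keeps serialization in sync with the declared fields automatically.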
# Owner(s): ["module: dynamo"]

import unittest
from unittest.mock import Mock

import torch
from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger
from torch._dynamo.test_case import run_tests, TestCase
from torch._guards import CompileId
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu


device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class CallbackTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._on_compile_start = Mock()
        self._on_compile_end = Mock()
        callback_handler.register_start_callback(self._on_compile_start)
        callback_handler.register_end_callback(self._on_compile_end)

    def tearDown(self) -> None:
        callback_handler.clear()
        return super().tearDown()

    def test_callbacks_with_duplicate_prevention(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(frame_id=0, frame_compile_id=0)
        with (
            callback_handler.install_callbacks(trigger, compile_id),
            callback_handler.install_callbacks(trigger, compile_id),
        ):
            self._on_compile_start.assert_called_once()
        self._on_compile_end.assert_called_once()

    def test_counter(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(frame_id=0, frame_compile_id=0)
        with callback_handler.install_callbacks(trigger, compile_id):
            self.assertEqual(
                callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
                1,
            )
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    def test_counter_assertion(self) -> None:
        callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1
        with self.assertRaisesRegex(
            AssertionError, "Pending callbacks counter cannot become negative."
        ):
            trigger = CallbackTrigger.DYNAMO
            compile_id = CompileId(frame_id=0, frame_compile_id=0)
            with callback_handler.install_callbacks(trigger, str(compile_id)):
                pass
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    @unittest.skipIf(
        TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
    )
    @requires_gpu
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_triggers(self) -> None:
        torch._dynamo.reset()
        order = []

        def on_start(args: CallbackArgs):
            nonlocal order
            order.append(f"start={args}")

        def on_end(args: CallbackArgs):
            nonlocal order
            order.append(f"end={args}")

        torch._dynamo.callback.on_compile_start(on_start)
        torch._dynamo.callback.on_compile_end(on_end)

        class TinyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                temp = self.fc1(x)
                temp = self.relu(temp)
                torch._dynamo.graph_break()
                return self.fc2(temp)

        model = TinyModel().to(device_type)
        compiled_model = torch.compile(model, mode="max-autotune")
        x = torch.randn(10, 10, device=device_type)

        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        if not HAS_CUDA_AND_TRITON:
            return

        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()

        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertEqual(len(order), 0)


if __name__ == "__main__":
    run_tests()