Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

This is a requirement from Meta-internal use cases, where people want to register a callback function to detect whether a job is stuck during compilation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120764
Approved by: https://github.com/jansel
161 lines · 4.8 KiB · Python
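The tests below exercise this hook pair, `torch._dynamo.on_compile_start` and `torch._dynamo.on_compile_end`. As a minimal sketch of the stuck-job detection described above, one might arm a watchdog timer when compilation starts and cancel it when compilation ends; the `STUCK_TIMEOUT_S` threshold and `warn_stuck` handler are hypothetical illustrations, not part of this PR:

import threading

import torch

STUCK_TIMEOUT_S = 600.0  # hypothetical threshold for "stuck in compilation"
_watchdog = None


def warn_stuck():
    # Hypothetical handler; a real job might log or page an operator here.
    print(f"Compilation has run for more than {STUCK_TIMEOUT_S}s; job may be stuck.")


@torch._dynamo.on_compile_start
def arm_watchdog():
    # Start a timer that fires only if compilation does not finish in time.
    global _watchdog
    _watchdog = threading.Timer(STUCK_TIMEOUT_S, warn_stuck)
    _watchdog.start()


@torch._dynamo.on_compile_end
def disarm_watchdog():
    # Compilation finished; cancel the pending watchdog.
    if _watchdog is not None:
        _watchdog.cancel()

Since the hooks accept zero-argument callables (as the tests below confirm), the watchdog state has to live at module scope or in a closure.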
# Owner(s): ["module: dynamo"]

import inspect
import io
import os
import tempfile
from unittest.mock import patch

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter

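# Small two-layer module used as the compilation target throughout these tests.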
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))


class InPlaceCompilationTests(TestCase):
    def test_compilation(self):
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        torch._dynamo.reset()
        model = ToyModel()
        self.assertTrue(model._compiled_call_impl is None)
        model.compile()
        self.assertTrue(model._compiled_call_impl is not None)

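    # The next three tests check that an in-place compiled module still
    # round-trips through torch.save, state_dict saving, and torch.jit.script.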
    def test_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))

        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model, os.path.join(tmpdirname, "model.pt"))
            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
            loaded_model(torch.randn(1, 10))

    def test_state_dict_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
            loaded_model = ToyModel()
            loaded_model.load_state_dict(
                torch.load(os.path.join(tmpdirname, "model.pt"))
            )
            loaded_model(torch.randn(1, 10))

    def test_jit_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
            loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
            loaded_model(torch.randn(1, 10))

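    # Each dynamo compilation fires the registered on_compile_start and
    # on_compile_end hooks exactly once, in that order.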
    def test_compilation_callback(self):
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback():
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback():
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

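    # A graph break splits fn into two frames that are compiled separately, so
    # the start/end callbacks each fire twice and the counter reaches 4.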
    def test_compilation_callback_with_graph_break(self):
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )


# The private variants of the functions below are extensively tested, so as
# long as the public signatures match the private ones we're good.
class PublicTorchCompilerTests(TestCase):
    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
        public_fn = getattr(torch.compiler, public_fn_name)
        private_fn = getattr(private_namespace, private_fn_name)

        public_sig = inspect.signature(public_fn)
        private_sig = inspect.signature(private_fn)

        self.assertEqual(
            public_sig,
            private_sig,
            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
        )

    def test_dynamo_signatures(self):
        function_names = [
            "reset",
            "allow_in_graph",
            "list_backends",
            "assume_constant_result",
            "disable",
        ]

        for fn_name in function_names:
            self.check_signature(fn_name, fn_name, torch._dynamo)


if __name__ == "__main__":
    run_tests()