Files
pytorch/test/dynamo/test_recompile_ux.py
Yuanyuan Chen a029675f6f More ruff SIM fixes (#164695)
This PR applies ruff `SIM` rules to more files. Most changes are about simplifying `dict.get` because `None` is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
2025-10-09 03:24:50 +00:00

327 lines
11 KiB
Python

# Owner(s): ["module: dynamo"]
import unittest
import weakref
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch._dynamo.exc import FailOnRecompileLimitHit
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
class RecompileUxTests(torch._dynamo.test_case.TestCase):
# TODO(whc) dynamo actually recompiles one more time than the cache limit
cache_limit = 1
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
torch._dynamo.config.patch("recompile_limit", cls.cache_limit)
)
def test_drop_cache_on_skip(self):
def model(x, i):
return x + i
attached = False
triggered = False
def trigger():
nonlocal triggered
triggered = True
def compiler(gm, input):
nonlocal attached
f = gm.forward
assert not attached
# NB: making this a weakref.ref causes the cycle to no
# longer be promptly GC'ed
weakref.finalize(f, trigger)
attached = True
return f
x = torch.randn(2)
for i in range(2):
opt_model = torch.compile(model, backend=compiler)
opt_model(x, i)
self.assertTrue(triggered)
def test_loop_torture(self):
def loop_torture(input, iters):
out = input
# randint itself causes one graph break
for _ in range(iters):
out += input
return out
compile_counter = torch._dynamo.testing.CompileCounter()
for _ in range(10):
x = torch.randn(3)
iters = torch.randint(low=0, high=1000, size=())
opt_loop_torture = torch.compile(loop_torture, backend=compile_counter)
opt_loop_torture(x, iters)
# Currently, we recompile each time,
# We'd probably like to bail out quickly and warn
# TODO(whc) these checks fail on py37. Why?
# self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
# self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)
# compile_counter only sees frames that were fed to the backend compiler,
# which is a subset of counters["frames"]["ok"] -- probably because
# counters["frames"]["ok"] includes frames not containing torch ops?
self.assertEqual(compile_counter.frame_count, self.cache_limit)
@torch._dynamo.config.patch("automatic_dynamic_shapes", False)
def test_dynamic_input(self):
def model(input):
return input + input
expected_recompiles = 2
compile_counter = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch("recompile_limit", expected_recompiles):
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
for _ in range(10):
bsz = torch.randint(low=0, high=1000, size=())
x = torch.randn((bsz, 3, 4))
opt_model = torch.compile(model, backend=compile_counter)
opt_model(x)
self.assertEqual(compile_counter.frame_count, expected_recompiles)
self.assertEqual(len(logs.records), 1)
print(logs.records[0])
self.assertTrue(
logs.records[0]
.getMessage()
.startswith("torch._dynamo hit config.recompile_limit")
)
@unittest.skipIf(
not torch.cuda.is_available() and not torch.xpu.is_available(),
"requires cuda or xpu",
)
def test_nvfuser_guards(self):
# we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards
# such that we ensure dynamo is in charge of all the recompilations at the top level,
# and we could thus simplify the underlying torchscript executor
def func(a, b, c):
return a + b * c
a = torch.rand(3, 4, 5, device=device_type)
b = torch.rand(3, 4, 5, device=device_type)
b_v = torch.rand(3, 5, 4, device=device_type).view(3, 4, 5)
b_p = torch.rand(3, 5, 4, device=device_type).permute(0, 2, 1)
c = torch.rand(3, 4, 5, device=device_type)
compile_counter = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch("recompile_limit", 2):
opt_func = torch.compile(func, backend=compile_counter)
opt_func(a, b, c) # warmup
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b, c) # no guard fail or recompile
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b_v, c) # a view should not cause nvfuser recompile
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b_p, c) # a permutation should cause recompile
self.assertEqual(compile_counter.frame_count, 2)
def assert_single_log_contains(self, logs, contains_str):
self.assertEqual(len(logs.records), 1)
self.assertTrue(
logs.records[0].getMessage().find(contains_str) > 0,
msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
)
def test_verbose_tensor_check(self):
def func(a):
# Warning: choose a function here whose meta implementation lives
# entirely in C++. If you do a Python one, Dynamo will dive into
# torch._refs which is OK but it will muddy up the warnings
return torch.add(a, 4)
def cache_fail_test(cached_input, missed_input, expected_failure):
# TODO(whc) maybe its hacky to have a 'test within a test' but this seemed convenient
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
opt_func = torch.compile(func, backend="eager")
# warmup
opt_func(cached_input)
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
opt_func = torch.compile(func, backend="eager")
opt_func(missed_input)
self.assert_single_log_contains(logs, expected_failure)
a = torch.rand(3, 4, 5)
cache_fail_test(
a,
a[0:2, :, :],
"tensor 'a' size mismatch at index 0. expected 3, actual 2",
)
cache_fail_test(
a,
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
"tensor 'a' stride mismatch at index 0. expected 20, actual 1",
)
cache_fail_test(a, a[0, :, :], "tensor 'a' rank mismatch. expected 3, actual 2")
cache_fail_test(a, a.to("meta"), "tensor 'a' dispatch key set mismatch.")
cache_fail_test(
a,
a.to(torch.float16),
"tensor 'a' dtype mismatch. expected Float, actual Half",
)
a_grad = a.clone()
a_grad.requires_grad = True
cache_fail_test(
a,
a_grad,
"tensor 'a' requires_grad mismatch. expected requires_grad=0",
)
def test_mismatched_type(self):
a = torch.rand(3, 4, 5)
b = torch.rand(3, 4, 5)
def func(a, b):
return a + b
opt_func = torch.compile(func, backend="eager")
# warmup
opt_func(a, b)
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
opt_func = torch.compile(func, backend="eager")
opt_func(a, 1)
self.assert_single_log_contains(
logs,
"expected type of 'b' to be a tensor type, ' but found <class 'int'>",
)
@torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
def test_fail_on_recompile_limit_hit(self):
@torch.compile(backend="eager")
def func(b, a):
if a:
return b * 2
else:
return b + 1
func(torch.randn(5), True)
with self.assertRaises(FailOnRecompileLimitHit):
func(torch.randn(5), False)
@torch._dynamo.config.patch("recompile_limit", 32)
def test_multiple_guard_fails(self):
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.relu(x)
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
for i in range(5):
failure_reasons.clear()
opt_f(torch.randn(8 + i))
failure_str = "\n".join(failure_reasons)
for line in [
"tensor 'x' size mismatch at index 0. expected 11, actual 12",
"tensor 'x' size mismatch at index 0. expected 10, actual 12",
"tensor 'x' size mismatch at index 0. expected 9, actual 12",
"tensor 'x' size mismatch at index 0. expected 8, actual 12",
]:
self.assertIn(
line,
failure_str,
)
@torch._dynamo.config.patch("recompile_limit", 32)
def test_multiple_guard_fails_report_all(self):
with log_settings(kwargs_to_settings(recompiles_verbose=True)):
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.ones(len(x), x[-1])
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
opt_f([4, 5, 6])
def filter_reasons():
return "\n".join(
[
line
for line in "\n".join(failure_reasons).splitlines()
if not line.startswith("___check_type_id")
]
)
failure_reasons.clear()
opt_f([7, 8])
for line in ["len(x) == 3"]:
self.assertIn(line, filter_reasons())
failure_reasons.clear()
opt_f([9])
for line in ["len(x) == 2", "len(x) == 3"]:
self.assertIn(line, filter_reasons())
@torch._dynamo.config.patch(recompile_limit=1)
def test_recompile_child_run_only(self):
def f(x, n):
if torch.compiler.is_compiling():
x = x + 1
x = g(x)
return h(x) + n
def g(x):
if torch.compiler.is_compiling():
return x + 2
return x
def h(x):
if torch.compiler.is_compiling():
return x + 4
return x
torch.compile(g, backend="eager")(torch.randn(3))
inp = torch.randn(3)
opt_f = torch.compile(f, backend="eager")
opt_f(inp, 0)
# expect f to run eager, g compiled (from previous invocatino), h eager
res = opt_f(inp, 1)
self.assertEqual(res, inp + 3)
# Run this file's tests through dynamo's test runner when executed as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()