Rename cache limit to recompile limit in configs (#143709)
This PR renames every cache_limit to recompile_limit via sed. Old config options are maintained via Config(alias='xyz').
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143709
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 9bf4b1c2e9
Commit: dc55704b48
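As a quick illustration of the aliasing described in the commit message (a sketch, not part of this commit), setting a config through its old name should now be equivalent to setting the new one:

    import torch

    # Old spelling still works; it is declared as an alias of the new option.
    torch._dynamo.config.cache_size_limit = 4
    assert torch._dynamo.config.recompile_limit == 4

    # The new spelling is the canonical name going forward.
    torch._dynamo.config.recompile_limit = 8
    assert torch._dynamo.config.cache_size_limit == 8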
@@ -18,7 +18,7 @@ def foo({", ".join(args)}):
 """

 exec(fn_str, globals())
-torch._dynamo.config.cache_size_limit = 16
+torch._dynamo.config.recompile_limit = 16


 def bench(name, fn):
@@ -8,7 +8,7 @@ def setup_baseline():

 recommended_inductor_config_setter()
 torch._dynamo.config.automatic_dynamic_shapes = False
-torch._dynamo.config.cache_size_limit = 10000
+torch._dynamo.config.recompile_limit = 10000


 def torchao_optimize_ctx(quantization: str):
@@ -26,7 +26,7 @@ from torch.nn.attention.flex_attention import (

 torch._dynamo.config.automatic_dynamic_shapes = False
 # Needed since changing args to function causes recompiles
-torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.recompile_limit = 1000


 from torch._inductor.runtime.benchmarking import benchmarker
@@ -126,7 +126,7 @@ Why is compilation slow?
 optimizations, and expresses these assumptions as guards that check
 particular values at runtime. If any of these guards fail, Dynamo will
 recompile that function (or part) up to
-``torch._dynamo.config.cache_size_limit`` times. If your program is
+``torch._dynamo.config.recompile_limit`` times. If your program is
 hitting the cache limit, you will first need to determine which guard is
 failing and what part of your program is triggering it. The
 `recompilation profiler <#recompilation-profiler>`__ automates the
@@ -618,8 +618,8 @@ For more information on dynamic shapes, see `The dynamic shapes manual <https://
 Changing the cache size limit
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-There is a limit to how many times a function can be recompiled, determined by ``torch._dynamo.config.cache_size_limit``
-and ``torch._dynamo.config.accumulated_cache_size_limit``.
+There is a limit to how many times a function can be recompiled, determined by ``torch._dynamo.config.recompile_limit``
+and ``torch._dynamo.config.accumulated_recompile_limit``.
 If either limit is exceeded, then we will not attempt to compile the function again and instead will run the function eagerly.
 ``torch.compile`` will also issue a warning containing the affected function and which limit was hit.
 In the example below, each function call results in a recompile attempt.
@@ -639,7 +639,7 @@ When we hit the cache size limit (8), we stop attempting to recompile.
 ::

 $ python playground.py
-torch._dynamo hit config.cache_size_limit (8)
+torch._dynamo hit config.recompile_limit (8)
 function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
 last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

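The playground.py referenced in this output is not included in the diff; a minimal reconstruction (illustrative only, not the original script) that produces this kind of warning could be:

    import torch

    @torch.compile(dynamic=False)
    def fn(x):
        return x + 1

    # Each new input size triggers a recompile; the 9th distinct size
    # pushes past the default limit of 8 and falls back to eager.
    for i in range(1, 10):
        fn(torch.ones(i))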
@@ -676,7 +676,7 @@ In the below example, we have a recompilation for each function call.
 - 0/2: L['c'] == 3.5
 - 0/1: L['c'] == 2.5
 - 0/0: L['c'] == 1.5
-torch._dynamo hit config.cache_size_limit (8)
+torch._dynamo hit config.recompile_limit (8)
 function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
 last reason: 0/0: L['c'] == 1.5

@@ -714,7 +714,7 @@ In particular, for LR schedulers, initializing with a constant can lead to recom
 - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
 - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
 - 3/0: L['self'].param_groups[0]['lr'] == 0.01
-torch._dynamo hit config.cache_size_limit (8)
+torch._dynamo hit config.recompile_limit (8)
 function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
 last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

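A common way to avoid this particular source of recompiles (a sketch, not part of this commit) is to pass the learning rate as a tensor so that its changing value is read at runtime rather than guarded on:

    import torch

    model = torch.nn.Linear(4, 4)
    # A tensor lr keeps the compiled optimizer step from specializing
    # (and recompiling) on every new float value set by the scheduler.
    opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

    @torch.compile
    def step():
        opt.step()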
@@ -668,7 +668,7 @@ assumptions about locals and globals in order to allow compiler
 optimizations, and expresses these assumptions as guards that check
 particular values at runtime. If any of these guards fail, Dynamo will
 recompile that function (or part) up to
-``torch._dynamo.config.cache_size_limit`` times. If your program is
+``torch._dynamo.config.recompile_limit`` times. If your program is
 hitting the cache limit, you will first need to determine which guard is
 failing and what part of your program is triggering it.

@@ -679,7 +679,7 @@ cost of recompilation outweighs any optimization benefits.

 ::

-torch._dynamo.config.cache_size_limit = <your desired cache limit>
+torch._dynamo.config.recompile_limit = <your desired cache limit>

 TorchDynamo plans to support many common cases of dynamic tensor shapes,
 such as varying batch size or sequence length. It does not plan to
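For a concrete, purely illustrative setting of the knob shown above:

    import torch

    # Allow up to 64 recompiles of a code object before running it eagerly.
    torch._dynamo.config.recompile_limit = 64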
@@ -66,8 +66,8 @@ class ConfigTests(torch._dynamo.test_case.TestCase):
 "verbose",
 "verify_correctness", # will not affect model, will raise RuntimeError
 # (no silent change to compilation behaviour)
-"cache_size_limit",
-"accumulated_cache_size_limit",
+"recompile_limit",
+"accumulated_recompile_limit",
 "replay_record_enabled",
 "cprofile", # only wraps _compile, not graph
 "repro_after",
@@ -217,7 +217,7 @@ class MiscTests(torch._inductor.test_case.TestCase):
 self.assertTrue(same(val4, correct1))
 self.assertEqual(counter.frame_count, 3)

-@torch._dynamo.config.patch(accumulated_cache_size_limit=1)
+@torch._dynamo.config.patch(accumulated_recompile_limit=1)
 def test_dynamo_disabled_in_custom_op_kernels(self):
 counters.clear()

@@ -2564,7 +2564,7 @@ utils_device.CURRENT_DEVICE == None""".split(

 cnts = torch._dynamo.testing.CompileCounter()
 opt_fn = torch.compile(mandelbrot_numpy, backend=cnts, fullgraph=True)
-n_iter = torch._dynamo.config.cache_size_limit - 2
+n_iter = torch._dynamo.config.recompile_limit - 2
 for i in range(n_iter):
 x = i + 3
 ref = mandelbrot_numpy(x)
@@ -2757,7 +2757,7 @@ utils_device.CURRENT_DEVICE == None""".split(
 self.assertEqual(cnts.frame_count, 1)

 # cache size limit needs to be larger than the `dtypes` list size
-@torch._dynamo.config.patch(cache_size_limit=12)
+@torch._dynamo.config.patch(recompile_limit=12)
 def test_dtypes_no_graphbreaks(self):
 dtypes = [
 # floats
@@ -489,7 +489,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
 self.assertEqual(expected, actual)

 # Needs larger cache size since we recompile for each op
-@patch.object(torch._dynamo.config, "cache_size_limit", 48)
+@patch.object(torch._dynamo.config, "recompile_limit", 48)
 def test_builtin_equivalent_funcs(self):
 from torch._dynamo.variables.torch_function import (
 bin_int_ops,
@@ -2086,7 +2086,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):

 def test_no_recompile_on_nn_guarded_modules(self):
 size = (10, 10)
-cache_size_limit = 1
+recompile_limit = 1
 num_submodules = 4
 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

@@ -2116,8 +2116,8 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 with unittest.mock.patch(
 "torch._dynamo.config.error_on_recompile", True
 ), unittest.mock.patch(
-"torch._dynamo.config.cache_size_limit",
-cache_size_limit,
+"torch._dynamo.config.recompile_limit",
+recompile_limit,
 ):
 x = torch.randn(*size, requires_grad=True)
 mod(x)
@@ -2126,7 +2126,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 else:
 self.assertEqual(cnts.frame_count, num_submodules)

-@patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2)
+@patch.object(torch._dynamo.config, "accumulated_recompile_limit", 2)
 @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
 def test_recompile_limit_on_freed_module(self):
 class Mod(torch.nn.Module):
@@ -2152,7 +2152,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
 def test_inline_inbuilt_nn_modules(self):
 size = (10, 10)
-cache_size_limit = 1
+recompile_limit = 1
 num_submodules = 4
 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

@@ -2182,15 +2182,15 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 with unittest.mock.patch(
 "torch._dynamo.config.error_on_recompile", True
 ), unittest.mock.patch(
-"torch._dynamo.config.cache_size_limit",
-cache_size_limit,
+"torch._dynamo.config.recompile_limit",
+recompile_limit,
 ):
 x = torch.randn(*size, requires_grad=True)
 mod(x)
 self.assertEqual(cnts.frame_count, 1)

-def test_cache_size_limit_on_guarded_nn_modules(self):
-cache_size_limit = 2
+def test_recompile_limit_on_guarded_nn_modules(self):
+recompile_limit = 2
 num_submodules = 4
 cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

@@ -2219,8 +2219,8 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
 # therefore the total number of expected frame count is 2 *
 # num_submodules.
 with unittest.mock.patch(
-"torch._dynamo.config.cache_size_limit",
-cache_size_limit,
+"torch._dynamo.config.recompile_limit",
+recompile_limit,
 ):
 for size in [
 (4,),
@@ -20,7 +20,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 def setUpClass(cls):
 super().setUpClass()
 cls._exit_stack.enter_context(
-torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
+torch._dynamo.config.patch("recompile_limit", cls.cache_limit)
 )

 def test_drop_cache_on_skip(self):
@@ -84,7 +84,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):

 expected_recompiles = 2
 compile_counter = torch._dynamo.testing.CompileCounter()
-with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
+with torch._dynamo.config.patch("recompile_limit", expected_recompiles):
 with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
 for _ in range(10):
 bsz = torch.randint(low=0, high=1000, size=())
@@ -98,7 +98,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 self.assertTrue(
 logs.records[0]
 .getMessage()
-.startswith("torch._dynamo hit config.cache_size_limit")
+.startswith("torch._dynamo hit config.recompile_limit")
 )

 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@@ -116,7 +116,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 c = torch.rand(3, 4, 5, device="cuda")
 compile_counter = torch._dynamo.testing.CompileCounter()

-with torch._dynamo.config.patch("cache_size_limit", 2):
+with torch._dynamo.config.patch("recompile_limit", 2):
 opt_func = torch.compile(func, backend=compile_counter)
 opt_func(a, b, c) # warmup
 self.assertEqual(compile_counter.frame_count, 1)
@@ -204,8 +204,8 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
 )

-@torch._dynamo.config.patch(cache_size_limit=1, fail_on_cache_limit_hit=True)
-def test_fail_on_cache_limit_hit(self):
+@torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
+def test_fail_on_recompile_limit_hit(self):
 @torch.compile(backend="eager")
 def func(b, a):
 if a:
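Outside the test suite, the renamed flag can be exercised the same way. A sketch, assuming FailOnRecompileLimitHit is importable from torch._dynamo.exc as used in the test above:

    import torch
    from torch._dynamo import config

    config.recompile_limit = 1
    config.fail_on_recompile_limit_hit = True

    @torch.compile(backend="eager")
    def func(x, flag):
        return x + 1 if flag else x - 1

    func(torch.randn(5), True)
    # A second specialization would exceed the limit and raise
    # FailOnRecompileLimitHit instead of silently falling back to eager:
    # func(torch.randn(5), False)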
@@ -217,7 +217,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 with self.assertRaises(FailOnRecompileLimitHit):
 func(torch.randn(5), False)

-@torch._dynamo.config.patch("cache_size_limit", 32)
+@torch._dynamo.config.patch("recompile_limit", 32)
 def test_multiple_guard_fails(self):
 failure_reasons = []

@@ -248,7 +248,7 @@ tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
 failure_str,
 )

-@torch._dynamo.config.patch("cache_size_limit", 32)
+@torch._dynamo.config.patch("recompile_limit", 32)
 def test_multiple_guard_fails_report_all(self):
 with log_settings(kwargs_to_settings(recompiles_verbose=True)):
 failure_reasons = []
@@ -315,7 +315,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
 model(x)
 self.assertEqual(counter.frame_count, 2)

-@patch.object(torch._dynamo.config, "cache_size_limit", 2)
+@patch.object(torch._dynamo.config, "recompile_limit", 2)
 def test_no_recursive_compile_after_cache_limit_hit(self):
 def f(x, n):
 x = x + n
@@ -351,7 +351,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
 h(torch.randn(5), f(i))
 self.assertEqual(counter.frame_count, 2)

-@patch.object(torch._dynamo.config, "cache_size_limit", 2)
+@patch.object(torch._dynamo.config, "recompile_limit", 2)
 def test_run_mode_after_cache_limit_hit(self):
 def f(x, n):
 x = x + n
@@ -5861,7 +5861,7 @@ def forward(self, arg0_1, arg1_1):

 inp = torch.ones(3, 4)
 exp_out = inp.sin()
-iter_n = torch._dynamo.config.cache_size_limit + 1
+iter_n = torch._dynamo.config.recompile_limit + 1

 # Need functions that cause recompilations
 def get_dummy_fns(str):
@@ -14,7 +14,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 class B2BGEMMTest(TestCase):
 device = GPU_TYPE

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_left_assoc_good_shape(self):
 """
@@ -48,7 +48,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
 self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_right_assoc_good_shape(self):
 """
@@ -74,7 +74,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
 self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_trivial_left_assoc_good_shape(self):
 """
@@ -99,7 +99,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
 self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_trivial_right_assoc_good_shape(self):
 """
@@ -124,7 +124,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
 self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_bad_pattern_good_shape(self):
 """
@@ -145,7 +145,7 @@ class B2BGEMMTest(TestCase):
 self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
 self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)

-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def test_b2b_gemm_good_pattern_bad_shape(self):
 """
@@ -167,7 +167,7 @@ class B2BGEMMTest(TestCase):
 @unittest.skipIf(
 not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
 )
-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 def test_plain_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

@@ -222,7 +222,7 @@ class B2BGEMMTest(TestCase):
 @unittest.skipIf(
 not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
 )
-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 def test_gelu_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

@@ -279,7 +279,7 @@ class B2BGEMMTest(TestCase):
 @unittest.skipIf(
 not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
 )
-@torch._dynamo.config.patch(cache_size_limit=32)
+@torch._dynamo.config.patch(recompile_limit=32)
 def test_gelu_mlp_b2b_gemm_performance(self):
 """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

@@ -7120,7 +7120,7 @@ class CommonTemplate:
 for ref, test in zip(refs, tests):
 torch.testing.assert_close(ref, test)

-@torch._dynamo.config.patch(cache_size_limit=10)
+@torch._dynamo.config.patch(recompile_limit=10)
 def test_tensor_index_put_slice(self):
 def fn(a, version):
 x = torch.tensor([1, 2], device=self.device, dtype=torch.int32)
@@ -3132,7 +3132,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
 self.assertNotIn(opname, code)

 @requires_gpu
-@patch.object(torch._dynamo.config, "cache_size_limit", 1)
+@patch.object(torch._dynamo.config, "recompile_limit", 1)
 def test_triton_dynamic_grid_no_recompile(self):
 libname = "my_cool_namespace"
 opname = "my_triton_operator"
@@ -67,7 +67,7 @@ class TestContentStore(TestCase):
 # Should not raise an error
 hash_storage(torch.tensor(2, device=device).untyped_storage())

-@torch._dynamo.config.patch(cache_size_limit=1)
+@torch._dynamo.config.patch(recompile_limit=1)
 def test_repeated_hash(self, device):
 # Test that repeated hashing doesn't trigger a recompile in dynamo
 # If it does, we will execute prims.xor_sum in eager which fails
@@ -2417,7 +2417,7 @@ def compile(
 results are not applicable for subsequent calls (this is called a "guard
 failure), you can use TORCH_LOGS=guards to debug these situations.
 Multiple compiled results can be associated with a frame up to
-``torch._dynamo.config.cache_size_limit``, which defaults to 8; at which
+``torch._dynamo.config.recompile_limit``, which defaults to 8; at which
 point we will fall back to eager. Note that compile caches are per
 *code object*, not frame; if you dynamically create multiple copies of a
 function, they will all share the same code cache.
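To illustrate the per-code-object caching noted in this docstring (a sketch, not from the diff): closures produced by one factory share a single code object, so they also share a single recompile budget.

    import torch

    def make_adder(c):
        # Every closure returned here shares the same underlying code object.
        def add(x):
            return x + c
        return add

    f1 = torch.compile(make_adder(1.0))
    f2 = torch.compile(make_adder(2.0))

    x = torch.randn(4)
    f1(x)
    f2(x)  # new guard on c -> a second entry in the same per-code-object cache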
@@ -22,8 +22,8 @@ of the guard_manager's returns True, we recompile and add a new entry. To ensure
 don't end up recompiling infinitely, we put limits on the cache size.

 There are two limits
-1) cache_size_limit
-2) accumulated_cache_size_limit
+1) recompile_limit
+2) accumulated_recompile_limit


 Earlier we used to have only limit - maximum number of entries in 1 cache line
@@ -33,11 +33,11 @@ to understand that.
 In general, we want our cache limit value to be a small number (e.g. 8 or even
 lower). This ensures that for frames that cause too many recompilation fall to
 eager quickly. However, there is another problem that prevents us from lowering
-the value of cache_size_limit. This is due to ID_MATCH'd guards. Today, we put
+the value of recompile_limit. This is due to ID_MATCH'd guards. Today, we put
 ID_MATCH guards on nn module if there is a graph break. This means we will have
 many recompilations for the same code object because the ID_MATCH guard fails
 for different instances of the nn module. This is a common pattern in how models
-are authored. Therefore, this requires us to keep the cache_size_limit high.
+are authored. Therefore, this requires us to keep the recompile_limit high.

 We resolve this by introducing these two limits. The first limit (1) limits the
 number of cache entries that have an ID_MATCH'd guard for an nn module instance.
@@ -58,8 +58,8 @@ compilations to burst the cache and fallback to eager. These 32 recompilations
 are too many and we want to fallback for these compilation-unfriendly functions
 sooner.

-In the new scenario, we can have (1) cache_size_limit = 2, (2)
-accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can
+In the new scenario, we can have (1) recompile_limit = 2, (2)
+accumulated_recompile_limit = 32. This means that each ID_MATCH'd object can
 have maximum of two cache entries, and the maximum number of cache entries
 (irrespective of ID_MATCH obj) is 32. This covers the case of forward code
 object which has 32 recompilations. For the other function, the one unsuitable
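The scenario described above corresponds to a configuration like the following (values are illustrative):

    import torch

    # At most 2 cache entries per ID_MATCH'd object (e.g. per nn.Module instance)
    torch._dynamo.config.recompile_limit = 2
    # and at most 32 cache entries in total for the code object, after which
    # Dynamo falls back to eager for that frame.
    torch._dynamo.config.accumulated_recompile_limit = 32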
@@ -94,7 +94,7 @@ class CacheSizeRelevantForFrame:
 )

 def will_compilation_exceed_accumulated_limit(self) -> bool:
-return self.num_cache_entries >= config.accumulated_cache_size_limit
+return self.num_cache_entries >= config.accumulated_recompile_limit

 def will_compilation_exceed_specific_limit(self, limit: int) -> bool:
 return self.num_cache_entries_with_same_id_matched_objs >= limit
@@ -142,7 +142,7 @@ def compute_cache_size(
 num_cache_entries += 1
 # Track the number of cache entries having same ID_MATCH'd objects as
 # that of frame.f_locals. This will be used later to compare against the
-# cache_size_limit.
+# recompile_limit.
 if _has_same_id_matched_objs(frame, cache_entry):
 num_cache_entries_with_same_id_matched_objs += 1
 cache_entry = cache_entry.next
@@ -165,22 +165,22 @@ def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
 return cache_size.will_compilation_exceed(1)


-def exceeds_cache_size_limit(
+def exceeds_recompile_limit(
 cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
 ) -> Tuple[bool, str]:
 """
 Checks if we are exceeding the cache size limit.
 """
 if cache_size.will_compilation_exceed_accumulated_limit():
-return True, "accumulated_cache_size_limit"
-if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
-return True, "cache_size_limit"
+return True, "accumulated_recompile_limit"
+if cache_size.will_compilation_exceed_specific_limit(config.recompile_limit):
+return True, "recompile_limit"
 # NOTE this check is needed in the case that the frame's cache doesn't grow
 # and we keep recompiling. This can happen if the guard guard_manager becomes invalidated,
 # e.g. due to guarded objects being freed. This technically makes the
 # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
 # check in case we have a better fix in the future.
 assert compile_id.frame_compile_id is not None
-if compile_id.frame_compile_id >= config.accumulated_cache_size_limit:
-return True, "accumulated_cache_size_limit"
+if compile_id.frame_compile_id >= config.accumulated_recompile_limit:
+return True, "accumulated_recompile_limit"
 return False, ""
@@ -41,19 +41,30 @@ dead_code_elimination = True
 # object. It also controls the maximum size of cache entries if they don't have
 # any ID_MATCH'd guards.
 # [@compile_ignored: runtime_behaviour]
-cache_size_limit = 8
+recompile_limit = 8

 # [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
-accumulated_cache_size_limit = 256
+accumulated_recompile_limit = 256

 # [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
-skip_code_recursive_on_cache_limit_hit = True
+skip_code_recursive_on_recompile_limit_hit = True

 # raise a hard error if cache limit is hit. If you are on a model where you
 # know you've sized the cache correctly, this can help detect problems when
-# you regress guards/specialization. This works best when cache_size_limit = 1.
+# you regress guards/specialization. This works best when recompile_limit = 1.
 # [@compile_ignored: runtime_behaviour]
-fail_on_cache_limit_hit = False
+fail_on_recompile_limit_hit = False

+cache_size_limit: int = Config(alias="torch._dynamo.config.recompile_limit")
+accumulated_cache_size_limit: int = Config(
+alias="torch._dynamo.config.accumulated_recompile_limit"
+)
+skip_code_recursive_on_cache_limit_hit: bool = Config(
+alias="torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit"
+)
+fail_on_cache_limit_hit: bool = Config(
+alias="torch._dynamo.config.fail_on_recompile_limit_hit"
+)
+
 # whether or not to specialize on int inputs. This only has an effect with
 # dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
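A hedged sketch of how the renamed knobs might be combined in user code (values are illustrative; the old keys should keep working through the aliases above):

    import torch

    with torch._dynamo.config.patch(
        recompile_limit=4,
        accumulated_recompile_limit=64,
        skip_code_recursive_on_recompile_limit_hit=True,
        fail_on_recompile_limit_hit=False,
    ):
        compiled = torch.compile(lambda x: x * 2)
        compiled(torch.randn(3))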
@@ -66,7 +66,7 @@ from .bytecode_transformation import (
 from .cache_size import (
 CacheSizeRelevantForFrame,
 compute_cache_size,
-exceeds_cache_size_limit,
+exceeds_recompile_limit,
 is_recompilation,
 )
 from .eval_frame import (
@@ -911,7 +911,7 @@ def _compile(
 cache_entry, frame
 )

-exceeded, limit_type = exceeds_cache_size_limit(cache_size, compile_id)
+exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id)
 if exceeded:

 def format_func_info(code: CodeType) -> str:
@@ -934,12 +934,12 @@ def _compile(
 format_guard_failures(),
 troubleshooting_url,
 )
-if config.fail_on_cache_limit_hit:
+if config.fail_on_recompile_limit_hit:
 raise FailOnRecompileLimitHit(
-f"{limit_type} reached, because fail_on_cache_limit_hit = True this is a HARD failure"
+f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
 )
-elif config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
-"pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
+elif config.skip_code_recursive_on_recompile_limit_hit and justknobs_check(
+"pytorch/compiler:skip_code_recursive_on_recompile_limit_hit"
 ):
 raise RecompileLimitExceeded(f"{limit_type} reached")
 else:
@@ -2401,16 +2401,16 @@ def format_func_info(code):

 @contextlib.contextmanager
 def disable_cache_limit():
-prior = config.cache_size_limit
-config.cache_size_limit = sys.maxsize
-prior_acc_limit = config.accumulated_cache_size_limit
-config.accumulated_cache_size_limit = sys.maxsize
+prior = config.recompile_limit
+config.recompile_limit = sys.maxsize
+prior_acc_limit = config.accumulated_recompile_limit
+config.accumulated_recompile_limit = sys.maxsize

 try:
 yield
 finally:
-config.cache_size_limit = prior
-config.accumulated_cache_size_limit = prior_acc_limit
+config.recompile_limit = prior
+config.accumulated_recompile_limit = prior_acc_limit


 # map from transformed code back to original user code
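For completeness, a usage sketch of the context manager above, assuming it stays exposed as torch._dynamo.utils.disable_cache_limit:

    import torch
    from torch._dynamo.utils import disable_cache_limit  # assumed import path

    @torch.compile(dynamic=False)
    def fn(x):
        return x.sin()

    # Inside the context both recompile_limit and accumulated_recompile_limit
    # are raised to sys.maxsize, so none of these shape changes falls back to eager.
    with disable_cache_limit():
        for n in range(1, 20):
            fn(torch.randn(n))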
@@ -1574,7 +1574,7 @@ TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag(
 if TEST_WITH_TORCHDYNAMO:
 import torch._dynamo
 # Do not spend time on helper functions that are called with different inputs
-torch._dynamo.config.accumulated_cache_size_limit = 64
+torch._dynamo.config.accumulated_recompile_limit = 64
 # Do not log compilation metrics from unit tests
 torch._dynamo.config.log_compilation_metrics = False
 # Silence 3.13.0 guard performance warnings