diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 1d2a7b5f9d2d..94f2b3fcb0a5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3204,6 +3204,37 @@ class TestGuardsExpressions(TestCase): self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) + @skipIfTorchDynamo("Attempt to trace generator") + @torch.fx.experimental._config.patch("use_duck_shape", False) + def test_size_comparison_no_recompile(self): + """ + Test that size comparisons don't cause recompilation. + When comparing x.size() == b.size() with different sizes, + the compiled function should only compile once. + We should not guard in sizes of the inner elements. + """ + cnt = CompileCounter() + + @torch.compile(fullgraph=True, dynamic=True, backend=cnt) + def f(x, b): + if x.size() == b.size(): + return x + return x * 2 + + # First call: shapes differ (1, 2) vs (2, 4, 9), so if branch is False + f(torch.rand(10, 2), torch.rand(20, 4, 9)) + + # Second call: shapes differ again (1, 2) vs (1, 4, 9), so if branch is False + f(torch.rand(10, 2), torch.rand(10, 4, 9)) + + # Should only compile once despite different input shapes + self.assertEqual( + cnt.frame_count, + 1, + f"Expected 1 compilation, got {cnt.frame_count}. " + f"Size comparison should not cause recompilation.", + ) + def test_remove_symbols_without_guarding(self): from torch._functorch.partitioners import _remove_symbols_without_guarding diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 4fc777ffe7ef..6f071e818356 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -12,6 +12,7 @@ import types from collections import OrderedDict from collections.abc import Hashable, Iterable, MutableMapping, Sequence from itertools import repeat as _repeat +from operator import eq, ne from typing import Any, Callable, TYPE_CHECKING import torch @@ -106,13 +107,24 @@ def accumulate_grad(x, new_grad): # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413 def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]): """emulate `(1,2,3) > (1,2)` etc""" + + # Optimization: For equality, short-circuit if lengths differ + # This avoids iterating through elements and triggering guards on SymInts + left_len = len(left) + right_len = len(right) + + if op is eq and left_len != right_len: + return False + if op is ne and left_len != right_len: + return True + # Apply `op` to the first pair that differ for a, b in zip(left, right): if a != b: return op(a, b) # No more pairs to compare, so compare sizes. - return op(len(left), len(right)) + return op(left_len, right_len) def dict___eq__(d, other):