do not overguard when comparing lists (#165091)

If we are comparing two lists l1, l2 of different lengths for equality,
we should early-exit when len(l1) != len(l2)
and avoid guarding on / comparing the inner elements.

This avoids recompilations, as shown in the unit test.
Addresses https://github.com/pytorch/pytorch/issues/137515

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165091
Approved by: https://github.com/aorenste, https://github.com/mlazos
ghstack dependencies: #164884, #164885, #164886, #164887, #164888, #164889
Author: Laith Sakka
Date: 2025-10-10 08:16:50 -07:00
Committed by: PyTorch MergeBot
parent f0325d0787
commit 2d4654d208
2 changed files with 44 additions and 1 deletion
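
Before the diff itself, a minimal standalone sketch of the idea (hypothetical helpers for illustration only, not PyTorch code): comparing sizes element by element touches every inner element, and under torch.compile each such comparison on a SymInt installs a guard; checking the lengths first lets sizes of different rank compare unequal without ever looking at the elements.

    # Hypothetical helpers illustrating the optimization; not PyTorch code.
    def sizes_equal_naive(left, right):
        # Element-wise first: under torch.compile, each `a != b` on
        # symbolic ints would install a guard on those inner elements.
        for a, b in zip(left, right):
            if a != b:
                return False
        return len(left) == len(right)

    def sizes_equal_short_circuit(left, right):
        # Length check first: sizes of different rank can never be equal,
        # so no inner element is compared (and hence never guarded on).
        if len(left) != len(right):
            return False
        return all(a == b for a, b in zip(left, right))

    print(sizes_equal_naive((10, 2), (20, 4, 9)))          # False, after comparing 10 != 20
    print(sizes_equal_short_circuit((10, 2), (20, 4, 9)))  # False, no element comparisons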

@@ -3204,6 +3204,37 @@ class TestGuardsExpressions(TestCase):
         self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
         self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
 
+    @skipIfTorchDynamo("Attempt to trace generator")
+    @torch.fx.experimental._config.patch("use_duck_shape", False)
+    def test_size_comparison_no_recompile(self):
+        """
+        Test that size comparisons don't cause recompilation.
+        When comparing x.size() == b.size() with different sizes,
+        the compiled function should compile only once.
+        We should not guard on the sizes of the inner elements.
+        """
+        cnt = CompileCounter()
+
+        @torch.compile(fullgraph=True, dynamic=True, backend=cnt)
+        def f(x, b):
+            if x.size() == b.size():
+                return x
+            return x * 2
+
+        # First call: shapes differ, (10, 2) vs (20, 4, 9), so the if branch is False
+        f(torch.rand(10, 2), torch.rand(20, 4, 9))
+        # Second call: shapes differ again, (10, 2) vs (10, 4, 9), so the if branch is False
+        f(torch.rand(10, 2), torch.rand(10, 4, 9))
+        # Should compile only once despite the different input shapes
+        self.assertEqual(
+            cnt.frame_count,
+            1,
+            f"Expected 1 compilation, got {cnt.frame_count}. "
+            f"Size comparison should not cause recompilation.",
+        )
+
     def test_remove_symbols_without_guarding(self):
         from torch._functorch.partitioners import _remove_symbols_without_guarding

@@ -12,6 +12,7 @@ import types
 from collections import OrderedDict
 from collections.abc import Hashable, Iterable, MutableMapping, Sequence
 from itertools import repeat as _repeat
+from operator import eq, ne
 from typing import Any, Callable, TYPE_CHECKING
 
 import torch
@@ -106,13 +107,24 @@ def accumulate_grad(x, new_grad):
 
 # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
 def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
     """emulate `(1,2,3) > (1,2)` etc"""
+    # Optimization: for equality, short-circuit if the lengths differ.
+    # This avoids iterating through the elements and triggering guards on SymInts.
+    left_len = len(left)
+    right_len = len(right)
+    if op is eq and left_len != right_len:
+        return False
+    if op is ne and left_len != right_len:
+        return True
+
     # Apply `op` to the first pair that differs
     for a, b in zip(left, right):
         if a != b:
             return op(a, b)
 
     # No more pairs to compare, so compare sizes.
-    return op(len(left), len(right))
+    return op(left_len, right_len)
 
 
 def dict___eq__(d, other):
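
For reference, a quick sanity check of the patched list_cmp semantics; a sketch assuming only the function above, with `op` passed from the operator module, as the new `from operator import eq, ne` suggests:

    import operator

    # Equal prefix, longer left side wins, as in `(1, 2, 3) > (1, 2)`.
    print(list_cmp(operator.gt, (1, 2, 3), (1, 2)))  # True
    # Equality/inequality now short-circuit on mismatched lengths,
    # without comparing (and thus guarding on) any inner element.
    print(list_cmp(operator.eq, (1, 2), (1, 2, 3)))  # False
    print(list_cmp(operator.ne, (1, 2), (1, 2, 3)))  # True
    # Same length falls through to element-wise comparison.
    print(list_cmp(operator.eq, (1, 2), (1, 2)))     # True

The `op is eq` / `op is ne` identity checks work because `operator.eq` and `operator.ne` are singletons, so the fast path triggers exactly when callers pass the stock operators.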