do not overguard when comparing lists (#165091)
If we are comparing two lists l1, l2 for equality, we should early-exit when len(l1) != len(l2) and avoid guarding on, or comparing, the inner elements. This avoids recompilations, as shown in the unit test. Addresses https://github.com/pytorch/pytorch/issues/137515

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165091
Approved by: https://github.com/aorenste, https://github.com/mlazos
ghstack dependencies: #164884, #164885, #164886, #164887, #164888, #164889
Committed by: PyTorch MergeBot
Parent: f0325d0787
Commit: 2d4654d208
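For illustration only, a minimal standalone sketch of the early-exit idea in plain Python (not the PR's code; the actual change to `list_cmp` appears in the diff below):

    from operator import eq, ne

    # A length mismatch alone decides == and !=, so the elements are never
    # compared; under torch.compile that means no guards are created on the
    # inner SymInts, hence no recompilation when only those values change.
    def compare_lists(op, l1, l2):
        if op is eq and len(l1) != len(l2):
            return False
        if op is ne and len(l1) != len(l2):
            return True
        return op(list(l1), list(l2))  # same lengths: fall back to element-wise

    assert compare_lists(eq, [1, 2], [1, 2, 3]) is False  # early exit
    assert compare_lists(ne, [1, 2], [1, 2, 3]) is True   # early exit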
@@ -3204,6 +3204,37 @@ class TestGuardsExpressions(TestCase):
         self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
         self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
 
+    @skipIfTorchDynamo("Attempt to trace generator")
+    @torch.fx.experimental._config.patch("use_duck_shape", False)
+    def test_size_comparison_no_recompile(self):
+        """
+        Test that size comparisons don't cause recompilation.
+        When comparing x.size() == b.size() with different sizes,
+        the compiled function should only compile once.
+        We should not guard on the sizes of the inner elements.
+        """
+        cnt = CompileCounter()
+
+        @torch.compile(fullgraph=True, dynamic=True, backend=cnt)
+        def f(x, b):
+            if x.size() == b.size():
+                return x
+            return x * 2
+
+        # First call: shapes differ, (10, 2) vs (20, 4, 9), so the if branch is False
+        f(torch.rand(10, 2), torch.rand(20, 4, 9))
+
+        # Second call: shapes differ again, (10, 2) vs (10, 4, 9), so the if branch is False
+        f(torch.rand(10, 2), torch.rand(10, 4, 9))
+
+        # Should only compile once despite different input shapes
+        self.assertEqual(
+            cnt.frame_count,
+            1,
+            f"Expected 1 compilation, got {cnt.frame_count}. "
+            f"Size comparison should not cause recompilation.",
+        )
+
     def test_remove_symbols_without_guarding(self):
         from torch._functorch.partitioners import _remove_symbols_without_guarding
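By contrast, a comparison that depends on the inner symbolic values does need a guard. A hypothetical counterpart to the test above (not part of this PR; a sketch of expected Dynamo behavior) where a recompile is anticipated once the outcome of the comparison flips between calls:

    import torch
    from torch._dynamo.testing import CompileCounter

    cnt = CompileCounter()

    # Hypothetical function, not in the diff: comparing one dimension inspects
    # the symbolic values themselves, so Dynamo guards on the outcome of
    # `x.size(0) == b.size(0)` rather than just the ranks.
    @torch.compile(fullgraph=True, dynamic=True, backend=cnt)
    def g(x, b):
        if x.size(0) == b.size(0):
            return x
        return x * 2

    g(torch.rand(10, 2), torch.rand(10, 3))  # dim-0 sizes match: guard recorded
    g(torch.rand(10, 2), torch.rand(20, 3))  # outcome flips: guard fails
    print(cnt.frame_count)  # expected to be 2 under these assumptions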
@@ -12,6 +12,7 @@ import types
 from collections import OrderedDict
 from collections.abc import Hashable, Iterable, MutableMapping, Sequence
 from itertools import repeat as _repeat
+from operator import eq, ne
 from typing import Any, Callable, TYPE_CHECKING
 
 import torch
@@ -106,13 +107,24 @@ def accumulate_grad(x, new_grad):
 # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
 def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
     """emulate `(1,2,3) > (1,2)` etc"""
+
+    # Optimization: For equality, short-circuit if lengths differ
+    # This avoids iterating through elements and triggering guards on SymInts
+    left_len = len(left)
+    right_len = len(right)
+
+    if op is eq and left_len != right_len:
+        return False
+    if op is ne and left_len != right_len:
+        return True
+
     # Apply `op` to the first pair that differ
     for a, b in zip(left, right):
         if a != b:
             return op(a, b)
 
     # No more pairs to compare, so compare sizes.
-    return op(len(left), len(right))
+    return op(left_len, right_len)
 
 
 def dict___eq__(d, other):
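If the module layout stays as in the diff, the patched helper can be exercised directly; a small usage sketch (assuming `list_cmp` is importable from `torch._dynamo.polyfills`):

    from operator import eq, lt, ne

    from torch._dynamo.polyfills import list_cmp

    assert list_cmp(eq, [1, 2], [1, 2, 3]) is False  # length mismatch: early exit
    assert list_cmp(ne, [1, 2], [1, 2, 3]) is True   # likewise decided by length alone
    assert list_cmp(eq, [1, 2], [1, 2]) is True      # equal lengths: element-wise
    assert list_cmp(lt, (1, 2), (1, 2, 3)) is True   # ordering still walks elements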