mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
do not overguard when comparing lists (#165091)
if we are comparing two lists l1, l2 of different lengths for equality. we should early exist if len(l1) != len(l2) and avoid guarding/comparing inner elements. This avoids recompilations as in the unit test. address https://github.com/pytorch/pytorch/issues/137515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165091 Approved by: https://github.com/aorenste, https://github.com/mlazos ghstack dependencies: #164884, #164885, #164886, #164887, #164888, #164889
This commit is contained in:
committed by
PyTorch MergeBot
parent
f0325d0787
commit
2d4654d208
@ -3204,6 +3204,37 @@ class TestGuardsExpressions(TestCase):
|
|||||||
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
||||||
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("Attempt to trace generator")
|
||||||
|
@torch.fx.experimental._config.patch("use_duck_shape", False)
|
||||||
|
def test_size_comparison_no_recompile(self):
|
||||||
|
"""
|
||||||
|
Test that size comparisons don't cause recompilation.
|
||||||
|
When comparing x.size() == b.size() with different sizes,
|
||||||
|
the compiled function should only compile once.
|
||||||
|
We should not guard in sizes of the inner elements.
|
||||||
|
"""
|
||||||
|
cnt = CompileCounter()
|
||||||
|
|
||||||
|
@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
|
||||||
|
def f(x, b):
|
||||||
|
if x.size() == b.size():
|
||||||
|
return x
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
# First call: shapes differ (1, 2) vs (2, 4, 9), so if branch is False
|
||||||
|
f(torch.rand(10, 2), torch.rand(20, 4, 9))
|
||||||
|
|
||||||
|
# Second call: shapes differ again (1, 2) vs (1, 4, 9), so if branch is False
|
||||||
|
f(torch.rand(10, 2), torch.rand(10, 4, 9))
|
||||||
|
|
||||||
|
# Should only compile once despite different input shapes
|
||||||
|
self.assertEqual(
|
||||||
|
cnt.frame_count,
|
||||||
|
1,
|
||||||
|
f"Expected 1 compilation, got {cnt.frame_count}. "
|
||||||
|
f"Size comparison should not cause recompilation.",
|
||||||
|
)
|
||||||
|
|
||||||
def test_remove_symbols_without_guarding(self):
|
def test_remove_symbols_without_guarding(self):
|
||||||
from torch._functorch.partitioners import _remove_symbols_without_guarding
|
from torch._functorch.partitioners import _remove_symbols_without_guarding
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import types
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Hashable, Iterable, MutableMapping, Sequence
|
from collections.abc import Hashable, Iterable, MutableMapping, Sequence
|
||||||
from itertools import repeat as _repeat
|
from itertools import repeat as _repeat
|
||||||
|
from operator import eq, ne
|
||||||
from typing import Any, Callable, TYPE_CHECKING
|
from typing import Any, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -106,13 +107,24 @@ def accumulate_grad(x, new_grad):
|
|||||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
|
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
|
||||||
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
|
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
|
||||||
"""emulate `(1,2,3) > (1,2)` etc"""
|
"""emulate `(1,2,3) > (1,2)` etc"""
|
||||||
|
|
||||||
|
# Optimization: For equality, short-circuit if lengths differ
|
||||||
|
# This avoids iterating through elements and triggering guards on SymInts
|
||||||
|
left_len = len(left)
|
||||||
|
right_len = len(right)
|
||||||
|
|
||||||
|
if op is eq and left_len != right_len:
|
||||||
|
return False
|
||||||
|
if op is ne and left_len != right_len:
|
||||||
|
return True
|
||||||
|
|
||||||
# Apply `op` to the first pair that differ
|
# Apply `op` to the first pair that differ
|
||||||
for a, b in zip(left, right):
|
for a, b in zip(left, right):
|
||||||
if a != b:
|
if a != b:
|
||||||
return op(a, b)
|
return op(a, b)
|
||||||
|
|
||||||
# No more pairs to compare, so compare sizes.
|
# No more pairs to compare, so compare sizes.
|
||||||
return op(len(left), len(right))
|
return op(left_len, right_len)
|
||||||
|
|
||||||
|
|
||||||
def dict___eq__(d, other):
|
def dict___eq__(d, other):
|
||||||
|
|||||||
Reference in New Issue
Block a user