Handle f([]) vs. f() in fake tensor caching (#162284)

Fixes https://github.com/pytorch/pytorch/issues/162279
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162284
Approved by: https://github.com/manuelcandales, https://github.com/aorenste
This commit is contained in:
angelayi
2025-09-08 18:28:01 +00:00
committed by PyTorch MergeBot
parent 314d47a210
commit fbcabb4fbd
2 changed files with 14 additions and 0 deletions

View File

@ -1943,6 +1943,16 @@ class FakeTensorDispatchCache(TestCase):
self._test_cache_key(fm, 1.0, 1.0, 1)
self._test_cache_key(fm, 0.0, 0.0, 0)
def test_empty_list(self):
    """An explicit empty-list argument must hash differently than an
    omitted argument, so f(x, []) and f(x) never share a cache entry."""
    with FakeTensorMode() as mode:
        op = aten.any.dims
        key_state = _CacheKeyState()
        tensor = torch.ones((2, 3))
        # Same op and tensor; the only difference is the trailing [] arg.
        key_with_empty = mode._cache_key(key_state, op, [tensor, []], {})
        key_without = mode._cache_key(key_state, op, [tensor], {})
        self.assertNotEqual(key_with_empty, key_without)
def assertHitsMisses(self, hits, misses):
"""
Helper to assert on the number of recorded hits and misses.

View File

@ -1677,6 +1677,10 @@ class FakeTensorMode(TorchDispatchMode):
)
from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
if isinstance(args, (list, tuple, dict)):
result.append(type(args))
result.append(f"length_{len(args)}")
if isinstance(args, dict):
self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects)
self._prep_args_for_hash(result, args.values(), state, id_hashed_objects)