mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Handle f([]) vs. f() in fake tensor caching (#162284)
Fixes https://github.com/pytorch/pytorch/issues/162279 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162284 Approved by: https://github.com/manuelcandales, https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
314d47a210
commit
fbcabb4fbd
@ -1943,6 +1943,16 @@ class FakeTensorDispatchCache(TestCase):
|
|||||||
self._test_cache_key(fm, 1.0, 1.0, 1)
|
self._test_cache_key(fm, 1.0, 1.0, 1)
|
||||||
self._test_cache_key(fm, 0.0, 0.0, 0)
|
self._test_cache_key(fm, 0.0, 0.0, 0)
|
||||||
|
|
||||||
|
def test_empty_list(self):
|
||||||
|
with FakeTensorMode() as fm:
|
||||||
|
func = aten.any.dims
|
||||||
|
state = _CacheKeyState()
|
||||||
|
x = torch.ones((2, 3))
|
||||||
|
key_x = fm._cache_key(state, func, [x, []], {})
|
||||||
|
key_y = fm._cache_key(state, func, [x], {})
|
||||||
|
|
||||||
|
self.assertNotEqual(key_x, key_y)
|
||||||
|
|
||||||
def assertHitsMisses(self, hits, misses):
|
def assertHitsMisses(self, hits, misses):
|
||||||
"""
|
"""
|
||||||
Helper to assert on the number of recorded hits and misses.
|
Helper to assert on the number of recorded hits and misses.
|
||||||
|
@ -1677,6 +1677,10 @@ class FakeTensorMode(TorchDispatchMode):
|
|||||||
)
|
)
|
||||||
from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
|
from torch._higher_order_ops.utils import FunctionalizeCtxWrapper
|
||||||
|
|
||||||
|
if isinstance(args, (list, tuple, dict)):
|
||||||
|
result.append(type(args))
|
||||||
|
result.append(f"length_{len(args)}")
|
||||||
|
|
||||||
if isinstance(args, dict):
|
if isinstance(args, dict):
|
||||||
self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects)
|
self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects)
|
||||||
self._prep_args_for_hash(result, args.values(), state, id_hashed_objects)
|
self._prep_args_for_hash(result, args.values(), state, id_hashed_objects)
|
||||||
|
Reference in New Issue
Block a user