Files
pytorch/test/inductor/test_remote_cache.py
Colin L Reliability Rice 8b2a650572 pt2_remote_cache: Log sample for failures, and log the explicit reason we're failing. (#156874)
Summary: This allows us to start alerting on cache failures, based on scuba data

Test Plan:
Added new tests explicitly for the Remote Cache API.

Note that we have existing tests for memcache, but not for manifold AFAICT.

There are two potential wrinkles. First, we're adding a new field (and everything uses ScubaData AFAICT, so this should just work).

The other one is the implicit API contract that if the sample is None, then it will be ignored (and not crash). I believe the second one is implemented correctly (and tested). The first one is a little more nebulous, but I think it won't cause any breakages.

Also manually ran a compile and made sure it didn't break - P1851504490 as well as forcing it to break and checking we didn't screw up the exception handling - P1851504243

Rollback Plan:

Differential Revision: D77054339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156874
Approved by: https://github.com/oulgen, https://github.com/masnesral
2025-07-18 20:28:27 +00:00

77 lines
1.8 KiB
Python

# Owner(s): ["module: inductor"]
from dataclasses import dataclass
from typing import Optional

from torch._inductor.remote_cache import (
    RemoteCache,
    RemoteCacheBackend,
    RemoteCachePassthroughSerde,
)
from torch.testing._internal.common_utils import TestCase
class FailingBackend(RemoteCacheBackend):
    """Backend stub whose read and write hooks always raise.

    Each hook raises ``AssertionError`` with its own message ("testget" or
    "testput") so a caller can tell which operation failed.
    """

    def _get(self, key):
        """Always fail a lookup; ``key`` is ignored."""
        failure = AssertionError("testget")
        raise failure

    def _put(self, key, data):
        """Always fail a store; ``key`` and ``data`` are ignored."""
        failure = AssertionError("testput")
        raise failure
class NoopBackend(RemoteCacheBackend):
    """Backend stub that does nothing and never raises."""

    def _get(self, key):
        """Return ``None`` for every ``key``."""
        return

    def _put(self, key, data):
        """Discard ``key`` and ``data``; returns ``None``."""
        return
@dataclass
class TestSample:
    """Minimal stand-in for a cache logging sample, used by the tests here.

    The tests in this file read ``fail_reason`` from the sample after a
    failed cache operation, so that attribute is declared explicitly rather
    than relying on it being attached dynamically.
    NOTE(review): presumably ``RemoteCache`` assigns ``fail_reason`` when the
    backend raises -- confirm against ``torch._inductor.remote_cache``.
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern.
    # This is a helper, not a test case, so opt out of collection (avoids a
    # "cannot collect test class" warning caused by the generated __init__).
    # Left unannotated on purpose so @dataclass does not treat it as a field.
    __test__ = False

    # Original field, kept first so existing positional/keyword construction
    # (``TestSample("x")`` / ``TestSample(fail="x")``) keeps working.
    fail: Optional[str] = None
    # Failure message the tests assert on; None until a failure is recorded.
    fail_reason: Optional[str] = None
class FakeCache(RemoteCache):
    """``RemoteCache`` over a ``FailingBackend`` that keeps the sample it is given.

    Overrides the sample hooks so a test can inspect the sample afterwards
    through ``self.sample``.
    """

    def __init__(self):
        """Wire the cache to an always-failing backend and a passthrough serde."""
        backend = FailingBackend()
        serde = RemoteCachePassthroughSerde()
        super().__init__(backend, serde)

    def _create_sample(self):
        """Return a fresh ``TestSample`` with default field values."""
        fresh_sample = TestSample()
        return fresh_sample

    def _log_sample(self, sample):
        """Store ``sample`` on the instance so the test can read it back."""
        self.sample = sample
class TestRemoteCache(TestCase):
    """Exercises ``RemoteCache.put`` / ``RemoteCache.get`` against stub backends."""

    def test_normal_logging(self) -> None:
        # With a backend that never fails, both calls should simply complete.
        cache = RemoteCache(NoopBackend(), RemoteCachePassthroughSerde())
        cache.put("test", "value")
        cache.get("test")

    def test_failure_no_sample(self) -> None:
        # Plain RemoteCache (no sample hooks overridden in this file): the
        # backend's AssertionError must still reach the caller.
        cache = RemoteCache(FailingBackend(), RemoteCachePassthroughSerde())
        with self.assertRaises(AssertionError):
            cache.put("test", "value")
        with self.assertRaises(AssertionError):
            cache.get("test")

    def test_failure_logging(self) -> None:
        # FakeCache captures the sample, so after each failing call the
        # captured sample should carry that backend error's message.
        cache = FakeCache()
        scenarios = (
            (cache.put, ("test", "value"), "testput"),
            (cache.get, ("test",), "testget"),
        )
        for operation, call_args, expected_reason in scenarios:
            with self.assertRaises(AssertionError):
                operation(*call_args)
            self.assertEqual(cache.sample.fail_reason, expected_reason)
if __name__ == "__main__":
    # Imported only when this file is executed directly as a script.
    import torch._inductor.test_case as inductor_test_case

    inductor_test_case.run_tests()