Compare commits

...

2 Commits

Author SHA1 Message Date
d64384ab77 [inductor][determinism] deterministically cache padmm (#166648)
Summary:

Luckily we can just reuse the prior key calculation for matmul padding with minor tweaks; beyond that, it seems to just work as expected. (A conceptual sketch of the record-and-replay pattern follows this commit entry.)

Test Plan: `buck test fbcode//mode/opt caffe2/test/inductor:caching -- test_matmul_padding`

Differential Revision: D85684098
2025-11-09 22:21:27 -08:00
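
For illustration, here is a minimal, self-contained sketch of the record-and-replay idea behind this change: a caching decorator keys the wrapped benchmark decision on an externally supplied key builder, so a later run replays the first run's decision instead of re-benchmarking. This is a hypothetical sketch, not the actual `icache.record` implementation; the real change (see the pad_mm diff below) reuses `should_pad_bench_key` as that key builder.

```
# Hypothetical sketch of record-and-replay caching; not torch's icache.
from typing import Any, Callable

_memo: dict[str, Any] = {}

def record(custom_params_encoder: Callable[..., dict[str, Any]]):
    """Cache a function's result under a key derived from its arguments."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # The encoder sees the same arguments as the wrapped function
            # and reduces them to the cache-key material.
            key = repr(custom_params_encoder(*args, **kwargs))
            if key not in _memo:
                _memo[key] = fn(*args, **kwargs)  # first run: compute and remember
            return _memo[key]                     # later runs: replay the prior result
        return wrapper
    return decorator

@record(custom_params_encoder=lambda m, n: {"key": (m, n)})
def should_pad(m: int, n: int) -> bool:
    # stand-in for a benchmark-based padding decision
    return (m * n) % 2 == 0
```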
5017a24fdb [inductor][determinism] type errors + use odc to dump imc on exit (#167136)
Summary:

Fix some type errors, and instead of manually creating a filelock when dumping dcache's imc to a file, simply use an odc, since this is the intended behavior of an odc anyway. (A sketch of the resulting locking pattern follows this commit entry.)

Test Plan:
```
buck test fbcode//mode/opt caffe2/test/inductor:caching
```

Reviewed By: aorenste

Differential Revision: D86345594
2025-11-09 22:21:27 -08:00
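
In outline, the locking change described above replaces a hand-rolled `FileLock(fpath + ".lock")` with the lock owned by an on-disk cache scoped to the dump file's directory. Below is a hedged sketch of the resulting read path, based on the `_OnDiskCacheImpl(sub_dir=...)` / `odc.lock()` usage visible in the diff further down; the import path is inferred from the diff's relative imports and may differ.

```
# Sketch only: import path inferred from the diff, not guaranteed.
import json
from pathlib import Path

from torch._inductor.runtime.caching import implementations as impls

def load_pre_population_dump(fpath: Path) -> dict[str, str]:
    # The on-disk cache owns the file locking for its directory, so no
    # manually constructed FileLock is needed around the dump file.
    odc = impls._OnDiskCacheImpl(sub_dir=fpath.parent)
    with odc.lock():
        with open(fpath) as fp:
            return json.load(fp)
```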
3 changed files with 116 additions and 18 deletions

View File

@@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
+from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
 if TYPE_CHECKING:
@@ -693,6 +694,80 @@ class ImplementationsTest(TestMixin, TestCase):
         self.assert_key_has_value_in(key, value, impl)
+class IntegrationTest(TestMixin, TestCase):
+    @classmethod
+    def sub_dir(cls) -> str:
+        return f"testing-integration-instance-{cls.cls_id}"
+    @classmethod
+    def setUpClass(cls) -> None:
+        rmtree(impls._OnDiskCacheImpl()._base_dir / cls.sub_dir(), ignore_errors=True)
+    @classmethod
+    def tearDownClass(cls) -> None:
+        rmtree(impls._OnDiskCacheImpl()._base_dir / cls.sub_dir(), ignore_errors=True)
+    @requires_gpu()
+    @set_caching_module_enabled(True)
+    @set_deterministic_caching_enabled(True)
+    @set_strictly_pre_populated_determinism(False)
+    @set_strictly_cached_determinism(False)
+    @patch_on_disk_cache_base_dir
+    @patch_remote_cache_with_on_disk_cache
+    @patch_deterministic_cache_intf_no_dump_on_exit
+    def test_matmul_padding(self) -> None:
+        from torch._inductor.fx_passes import pad_mm
+        def foo() -> torch.Tensor:
+            mat1: torch.Tensor = torch.randn(
+                2048, 1737, dtype=torch.float16, device=GPU_TYPE
+            )
+            mat2: torch.Tensor = torch.randn(
+                1737, 2048, dtype=torch.float16, device=GPU_TYPE
+            )
+            return torch.matmul(mat1, mat2)
+        # these are in pairs, with non-padded benchmarked first
+        # meaning on the first execution, pad mm should see that
+        # the orig time is 0.1 and the pad time is 0.2 (i.e. no pad)
+        # and on the second run the orig time is 0.4 and the pad time
+        # is 0.3 (i.e. pad)
+        do_bench_iter = iter([0.1, 0.2, 0.4, 0.3])
+        def fake_do_bench(*args, **kwargs) -> float:
+            nonlocal do_bench_iter
+            return next(do_bench_iter)
+        class FakePadCache:
+            def lookup(self, *args, **kwargs) -> None:
+                return None
+            def set_value(self, *args, **kwargs) -> None:
+                return None
+        was_padded = None
+        def fake_set_cached_should_pad(_, padded: bool) -> None:
+            nonlocal was_padded
+            was_padded = padded
+            return None
+        with (
+            patch.object(pad_mm, "get_do_bench", lambda: fake_do_bench),
+            patch.object(pad_mm, "get_pad_cache", lambda: FakePadCache()),
+            patch.object(pad_mm, "set_cached_should_pad", fake_set_cached_should_pad),
+        ):
+            cfoo = torch.compile(foo)
+            cfoo()
+            # based on the fake do bench values, the cold run should not pad
+            self.assertFalse(was_padded)
+            torch._dynamo.reset()
+            cfoo()
+            # based on the fake do bench values, the subsequent run would usually pad
+            # but with the deterministic cache it should reuse the prior result
+            self.assertFalse(was_padded)
 @instantiate_parametrized_tests
 class InterfacesTest(TestMixin, TestCase):
     intf_typenames: list[str] = [

View File

@@ -21,6 +21,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
     pad_mm_operations,
     pad_mm_precondition,
 )
+from torch._inductor.runtime.caching import icache
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.utils._mode_utils import no_dispatch
@@ -405,6 +406,21 @@ def get_do_bench() -> Callable[[Callable[[], Any]], float]:
     )
+def _should_pad_bench_params_encoder(
+    match: Match,
+    mat1: Tensor,
+    mat2: Tensor,
+    op: torch._ops.OpOverloadPacket,
+    input: Tensor | None = None,
+) -> dict[str, Any]:
+    return {
+        "key": should_pad_bench_key(
+            match, mat1, mat2, op, input, is_base_time_key=False
+        )
+    }
+@icache.record(custom_params_encoder=_should_pad_bench_params_encoder)
 def _should_pad_bench(
     match: Match,
     mat1: Tensor,

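For context on the `custom_params_encoder` hook used in the hunk above: the encoder is called with the wrapped function's arguments and returns the payload that keys the cache entry. Here is a toy, standalone illustration (hypothetical names, not the pad_mm code) of why such a key is built from shape/dtype-style properties rather than tensor values:

```
# Toy illustration of a params encoder; hypothetical, not torch internals.
import torch

def toy_params_encoder(mat1: torch.Tensor, mat2: torch.Tensor) -> dict:
    # Key on the properties that determine whether padding helps
    # (shapes and dtypes), not on tensor contents, so the cache entry
    # is stable across runs with different data.
    return {
        "key": (
            tuple(mat1.shape), str(mat1.dtype),
            tuple(mat2.shape), str(mat2.dtype),
        )
    }

# Same shapes/dtypes -> same key, regardless of values:
a = toy_params_encoder(
    torch.empty(2048, 1737, dtype=torch.float16),
    torch.empty(1737, 2048, dtype=torch.float16),
)
b = toy_params_encoder(
    torch.empty(2048, 1737, dtype=torch.float16),
    torch.empty(1737, 2048, dtype=torch.float16),
)
assert a == b
```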
View File

@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
 from ast import literal_eval
 from enum import Enum
 from functools import partial, wraps
-from logging import DEBUG, getLogger, Logger
+from logging import DEBUG, getLogger, INFO, Logger
 from os import PathLike
 from pathlib import Path
 from threading import Lock
@@ -15,8 +15,6 @@ from time import time
 from typing import Any, TYPE_CHECKING, TypeAlias
 from typing_extensions import override
-from filelock import FileLock
 from . import config, context, exceptions, implementations as impls, locks
@@ -329,10 +327,10 @@ class _CacheIntf(ABC):
     def record(
         self,
         ischema: context.IsolationSchema | None = None,
-        custom_params_encoder: Callable[P, Any] | None = None,
-        custom_result_encoder: Callable[[R], Any] | None = None,
-        custom_result_decoder: Callable[[Any], R] | None = None,
-    ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+        custom_params_encoder: Callable[..., Any] | None = None,
+        custom_result_encoder: Callable[..., Any] | None = None,
+        custom_result_decoder: Callable[..., ...] | None = None,
+    ) -> Callable[[Callable[..., ...]], Callable[..., ...]]:
         if custom_result_encoder and not custom_result_decoder:
             raise exceptions.CustomResultDecoderRequiredError(
                 "Custom result encoder provided without custom result decoder."
@@ -506,16 +504,22 @@ class _DeterministicCacheIntf(_CacheIntf):
         super().__init__()
         self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl()
-        if fpath := os.environ.get("TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE"):
-            # pyrefly: ignore [bad-assignment]
-            flock: FileLock = FileLock(str(fpath) + ".lock")
-            with locks._acquire_flock_with_timeout(flock):
-                with open(fpath) as fp:
-                    dump_for_pre_population: dict[str, str] = json.load(fp)
-                    for key_r, value_r in dump_for_pre_population.items():
-                        key: bytes = literal_eval(key_r)
-                        value: bytes = literal_eval(value_r)
-                        self._imc._memory[key] = value
+        if fpath_str := os.environ.get(
+            "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE"
+        ):
+            fpath: Path = Path(fpath_str)
+            fpath_parent: PathLike[str] = fpath.parent
+            if fpath.is_file():
+                odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl(
+                    sub_dir=fpath_parent
+                )
+                with odc.lock():
+                    with open(fpath) as fp:
+                        dump_for_pre_population: dict[str, str] = json.load(fp)
+                        for key_r, value_r in dump_for_pre_population.items():
+                            key: bytes = literal_eval(key_r)
+                            value: bytes = literal_eval(value_r)
+                            self._imc._memory[key] = value
         if config.STRICTLY_PRE_POPULATED_DETERMINISM:
             # we'll never need a synchronization cache if we're in strictly pre-populated mode,
@@ -578,7 +582,7 @@
                for key, value in existing_dump.items():
                    if key not in to_dump:
                        to_dump[key] = value
-                   else:
+                   elif to_dump[key] != value:
                        raise exceptions.DeterministicCachingIMCDumpConflictError from None
            w_fp = open(fpath, "w")
@@ -586,6 +590,9 @@
        assert w_fp is not None
        try:
            json.dump(to_dump, w_fp, indent=4)
+           logger.log(
+               INFO, "Dumped deterministic cache memoization to %s", fpath
+           )
        finally:
            w_fp.close()
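
As a worked illustration of the dump format the pre-population path above reads (keys and values are bytes, stored in JSON as their `repr` strings and recovered with `ast.literal_eval`), here is a minimal standalone round-trip; the file name is arbitrary and the snippet is a sketch, not the module's own dump routine.

```
# Standalone sketch of the repr/literal_eval round-trip used for the dump.
import json
from ast import literal_eval

memo: dict[bytes, bytes] = {b"\x00key": b"\x01value"}

# Dump: bytes -> repr string, matching the dict[str, str] layout on disk.
with open("determinism_dump.json", "w") as fp:
    json.dump({repr(k): repr(v) for k, v in memo.items()}, fp, indent=4)

# Pre-populate: the same recovery the diff performs with literal_eval.
with open("determinism_dump.json") as fp:
    dump_for_pre_population: dict[str, str] = json.load(fp)
restored = {
    literal_eval(key_r): literal_eval(value_r)
    for key_r, value_r in dump_for_pre_population.items()
}
assert restored == memo
```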