Files: pytorch/test/inductor/test_deterministic.py
Shunting Zhang a3c700656f [inductor] verify determinism with inductor benchmark script (#164904)
Verify the deterministic mode with torch.compile benchmark scripts.

Here is what my testing script does (pasted at the end):
- run a model in default mode and save its result
- run the model again in default mode, but distort the benchmarking results (see the sketch below); compare with the saved result
- do the above again in deterministic mode
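
For intuition, here is a minimal sketch (mine, not Inductor internals; the names are made up) of why distorting timings flips any benchmarking-based choice: with the "inverse" distortion the fastest candidate becomes the slowest.

```
# Hypothetical illustration; `distort` is not a real Inductor function.
def distort(timings, mode="inverse"):
    if mode == "inverse":
        # Reciprocal timings reverse the ranking of candidates.
        return [1.0 / t for t in timings]
    return timings

timings = {"kernel_a": 0.8, "kernel_b": 1.3}  # made-up times in ms
best = min(timings, key=timings.get)  # kernel_a
distorted = dict(zip(timings, distort(list(timings.values()))))
best_after = min(distorted, key=distorted.get)  # kernel_b
assert best != best_after  # the autotune choice changes
```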

I tried to test a few models:
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark result in the default mode. The non-determinism is gone in deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the benchmarking result in the default mode. That does not surprise me much: a reduction-order change does not always cause a numeric change (see the sketch after the script).

```
model=GoogleFnet

export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non-deterministic mode
# --float32 rather than --amp to make it easier to repro non-determinism
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```
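
For intuition on the DistillGPT2 case: floating-point addition is not associative, so a different reduction order *can* change the result, but it does not have to; sometimes every intermediate happens to round the same way. A standalone sketch (mine, not part of the PR):

```
import random

random.seed(0)
xs = [random.random() for _ in range(2048)]

s1 = sum(xs)
s2 = sum(reversed(xs))  # same values, different accumulation order
s3 = sum(sorted(xs))    # yet another order

# Depending on the values, these may be bit-identical or differ in the
# last few bits -- which is why distorted benchmarking (and hence changed
# reduction configs) shows up for some models and not others.
print(s1 == s2, s1 == s3)
```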

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
2025-10-10 00:00:58 +00:00


# Owner(s): ["module: inductor"]
import contextlib
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CUDA_AND_TRITON,
    IS_BIG_GPU,
)


@instantiate_parametrized_tests
class DeterministicTest(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._exit_stack = contextlib.ExitStack()
        self._exit_stack.enter_context(fresh_cache())
        self._exit_stack.enter_context(
            getattr(torch.backends, "__allow_nonbracketed_mutation")()  # noqa: B009
        )
        self.old_flags = [
            torch.backends.cudnn.deterministic,
            torch.backends.cudnn.benchmark,
            torch.backends.mkldnn.deterministic,
        ]

    def tearDown(self) -> None:
        (
            torch.backends.cudnn.deterministic,
            torch.backends.cudnn.benchmark,
            torch.backends.mkldnn.deterministic,
        ) = self.old_flags
        self._exit_stack.close()
        super().tearDown()
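
    # pad_mm decides whether to pad matmul inputs by benchmarking both
    # variants; deterministic mode skips that benchmarking, so the counter
    # stays at zero, while the default mode records at least one run.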
    @parametrize("deterministic", [False, True])
    def test_mm_padding(self, deterministic):
        with inductor_config.patch(deterministic=deterministic):

            @torch.compile()
            def foo(x, y):
                return x @ y

            inps = [torch.rand([2049, 2049], device=GPU_TYPE) for _ in range(2)]
            out = foo(*inps)
            self.assertEqual(out, inps[0] @ inps[1])

            if deterministic:
                self.assertTrue(counters["inductor"]["pad_mm_bench"] == 0)
            else:
                self.assertTrue(counters["inductor"]["pad_mm_bench"] > 0)
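
    # With max_autotune, kernel choice normally comes from benchmarking
    # candidate templates (select_algorithm); deterministic mode skips the
    # autotune step, so no select_algorithm_autotune events are counted.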
    @parametrize("deterministic", [False, True])
    @inductor_config.patch(max_autotune=True)
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_max_autotune(self, deterministic):
        with inductor_config.patch(deterministic=deterministic):

            @torch.compile()
            def foo(x, y):
                return x @ y

            inps = [torch.rand([2048, 2048], device=GPU_TYPE) for _ in range(2)]
            out = foo(*inps)
            self.assertEqual(out, inps[0] @ inps[1])

            if deterministic:
                self.assertTrue(counters["inductor"]["select_algorithm_autotune"] == 0)
            else:
                self.assertTrue(counters["inductor"]["select_algorithm_autotune"] > 0)
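
    # Baseline (no deterministic flag): coordinate-descent tuning does run
    # for a pointwise kernel under max-autotune, so the counter is non-zero.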
    def test_pointwise_coordesc_tuning(self):
        @torch.compile(mode="max-autotune")
        def f(x):
            return x + 1

        x = torch.randn(2048, device=GPU_TYPE)
        self.assertEqual(f(x), x + 1)
        self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0)
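
    # Reductions are the sensitive case: tuning can pick different block
    # sizes, which changes the floating-point accumulation order, so
    # deterministic mode must skip coordinate-descent benchmarking here.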
    @parametrize("deterministic", [False, True])
    def test_reduction_coordesc_tuning(self, deterministic):
        with inductor_config.patch(
            deterministic=deterministic, coordinate_descent_tuning=True
        ):

            @torch.compile()
            def foo(x):
                return x.sum(dim=-1)

            inp = torch.rand([2048, 2048], device=GPU_TYPE)
            out = foo(inp)
            self.assertEqual(out, inp.sum(dim=-1))

            if deterministic:
                self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] == 0)
            else:
                self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0)


if __name__ == "__main__":
    if HAS_CUDA_AND_TRITON:
        run_tests()