Verify the deterministic mode with the torch.compile benchmark scripts. Here is what my testing script does (pasted at the end):

- Run a model in default mode and save its results.
- Run the model again in default mode, but distort the benchmarking results. Compare the outputs with the saved results.
- Do the same again in deterministic mode.

I tested a few models:

- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark results in default mode. The non-determinism is gone in deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the benchmarking results in default mode. That does not surprise me much; a change in reduction order does not always cause a numeric change.

```bash
model=GoogleFnet
export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non-deterministic mode
# --float32 rather than --amp to make it easier to repro the non-determinism
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
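The `--save-model-outputs-to` / `--compare-model-outputs-with` flow above boils down to snapshotting model outputs on one run and demanding bitwise equality on the next. A minimal sketch of that idea, independent of the benchmark harness (`run_one_training_step`, the pickle path, and the zero tolerances are hypothetical illustration, not the harness's actual implementation):

```python
import torch


def save_outputs(outputs, path):
    # Snapshot the flattened outputs (losses, logits, grads, ...) so a
    # later run can be diffed against them.
    torch.save([t.detach().cpu() for t in outputs], path)


def compare_outputs(outputs, path):
    # atol=rtol=0 demands bitwise equality, so any reduction-order change
    # caused by a different tuning decision shows up as a failure.
    saved = torch.load(path)
    for got, want in zip(outputs, saved):
        torch.testing.assert_close(got.detach().cpu(), want, atol=0.0, rtol=0.0)


# outs = run_one_training_step(model)     # hypothetical driver
# save_outputs(outs, "/tmp/saved.pkl")    # first run
# compare_outputs(outs, "/tmp/saved.pkl") # rerun with distorted benchmarking
```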
116 lines · 3.6 KiB · Python
# Owner(s): ["module: inductor"]
|
|
import contextlib
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._inductor.config as inductor_config
|
|
from torch._dynamo.utils import counters
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch._inductor.utils import fresh_cache
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
)
|
|
from torch.testing._internal.inductor_utils import (
|
|
GPU_TYPE,
|
|
HAS_CUDA_AND_TRITON,
|
|
IS_BIG_GPU,
|
|
)
|
|
|
|
|
|
@instantiate_parametrized_tests
|
|
class DeterministicTest(TestCase):
|
|
def setUp(self) -> None:
|
|
super().setUp()
|
|
self._exit_stack = contextlib.ExitStack()
|
|
self._exit_stack.enter_context(fresh_cache())
|
|
self._exit_stack.enter_context(
|
|
getattr(torch.backends, "__allow_nonbracketed_mutation")() # noqa: B009
|
|
)
|
|
|
|
self.old_flags = [
|
|
torch.backends.cudnn.deterministic,
|
|
torch.backends.cudnn.benchmark,
|
|
torch.backends.mkldnn.deterministic,
|
|
]
|
|
|
|
def tearDown(self) -> None:
|
|
(
|
|
torch.backends.cudnn.deterministic,
|
|
torch.backends.cudnn.benchmark,
|
|
torch.backends.mkldnn.deterministic,
|
|
) = self.old_flags
|
|
self._exit_stack.close()
|
|
super().tearDown()
|
|
|
|
@parametrize("deterministic", [False, True])
|
|
def test_mm_padding(self, deterministic):
|
|
with inductor_config.patch(deterministic=deterministic):
|
|
|
|
@torch.compile()
|
|
def foo(x, y):
|
|
return x @ y
|
|
|
|
inps = [torch.rand([2049, 2049], device=GPU_TYPE) for _ in range(2)]
|
|
out = foo(*inps)
|
|
self.assertEqual(out, inps[0] @ inps[1])
|
|
|
|
if deterministic:
|
|
self.assertTrue(counters["inductor"]["pad_mm_bench"] == 0)
|
|
else:
|
|
self.assertTrue(counters["inductor"]["pad_mm_bench"] > 0)
|
|
|
|
@parametrize("deterministic", [False, True])
|
|
@inductor_config.patch(max_autotune=True)
|
|
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
|
|
def test_max_autotune(self, deterministic):
|
|
with inductor_config.patch(deterministic=deterministic):
|
|
|
|
@torch.compile()
|
|
def foo(x, y):
|
|
return x @ y
|
|
|
|
inps = [torch.rand([2048, 2048], device=GPU_TYPE) for _ in range(2)]
|
|
out = foo(*inps)
|
|
self.assertEqual(out, inps[0] @ inps[1])
|
|
|
|
if deterministic:
|
|
self.assertTrue(counters["inductor"]["select_algorithm_autotune"] == 0)
|
|
else:
|
|
self.assertTrue(counters["inductor"]["select_algorithm_autotune"] > 0)
|
|
|
|
def test_pointwise_coordesc_tuning(self):
|
|
@torch.compile(mode="max-autotune")
|
|
def f(x):
|
|
return x + 1
|
|
|
|
x = torch.randn(2048, device=GPU_TYPE)
|
|
self.assertEqual(f(x), x + 1)
|
|
|
|
self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0)
|
|
|
|
@parametrize("deterministic", [False, True])
|
|
def test_reduction_coordesc_tuning(self, deterministic):
|
|
with inductor_config.patch(
|
|
deterministic=deterministic, coordinate_descent_tuning=True
|
|
):
|
|
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x.sum(dim=-1)
|
|
|
|
inp = torch.rand([2048, 2048], device=GPU_TYPE)
|
|
|
|
out = foo(inp)
|
|
self.assertEqual(out, inp.sum(dim=-1))
|
|
|
|
if deterministic:
|
|
self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] == 0)
|
|
else:
|
|
self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if HAS_CUDA_AND_TRITON:
|
|
run_tests()
|
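As a quick sanity check outside the test harness, the same config gate these tests exercise can be flipped directly. A minimal sketch based on the test bodies above (a CUDA device string is assumed here; the tests use `GPU_TYPE` instead):

```python
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters


@torch.compile()
def matmul(x, y):
    return x @ y


with inductor_config.patch(deterministic=True):
    # Odd shapes would normally make pad_mm benchmark a padded variant.
    a = torch.rand(2049, 2049, device="cuda")
    b = torch.rand(2049, 2049, device="cuda")
    matmul(a, b)

# Deterministic mode should have skipped the mm-padding benchmark.
assert counters["inductor"]["pad_mm_bench"] == 0
```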