From a3c700656f9a666eb33074b60333a23eb7e99a15 Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Thu, 9 Oct 2025 11:22:55 -0700
Subject: [PATCH] [inductor] verify determinism with inductor benchmark script
 (#164904)

Verify the deterministic mode with the torch.compile benchmark scripts.
Here is what my testing script does (pasted at the end):

- run a model in default mode and save its result
- run the model again in default mode, but distort the benchmarking
  results; compare the output with the saved result
- do the above again in deterministic mode

I tested a few models:

- BertForMaskedLM and GoogleFnet: I can repro the numeric change by
  distorting the benchmark result in the default mode. The
  non-determinism is gone in the deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the
  benchmarking result in the default mode. That does not surprise me
  much; a reduction order change does not always cause a numeric change.

```
model=GoogleFnet
export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non-deterministic mode
# --float32 rather than --amp to make it easier to repro non-determinism
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```
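To see why a distorted benchmark result can flip numerics: autotuning then
picks a different kernel config, which can change the accumulation order of
a reduction, and float32 addition is not associative. A minimal standalone
illustration (not part of the patch; plain PyTorch):

```python
# Summing the same float32 values in two different orders usually does not
# produce bitwise-identical results, because float addition is not associative.
import torch

torch.manual_seed(0)
x = torch.randn(2**20, dtype=torch.float32)

s_flat = x.sum()                               # one accumulation order
s_tiled = x.view(1024, 1024).sum(dim=1).sum()  # a different order

print(s_flat.item(), s_tiled.item())
print(torch.equal(s_flat, s_tiled))  # typically False
```

This also matches the DistillGPT2 observation above: a changed reduction
order may, but need not, change the result.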
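The bitwise comparison itself is just `same(..., tol=0.0)` (see the new
`bitwise_same` helper in `torch/_dynamo/utils.py` below). A rough sketch of
the save/compare flow the script relies on, with a hypothetical path:

```python
# Rough sketch (not part of the patch): persist outputs from one run, reload
# them in a later run, and require bitwise equality via same(..., tol=0.0).
import torch
from torch._dynamo.utils import same

outputs = (torch.randn(2, 3),)         # stand-in for real model outputs
torch.save(outputs, "/tmp/saved.pkl")  # hypothetical path

reloaded = torch.load("/tmp/saved.pkl")
print(same(reloaded, outputs, tol=0.0))  # True only if bitwise identical
```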
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
---
 benchmarks/dynamo/common.py                  | 120 +++++++++++++------
 test/inductor/test_deterministic.py          |  14 +++
 torch/_dynamo/utils.py                       |   9 ++
 torch/_inductor/codegen/triton.py            |   7 +-
 torch/_inductor/compile_fx.py                |   5 +
 torch/_inductor/config.py                    |  12 ++
 torch/_inductor/runtime/benchmarking.py      |  37 ++++++
 torch/_inductor/runtime/triton_heuristics.py |   1 +
 torch/_inductor/utils.py                     |   2 +
 9 files changed, 172 insertions(+), 35 deletions(-)

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index feb5f97c2dc7..bc4af146967d 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -50,6 +50,7 @@ from torch._dynamo.testing import (
     reset_rng_state,
     same,
 )
+from torch._dynamo.utils import bitwise_same
 from torch._logging.scribe import open_source_signpost
@@ -2321,6 +2322,40 @@
         new_result = process_fn(new_result)
         fp64_outputs = process_fn(fp64_outputs)
 
+        if (
+            self.args.save_model_outputs_to
+            and self.args.compare_model_outputs_with
+            and self.args.save_model_outputs_to
+            == self.args.compare_model_outputs_with
+        ):
+            log.warning(
+                "args.save_model_outputs_to and args.compare_model_outputs_with point to the same path. "
+                "Result will be undefined."
+            )
+
+        if self.args.save_model_outputs_to:
+            print(f"Save model outputs to: {self.args.save_model_outputs_to}")
+            torch.save(new_result, self.args.save_model_outputs_to)
+
+        if self.args.compare_model_outputs_with:
+            print(
+                f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
+            )
+            saved_result = torch.load(self.args.compare_model_outputs_with)
+            is_bitwise_same = bitwise_same(saved_result, new_result)
+            if not is_bitwise_same:
+                print(
+                    "The result is not bitwise equivalent to the previously saved result"
+                )
+                return record_status(
+                    "not_bitwise_equivalent", dynamo_start_stats=start_stats
+                )
+
+            print(
+                "The result is bitwise equivalent to the previously saved result"
+            )
+            del saved_result
+
         if not same(
             correct_result,
             new_result,
@@ -3361,6 +3396,17 @@ def parse_args(args=None):
         help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
     )
 
+    parser.add_argument(
+        "--save-model-outputs-to",
+        default="",
+        help="Specify the path to save model outputs to so we can load them later for comparison",
+    )
+    parser.add_argument(
+        "--compare-model-outputs-with",
+        default="",
+        help="Specify the path of the saved model outputs to compare against",
+    )
+
     group_latency = parser.add_mutually_exclusive_group()
     group_latency.add_argument(
         "--cold-start-latency",
@@ -3640,6 +3686,43 @@ def write_csv_when_exception(args, name: str, status: str, device=None):
         write_outputs(output_filename, headers, row)
 
 
+def setup_determinism_for_accuracy_test(args):
+    if args.only is not None and args.only not in {
+        "alexnet",
+        "Background_Matting",
+        "pytorch_CycleGAN_and_pix2pix",
+        "pytorch_unet",
+        "Super_SloMo",
+        "vgg16",
+        # https://github.com/pytorch/pytorch/issues/96724
+        "Wav2Vec2ForCTC",
+        "Wav2Vec2ForPreTraining",
+        "sam",
+        "sam_fast",
+        "resnet50_quantized_qat",
+        "mobilenet_v2_quantized_qat",
+        "detectron2_maskrcnn",
+        "detectron2_maskrcnn_r_101_c4",
+        "detectron2_maskrcnn_r_101_fpn",
+        "detectron2_maskrcnn_r_50_c4",
+        "detectron2_maskrcnn_r_50_fpn",
+        "detectron2_fasterrcnn_r_101_c4",
+        "detectron2_fasterrcnn_r_101_dc5",
+        "detectron2_fasterrcnn_r_101_fpn",
+        "detectron2_fasterrcnn_r_50_c4",
+        "detectron2_fasterrcnn_r_50_dc5",
+        "detectron2_fasterrcnn_r_50_fpn",
+    }:
+        # some of the models do not support use_deterministic_algorithms
+        torch.use_deterministic_algorithms(True)
+        if args.devices == ["xpu"]:
+            torch.use_deterministic_algorithms(True, warn_only=True)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.mkldnn.deterministic = True
+
+
 def run(runner, args, original_dir=None):
     # Pass the parsed args object to benchmark runner object
     torch._dynamo.reset()
@@ -3705,36 +3788,9 @@ def run(runner, args, original_dir=None):
         # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
         args.use_eval_mode = True
         inductor_config.fallback_random = True
-        if args.only is not None and args.only not in {
-            "alexnet",
-            "Background_Matting",
-            "pytorch_CycleGAN_and_pix2pix",
-            "pytorch_unet",
-            "Super_SloMo",
-            "vgg16",
-            # https://github.com/pytorch/pytorch/issues/96724
-            "Wav2Vec2ForCTC",
-            "Wav2Vec2ForPreTraining",
-            "sam",
-            "sam_fast",
-            "resnet50_quantized_qat",
-            "mobilenet_v2_quantized_qat",
-            "detectron2_maskrcnn",
-            "detectron2_maskrcnn_r_101_c4",
-            "detectron2_maskrcnn_r_101_fpn",
-            "detectron2_maskrcnn_r_50_c4",
-            "detectron2_maskrcnn_r_50_fpn",
-            "detectron2_fasterrcnn_r_101_c4",
-            "detectron2_fasterrcnn_r_101_dc5",
-            "detectron2_fasterrcnn_r_101_fpn",
-            "detectron2_fasterrcnn_r_50_c4",
-            "detectron2_fasterrcnn_r_50_dc5",
-            "detectron2_fasterrcnn_r_50_fpn",
-        }:
-            # some of the models do not support use_deterministic_algorithms
-            torch.use_deterministic_algorithms(True)
-            if args.devices == ["xpu"]:
-                torch.use_deterministic_algorithms(True, warn_only=True)
+
+        setup_determinism_for_accuracy_test(args)
+
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
         if args.only is not None and args.only in {
             "nvidia_deeprecommender",
@@ -3743,14 +3799,10 @@
 
         torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
         torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False
-        torch.backends.cudnn.benchmark = False
         torch.backends.cuda.matmul.allow_tf32 = False
         torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
 
-        torch.backends.mkldnn.deterministic = True
-
         # Remove randomness when torch manual seed is called
         patch_torch_manual_seed()
 
diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py
index 3d512bba6eac..b139c68c577c 100644
--- a/test/inductor/test_deterministic.py
+++ b/test/inductor/test_deterministic.py
@@ -24,8 +24,22 @@ class DeterministicTest(TestCase):
         super().setUp()
         self._exit_stack = contextlib.ExitStack()
         self._exit_stack.enter_context(fresh_cache())
+        self._exit_stack.enter_context(
+            getattr(torch.backends, "__allow_nonbracketed_mutation")()  # noqa: B009
+        )
+
+        self.old_flags = [
+            torch.backends.cudnn.deterministic,
+            torch.backends.cudnn.benchmark,
+            torch.backends.mkldnn.deterministic,
+        ]
 
     def tearDown(self) -> None:
+        (
+            torch.backends.cudnn.deterministic,
+            torch.backends.cudnn.benchmark,
+            torch.backends.mkldnn.deterministic,
+        ) = self.old_flags
         self._exit_stack.close()
         super().tearDown()
 
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 8da851d66b98..1930aaf69a26 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2912,6 +2912,15 @@ def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
     return torch.sqrt(torch.mean(torch.square(ref - res)))
 
 
+def bitwise_same(ref: Any, res: Any, equal_nan: bool = False) -> bool:
+    return same(
+        ref,
+        res,
+        tol=0.0,
+        equal_nan=equal_nan,
+    )
+
+
 def same(
     ref: Any,
     res: Any,
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index d6b80268c8f0..a29de68e55ef 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -4274,7 +4274,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
     def inductor_meta_common():
         inductor_meta = {
             "backend_hash": torch.utils._triton.triton_hash_with_backend(),
-            "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
             "assert_indirect_indexing": config.assert_indirect_indexing,
             "autotune_local_cache": config.autotune_local_cache,
             "autotune_pointwise": config.triton.autotune_pointwise,
@@ -4288,6 +4287,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             "store_cubin": config.triton.store_cubin,
             "deterministic": config.deterministic,
         }
+
+        if config.write_are_deterministic_algorithms_enabled:
+            inductor_meta["are_deterministic_algorithms_enabled"] = (
+                torch.are_deterministic_algorithms_enabled()
+            )
+
         if torch.version.hip is not None:
             inductor_meta["is_hip"] = True
         if config.is_fbcode():
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 679bfbaac46c..7947e9cb8445 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -2447,6 +2447,11 @@ def compile_fx(
         ignore_shape_env=ignore_shape_env,
     )
 
+    if config.deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+        torch.backends.mkldnn.deterministic = True  # type: ignore[assignment]
+
     # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda).
     if any(
         isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu")
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c487d259afca..34cf5a8a84fb 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -2004,6 +2004,10 @@ _cache_config_ignore_prefix: list[str] = [
 # External callable for matmul tuning candidates
 external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
 
+write_are_deterministic_algorithms_enabled = (
+    os.getenv("TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED", "1") == "1"
+)
+
 
 class test_configs:
     force_extern_kernel_in_multi_template: bool = False
@@ -2049,6 +2053,14 @@ class test_configs:
         os.getenv("TORCHINDUCTOR_FORCE_FILTER_REDUCTION_CONFIGS") == "1"
     )
 
+    # a testing config to distort benchmarking results
+    # - empty string to disable
+    # - "inverse" to invert the numbers
+    # - "random" to return a random value
+    distort_benchmarking_result = os.getenv(
+        "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
+    )
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py
index 24e908a52773..21ee339b7df6 100644
--- a/torch/_inductor/runtime/benchmarking.py
+++ b/torch/_inductor/runtime/benchmarking.py
@@ -1,3 +1,4 @@
+import functools
 import inspect
 import time
 from functools import cached_property, wraps
@@ -23,6 +24,40 @@ P = ParamSpec("P")
 T = TypeVar("T")
 
 
+def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]:
+    from torch._inductor import config
+
+    if config.test_configs.distort_benchmarking_result == "":
+        return fn
+
+    def distort(
+        ms: Union[list[float], tuple[float], float],
+    ) -> Union[list[float], tuple[float], float]:
+        if isinstance(ms, (list, tuple)):
+            return type(ms)(distort(val) for val in ms)  # type: ignore[misc]
+
+        distort_method = config.test_configs.distort_benchmarking_result
+        assert isinstance(ms, float)
+        if distort_method == "inverse":
+            return 1.0 / ms if ms else 0.0
+        elif distort_method == "random":
+            import random
+
+            return random.random()
+        else:
+            raise RuntimeError(f"Unrecognized distort method {distort_method}")
+
+    @functools.wraps(fn)
+    def wrapper(
+        *args: list[Any], **kwargs: dict[str, Any]
+    ) -> Union[list[float], tuple[float], float]:
+        ms = fn(*args, **kwargs)
+
+        return distort(ms)
+
+    return wrapper
+
+
 def may_ban_benchmarking() -> None:
     if torch._inductor.config.deterministic:
         raise RuntimeError("""In the deterministic mode of Inductor, we will avoid those
@@ -159,6 +194,7 @@ class TritonBenchmarker(Benchmarker):
             raise NotImplementedError("requires Triton") from e
         return do_bench
 
+    @may_distort_benchmarking_result
     @time_and_count
     def benchmark_gpu(
         self: Self,
@@ -227,6 +263,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
             ]
         )
 
+    @may_distort_benchmarking_result
     @time_and_count
     def benchmark_gpu(  # type: ignore[override]
         self: Self,
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 806843e360c6..ad2597867ad4 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -3020,6 +3020,7 @@ def reduction(
     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
     configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
+
     return cached_autotune(
         size_hints,
         configs=configs,
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 5291a0f8f9ab..170f80424e81 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -72,6 +72,7 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
     "inductor_autotune_lookup_table",
 ]
 
+from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result
 from torch.fx.experimental.symbolic_shapes import (
     free_symbols,
     free_unbacked_symbols,
@@ -271,6 +272,7 @@ def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float
     return res
 
 
+@may_distort_benchmarking_result
 def do_bench_using_profiling(
     fn: Callable[[], Any],
     warmup: int = 25,