mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[inductor] verify determinism with inductor benchmark script (#164904)
Verify the deterministic mode with torch.compile benchmark scripts. Here is what my testing script does (pasted at the end): - run a model in default mode, save its result - run the model again in default mode, but distort the benchmarking results. Compare it with the saved result. - Do the above again in deterministic mode. I tried to test a few models - BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark result in the default mode. The non-determinism is gone in the deterministic mode. - DistillGPT2: I cannot repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change. ``` model=GoogleFnet export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0 export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 # disable autotune cache export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0 export TORCHINDUCTOR_FX_GRAPH_CACHE=0 export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/ export TORCHINDUCTOR_BENCHMARK_KERNEL=1 export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 # Non deterministic mode # --float32 rather than --amp to make it easier to repro non-deterministic echo "Save results for non-deterministic mode" python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl echo "Compare results with distorted benchmarking in non-deterministic mode" TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl echo "Save results for deterministic mode" TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training 
--disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl echo "Compare results with distorted benchmarking in deterministic mode" TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904 Approved by: https://github.com/jansel, https://github.com/v0i0 ghstack dependencies: #164801, #164532
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							600db525bd
						
					
				
				
					commit
					a3c700656f
				
			| @ -50,6 +50,7 @@ from torch._dynamo.testing import ( | ||||
|     reset_rng_state, | ||||
|     same, | ||||
| ) | ||||
| from torch._dynamo.utils import bitwise_same | ||||
| from torch._logging.scribe import open_source_signpost | ||||
|  | ||||
|  | ||||
| @ -2321,6 +2322,40 @@ class BenchmarkRunner: | ||||
|                         new_result = process_fn(new_result) | ||||
|                         fp64_outputs = process_fn(fp64_outputs) | ||||
|  | ||||
|                 if ( | ||||
|                     self.args.save_model_outputs_to | ||||
|                     and self.args.compare_model_outputs_with | ||||
|                     and self.args.save_model_outputs_to | ||||
|                     == self.args.compare_model_outputs_with | ||||
|                 ): | ||||
|                     log.warning( | ||||
|                         "args.save_model_outputs_to and args.compare_model_outputs_with points to the same path." | ||||
|                         "Result will be undefined." | ||||
|                     ) | ||||
|  | ||||
|                 if self.args.save_model_outputs_to: | ||||
|                     print(f"Save model outputs to: {self.args.save_model_outputs_to}") | ||||
|                     torch.save(new_result, self.args.save_model_outputs_to) | ||||
|  | ||||
|                 if self.args.compare_model_outputs_with: | ||||
|                     print( | ||||
|                         f"Load model outputs from {self.args.compare_model_outputs_with} to compare" | ||||
|                     ) | ||||
|                     saved_result = torch.load(self.args.compare_model_outputs_with) | ||||
|                     is_bitwise_same = bitwise_same(saved_result, new_result) | ||||
|                     if not is_bitwise_same: | ||||
|                         print( | ||||
|                             "The result is not bitwise equivalent to the previously saved result" | ||||
|                         ) | ||||
|                         return record_status( | ||||
|                             "not_bitwise_equivalent", dynamo_start_stats=start_stats | ||||
|                         ) | ||||
|  | ||||
|                     print( | ||||
|                         "The result is bitwise equivalent to the previously saved result" | ||||
|                     ) | ||||
|                     del saved_result | ||||
|  | ||||
|                 if not same( | ||||
|                     correct_result, | ||||
|                     new_result, | ||||
| @ -3361,6 +3396,17 @@ def parse_args(args=None): | ||||
|         help="Enables caching precompile, serializing artifacts to DynamoCache between runs", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--save-model-outputs-to", | ||||
|         default="", | ||||
|         help="Specify the path to save model output to so we can load later for comparison", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--compare-model-outputs-with", | ||||
|         default="", | ||||
|         help="Specify the path for the saved model outputs to compare against", | ||||
|     ) | ||||
|  | ||||
|     group_latency = parser.add_mutually_exclusive_group() | ||||
|     group_latency.add_argument( | ||||
|         "--cold-start-latency", | ||||
| @ -3640,6 +3686,43 @@ def write_csv_when_exception(args, name: str, status: str, device=None): | ||||
|         write_outputs(output_filename, headers, row) | ||||
|  | ||||
|  | ||||
| def setup_determinism_for_accuracy_test(args): | ||||
|     if args.only is not None and args.only not in { | ||||
|         "alexnet", | ||||
|         "Background_Matting", | ||||
|         "pytorch_CycleGAN_and_pix2pix", | ||||
|         "pytorch_unet", | ||||
|         "Super_SloMo", | ||||
|         "vgg16", | ||||
|         # https://github.com/pytorch/pytorch/issues/96724 | ||||
|         "Wav2Vec2ForCTC", | ||||
|         "Wav2Vec2ForPreTraining", | ||||
|         "sam", | ||||
|         "sam_fast", | ||||
|         "resnet50_quantized_qat", | ||||
|         "mobilenet_v2_quantized_qat", | ||||
|         "detectron2_maskrcnn", | ||||
|         "detectron2_maskrcnn_r_101_c4", | ||||
|         "detectron2_maskrcnn_r_101_fpn", | ||||
|         "detectron2_maskrcnn_r_50_c4", | ||||
|         "detectron2_maskrcnn_r_50_fpn", | ||||
|         "detectron2_fasterrcnn_r_101_c4", | ||||
|         "detectron2_fasterrcnn_r_101_dc5", | ||||
|         "detectron2_fasterrcnn_r_101_fpn", | ||||
|         "detectron2_fasterrcnn_r_50_c4", | ||||
|         "detectron2_fasterrcnn_r_50_dc5", | ||||
|         "detectron2_fasterrcnn_r_50_fpn", | ||||
|     }: | ||||
|         # some of the models do not support use_deterministic_algorithms | ||||
|         torch.use_deterministic_algorithms(True) | ||||
|     if args.devices == ["xpu"]: | ||||
|         torch.use_deterministic_algorithms(True, warn_only=True) | ||||
|  | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.mkldnn.deterministic = True | ||||
|  | ||||
|  | ||||
| def run(runner, args, original_dir=None): | ||||
|     # Pass the parsed args object to benchmark runner object | ||||
|     torch._dynamo.reset() | ||||
| @ -3705,36 +3788,9 @@ def run(runner, args, original_dir=None): | ||||
|             # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well. | ||||
|             args.use_eval_mode = True | ||||
|         inductor_config.fallback_random = True | ||||
|         if args.only is not None and args.only not in { | ||||
|             "alexnet", | ||||
|             "Background_Matting", | ||||
|             "pytorch_CycleGAN_and_pix2pix", | ||||
|             "pytorch_unet", | ||||
|             "Super_SloMo", | ||||
|             "vgg16", | ||||
|             # https://github.com/pytorch/pytorch/issues/96724 | ||||
|             "Wav2Vec2ForCTC", | ||||
|             "Wav2Vec2ForPreTraining", | ||||
|             "sam", | ||||
|             "sam_fast", | ||||
|             "resnet50_quantized_qat", | ||||
|             "mobilenet_v2_quantized_qat", | ||||
|             "detectron2_maskrcnn", | ||||
|             "detectron2_maskrcnn_r_101_c4", | ||||
|             "detectron2_maskrcnn_r_101_fpn", | ||||
|             "detectron2_maskrcnn_r_50_c4", | ||||
|             "detectron2_maskrcnn_r_50_fpn", | ||||
|             "detectron2_fasterrcnn_r_101_c4", | ||||
|             "detectron2_fasterrcnn_r_101_dc5", | ||||
|             "detectron2_fasterrcnn_r_101_fpn", | ||||
|             "detectron2_fasterrcnn_r_50_c4", | ||||
|             "detectron2_fasterrcnn_r_50_dc5", | ||||
|             "detectron2_fasterrcnn_r_50_fpn", | ||||
|         }: | ||||
|             # some of the models do not support use_deterministic_algorithms | ||||
|             torch.use_deterministic_algorithms(True) | ||||
|         if args.devices == ["xpu"]: | ||||
|             torch.use_deterministic_algorithms(True, warn_only=True) | ||||
|  | ||||
|         setup_determinism_for_accuracy_test(args) | ||||
|  | ||||
|         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | ||||
|         if args.only is not None and args.only in { | ||||
|             "nvidia_deeprecommender", | ||||
| @ -3743,14 +3799,10 @@ def run(runner, args, original_dir=None): | ||||
|             torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False | ||||
|             torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False | ||||
|  | ||||
|         torch.backends.cudnn.deterministic = True | ||||
|         torch.backends.cudnn.allow_tf32 = False | ||||
|         torch.backends.cudnn.benchmark = False | ||||
|         torch.backends.cuda.matmul.allow_tf32 = False | ||||
|         torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False) | ||||
|  | ||||
|         torch.backends.mkldnn.deterministic = True | ||||
|  | ||||
|         # Remove randomness when torch manual seed is called | ||||
|         patch_torch_manual_seed() | ||||
|  | ||||
|  | ||||
| @ -24,8 +24,22 @@ class DeterministicTest(TestCase): | ||||
|         super().setUp() | ||||
|         self._exit_stack = contextlib.ExitStack() | ||||
|         self._exit_stack.enter_context(fresh_cache()) | ||||
|         self._exit_stack.enter_context( | ||||
|             getattr(torch.backends, "__allow_nonbracketed_mutation")()  # noqa: B009 | ||||
|         ) | ||||
|  | ||||
|         self.old_flags = [ | ||||
|             torch.backends.cudnn.deterministic, | ||||
|             torch.backends.cudnn.benchmark, | ||||
|             torch.backends.mkldnn.deterministic, | ||||
|         ] | ||||
|  | ||||
|     def tearDown(self) -> None: | ||||
|         ( | ||||
|             torch.backends.cudnn.deterministic, | ||||
|             torch.backends.cudnn.benchmark, | ||||
|             torch.backends.mkldnn.deterministic, | ||||
|         ) = self.old_flags | ||||
|         self._exit_stack.close() | ||||
|         super().tearDown() | ||||
|  | ||||
|  | ||||
| @ -2912,6 +2912,15 @@ def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor: | ||||
|     return torch.sqrt(torch.mean(torch.square(ref - res))) | ||||
|  | ||||
|  | ||||
| def bitwise_same(ref: Any, res: Any, equal_nan: bool = False) -> bool: | ||||
|     return same( | ||||
|         ref, | ||||
|         res, | ||||
|         tol=0.0, | ||||
|         equal_nan=equal_nan, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def same( | ||||
|     ref: Any, | ||||
|     res: Any, | ||||
|  | ||||
| @ -4274,7 +4274,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): | ||||
|     def inductor_meta_common(): | ||||
|         inductor_meta = { | ||||
|             "backend_hash": torch.utils._triton.triton_hash_with_backend(), | ||||
|             "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), | ||||
|             "assert_indirect_indexing": config.assert_indirect_indexing, | ||||
|             "autotune_local_cache": config.autotune_local_cache, | ||||
|             "autotune_pointwise": config.triton.autotune_pointwise, | ||||
| @ -4288,6 +4287,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): | ||||
|             "store_cubin": config.triton.store_cubin, | ||||
|             "deterministic": config.deterministic, | ||||
|         } | ||||
|  | ||||
|         if config.write_are_deterministic_algorithms_enabled: | ||||
|             inductor_meta["are_deterministic_algorithms_enabled"] = ( | ||||
|                 torch.are_deterministic_algorithms_enabled() | ||||
|             ) | ||||
|  | ||||
|         if torch.version.hip is not None: | ||||
|             inductor_meta["is_hip"] = True | ||||
|         if config.is_fbcode(): | ||||
|  | ||||
| @ -2447,6 +2447,11 @@ def compile_fx( | ||||
|                 ignore_shape_env=ignore_shape_env, | ||||
|             ) | ||||
|  | ||||
|     if config.deterministic: | ||||
|         torch.backends.cudnn.deterministic = True | ||||
|         torch.backends.cudnn.benchmark = False | ||||
|         torch.backends.mkldnn.deterministic = True  # type: ignore[assignment] | ||||
|  | ||||
|     # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda). | ||||
|     if any( | ||||
|         isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu") | ||||
|  | ||||
| @ -2004,6 +2004,10 @@ _cache_config_ignore_prefix: list[str] = [ | ||||
| # External callable for matmul tuning candidates | ||||
| external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] | ||||
|  | ||||
| write_are_deterministic_algorithms_enabled = ( | ||||
|     os.getenv("TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED", "1") == "1" | ||||
| ) | ||||
|  | ||||
|  | ||||
| class test_configs: | ||||
|     force_extern_kernel_in_multi_template: bool = False | ||||
| @ -2049,6 +2053,14 @@ class test_configs: | ||||
|         os.getenv("TORCHINDUCTOR_FORCE_FILTER_REDUCTION_CONFIGS") == "1" | ||||
|     ) | ||||
|  | ||||
|     # a testing config to distort benchmarking result | ||||
|     # - empty string to disable | ||||
|     # - "inverse" to invert the numbers (1 / x) | ||||
|     # - "random" to return a random value | ||||
|     distort_benchmarking_result = os.getenv( | ||||
|         "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", "" | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from torch.utils._config_typing import *  # noqa: F401, F403 | ||||
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| import functools | ||||
| import inspect | ||||
| import time | ||||
| from functools import cached_property, wraps | ||||
| @ -23,6 +24,40 @@ P = ParamSpec("P") | ||||
| T = TypeVar("T") | ||||
|  | ||||
|  | ||||
| def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]: | ||||
|     from torch._inductor import config | ||||
|  | ||||
|     if config.test_configs.distort_benchmarking_result == "": | ||||
|         return fn | ||||
|  | ||||
|     def distort( | ||||
|         ms: Union[list[float], tuple[float], float], | ||||
|     ) -> Union[list[float], tuple[float], float]: | ||||
|         if isinstance(ms, (list, tuple)): | ||||
|             return type(ms)(distort(val) for val in ms)  # type: ignore[misc] | ||||
|  | ||||
|         distort_method = config.test_configs.distort_benchmarking_result | ||||
|         assert isinstance(ms, float) | ||||
|         if distort_method == "inverse": | ||||
|             return 1.0 / ms if ms else 0.0 | ||||
|         elif distort_method == "random": | ||||
|             import random | ||||
|  | ||||
|             return random.random() | ||||
|         else: | ||||
|             raise RuntimeError(f"Unrecognized distort method {distort_method}") | ||||
|  | ||||
|     @functools.wraps(fn) | ||||
|     def wrapper( | ||||
|         *args: list[Any], **kwargs: dict[str, Any] | ||||
|     ) -> Union[list[float], tuple[float], float]: | ||||
|         ms = fn(*args, **kwargs) | ||||
|  | ||||
|         return distort(ms) | ||||
|  | ||||
|     return wrapper | ||||
|  | ||||
|  | ||||
| def may_ban_benchmarking() -> None: | ||||
|     if torch._inductor.config.deterministic: | ||||
|         raise RuntimeError("""In the deterministic mode of Inductor, we will avoid those | ||||
| @ -159,6 +194,7 @@ class TritonBenchmarker(Benchmarker): | ||||
|             raise NotImplementedError("requires Triton") from e | ||||
|         return do_bench | ||||
|  | ||||
|     @may_distort_benchmarking_result | ||||
|     @time_and_count | ||||
|     def benchmark_gpu( | ||||
|         self: Self, | ||||
| @ -227,6 +263,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     @may_distort_benchmarking_result | ||||
|     @time_and_count | ||||
|     def benchmark_gpu(  # type: ignore[override] | ||||
|         self: Self, | ||||
|  | ||||
| @ -3020,6 +3020,7 @@ def reduction( | ||||
|  | ||||
|     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) | ||||
|     configs = filter_reduction_configs_for_determinism(inductor_meta, configs) | ||||
|  | ||||
|     return cached_autotune( | ||||
|         size_hints, | ||||
|         configs=configs, | ||||
|  | ||||
| @ -72,6 +72,7 @@ OPTIMUS_EXCLUDE_POST_GRAD = [ | ||||
|     "inductor_autotune_lookup_table", | ||||
| ] | ||||
|  | ||||
| from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result | ||||
| from torch.fx.experimental.symbolic_shapes import ( | ||||
|     free_symbols, | ||||
|     free_unbacked_symbols, | ||||
| @ -271,6 +272,7 @@ def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float: | ||||
|     return res | ||||
|  | ||||
|  | ||||
| @may_distort_benchmarking_result | ||||
| def do_bench_using_profiling( | ||||
|     fn: Callable[[], Any], | ||||
|     warmup: int = 25, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user