[inductor] verify determinism with inductor benchmark script (#164904)

Verify Inductor's deterministic mode with the torch.compile benchmark scripts.

Here is what my testing script does (pasted at the end):
- run a model in default mode and save its results
- run the model again in default mode, but distort the benchmarking results; compare the output with the saved results
- do the above again in deterministic mode

I tested a few models:
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark results in default mode. The non-determinism is gone in deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the benchmarking results in default mode. That does not surprise me much: a change in reduction order does not always cause a numeric change (see the sketch below).
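
To illustrate that last bullet, here is a minimal sketch (my own illustration, not part of this PR) of why a reduction-order change may or may not move the numerics:

```
# Illustration only (not from this PR): float32 addition is not
# associative, so summing the same values in a different order *may*
# change the low-order bits -- but it is not guaranteed to.
import torch

torch.manual_seed(0)
x = torch.randn(2**20, dtype=torch.float32)

# Same values, two reduction orders; often False, sometimes True.
print(torch.equal(x.sum(), x.flip(0).sum()))
```

The full testing script follows: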

```
model=GoogleFnet

export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non-deterministic mode
# use --float32 rather than --amp to make non-determinism easier to repro
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
Author: Shunting Zhang
Date: 2025-10-10 15:21:37 -07:00
Committed by: PyTorch MergeBot
Commit: 5171f14064 (parent: df26c51478)
9 changed files with 194 additions and 35 deletions

View File

@@ -50,6 +50,7 @@ from torch._dynamo.testing import (
reset_rng_state,
same,
)
from torch._dynamo.utils import bitwise_same
from torch._logging.scribe import open_source_signpost
@@ -2321,6 +2322,40 @@ class BenchmarkRunner:
new_result = process_fn(new_result)
fp64_outputs = process_fn(fp64_outputs)
if (
self.args.save_model_outputs_to
and self.args.compare_model_outputs_with
and self.args.save_model_outputs_to
== self.args.compare_model_outputs_with
):
log.warning(
"args.save_model_outputs_to and args.compare_model_outputs_with points to the same path."
"Result will be undefined."
)
if self.args.save_model_outputs_to:
print(f"Save model outputs to: {self.args.save_model_outputs_to}")
torch.save(new_result, self.args.save_model_outputs_to)
if self.args.compare_model_outputs_with:
print(
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
)
saved_result = torch.load(self.args.compare_model_outputs_with)
is_bitwise_same = bitwise_same(saved_result, new_result)
if not is_bitwise_same:
print(
"The result is not bitwise equivalent to the previously saved result"
)
return record_status(
"not_bitwise_equivalent", dynamo_start_stats=start_stats
)
print(
"The result is bitwise equivalent to the previously saved result"
)
del saved_result
if not same(
correct_result,
new_result,
@@ -3361,6 +3396,17 @@ def parse_args(args=None):
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
)
parser.add_argument(
"--save-model-outputs-to",
default="",
help="Specify the path to save model output to so we can load later for comparison",
)
parser.add_argument(
"--compare-model-outputs-with",
default="",
help="Specify the path for the saved model outputs to compare against",
)
group_latency = parser.add_mutually_exclusive_group()
group_latency.add_argument(
"--cold-start-latency",
@@ -3640,6 +3686,43 @@ def write_csv_when_exception(args, name: str, status: str, device=None):
write_outputs(output_filename, headers, row)
def setup_determinism_for_accuracy_test(args):
if args.only is not None and args.only not in {
"alexnet",
"Background_Matting",
"pytorch_CycleGAN_and_pix2pix",
"pytorch_unet",
"Super_SloMo",
"vgg16",
# https://github.com/pytorch/pytorch/issues/96724
"Wav2Vec2ForCTC",
"Wav2Vec2ForPreTraining",
"sam",
"sam_fast",
"resnet50_quantized_qat",
"mobilenet_v2_quantized_qat",
"detectron2_maskrcnn",
"detectron2_maskrcnn_r_101_c4",
"detectron2_maskrcnn_r_101_fpn",
"detectron2_maskrcnn_r_50_c4",
"detectron2_maskrcnn_r_50_fpn",
"detectron2_fasterrcnn_r_101_c4",
"detectron2_fasterrcnn_r_101_dc5",
"detectron2_fasterrcnn_r_101_fpn",
"detectron2_fasterrcnn_r_50_c4",
"detectron2_fasterrcnn_r_50_dc5",
"detectron2_fasterrcnn_r_50_fpn",
}:
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
if args.devices == ["xpu"]:
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.mkldnn.deterministic = True
def run(runner, args, original_dir=None):
# Pass the parsed args object to benchmark runner object
torch._dynamo.reset()
@@ -3705,36 +3788,9 @@ def run(runner, args, original_dir=None):
# TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
args.use_eval_mode = True
inductor_config.fallback_random = True
if args.only is not None and args.only not in {
"alexnet",
"Background_Matting",
"pytorch_CycleGAN_and_pix2pix",
"pytorch_unet",
"Super_SloMo",
"vgg16",
# https://github.com/pytorch/pytorch/issues/96724
"Wav2Vec2ForCTC",
"Wav2Vec2ForPreTraining",
"sam",
"sam_fast",
"resnet50_quantized_qat",
"mobilenet_v2_quantized_qat",
"detectron2_maskrcnn",
"detectron2_maskrcnn_r_101_c4",
"detectron2_maskrcnn_r_101_fpn",
"detectron2_maskrcnn_r_50_c4",
"detectron2_maskrcnn_r_50_fpn",
"detectron2_fasterrcnn_r_101_c4",
"detectron2_fasterrcnn_r_101_dc5",
"detectron2_fasterrcnn_r_101_fpn",
"detectron2_fasterrcnn_r_50_c4",
"detectron2_fasterrcnn_r_50_dc5",
"detectron2_fasterrcnn_r_50_fpn",
}:
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
if args.devices == ["xpu"]:
torch.use_deterministic_algorithms(True, warn_only=True)
setup_determinism_for_accuracy_test(args)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
if args.only is not None and args.only in {
"nvidia_deeprecommender",
@@ -3743,14 +3799,10 @@ def run(runner, args, original_dir=None):
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
torch.backends.mkldnn.deterministic = True
# Remove randomness when torch manual seed is called
patch_torch_manual_seed()

View File

@@ -24,8 +24,22 @@ class DeterministicTest(TestCase):
super().setUp()
self._exit_stack = contextlib.ExitStack()
self._exit_stack.enter_context(fresh_cache())
self._exit_stack.enter_context(
getattr(torch.backends, "__allow_nonbracketed_mutation")() # noqa: B009
)
self.old_flags = [
torch.backends.cudnn.deterministic,
torch.backends.cudnn.benchmark,
torch.backends.mkldnn.deterministic,
]
def tearDown(self) -> None:
(
torch.backends.cudnn.deterministic,
torch.backends.cudnn.benchmark,
torch.backends.mkldnn.deterministic,
) = self.old_flags
self._exit_stack.close()
super().tearDown()

View File

@@ -2914,6 +2914,15 @@ def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
return torch.sqrt(torch.mean(torch.square(ref - res)))
def bitwise_same(ref: Any, res: Any, equal_nan: bool = False) -> bool:
return same(
ref,
res,
tol=0.0,
equal_nan=equal_nan,
)
def same(
ref: Any,
res: Any,
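
As a side note, here is a small usage sketch of the new helper (my illustration, not part of the diff): bitwise_same is simply same with tol=0.0, so any bit-level difference fails the comparison.

```
# Illustration only: bitwise_same(ref, res) == same(ref, res, tol=0.0).
import torch
from torch._dynamo.utils import bitwise_same, same

a = torch.randn(8)
print(bitwise_same(a, a.clone()))   # True: bit-identical copy
print(bitwise_same(a, a + 1e-7))    # False: differs in the last bits
print(same(a, a + 1e-7, tol=1e-3))  # True: close enough under a tolerance
```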

View File

@@ -4294,7 +4294,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
def inductor_meta_common():
inductor_meta = {
"backend_hash": torch.utils._triton.triton_hash_with_backend(),
"are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
"assert_indirect_indexing": config.assert_indirect_indexing,
"autotune_local_cache": config.autotune_local_cache,
"autotune_pointwise": config.triton.autotune_pointwise,
@@ -4308,6 +4307,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"store_cubin": config.triton.store_cubin,
"deterministic": config.deterministic,
}
if config.write_are_deterministic_algorithms_enabled:
inductor_meta["are_deterministic_algorithms_enabled"] = (
torch.are_deterministic_algorithms_enabled()
)
if torch.version.hip is not None:
inductor_meta["is_hip"] = True
if config.is_fbcode():

View File

@@ -2447,6 +2447,11 @@ def compile_fx(
ignore_shape_env=ignore_shape_env,
)
if config.deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.mkldnn.deterministic = True # type: ignore[assignment]
# Wake up the AsyncCompile subproc pool as early as possible (if there's cuda).
if any(
isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu")

View File

@@ -2018,6 +2018,10 @@ _cache_config_ignore_prefix: list[str] = [
# External callable for matmul tuning candidates
external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
write_are_deterministic_algorithms_enabled = (
os.getenv("TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED", "1") == "1"
)
class test_configs:
force_extern_kernel_in_multi_template: bool = False
@@ -2063,6 +2067,14 @@ class test_configs:
os.getenv("TORCHINDUCTOR_FORCE_FILTER_REDUCTION_CONFIGS") == "1"
)
# a testing config to distort benchmarking results
# - empty string to disable
# - "inverse" to invert the numbers
# - "random" to return a random value
distort_benchmarking_result = os.getenv(
"TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
)
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@@ -1,3 +1,4 @@
import functools
import inspect
import time
from functools import cached_property, wraps
@@ -23,6 +24,40 @@ P = ParamSpec("P")
T = TypeVar("T")
def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]:
from torch._inductor import config
if config.test_configs.distort_benchmarking_result == "":
return fn
def distort(
ms: Union[list[float], tuple[float], float],
) -> Union[list[float], tuple[float], float]:
if isinstance(ms, (list, tuple)):
return type(ms)(distort(val) for val in ms) # type: ignore[misc]
distort_method = config.test_configs.distort_benchmarking_result
assert isinstance(ms, float)
if distort_method == "inverse":
return 1.0 / ms if ms else 0.0
elif distort_method == "random":
import random
return random.random()
else:
raise RuntimeError(f"Unrecognized distort method {distort_method}")
@functools.wraps(fn)
def wrapper(
*args: list[Any], **kwargs: dict[str, Any]
) -> Union[list[float], tuple[float], float]:
ms = fn(*args, **kwargs)
return distort(ms)
return wrapper
def may_ban_benchmarking() -> None:
if torch._inductor.config.deterministic:
raise RuntimeError("""In the deterministic mode of Inductor, we will avoid those
@@ -159,6 +194,7 @@ class TritonBenchmarker(Benchmarker):
raise NotImplementedError("requires Triton") from e
return do_bench
@may_distort_benchmarking_result
@time_and_count
def benchmark_gpu(
self: Self,
@@ -227,6 +263,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
]
)
@may_distort_benchmarking_result
@time_and_count
def benchmark_gpu( # type: ignore[override]
self: Self,
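
To make the distortion hook concrete, here is a hedged usage sketch (mine, not from the diff). may_distort_benchmarking_result reads the config when the decorator is applied, so the environment variable must be set before the decorated function is defined:

```
# Illustration only: with "inverse" distortion, a benchmark that measures
# 2.0 ms reports 1/2.0 = 0.5 instead, which can flip the config that
# autotuning picks -- exactly the perturbation the test script exercises.
import os
os.environ["TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT"] = "inverse"

from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result

@may_distort_benchmarking_result
def fake_benchmark() -> float:
    return 2.0  # pretend the kernel took 2.0 ms

print(fake_benchmark())  # prints 0.5
```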

View File

@@ -3020,6 +3020,7 @@ def reduction(
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,

View File

@@ -276,6 +276,30 @@ def do_bench_using_profiling(
warmup: int = 25,
rep: int = 100,
is_vetted_benchmarking: bool = False,
) -> float:
# We didn't use the decorator may_distort_benchmarking_result directly since that
# requires us to import torch._inductor.runtime.benchmarking into global scope.
# Importing torch._inductor.runtime.benchmarking causes cuda initialization
# (because torch.cuda.is_available is called in global scope), which causes
# failures in vllm when it creates child processes. Check the log:
# https://gist.github.com/shunting314/c194e147bf981e58df095c14874dd65a
#
# Another way to solve the issue is to just move do_bench_using_profiling
# to torch._inductor.runtime.benchmarking and change all the call sites.
# But that's not trivial due to the many call sites in and out of pytorch.
from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result
return may_distort_benchmarking_result(_do_bench_using_profiling)(
fn, warmup, rep, is_vetted_benchmarking
)
def _do_bench_using_profiling(
fn: Callable[[], Any],
warmup: int = 25,
rep: int = 100,
is_vetted_benchmarking: bool = False,
) -> float:
"""
Returns benchmark results by examining torch profiler events.