[inductor] verify determinism with inductor benchmark script (#164904)

Verify the deterministic mode with torch.compile benchmark scripts.

Here is what my testing script does (pasted in the end):
- run a model in default mode, save it's result
- run the model again in default mode, but distort the benchmarking results. Compare it with the saved result.
- Do the above again in deterministic mode.

I tried to test a few modes
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchnmark result in the default mode. The non-determinism is gone in the deterministic mode
- DistillGPT2: I can not repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change.

```
model=GoogleFnet

export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
This commit is contained in:
Shunting Zhang
2025-10-09 11:22:55 -07:00
committed by PyTorch MergeBot
parent 600db525bd
commit a3c700656f
9 changed files with 172 additions and 35 deletions

View File

@ -50,6 +50,7 @@ from torch._dynamo.testing import (
reset_rng_state,
same,
)
from torch._dynamo.utils import bitwise_same
from torch._logging.scribe import open_source_signpost
@ -2321,6 +2322,40 @@ class BenchmarkRunner:
new_result = process_fn(new_result)
fp64_outputs = process_fn(fp64_outputs)
if (
self.args.save_model_outputs_to
and self.args.compare_model_outputs_with
and self.args.save_model_outputs_to
== self.args.compare_model_outputs_with
):
log.warning(
"args.save_model_outputs_to and args.compare_model_outputs_with points to the same path."
"Result will be undefined."
)
if self.args.save_model_outputs_to:
print(f"Save model outputs to: {self.args.save_model_outputs_to}")
torch.save(new_result, self.args.save_model_outputs_to)
if self.args.compare_model_outputs_with:
print(
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
)
saved_result = torch.load(self.args.compare_model_outputs_with)
is_bitwise_same = bitwise_same(saved_result, new_result)
if not is_bitwise_same:
print(
"The result is not bitwise equivalent to the previously saved result"
)
return record_status(
"not_bitwise_equivalent", dynamo_start_stats=start_stats
)
print(
"The result is bitwise equivalent to the previously saved result"
)
del saved_result
if not same(
correct_result,
new_result,
@ -3361,6 +3396,17 @@ def parse_args(args=None):
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
)
parser.add_argument(
"--save-model-outputs-to",
default="",
help="Specify the path to save model output to so we can load later for comparison",
)
parser.add_argument(
"--compare-model-outputs-with",
default="",
help="Specify the path for the saved model outputs to compare against",
)
group_latency = parser.add_mutually_exclusive_group()
group_latency.add_argument(
"--cold-start-latency",
@ -3640,6 +3686,43 @@ def write_csv_when_exception(args, name: str, status: str, device=None):
write_outputs(output_filename, headers, row)
def setup_determinism_for_accuracy_test(args):
if args.only is not None and args.only not in {
"alexnet",
"Background_Matting",
"pytorch_CycleGAN_and_pix2pix",
"pytorch_unet",
"Super_SloMo",
"vgg16",
# https://github.com/pytorch/pytorch/issues/96724
"Wav2Vec2ForCTC",
"Wav2Vec2ForPreTraining",
"sam",
"sam_fast",
"resnet50_quantized_qat",
"mobilenet_v2_quantized_qat",
"detectron2_maskrcnn",
"detectron2_maskrcnn_r_101_c4",
"detectron2_maskrcnn_r_101_fpn",
"detectron2_maskrcnn_r_50_c4",
"detectron2_maskrcnn_r_50_fpn",
"detectron2_fasterrcnn_r_101_c4",
"detectron2_fasterrcnn_r_101_dc5",
"detectron2_fasterrcnn_r_101_fpn",
"detectron2_fasterrcnn_r_50_c4",
"detectron2_fasterrcnn_r_50_dc5",
"detectron2_fasterrcnn_r_50_fpn",
}:
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
if args.devices == ["xpu"]:
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.mkldnn.deterministic = True
def run(runner, args, original_dir=None):
# Pass the parsed args object to benchmark runner object
torch._dynamo.reset()
@ -3705,36 +3788,9 @@ def run(runner, args, original_dir=None):
# TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
args.use_eval_mode = True
inductor_config.fallback_random = True
if args.only is not None and args.only not in {
"alexnet",
"Background_Matting",
"pytorch_CycleGAN_and_pix2pix",
"pytorch_unet",
"Super_SloMo",
"vgg16",
# https://github.com/pytorch/pytorch/issues/96724
"Wav2Vec2ForCTC",
"Wav2Vec2ForPreTraining",
"sam",
"sam_fast",
"resnet50_quantized_qat",
"mobilenet_v2_quantized_qat",
"detectron2_maskrcnn",
"detectron2_maskrcnn_r_101_c4",
"detectron2_maskrcnn_r_101_fpn",
"detectron2_maskrcnn_r_50_c4",
"detectron2_maskrcnn_r_50_fpn",
"detectron2_fasterrcnn_r_101_c4",
"detectron2_fasterrcnn_r_101_dc5",
"detectron2_fasterrcnn_r_101_fpn",
"detectron2_fasterrcnn_r_50_c4",
"detectron2_fasterrcnn_r_50_dc5",
"detectron2_fasterrcnn_r_50_fpn",
}:
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
if args.devices == ["xpu"]:
torch.use_deterministic_algorithms(True, warn_only=True)
setup_determinism_for_accuracy_test(args)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
if args.only is not None and args.only in {
"nvidia_deeprecommender",
@ -3743,14 +3799,10 @@ def run(runner, args, original_dir=None):
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
torch.backends.mkldnn.deterministic = True
# Remove randomness when torch manual seed is called
patch_torch_manual_seed()