mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] verify determinism with inductor benchmark script (#164904)
Verify the deterministic mode with torch.compile benchmark scripts. Here is what my testing script does (pasted in the end): - run a model in default mode, save it's result - run the model again in default mode, but distort the benchmarking results. Compare it with the saved result. - Do the above again in deterministic mode. I tried to test a few modes - BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchnmark result in the default mode. The non-determinism is gone in the deterministic mode - DistillGPT2: I can not repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change. ``` model=GoogleFnet export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0 export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 # disable autotune cache export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0 export TORCHINDUCTOR_FX_GRAPH_CACHE=0 export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/ export TORCHINDUCTOR_BENCHMARK_KERNEL=1 export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 # Non deterministic mode # --float32 rather than --amp to make it easier to repro non-deterministic echo "Save results for non-deterministic mode" python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl echo "Compare results with distorted benchmarking in non-deterministic mode" TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl echo "Save results for deterministic mode" TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl echo "Compare results with distorted benchmarking in deterministic mode" TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904 Approved by: https://github.com/jansel, https://github.com/v0i0 ghstack dependencies: #164801, #164532
This commit is contained in:
committed by
PyTorch MergeBot
parent
600db525bd
commit
a3c700656f
@ -50,6 +50,7 @@ from torch._dynamo.testing import (
|
||||
reset_rng_state,
|
||||
same,
|
||||
)
|
||||
from torch._dynamo.utils import bitwise_same
|
||||
from torch._logging.scribe import open_source_signpost
|
||||
|
||||
|
||||
@ -2321,6 +2322,40 @@ class BenchmarkRunner:
|
||||
new_result = process_fn(new_result)
|
||||
fp64_outputs = process_fn(fp64_outputs)
|
||||
|
||||
if (
|
||||
self.args.save_model_outputs_to
|
||||
and self.args.compare_model_outputs_with
|
||||
and self.args.save_model_outputs_to
|
||||
== self.args.compare_model_outputs_with
|
||||
):
|
||||
log.warning(
|
||||
"args.save_model_outputs_to and args.compare_model_outputs_with points to the same path."
|
||||
"Result will be undefined."
|
||||
)
|
||||
|
||||
if self.args.save_model_outputs_to:
|
||||
print(f"Save model outputs to: {self.args.save_model_outputs_to}")
|
||||
torch.save(new_result, self.args.save_model_outputs_to)
|
||||
|
||||
if self.args.compare_model_outputs_with:
|
||||
print(
|
||||
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
|
||||
)
|
||||
saved_result = torch.load(self.args.compare_model_outputs_with)
|
||||
is_bitwise_same = bitwise_same(saved_result, new_result)
|
||||
if not is_bitwise_same:
|
||||
print(
|
||||
"The result is not bitwise equivalent to the previously saved result"
|
||||
)
|
||||
return record_status(
|
||||
"not_bitwise_equivalent", dynamo_start_stats=start_stats
|
||||
)
|
||||
|
||||
print(
|
||||
"The result is bitwise equivalent to the previously saved result"
|
||||
)
|
||||
del saved_result
|
||||
|
||||
if not same(
|
||||
correct_result,
|
||||
new_result,
|
||||
@ -3361,6 +3396,17 @@ def parse_args(args=None):
|
||||
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-model-outputs-to",
|
||||
default="",
|
||||
help="Specify the path to save model output to so we can load later for comparison",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare-model-outputs-with",
|
||||
default="",
|
||||
help="Specify the path for the saved model outputs to compare against",
|
||||
)
|
||||
|
||||
group_latency = parser.add_mutually_exclusive_group()
|
||||
group_latency.add_argument(
|
||||
"--cold-start-latency",
|
||||
@ -3640,6 +3686,43 @@ def write_csv_when_exception(args, name: str, status: str, device=None):
|
||||
write_outputs(output_filename, headers, row)
|
||||
|
||||
|
||||
def setup_determinism_for_accuracy_test(args):
|
||||
if args.only is not None and args.only not in {
|
||||
"alexnet",
|
||||
"Background_Matting",
|
||||
"pytorch_CycleGAN_and_pix2pix",
|
||||
"pytorch_unet",
|
||||
"Super_SloMo",
|
||||
"vgg16",
|
||||
# https://github.com/pytorch/pytorch/issues/96724
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"sam",
|
||||
"sam_fast",
|
||||
"resnet50_quantized_qat",
|
||||
"mobilenet_v2_quantized_qat",
|
||||
"detectron2_maskrcnn",
|
||||
"detectron2_maskrcnn_r_101_c4",
|
||||
"detectron2_maskrcnn_r_101_fpn",
|
||||
"detectron2_maskrcnn_r_50_c4",
|
||||
"detectron2_maskrcnn_r_50_fpn",
|
||||
"detectron2_fasterrcnn_r_101_c4",
|
||||
"detectron2_fasterrcnn_r_101_dc5",
|
||||
"detectron2_fasterrcnn_r_101_fpn",
|
||||
"detectron2_fasterrcnn_r_50_c4",
|
||||
"detectron2_fasterrcnn_r_50_dc5",
|
||||
"detectron2_fasterrcnn_r_50_fpn",
|
||||
}:
|
||||
# some of the models do not support use_deterministic_algorithms
|
||||
torch.use_deterministic_algorithms(True)
|
||||
if args.devices == ["xpu"]:
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.mkldnn.deterministic = True
|
||||
|
||||
|
||||
def run(runner, args, original_dir=None):
|
||||
# Pass the parsed args object to benchmark runner object
|
||||
torch._dynamo.reset()
|
||||
@ -3705,36 +3788,9 @@ def run(runner, args, original_dir=None):
|
||||
# TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
|
||||
args.use_eval_mode = True
|
||||
inductor_config.fallback_random = True
|
||||
if args.only is not None and args.only not in {
|
||||
"alexnet",
|
||||
"Background_Matting",
|
||||
"pytorch_CycleGAN_and_pix2pix",
|
||||
"pytorch_unet",
|
||||
"Super_SloMo",
|
||||
"vgg16",
|
||||
# https://github.com/pytorch/pytorch/issues/96724
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"sam",
|
||||
"sam_fast",
|
||||
"resnet50_quantized_qat",
|
||||
"mobilenet_v2_quantized_qat",
|
||||
"detectron2_maskrcnn",
|
||||
"detectron2_maskrcnn_r_101_c4",
|
||||
"detectron2_maskrcnn_r_101_fpn",
|
||||
"detectron2_maskrcnn_r_50_c4",
|
||||
"detectron2_maskrcnn_r_50_fpn",
|
||||
"detectron2_fasterrcnn_r_101_c4",
|
||||
"detectron2_fasterrcnn_r_101_dc5",
|
||||
"detectron2_fasterrcnn_r_101_fpn",
|
||||
"detectron2_fasterrcnn_r_50_c4",
|
||||
"detectron2_fasterrcnn_r_50_dc5",
|
||||
"detectron2_fasterrcnn_r_50_fpn",
|
||||
}:
|
||||
# some of the models do not support use_deterministic_algorithms
|
||||
torch.use_deterministic_algorithms(True)
|
||||
if args.devices == ["xpu"]:
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
setup_determinism_for_accuracy_test(args)
|
||||
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
if args.only is not None and args.only in {
|
||||
"nvidia_deeprecommender",
|
||||
@ -3743,14 +3799,10 @@ def run(runner, args, original_dir=None):
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
|
||||
|
||||
torch.backends.mkldnn.deterministic = True
|
||||
|
||||
# Remove randomness when torch manual seed is called
|
||||
patch_torch_manual_seed()
|
||||
|
||||
|
Reference in New Issue
Block a user