mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aotinductor] Fix benchmarks with self.autocast for run_performance_test (#123699)
## Pitch

Similar to https://github.com/pytorch/pytorch/pull/110490, which fixed the `self.autocast` context in the `check_accuracy` function, this PR applies the same fix to the `run_performance_test` function.

## Description

The code inside `check_accuracy` after the fix in https://github.com/pytorch/pytorch/pull/110490: a4a49f77b8/benchmarks/dynamo/common.py (L2490-L2500)

The current code on the main branch before this PR does not have the `self.autocast` context in `run_performance_test`: a4a49f77b8/benchmarks/dynamo/common.py (L2685-L2692)

In eager mode, `model_iter_fn` (which is actually [forward_pass](e8ad5460c0/benchmarks/dynamo/huggingface.py (L556-L558))) is used in [warmup](e8ad5460c0/benchmarks/dynamo/common.py (L2690)) and [speedup_experiment](e8ad5460c0/benchmarks/dynamo/common.py (L648)). Because `forward_pass` runs under the `self.autocast` context, it can execute in BF16 when AMP is on. For AOTInductor, however, `export_aot_inductor` is called in both [warmup](e8ad5460c0/benchmarks/dynamo/common.py (L2695)) and [speedup_experiment](e8ad5460c0/benchmarks/dynamo/common.py (L644-L646)) without the `autocast` context, so it always runs in FP32.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123699
Approved by: https://github.com/jgong5, https://github.com/desertfire
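The guard this PR adds can be sketched in isolation. `FakeRunner` and `run_once` below are hypothetical stand-ins for `BenchmarkRunner`'s autocast plumbing, not the real harness; they only demonstrate the "autocast for AOTInductor, no-op otherwise" selection:

```python
import contextlib

class FakeRunner:
    """Hypothetical stand-in for BenchmarkRunner's autocast plumbing."""

    def __init__(self):
        self.autocast_entered = False
        self.autocast_arg = {}

    @contextlib.contextmanager
    def autocast(self, **kwargs):
        # The real benchmark would enter torch.autocast here;
        # this sketch only records that the context was entered.
        self.autocast_entered = True
        yield

def run_once(runner, export_aot_inductor):
    # The pattern from the fix: pick autocast only for the AOTInductor
    # path, otherwise a no-op context, so the exported model sees the
    # same dtype context as the eager forward_pass.
    ctx = (
        runner.autocast(**runner.autocast_arg)
        if export_aot_inductor
        else contextlib.nullcontext()
    )
    with ctx:
        pass
    return runner.autocast_entered
```

With `export_aot_inductor=True` the autocast context is entered; with `False` the `nullcontext` makes the `with` statement a no-op.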
Committed by: PyTorch MergeBot
Parent: 902cb2c842
Commit: ec00daf4f1
```diff
@@ -2680,7 +2680,16 @@ class BenchmarkRunner:
         model = self.deepcopy_and_maybe_parallelize(model)
 
         self.init_optimizer(name, current_device, model.parameters())
-        with self.pick_grad(name, self.args.training):
+
+        # The self.autocast context is needed for the model we export with aot_compile,
+        # similar to what we do in the check_accuracy function
+        ctx = (
+            self.autocast(**self.autocast_arg)
+            if self.args.export_aot_inductor
+            else contextlib.nullcontext()
+        )
+
+        with self.pick_grad(name, self.args.training), ctx:
             ok, total = Stats.reset_counters()
             experiment_kwargs = {}
             if tag is not None:
```
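The patched `with self.pick_grad(name, self.args.training), ctx:` line stacks two context managers in one statement. A minimal sketch of how such stacking nests (using toy context managers, not the benchmark's real `pick_grad`/`autocast`):

```python
import contextlib

log = []

@contextlib.contextmanager
def grad_mode():
    # Toy stand-in for pick_grad: records enter/exit order.
    log.append("grad:enter")
    yield
    log.append("grad:exit")

# ctx may be a real context manager or a no-op nullcontext,
# exactly as in the patched run_performance_test.
ctx = contextlib.nullcontext()

# Comma-stacked managers enter left-to-right and exit right-to-left,
# equivalent to nesting `with grad_mode(): with ctx: ...`
with grad_mode(), ctx:
    log.append("body")
```

Swapping `ctx` for an autocast-like manager would wrap the body in both grad and dtype contexts without changing the surrounding code.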