[aotinductor] Fix benchmarks with self.autocast for run_performance_test (#123699)

## Pitch
Similar to https://github.com/pytorch/pytorch/pull/110490, which fixed the `self.autocast` context in the `check_accuracy` function, this PR fixes the `self.autocast` context in the `run_performance_test` function.

## Description
The code inside `check_accuracy` after the fix from https://github.com/pytorch/pytorch/pull/110490:
a4a49f77b8/benchmarks/dynamo/common.py (L2490-L2500)
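For readers who can't resolve the permalink, here is a minimal self-contained sketch of that pattern. The `Runner` class, its attributes, and the CPU/BF16 `autocast_arg` are simplified stand-ins for the benchmark harness, not the real `common.py` code:

```python
import contextlib
import torch

# Simplified stand-in for the benchmark harness; only the context
# selection mirrors the pattern check_accuracy uses after #110490.
class Runner:
    def __init__(self, export_aot_inductor: bool):
        self.export_aot_inductor = export_aot_inductor
        # Illustrative AMP settings, not the harness defaults.
        self.autocast_arg = {"device_type": "cpu", "dtype": torch.bfloat16}

    def run(self, model, example_inputs):
        # Wrap the run in autocast when benchmarking the AOTInductor path,
        # otherwise fall back to a no-op context.
        ctx = (
            torch.autocast(**self.autocast_arg)
            if self.export_aot_inductor
            else contextlib.nullcontext()
        )
        with ctx:
            return model(*example_inputs)
```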

The current code in `run_performance_test` on the main branch (before this PR) lacks the `self.autocast` context:
a4a49f77b8/benchmarks/dynamo/common.py (L2685-L2692)

For eager mode, the `model_iter_fn` (which is actually [forward_pass](e8ad5460c0/benchmarks/dynamo/huggingface.py (L556-L558))) is used in [warmup](e8ad5460c0/benchmarks/dynamo/common.py (L2690)) and [speedup_experiment](e8ad5460c0/benchmarks/dynamo/common.py (L648)). Since `forward_pass` runs under the `self.autocast` context, it can execute in BF16 when AMP is on. For AOTInductor, however, we call `export_aot_inductor` in both [warmup](e8ad5460c0/benchmarks/dynamo/common.py (L2695)) and [speedup_experiment](e8ad5460c0/benchmarks/dynamo/common.py (L644-L646)), which has no `autocast` context and therefore always runs in FP32, as illustrated below.
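A minimal, self-contained illustration of that dtype divergence (plain `torch.autocast` on CPU here; not taken from the benchmark suite):

```python
import torch

# The same matmul runs in BF16 inside an autocast context (the eager
# forward_pass path) and in FP32 outside it (the un-wrapped
# AOTInductor path before this PR).
a = torch.randn(8, 8)
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    amp_out = a @ b    # autocast-eligible op -> torch.bfloat16

eager_out = a @ b      # no autocast context -> torch.float32

print(amp_out.dtype)   # torch.bfloat16
print(eager_out.dtype) # torch.float32
```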

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123699
Approved by: https://github.com/jgong5, https://github.com/desertfire
Author: chunyuan
Date: 2024-04-10 02:38:50 +00:00
Committed by: PyTorch MergeBot
Parent: 902cb2c842
Commit: ec00daf4f1

@@ -2680,7 +2680,16 @@ class BenchmarkRunner:
         model = self.deepcopy_and_maybe_parallelize(model)
         self.init_optimizer(name, current_device, model.parameters())
-        with self.pick_grad(name, self.args.training):
+        # The self.autocast context is needed for the model we export with aot_compile,
+        # similar to what we do in the check_accuracy function
+        ctx = (
+            self.autocast(**self.autocast_arg)
+            if self.args.export_aot_inductor
+            else contextlib.nullcontext()
+        )
+        with self.pick_grad(name, self.args.training), ctx:
             ok, total = Stats.reset_counters()
             experiment_kwargs = {}
             if tag is not None:
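Note that the new context is applied only on the AOTInductor path: the eager `forward_pass` already carries its own `self.autocast`, so only the exported model needed the extra wrapping. Using `contextlib.nullcontext()` as the fallback keeps the `with` statement uniform when AMP should not be applied, matching the structure `check_accuracy` already uses.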