GPT2ForSequenceClassification Hugging Face (HF) model fails on ROCm for bfloat16. The failure is numerically small. This PRs adds this model to an exception list for small tensors. The exception list already includes two models. This increases the multiplier factor to 10.0 instead of 3 (default) for this model used in `torch/_dynamo/utils.py`.
In the PR comment below, I include a short analysis of the numerics.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160001
Approved by: https://github.com/anijain2305, https://github.com/jataylo, https://github.com/jeffdaily
PT2 benchmark scripts has a pattern like:
```
def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
cloned_inputs = clone_inputs(inputs)
self.optimizer_zero_grad(mod)
with self.autocast(**self.autocast_arg):
pred = mod(**cloned_inputs)
loss = self.compute_loss(pred)
self.grad_scaler.scale(loss).backward()
self.optimizer_step()
if collect_outputs:
return collect_results(mod, pred, loss, cloned_inputs)
return None
```
for training.
The collect_outputs argument is True only for accuracy testing and it's false for performance testing.
For HF benchmark suite, a model usually returns tuple (loss, logits). For performance testing, even though the logits is never used anywhere, dynamo has to keep it due to the control flow.
A few bad things if we keep logits here
1. the peak memory will be higher since the logits is large and we can not release its memory earlier.
2. we can not do optimization like chunking for the logits because the tensor needs to be returned from the pre-grad graph
Actually I think it's fine to not return logits at all.
- For training cases, checking loss and gradients for accuracy is good enough. It's hard to see two runs have mismatch logits but matching loss/gradients.
- Also, discarding logits as soon as possible for perf benchmarking makes it more fair for us.
On the other hand, it may be interesting to let dynamo support something like dynamo.constexpr (similar to tl.constexpr). A variable annotated as dynamo.constexpr will be specialized at compile time and we can do more optimization (DCE e.g.) at compile time. (A small [repro](https://gist.github.com/shunting314/0912a8947028a904c34f361021b8024d))
Benchmark results here [link](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Fri%2C%2004%20Apr%202025%2018%3A03%3A26%20GMT&stopTime=Fri%2C%2011%20Apr%202025%2018%3A03%3A26%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/shunting314/204/head&lCommit=fe25dab3f65e1b0e9db0af03f7664af70fcc9c66&rBranch=main&rCommit=55e62ff74ad5614faf80b060c7bfc551e3b7af5a)
- HF 15% (1.51 -> 1.66 compression ratio) peak memory improvement
- I also see 5% (2.74 -> 2.79x) perf win for HF. It could be true. We may generate more efficient kernels since we don't need keep logits and return it from the pre-grad graph. But I'll double check
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151075
Approved by: https://github.com/eellison, https://github.com/jansel
Softmax need do some preparation work that access the input tensor in two passes
- compute amax of each row
- compute (x - amax).exp.sum for each row
When the row size is large, cache can not hold all the active data and accessing the input multiple passes increases execution time since the kernel is membw bounded.
Online softmax uses a customized reduction to compute max and sum at the same time by accessing the data in one pass. Check this paper for more details ( https://arxiv.org/abs/1805.02867 ).
Also here is an online softmax kernel generated by inductor as a reference: https://gist.github.com/shunting314/67ae4fffd45d4f2753c781780332fa54
## Microbenchmark
- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=0 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax` : without online softmax
- eager_ms=6.671296119689941
- opt_ms=8.06931209564209
- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=1 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax`: with online softmax
- eager_ms=6.634047985076904
- opt_ms=6.230591773986816
Ideally, online softmax should save about 2ms here. We saves about 1.84ms in practice.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127011
Approved by: https://github.com/jansel
This turns on AOTAutogradCache for all inductor tests. It clears AOTAutogradCache on each test as well, by virtue of the local cache using the same directory to store cache entries.
I've also tested with INDUCTOR_TEST_DISABLE_FRESH_CACHE=1, running all the tests. AOTAutogradCache successfully caches 99% of these. There are a few tests that use view_replay and therefore save functional tensors, which cause AOTAutogradCache to fail to pickle its result. Will look into next steps there, but for now, it seems okay if the cache just misses on those cases where it can't serialize the result. It would be better to check before pickling, though.
I've made the following small bugfixes to get this working:
- Inductor is sometimes used in a standalone mode without dynamo, which leads to attribute errors in check_can_cache. In general, we should *never* crash in cache checking, only bypass. So I change a try catch to check Exception instead of just a specific exception.
- Add extra structured logging for metadata on cache hits
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140890
Approved by: https://github.com/bdhirsh
This PR batch the fix for a few accuracy failures issues during training by raising tolerance. I do that only for models that I think it fails not due to real issue.
## sebotnet33ts_256
The accuracy test for this model start to fail around June 05 [link](https://hud.pytorch.org/benchmark/timm_models/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Sun%2C%2002%20Jun%202024%2007%3A19%3A38%20GMT&stopTime=Tue%2C%2002%20Jul%202024%2007%3A19%3A38%20GMT&granularity=day&mode=training&dtype=amp&lBranch=main&lCommit=04a0d856207d83c2031e4b9cb6825ba3e0092850&rBranch=main&rCommit=e62925930f6a62f6aeeb1fe1a661a9bd3352b53d&model=sebotnet33ts_256).
I can not repro locally, but from the log from the dashboard:
```
RMSE (res-fp64): 0.09441, (ref-fp64): 0.02971 and shape=torch.Size([1536]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```
raising the tolerance should fix it.
## DebertaForQuestionAnswering
This model fails accuracy test on the dashboard only in max-autotune mode. I can not repro locally by command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/huggingface.py --accuracy --no-translation-validation --training --amp --backend inductor --device cuda --only DebertaForQuestionAnswering
```
From error message on the dashboard:
```
RMSE (res-fp64): 0.01803, (ref-fp64): 0.00537 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
```
0.02 tolerance should suppress this error.
## gluon_inception_v3
This model fail on the dashboard in max-autotune mode. I can not repro locally by command
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only gluon_inception_v3
```
From error message on the dashboard
```
RMSE (res-fp64): 0.02798, (ref-fp64): 0.00730 and shape=torch.Size([384]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
Accuracy failed for key name Mixed_7c.branch3x3dbl_3a.bn.running_var
```
raising tolerance should suppress this error.
# mobilenetv3_large_100
Fail in MA model. I can not repro locally by command
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only
```
The error message on the dashboard is
```
RMSE (res-fp64): 0.29754, (ref-fp64): 0.05205 and shape=torch.Size([]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```
The tensor is so small that the noise can be high. I use larger multiplier for smaller tensor in torch._dynamo.utils.same.
# yolov3
Fail on dashboard with error
```
Error on the dashboard: RMSE (res-fp64): 0.01278, (ref-fp64): 0.00246 and shape=torch.Size([256]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```
Fix it by using a larger multiplier for smaller tensors and raising the tolereance.
# timm_efficientdet
Fail on the dashboard with error
```
E0623 18:37:43.638000 139924418725056 torch/_dynamo/utils.py:1468] RMSE (res-fp64): 0.00096, (ref-fp64): 0.00009 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```
But I can not repro locally with command
```
time python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --only timm_efficientdet --training
```
Raise the tolerance should fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129941
Approved by: https://github.com/jansel
ghstack dependencies: #129996
This PR batch the fix for a few accuracy failures issues during training by raising tolerance. I do that only for models that I think it fails not due to real issue.
## sebotnet33ts_256
The accuracy test for this model start to fail around June 05 [link](https://hud.pytorch.org/benchmark/timm_models/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Sun%2C%2002%20Jun%202024%2007%3A19%3A38%20GMT&stopTime=Tue%2C%2002%20Jul%202024%2007%3A19%3A38%20GMT&granularity=day&mode=training&dtype=amp&lBranch=main&lCommit=04a0d856207d83c2031e4b9cb6825ba3e0092850&rBranch=main&rCommit=e62925930f6a62f6aeeb1fe1a661a9bd3352b53d&model=sebotnet33ts_256).
I can not repro locally, but from the log from the dashboard:
```
RMSE (res-fp64): 0.09441, (ref-fp64): 0.02971 and shape=torch.Size([1536]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```
raising the tolerance should fix it.
## DebertaForQuestionAnswering
This model fails accuracy test on the dashboard only in max-autotune mode. I can not repro locally by command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/huggingface.py --accuracy --no-translation-validation --training --amp --backend inductor --device cuda --only DebertaForQuestionAnswering
```
From error message on the dashboard:
```
RMSE (res-fp64): 0.01803, (ref-fp64): 0.00537 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
```
0.02 tolerance should suppress this error.
## gluon_inception_v3
This model fail on the dashboard in max-autotune mode. I can not repro locally by command
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only gluon_inception_v3
```
From error message on the dashboard
```
RMSE (res-fp64): 0.02798, (ref-fp64): 0.00730 and shape=torch.Size([384]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
Accuracy failed for key name Mixed_7c.branch3x3dbl_3a.bn.running_var
```
raising tolerance should suppress this error.
# mobilenetv3_large_100
Fail in MA model. I can not repro locally by command
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only
```
The error message on the dashboard is
```
RMSE (res-fp64): 0.29754, (ref-fp64): 0.05205 and shape=torch.Size([]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```
The tensor is so small that the noise can be high. I use larger multiplier for smaller tensor in torch._dynamo.utils.same.
# yolov3
Fail on dashboard with error
```
Error on the dashboard: RMSE (res-fp64): 0.01278, (ref-fp64): 0.00246 and shape=torch.Size([256]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```
Fix it by using a larger multiplier for smaller tensors and raising the tolereance.
# timm_efficientdet
Fail on the dashboard with error
```
E0623 18:37:43.638000 139924418725056 torch/_dynamo/utils.py:1468] RMSE (res-fp64): 0.00096, (ref-fp64): 0.00009 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```
But I can not repro locally with command
```
time python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --only timm_efficientdet --training
```
Raise the tolerance should fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129941
Approved by: https://github.com/jansel
ghstack dependencies: #129996
Fixes#128510.
https://github.com/pytorch/pytorch/pull/124451 makes LayoutLMForSequenceClassification hit the SDPA pattern 1 and then encounter the accuracy issue. The issue only happens with BF16 inference single thread. This PR tends to increase the model tolerance and make the check pass. Note that even the math-version SDPA could have the issue because of some small implementation diff.
The test log:
Single thread
```
correct_result: SequenceClassifierOutput(loss=tensor(0.5998), logits=tensor([[0.3301, 0.1338]], dtype=torch.bfloat16), hidden_states=None, attentions=None)
new_result: SequenceClassifierOutput(loss=tensor(0.6016), logits=tensor([[0.3281, 0.1357]], dtype=torch.bfloat16), hidden_states=None, attentions=None)
E0627 01:09:16.762789 140281313759104 torch/_dynamo/utils.py:1476] RMSE (res-fp64): 0.00151, (ref-fp64): 0.00046 and shape=torch.Size([1, 2]). res.dtype: torch.bfloat16, multiplier: 3.000000, tol: 0.001000
E0627 01:09:16.762972 140281313759104 torch/_dynamo/utils.py:1390] Accuracy failed for key name logits
fail_accuracy
```
Multiple threads
```
correct_result: SequenceClassifierOutput(loss=tensor(0.6007), logits=tensor([[0.3301, 0.1357]], dtype=torch.bfloat16), hidden_states=None, attentions=None)
new_result: SequenceClassifierOutput(loss=tensor(0.6016), logits=tensor([[0.3281, 0.1357]], dtype=torch.bfloat16), hidden_states=None, attentions=None)
pass
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129728
Approved by: https://github.com/jgong5, https://github.com/jansel
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
Biggest movement is 4% HF inference, 9% TIMM inference. Note, this is max-autotune mode so we are more tolerant of compilation increases. We could improve compilation time by limiting:
```
# Take how many of the top triton kernels to benchmark epilogue
max_epilogue_benchmarked_choices = 3
```
There is a hf_Whisper failure which you can repro on main without this stack with `TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --backend inductor --amp --accuracy --training --only hf_Whisper`. When you turn off epilogue fusion, it fixes the accuracy. I bisected the failure to an epilogue, however when you compare the results of that epilogue with the corresponding separate kernels the results of the output are equivalent.
Inference:
<img width="1686" alt="image" src="https://github.com/pytorch/pytorch/assets/11477974/0b240080-cd33-4c08-89d3-583103b1fb0c">
Training:
<img width="1329" alt="Screenshot 2024-04-16 at 6 16 30 PM" src="https://github.com/pytorch/pytorch/assets/11477974/db0afcc9-7288-4c27-84ce-4fc1a5690788">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124031
Approved by: https://github.com/Chillee, https://github.com/shunting314
ghstack dependencies: #124030, #122642, #123229, #122825
We need a higher tolerance for GPT2ForSequenceClassification since if I change --bfloat16 in
```
time python benchmarks/dynamo/huggingface.py --accuracy --inference --bfloat16 --backend inductor --disable-cudagraphs --only GPT2ForSequenceClassification
```
to --float16 or --float32 it will pass the accuracy check.
Adding --freezing can also make the test pass for this model. I think that's may be due to different fusion output being generated (depending on if constant propagation is happening controlled by freezing) and cause some small numerical difference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120537
Approved by: https://github.com/jansel
The memory compression for these models is at parity, but because we interleave timings between torch.compile and eager run memory is duplicated between between eager and cudagraphs pool and causes OOM.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101837
Approved by: https://github.com/anijain2305
This pr accomplishes
1) Enables retries for downloading torchbenchmark and huggingface models in a similar method to how we do it for timm models right now.
2) creates a `_download_model` function for the hugging face and TIMM runners whose output I plan to use to preload the models somewhere if possible (please double check I'll be saving the right thing). Instead of retries, we plan to just add torchbench to a docker image as it is relatively small.
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 3361a4c</samp>
> _We're the brave and bold coders of the `common.py` module_
> _We've made a handy function for downloading models_
> _We've shared it with our mates in the other runners_
> _So pull and push and try again, we'll get them all in time_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101019
Approved by: https://github.com/huydhn, https://github.com/desertfire
Since the CI exclusions are hard-coded in our script, we might as well require them to match exactly. This solved some head scratching where I was like, "this model is not obviously excluded, why is it not showing up in CI."
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92761
Approved by: https://github.com/jansel