Fix recompile reason logging (#148200)

for the following test case

```
        @torch.compile(dynamic=False, backend=cnts)
        def fn(x, y, z):
            return x * y * z[0]

        fn(1, torch.randn(1), {0: torch.randn(1)})
        fn(2, torch.randn(2), {0: torch.randn(2)})
        fn(3, torch.randn(3), {0: torch.randn(3)})
        fn(4, torch.randn(4), {0: torch.randn(4)})
        fn(5, torch.randn(5), {0: torch.randn(5)})
```

previously we would log

```
0/0: L['x'] == 1
0/0: L['x'] == 1
0/0: L['x'] == 1
0/0: L['x'] == 1
```

but after this change we now log

```
0/0: L['x'] == 1
0/1: L['x'] == 2
0/2: L['x'] == 3
0/3: L['x'] == 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148200
Approved by: https://github.com/xmfan
This commit is contained in:
bobrenjc93
2025-02-28 09:42:27 -08:00
committed by PyTorch MergeBot
parent 40b3e4a358
commit 83ec7cdcd4

View File

@ -945,7 +945,7 @@ def _compile(
if is_recompilation(cache_size) and frame:
reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
recompile_reason = (
"Unable to find recompilation reasons" if not reasons else reasons[-1]
"Unable to find recompilation reasons" if not reasons else reasons[0]
)
metrics_context.update_outer({"recompile_reason": recompile_reason})