mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix recompile reason logging (#148200)
for the following test case ``` @torch.compile(dynamic=False, backend=cnts) def fn(x, y, z): return x * y * z[0] fn(1, torch.randn(1), {0: torch.randn(1)}) fn(2, torch.randn(2), {0: torch.randn(2)}) fn(3, torch.randn(3), {0: torch.randn(3)}) fn(4, torch.randn(4), {0: torch.randn(4)}) fn(5, torch.randn(5), {0: torch.randn(5)}) ``` previously we would log ``` 0/0: L['x'] == 1 0/0: L['x'] == 1 0/0: L['x'] == 1 0/0: L['x'] == 1 ``` but after this change we now log ``` 0/0: L['x'] == 1 0/1: L['x'] == 2 0/2: L['x'] == 3 0/3: L['x'] == 4 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148200 Approved by: https://github.com/xmfan
This commit is contained in:
committed by
PyTorch MergeBot
parent
40b3e4a358
commit
83ec7cdcd4
@ -945,7 +945,7 @@ def _compile(
|
||||
if is_recompilation(cache_size) and frame:
|
||||
reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
|
||||
recompile_reason = (
|
||||
"Unable to find recompilation reasons" if not reasons else reasons[-1]
|
||||
"Unable to find recompilation reasons" if not reasons else reasons[0]
|
||||
)
|
||||
metrics_context.update_outer({"recompile_reason": recompile_reason})
|
||||
|
||||
|
Reference in New Issue
Block a user