increment pending_callbacks_counter before initation the pt2 compile callbacks (#157185)

Summary: Since we increment the counter after performing the callback, it leads to the assertion error when callback raises an error and increment never happens. Let's increment first to avoid it.

Test Plan:
tba

Rollback Plan:

Differential Revision: D77475650

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157185
Approved by: https://github.com/xmfan
This commit is contained in:
Burak Turk
2025-06-30 01:23:59 +00:00
committed by PyTorch MergeBot
parent 12cb06e574
commit 86ced14453

View File

@ -126,9 +126,9 @@ class CompilationCallbackHandler:
args = CallbackArgs(trigger, compile_id)
try:
with self.__pending_callbacks_counter_lock:
if self.__pending_callbacks_counter == 0:
self.run_start_callbacks(args)
self.__pending_callbacks_counter += 1
if self.__pending_callbacks_counter == 1:
self.run_start_callbacks(args)
yield
finally:
with self.__pending_callbacks_counter_lock: