From 86ced144534c2c78460a2d0f1d9ed2ba048f2514 Mon Sep 17 00:00:00 2001 From: Burak Turk Date: Mon, 30 Jun 2025 01:23:59 +0000 Subject: [PATCH] increment pending_callbacks_counter before initation the pt2 compile callbacks (#157185) Summary: Since we increment the counter after performing the callback, it leads to the assertion error when callback raises an error and increment never happens. Let's increment first to avoid it. Test Plan: tba Rollback Plan: Differential Revision: D77475650 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157185 Approved by: https://github.com/xmfan --- torch/_dynamo/callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index 500378e8c250..58cfe66baee7 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -126,9 +126,9 @@ class CompilationCallbackHandler: args = CallbackArgs(trigger, compile_id) try: with self.__pending_callbacks_counter_lock: - if self.__pending_callbacks_counter == 0: - self.run_start_callbacks(args) self.__pending_callbacks_counter += 1 + if self.__pending_callbacks_counter == 1: + self.run_start_callbacks(args) yield finally: with self.__pending_callbacks_counter_lock: