diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index 500378e8c250..58cfe66baee7 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -126,9 +126,9 @@ class CompilationCallbackHandler: args = CallbackArgs(trigger, compile_id) try: with self.__pending_callbacks_counter_lock: - if self.__pending_callbacks_counter == 0: - self.run_start_callbacks(args) self.__pending_callbacks_counter += 1 + if self.__pending_callbacks_counter == 1: + self.run_start_callbacks(args) yield finally: with self.__pending_callbacks_counter_lock: