From f903bc475cc0b649155d2c1f113d4a857667e7f5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 25 Jul 2025 01:07:16 +0800 Subject: [PATCH] [BE] add noqa for flake8 rule B036: found `except BaseException` without re-raising (#159043) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043 Approved by: https://github.com/Skylion007 --- benchmarks/instruction_counts/worker/main.py | 2 +- functorch/benchmarks/chrome_trace_parser.py | 2 +- test/distributed/test_c10d_functional_native.py | 2 +- test/dynamo/test_exceptions.py | 8 ++++---- test/test_cuda.py | 4 ++-- tools/nightly.py | 2 +- torch/_functorch/partitioners.py | 2 +- torch/distributed/checkpoint/utils.py | 12 ++++++------ .../elastic/rendezvous/etcd_rendezvous.py | 2 +- torch/distributed/fsdp/_state_dict_utils.py | 2 +- torch/distributed/rpc/internal.py | 2 +- torch/distributed/rpc/rref_proxy.py | 4 ++-- torch/testing/_internal/common_distributed.py | 6 +++--- torch/testing/_internal/common_utils.py | 4 ++-- torch/testing/_internal/distributed/rpc/rpc_test.py | 2 +- torch/utils/_contextlib.py | 2 +- 16 files changed, 29 insertions(+), 29 deletions(-) diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py index 73cbe029878f..33021ec65004 100644 --- a/benchmarks/instruction_counts/worker/main.py +++ b/benchmarks/instruction_counts/worker/main.py @@ -170,7 +170,7 @@ def main(communication_file: str) -> None: # Runner process sent SIGINT. sys.exit() - except BaseException: + except BaseException: # noqa: B036 trace_f = io.StringIO() traceback.print_exc(file=trace_f) result = WorkerFailure(failure_trace=trace_f.getvalue()) diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py index 4f6b30606267..cc641c1cf81c 100755 --- a/functorch/benchmarks/chrome_trace_parser.py +++ b/functorch/benchmarks/chrome_trace_parser.py @@ -66,7 +66,7 @@ def main(): filenames, total_length ) print(f"{modelname}, {utilization}, {mm_conv_utilization}") - except BaseException: + except BaseException: # noqa: B036 log.exception("%s, ERROR", filename) print(f"{filename}, ERROR") diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 17a7966a1584..80f6705f4ef6 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -492,7 +492,7 @@ class TestWithNCCL(MultiProcessTestCase): try: func(arg) compiled(arg) - except BaseException as exc: + except BaseException as exc: # noqa: B036 self.exc = exc def join(self): diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 94ce690ed5b9..7a1913be5460 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -172,7 +172,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): def cm(): try: yield - except BaseException: + except BaseException: # noqa: B036 raise ValueError # noqa: B904 @contextlib.contextmanager @@ -250,7 +250,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): for x, y in args: try: fn(x, y) - except BaseException: + except BaseException: # noqa: B036 new_exc = sys.exc_info() fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1]) prev_exc = new_exc @@ -258,7 +258,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): try: fixed_ctx = prev_exc[1].__context__ raise prev_exc[1] - except BaseException: + except BaseException: # noqa: B036 prev_exc[1].__context__ = fixed_ctx raise @@ -749,7 +749,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): raise GeneratorExit except Exception: return t.sin() - except BaseException: + except BaseException: # noqa: B036 return t.cos() t = torch.randn(2) diff --git a/test/test_cuda.py b/test/test_cuda.py index 43d56394104f..689f4f38250c 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3538,14 +3538,14 @@ exit(2) try: with torch.cuda.stream(stream): mem = torch.cuda.caching_allocator_alloc(1024) - except BaseException: + except BaseException: # noqa: B036 if mem is None: return try: torch.cuda.caching_allocator_delete(mem) mem = None return None - except BaseException: + except BaseException: # noqa: B036 pass def throws_on_cuda_event(capture_error_mode): diff --git a/tools/nightly.py b/tools/nightly.py index 0ed8cfe165aa..c0af8bccf152 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -686,7 +686,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N logging_record_exception(e) print(f"log file: {log_file}") sys.exit(1) - except BaseException as e: + except BaseException as e: # noqa: B036 # You could logging.debug here to suppress the backtrace # entirely, but there is no reason to hide it from technically # savvy users. diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index c666a924b468..af6409926ae4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -2398,7 +2398,7 @@ def choose_saved_values_set( # if idx in all_recomputable_banned_nodes: try: dont_ban.add(all_recomputable_banned_nodes[idx]) - except BaseException: + except BaseException: # noqa: B036 pass assert dont_ban.issubset(all_recomputable_banned_nodes) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index e39bfd25cdd3..6d00026d9934 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -190,7 +190,7 @@ class _DistWrapper: local_data: Union[WRAPPED_EXCEPTION, T] try: local_data = map_fun() - except BaseException as e: + except BaseException as e: # noqa: B036 local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -206,7 +206,7 @@ class _DistWrapper: list[Union[R, CheckpointException]], reduce_fun(cast(list[T], all_data)), ) - except BaseException as e: + except BaseException as e: # noqa: B036 node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -237,7 +237,7 @@ class _DistWrapper: local_data: Union[T, WRAPPED_EXCEPTION] try: local_data = map_fun() - except BaseException as e: + except BaseException as e: # noqa: B036 local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -248,7 +248,7 @@ class _DistWrapper: if len(node_failures) == 0: try: result = reduce_fun(cast(list[T], all_data)) - except BaseException as e: + except BaseException as e: # noqa: B036 node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -274,7 +274,7 @@ class _DistWrapper: result: Union[T, WRAPPED_EXCEPTION] try: result = map_fun() - except BaseException as e: + except BaseException as e: # noqa: B036 result = _wrap_exception(e) all_results = self.all_gather_object(result) @@ -300,7 +300,7 @@ class _DistWrapper: if self.is_coordinator: try: result = map_fun() - except BaseException as e: + except BaseException as e: # noqa: B036 result = CheckpointException(step, {self.rank: _wrap_exception(e)}) final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 6b049423ffc6..0e4da86d4621 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -208,7 +208,7 @@ class EtcdRendezvousHandler(RendezvousHandler): try: self.set_closed() return True - except BaseException as e: + except BaseException as e: # noqa: B036 logger.warning("Shutdown failed. Error occurred: %s", str(e)) return False diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index d59b5b4492c0..a81d48ebdba8 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -330,7 +330,7 @@ def _full_post_state_dict_hook( try: state_dict[fqn] = state_dict[fqn].detach().clone() state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] - except BaseException as e: + except BaseException as e: # noqa: B036 warnings.warn( f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " "This may mean that this state_dict entry could point to invalid " diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 5faf7d14d0da..c830fc11d8ed 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -226,7 +226,7 @@ def _handle_exception(result): exc = None try: exc = result.exception_type(exception_msg) - except BaseException as e: + except BaseException as e: # noqa: B036 raise RuntimeError( # noqa: B904 f"Failed to create original exception type. Error msg was {str(e)}" f" Original exception on remote side was {exception_msg}" diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 85927b68bacb..71c111b2f2e6 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -53,13 +53,13 @@ def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _wrap_rref_type_cont(fut): try: _rref_type_cont(fut).then(_complete_op) - except BaseException as ex: + except BaseException as ex: # noqa: B036 result.set_exception(ex) def _complete_op(fut): try: result.set_result(fut.value()) - except BaseException as ex: + except BaseException as ex: # noqa: B036 result.set_exception(ex) rref_fut.then(_wrap_rref_type_cont) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 9b311411e34a..af1aafd3871a 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1149,7 +1149,7 @@ def spawn_threads_and_init_comms( ) try: callback() - except BaseException as ex: + except BaseException as ex: # noqa: B036 # Exceptions are handled in MultiThreadedTestCase MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( @@ -1310,7 +1310,7 @@ class MultiThreadedTestCase(TestCase): try: getattr(self, test_name)() - except BaseException as ex: + except BaseException as ex: # noqa: B036 self.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( ex @@ -1641,7 +1641,7 @@ class MultiProcContinousTest(TestCase): try: cls._run_test_given_id(test_id) completion_queue.put(test_id) - except BaseException as ex: + except BaseException as ex: # noqa: B036 raised_exception = True # Send the exception and stack trace back to the dispatcher exc_info = sys.exc_info() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 1b4f03da3dfc..70f6ebb99536 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3324,7 +3324,7 @@ class TestCase(expecttest.TestCase): def wrapper(*args, **kwargs): try: f(*args, **kwargs) - except BaseException as e: + except BaseException as e: # noqa: B036 self.skipTest(e) raise RuntimeError(f"Unexpected success, please remove `{file_name}`") return wrapper @@ -3346,7 +3346,7 @@ class TestCase(expecttest.TestCase): def wrapper(*args, **kwargs): try: f(*args, **kwargs) - except BaseException as e: + except BaseException as e: # noqa: B036 self.skipTest(e) method = getattr(self, self._testMethodName) if getattr(method, "__unittest_expecting_failure__", False): diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 5149c66810f8..4ec964092b39 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3560,7 +3560,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon): print(f"Got msg {msg}") self.assertTrue("Original exception on remote side was" in msg) self.assertTrue("CustomException" in msg) - except BaseException as e: + except BaseException as e: # noqa: B036 raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e finally: self.assertTrue(exc_caught) diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 26217de5bb32..8db27efa270a 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -48,7 +48,7 @@ def _wrap_generator(ctx_factory, func): gen.close() raise - except BaseException: + except BaseException: # noqa: B036 # Propagate the exception thrown at us by the caller with ctx_factory(): response = gen.throw(*sys.exc_info())