[BE] add noqa for flake8 rule B036: found except BaseException without re-raising (#159043)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043
Approved by: https://github.com/Skylion007
This commit is contained in:
Xuehai Pan
2025-07-25 01:07:16 +08:00
committed by PyTorch MergeBot
parent 4261e26a8b
commit f903bc475c
16 changed files with 29 additions and 29 deletions

View File

@ -170,7 +170,7 @@ def main(communication_file: str) -> None:
# Runner process sent SIGINT. # Runner process sent SIGINT.
sys.exit() sys.exit()
except BaseException: except BaseException: # noqa: B036
trace_f = io.StringIO() trace_f = io.StringIO()
traceback.print_exc(file=trace_f) traceback.print_exc(file=trace_f)
result = WorkerFailure(failure_trace=trace_f.getvalue()) result = WorkerFailure(failure_trace=trace_f.getvalue())

View File

@ -66,7 +66,7 @@ def main():
filenames, total_length filenames, total_length
) )
print(f"{modelname}, {utilization}, {mm_conv_utilization}") print(f"{modelname}, {utilization}, {mm_conv_utilization}")
except BaseException: except BaseException: # noqa: B036
log.exception("%s, ERROR", filename) log.exception("%s, ERROR", filename)
print(f"{filename}, ERROR") print(f"{filename}, ERROR")

View File

@ -492,7 +492,7 @@ class TestWithNCCL(MultiProcessTestCase):
try: try:
func(arg) func(arg)
compiled(arg) compiled(arg)
except BaseException as exc: except BaseException as exc: # noqa: B036
self.exc = exc self.exc = exc
def join(self): def join(self):

View File

@ -172,7 +172,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def cm(): def cm():
try: try:
yield yield
except BaseException: except BaseException: # noqa: B036
raise ValueError # noqa: B904 raise ValueError # noqa: B904
@contextlib.contextmanager @contextlib.contextmanager
@ -250,7 +250,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
for x, y in args: for x, y in args:
try: try:
fn(x, y) fn(x, y)
except BaseException: except BaseException: # noqa: B036
new_exc = sys.exc_info() new_exc = sys.exc_info()
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1]) fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
prev_exc = new_exc prev_exc = new_exc
@ -258,7 +258,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
try: try:
fixed_ctx = prev_exc[1].__context__ fixed_ctx = prev_exc[1].__context__
raise prev_exc[1] raise prev_exc[1]
except BaseException: except BaseException: # noqa: B036
prev_exc[1].__context__ = fixed_ctx prev_exc[1].__context__ = fixed_ctx
raise raise
@ -749,7 +749,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
raise GeneratorExit raise GeneratorExit
except Exception: except Exception:
return t.sin() return t.sin()
except BaseException: except BaseException: # noqa: B036
return t.cos() return t.cos()
t = torch.randn(2) t = torch.randn(2)

View File

@ -3538,14 +3538,14 @@ exit(2)
try: try:
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
mem = torch.cuda.caching_allocator_alloc(1024) mem = torch.cuda.caching_allocator_alloc(1024)
except BaseException: except BaseException: # noqa: B036
if mem is None: if mem is None:
return return
try: try:
torch.cuda.caching_allocator_delete(mem) torch.cuda.caching_allocator_delete(mem)
mem = None mem = None
return None return None
except BaseException: except BaseException: # noqa: B036
pass pass
def throws_on_cuda_event(capture_error_mode): def throws_on_cuda_event(capture_error_mode):

View File

@ -686,7 +686,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
logging_record_exception(e) logging_record_exception(e)
print(f"log file: {log_file}") print(f"log file: {log_file}")
sys.exit(1) sys.exit(1)
except BaseException as e: except BaseException as e: # noqa: B036
# You could logging.debug here to suppress the backtrace # You could logging.debug here to suppress the backtrace
# entirely, but there is no reason to hide it from technically # entirely, but there is no reason to hide it from technically
# savvy users. # savvy users.

View File

@ -2398,7 +2398,7 @@ def choose_saved_values_set(
# if idx in all_recomputable_banned_nodes: # if idx in all_recomputable_banned_nodes:
try: try:
dont_ban.add(all_recomputable_banned_nodes[idx]) dont_ban.add(all_recomputable_banned_nodes[idx])
except BaseException: except BaseException: # noqa: B036
pass pass
assert dont_ban.issubset(all_recomputable_banned_nodes) assert dont_ban.issubset(all_recomputable_banned_nodes)

View File

@ -190,7 +190,7 @@ class _DistWrapper:
local_data: Union[WRAPPED_EXCEPTION, T] local_data: Union[WRAPPED_EXCEPTION, T]
try: try:
local_data = map_fun() local_data = map_fun()
except BaseException as e: except BaseException as e: # noqa: B036
local_data = _wrap_exception(e) local_data = _wrap_exception(e)
all_data = self.gather_object(local_data) all_data = self.gather_object(local_data)
@ -206,7 +206,7 @@ class _DistWrapper:
list[Union[R, CheckpointException]], list[Union[R, CheckpointException]],
reduce_fun(cast(list[T], all_data)), reduce_fun(cast(list[T], all_data)),
) )
except BaseException as e: except BaseException as e: # noqa: B036
node_failures[self.rank] = _wrap_exception(e) node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0: if len(node_failures) > 0:
@ -237,7 +237,7 @@ class _DistWrapper:
local_data: Union[T, WRAPPED_EXCEPTION] local_data: Union[T, WRAPPED_EXCEPTION]
try: try:
local_data = map_fun() local_data = map_fun()
except BaseException as e: except BaseException as e: # noqa: B036
local_data = _wrap_exception(e) local_data = _wrap_exception(e)
all_data = self.gather_object(local_data) all_data = self.gather_object(local_data)
@ -248,7 +248,7 @@ class _DistWrapper:
if len(node_failures) == 0: if len(node_failures) == 0:
try: try:
result = reduce_fun(cast(list[T], all_data)) result = reduce_fun(cast(list[T], all_data))
except BaseException as e: except BaseException as e: # noqa: B036
node_failures[self.rank] = _wrap_exception(e) node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0: if len(node_failures) > 0:
@ -274,7 +274,7 @@ class _DistWrapper:
result: Union[T, WRAPPED_EXCEPTION] result: Union[T, WRAPPED_EXCEPTION]
try: try:
result = map_fun() result = map_fun()
except BaseException as e: except BaseException as e: # noqa: B036
result = _wrap_exception(e) result = _wrap_exception(e)
all_results = self.all_gather_object(result) all_results = self.all_gather_object(result)
@ -300,7 +300,7 @@ class _DistWrapper:
if self.is_coordinator: if self.is_coordinator:
try: try:
result = map_fun() result = map_fun()
except BaseException as e: except BaseException as e: # noqa: B036
result = CheckpointException(step, {self.rank: _wrap_exception(e)}) result = CheckpointException(step, {self.rank: _wrap_exception(e)})
final_result = self.broadcast_object(result) final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException): if isinstance(final_result, CheckpointException):

View File

@ -208,7 +208,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
try: try:
self.set_closed() self.set_closed()
return True return True
except BaseException as e: except BaseException as e: # noqa: B036
logger.warning("Shutdown failed. Error occurred: %s", str(e)) logger.warning("Shutdown failed. Error occurred: %s", str(e))
return False return False

View File

@ -330,7 +330,7 @@ def _full_post_state_dict_hook(
try: try:
state_dict[fqn] = state_dict[fqn].detach().clone() state_dict[fqn] = state_dict[fqn].detach().clone()
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
except BaseException as e: except BaseException as e: # noqa: B036
warnings.warn( warnings.warn(
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
"This may mean that this state_dict entry could point to invalid " "This may mean that this state_dict entry could point to invalid "

View File

@ -226,7 +226,7 @@ def _handle_exception(result):
exc = None exc = None
try: try:
exc = result.exception_type(exception_msg) exc = result.exception_type(exception_msg)
except BaseException as e: except BaseException as e: # noqa: B036
raise RuntimeError( # noqa: B904 raise RuntimeError( # noqa: B904
f"Failed to create original exception type. Error msg was {str(e)}" f"Failed to create original exception type. Error msg was {str(e)}"
f" Original exception on remote side was {exception_msg}" f" Original exception on remote side was {exception_msg}"

View File

@ -53,13 +53,13 @@ def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
def _wrap_rref_type_cont(fut): def _wrap_rref_type_cont(fut):
try: try:
_rref_type_cont(fut).then(_complete_op) _rref_type_cont(fut).then(_complete_op)
except BaseException as ex: except BaseException as ex: # noqa: B036
result.set_exception(ex) result.set_exception(ex)
def _complete_op(fut): def _complete_op(fut):
try: try:
result.set_result(fut.value()) result.set_result(fut.value())
except BaseException as ex: except BaseException as ex: # noqa: B036
result.set_exception(ex) result.set_exception(ex)
rref_fut.then(_wrap_rref_type_cont) rref_fut.then(_wrap_rref_type_cont)

View File

@ -1149,7 +1149,7 @@ def spawn_threads_and_init_comms(
) )
try: try:
callback() callback()
except BaseException as ex: except BaseException as ex: # noqa: B036
# Exceptions are handled in MultiThreadedTestCase # Exceptions are handled in MultiThreadedTestCase
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info())) MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle( ProcessLocalGroup.exception_handle(
@ -1310,7 +1310,7 @@ class MultiThreadedTestCase(TestCase):
try: try:
getattr(self, test_name)() getattr(self, test_name)()
except BaseException as ex: except BaseException as ex: # noqa: B036
self.exception_queue.put((rank, sys.exc_info())) self.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle( ProcessLocalGroup.exception_handle(
ex ex
@ -1641,7 +1641,7 @@ class MultiProcContinousTest(TestCase):
try: try:
cls._run_test_given_id(test_id) cls._run_test_given_id(test_id)
completion_queue.put(test_id) completion_queue.put(test_id)
except BaseException as ex: except BaseException as ex: # noqa: B036
raised_exception = True raised_exception = True
# Send the exception and stack trace back to the dispatcher # Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info() exc_info = sys.exc_info()

View File

@ -3324,7 +3324,7 @@ class TestCase(expecttest.TestCase):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
f(*args, **kwargs) f(*args, **kwargs)
except BaseException as e: except BaseException as e: # noqa: B036
self.skipTest(e) self.skipTest(e)
raise RuntimeError(f"Unexpected success, please remove `{file_name}`") raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
return wrapper return wrapper
@ -3346,7 +3346,7 @@ class TestCase(expecttest.TestCase):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
f(*args, **kwargs) f(*args, **kwargs)
except BaseException as e: except BaseException as e: # noqa: B036
self.skipTest(e) self.skipTest(e)
method = getattr(self, self._testMethodName) method = getattr(self, self._testMethodName)
if getattr(method, "__unittest_expecting_failure__", False): if getattr(method, "__unittest_expecting_failure__", False):

View File

@ -3560,7 +3560,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
print(f"Got msg {msg}") print(f"Got msg {msg}")
self.assertTrue("Original exception on remote side was" in msg) self.assertTrue("Original exception on remote side was" in msg)
self.assertTrue("CustomException" in msg) self.assertTrue("CustomException" in msg)
except BaseException as e: except BaseException as e: # noqa: B036
raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
finally: finally:
self.assertTrue(exc_caught) self.assertTrue(exc_caught)

View File

@ -48,7 +48,7 @@ def _wrap_generator(ctx_factory, func):
gen.close() gen.close()
raise raise
except BaseException: except BaseException: # noqa: B036
# Propagate the exception thrown at us by the caller # Propagate the exception thrown at us by the caller
with ctx_factory(): with ctx_factory():
response = gen.throw(*sys.exc_info()) response = gen.throw(*sys.exc_info())