[BE] add noqa for flake8 rule B036: found except BaseException without re-raising (#159043)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043
Approved by: https://github.com/Skylion007
This commit is contained in:
Xuehai Pan
2025-07-25 01:07:16 +08:00
committed by PyTorch MergeBot
parent 4261e26a8b
commit f903bc475c
16 changed files with 29 additions and 29 deletions

View File

@ -170,7 +170,7 @@ def main(communication_file: str) -> None:
# Runner process sent SIGINT.
sys.exit()
except BaseException:
except BaseException: # noqa: B036
trace_f = io.StringIO()
traceback.print_exc(file=trace_f)
result = WorkerFailure(failure_trace=trace_f.getvalue())

View File

@ -66,7 +66,7 @@ def main():
filenames, total_length
)
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
except BaseException:
except BaseException: # noqa: B036
log.exception("%s, ERROR", filename)
print(f"{filename}, ERROR")

View File

@ -492,7 +492,7 @@ class TestWithNCCL(MultiProcessTestCase):
try:
func(arg)
compiled(arg)
except BaseException as exc:
except BaseException as exc: # noqa: B036
self.exc = exc
def join(self):

View File

@ -172,7 +172,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def cm():
try:
yield
except BaseException:
except BaseException: # noqa: B036
raise ValueError # noqa: B904
@contextlib.contextmanager
@ -250,7 +250,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
for x, y in args:
try:
fn(x, y)
except BaseException:
except BaseException: # noqa: B036
new_exc = sys.exc_info()
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
prev_exc = new_exc
@ -258,7 +258,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
try:
fixed_ctx = prev_exc[1].__context__
raise prev_exc[1]
except BaseException:
except BaseException: # noqa: B036
prev_exc[1].__context__ = fixed_ctx
raise
@ -749,7 +749,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
raise GeneratorExit
except Exception:
return t.sin()
except BaseException:
except BaseException: # noqa: B036
return t.cos()
t = torch.randn(2)

View File

@ -3538,14 +3538,14 @@ exit(2)
try:
with torch.cuda.stream(stream):
mem = torch.cuda.caching_allocator_alloc(1024)
except BaseException:
except BaseException: # noqa: B036
if mem is None:
return
try:
torch.cuda.caching_allocator_delete(mem)
mem = None
return None
except BaseException:
except BaseException: # noqa: B036
pass
def throws_on_cuda_event(capture_error_mode):

View File

@ -686,7 +686,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
logging_record_exception(e)
print(f"log file: {log_file}")
sys.exit(1)
except BaseException as e:
except BaseException as e: # noqa: B036
# You could logging.debug here to suppress the backtrace
# entirely, but there is no reason to hide it from technically
# savvy users.

View File

@ -2398,7 +2398,7 @@ def choose_saved_values_set(
# if idx in all_recomputable_banned_nodes:
try:
dont_ban.add(all_recomputable_banned_nodes[idx])
except BaseException:
except BaseException: # noqa: B036
pass
assert dont_ban.issubset(all_recomputable_banned_nodes)

View File

@ -190,7 +190,7 @@ class _DistWrapper:
local_data: Union[WRAPPED_EXCEPTION, T]
try:
local_data = map_fun()
except BaseException as e:
except BaseException as e: # noqa: B036
local_data = _wrap_exception(e)
all_data = self.gather_object(local_data)
@ -206,7 +206,7 @@ class _DistWrapper:
list[Union[R, CheckpointException]],
reduce_fun(cast(list[T], all_data)),
)
except BaseException as e:
except BaseException as e: # noqa: B036
node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0:
@ -237,7 +237,7 @@ class _DistWrapper:
local_data: Union[T, WRAPPED_EXCEPTION]
try:
local_data = map_fun()
except BaseException as e:
except BaseException as e: # noqa: B036
local_data = _wrap_exception(e)
all_data = self.gather_object(local_data)
@ -248,7 +248,7 @@ class _DistWrapper:
if len(node_failures) == 0:
try:
result = reduce_fun(cast(list[T], all_data))
except BaseException as e:
except BaseException as e: # noqa: B036
node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0:
@ -274,7 +274,7 @@ class _DistWrapper:
result: Union[T, WRAPPED_EXCEPTION]
try:
result = map_fun()
except BaseException as e:
except BaseException as e: # noqa: B036
result = _wrap_exception(e)
all_results = self.all_gather_object(result)
@ -300,7 +300,7 @@ class _DistWrapper:
if self.is_coordinator:
try:
result = map_fun()
except BaseException as e:
except BaseException as e: # noqa: B036
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):

View File

@ -208,7 +208,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
try:
self.set_closed()
return True
except BaseException as e:
except BaseException as e: # noqa: B036
logger.warning("Shutdown failed. Error occurred: %s", str(e))
return False

View File

@ -330,7 +330,7 @@ def _full_post_state_dict_hook(
try:
state_dict[fqn] = state_dict[fqn].detach().clone()
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
except BaseException as e:
except BaseException as e: # noqa: B036
warnings.warn(
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
"This may mean that this state_dict entry could point to invalid "

View File

@ -226,7 +226,7 @@ def _handle_exception(result):
exc = None
try:
exc = result.exception_type(exception_msg)
except BaseException as e:
except BaseException as e: # noqa: B036
raise RuntimeError( # noqa: B904
f"Failed to create original exception type. Error msg was {str(e)}"
f" Original exception on remote side was {exception_msg}"

View File

@ -53,13 +53,13 @@ def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
def _wrap_rref_type_cont(fut):
try:
_rref_type_cont(fut).then(_complete_op)
except BaseException as ex:
except BaseException as ex: # noqa: B036
result.set_exception(ex)
def _complete_op(fut):
try:
result.set_result(fut.value())
except BaseException as ex:
except BaseException as ex: # noqa: B036
result.set_exception(ex)
rref_fut.then(_wrap_rref_type_cont)

View File

@ -1149,7 +1149,7 @@ def spawn_threads_and_init_comms(
)
try:
callback()
except BaseException as ex:
except BaseException as ex: # noqa: B036
# Exceptions are handled in MultiThreadedTestCase
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle(
@ -1310,7 +1310,7 @@ class MultiThreadedTestCase(TestCase):
try:
getattr(self, test_name)()
except BaseException as ex:
except BaseException as ex: # noqa: B036
self.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle(
ex
@ -1641,7 +1641,7 @@ class MultiProcContinousTest(TestCase):
try:
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex:
except BaseException as ex: # noqa: B036
raised_exception = True
# Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info()

View File

@ -3324,7 +3324,7 @@ class TestCase(expecttest.TestCase):
def wrapper(*args, **kwargs):
try:
f(*args, **kwargs)
except BaseException as e:
except BaseException as e: # noqa: B036
self.skipTest(e)
raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
return wrapper
@ -3346,7 +3346,7 @@ class TestCase(expecttest.TestCase):
def wrapper(*args, **kwargs):
try:
f(*args, **kwargs)
except BaseException as e:
except BaseException as e: # noqa: B036
self.skipTest(e)
method = getattr(self, self._testMethodName)
if getattr(method, "__unittest_expecting_failure__", False):

View File

@ -3560,7 +3560,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
print(f"Got msg {msg}")
self.assertTrue("Original exception on remote side was" in msg)
self.assertTrue("CustomException" in msg)
except BaseException as e:
except BaseException as e: # noqa: B036
raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
finally:
self.assertTrue(exc_caught)

View File

@ -48,7 +48,7 @@ def _wrap_generator(ctx_factory, func):
gen.close()
raise
except BaseException:
except BaseException: # noqa: B036
# Propagate the exception thrown at us by the caller
with ctx_factory():
response = gen.throw(*sys.exc_info())