mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] add noqa for flake8 rule B036: found except BaseException
without re-raising (#159043)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159043 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
4261e26a8b
commit
f903bc475c
@ -170,7 +170,7 @@ def main(communication_file: str) -> None:
|
|||||||
# Runner process sent SIGINT.
|
# Runner process sent SIGINT.
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
trace_f = io.StringIO()
|
trace_f = io.StringIO()
|
||||||
traceback.print_exc(file=trace_f)
|
traceback.print_exc(file=trace_f)
|
||||||
result = WorkerFailure(failure_trace=trace_f.getvalue())
|
result = WorkerFailure(failure_trace=trace_f.getvalue())
|
||||||
|
@ -66,7 +66,7 @@ def main():
|
|||||||
filenames, total_length
|
filenames, total_length
|
||||||
)
|
)
|
||||||
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
|
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
log.exception("%s, ERROR", filename)
|
log.exception("%s, ERROR", filename)
|
||||||
print(f"{filename}, ERROR")
|
print(f"{filename}, ERROR")
|
||||||
|
|
||||||
|
@ -492,7 +492,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
|||||||
try:
|
try:
|
||||||
func(arg)
|
func(arg)
|
||||||
compiled(arg)
|
compiled(arg)
|
||||||
except BaseException as exc:
|
except BaseException as exc: # noqa: B036
|
||||||
self.exc = exc
|
self.exc = exc
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
|
@ -172,7 +172,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||||||
def cm():
|
def cm():
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
raise ValueError # noqa: B904
|
raise ValueError # noqa: B904
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -250,7 +250,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||||||
for x, y in args:
|
for x, y in args:
|
||||||
try:
|
try:
|
||||||
fn(x, y)
|
fn(x, y)
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
new_exc = sys.exc_info()
|
new_exc = sys.exc_info()
|
||||||
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
|
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
|
||||||
prev_exc = new_exc
|
prev_exc = new_exc
|
||||||
@ -258,7 +258,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||||||
try:
|
try:
|
||||||
fixed_ctx = prev_exc[1].__context__
|
fixed_ctx = prev_exc[1].__context__
|
||||||
raise prev_exc[1]
|
raise prev_exc[1]
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
prev_exc[1].__context__ = fixed_ctx
|
prev_exc[1].__context__ = fixed_ctx
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -749,7 +749,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||||||
raise GeneratorExit
|
raise GeneratorExit
|
||||||
except Exception:
|
except Exception:
|
||||||
return t.sin()
|
return t.sin()
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
return t.cos()
|
return t.cos()
|
||||||
|
|
||||||
t = torch.randn(2)
|
t = torch.randn(2)
|
||||||
|
@ -3538,14 +3538,14 @@ exit(2)
|
|||||||
try:
|
try:
|
||||||
with torch.cuda.stream(stream):
|
with torch.cuda.stream(stream):
|
||||||
mem = torch.cuda.caching_allocator_alloc(1024)
|
mem = torch.cuda.caching_allocator_alloc(1024)
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
if mem is None:
|
if mem is None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
torch.cuda.caching_allocator_delete(mem)
|
torch.cuda.caching_allocator_delete(mem)
|
||||||
mem = None
|
mem = None
|
||||||
return None
|
return None
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def throws_on_cuda_event(capture_error_mode):
|
def throws_on_cuda_event(capture_error_mode):
|
||||||
|
@ -686,7 +686,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
|
|||||||
logging_record_exception(e)
|
logging_record_exception(e)
|
||||||
print(f"log file: {log_file}")
|
print(f"log file: {log_file}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
# You could logging.debug here to suppress the backtrace
|
# You could logging.debug here to suppress the backtrace
|
||||||
# entirely, but there is no reason to hide it from technically
|
# entirely, but there is no reason to hide it from technically
|
||||||
# savvy users.
|
# savvy users.
|
||||||
|
@ -2398,7 +2398,7 @@ def choose_saved_values_set(
|
|||||||
# if idx in all_recomputable_banned_nodes:
|
# if idx in all_recomputable_banned_nodes:
|
||||||
try:
|
try:
|
||||||
dont_ban.add(all_recomputable_banned_nodes[idx])
|
dont_ban.add(all_recomputable_banned_nodes[idx])
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert dont_ban.issubset(all_recomputable_banned_nodes)
|
assert dont_ban.issubset(all_recomputable_banned_nodes)
|
||||||
|
@ -190,7 +190,7 @@ class _DistWrapper:
|
|||||||
local_data: Union[WRAPPED_EXCEPTION, T]
|
local_data: Union[WRAPPED_EXCEPTION, T]
|
||||||
try:
|
try:
|
||||||
local_data = map_fun()
|
local_data = map_fun()
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
local_data = _wrap_exception(e)
|
local_data = _wrap_exception(e)
|
||||||
|
|
||||||
all_data = self.gather_object(local_data)
|
all_data = self.gather_object(local_data)
|
||||||
@ -206,7 +206,7 @@ class _DistWrapper:
|
|||||||
list[Union[R, CheckpointException]],
|
list[Union[R, CheckpointException]],
|
||||||
reduce_fun(cast(list[T], all_data)),
|
reduce_fun(cast(list[T], all_data)),
|
||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
node_failures[self.rank] = _wrap_exception(e)
|
node_failures[self.rank] = _wrap_exception(e)
|
||||||
|
|
||||||
if len(node_failures) > 0:
|
if len(node_failures) > 0:
|
||||||
@ -237,7 +237,7 @@ class _DistWrapper:
|
|||||||
local_data: Union[T, WRAPPED_EXCEPTION]
|
local_data: Union[T, WRAPPED_EXCEPTION]
|
||||||
try:
|
try:
|
||||||
local_data = map_fun()
|
local_data = map_fun()
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
local_data = _wrap_exception(e)
|
local_data = _wrap_exception(e)
|
||||||
|
|
||||||
all_data = self.gather_object(local_data)
|
all_data = self.gather_object(local_data)
|
||||||
@ -248,7 +248,7 @@ class _DistWrapper:
|
|||||||
if len(node_failures) == 0:
|
if len(node_failures) == 0:
|
||||||
try:
|
try:
|
||||||
result = reduce_fun(cast(list[T], all_data))
|
result = reduce_fun(cast(list[T], all_data))
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
node_failures[self.rank] = _wrap_exception(e)
|
node_failures[self.rank] = _wrap_exception(e)
|
||||||
|
|
||||||
if len(node_failures) > 0:
|
if len(node_failures) > 0:
|
||||||
@ -274,7 +274,7 @@ class _DistWrapper:
|
|||||||
result: Union[T, WRAPPED_EXCEPTION]
|
result: Union[T, WRAPPED_EXCEPTION]
|
||||||
try:
|
try:
|
||||||
result = map_fun()
|
result = map_fun()
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
result = _wrap_exception(e)
|
result = _wrap_exception(e)
|
||||||
|
|
||||||
all_results = self.all_gather_object(result)
|
all_results = self.all_gather_object(result)
|
||||||
@ -300,7 +300,7 @@ class _DistWrapper:
|
|||||||
if self.is_coordinator:
|
if self.is_coordinator:
|
||||||
try:
|
try:
|
||||||
result = map_fun()
|
result = map_fun()
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
|
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
|
||||||
final_result = self.broadcast_object(result)
|
final_result = self.broadcast_object(result)
|
||||||
if isinstance(final_result, CheckpointException):
|
if isinstance(final_result, CheckpointException):
|
||||||
|
@ -208,7 +208,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
|
|||||||
try:
|
try:
|
||||||
self.set_closed()
|
self.set_closed()
|
||||||
return True
|
return True
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
logger.warning("Shutdown failed. Error occurred: %s", str(e))
|
logger.warning("Shutdown failed. Error occurred: %s", str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -330,7 +330,7 @@ def _full_post_state_dict_hook(
|
|||||||
try:
|
try:
|
||||||
state_dict[fqn] = state_dict[fqn].detach().clone()
|
state_dict[fqn] = state_dict[fqn].detach().clone()
|
||||||
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
|
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
|
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
|
||||||
"This may mean that this state_dict entry could point to invalid "
|
"This may mean that this state_dict entry could point to invalid "
|
||||||
|
@ -226,7 +226,7 @@ def _handle_exception(result):
|
|||||||
exc = None
|
exc = None
|
||||||
try:
|
try:
|
||||||
exc = result.exception_type(exception_msg)
|
exc = result.exception_type(exception_msg)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
raise RuntimeError( # noqa: B904
|
raise RuntimeError( # noqa: B904
|
||||||
f"Failed to create original exception type. Error msg was {str(e)}"
|
f"Failed to create original exception type. Error msg was {str(e)}"
|
||||||
f" Original exception on remote side was {exception_msg}"
|
f" Original exception on remote side was {exception_msg}"
|
||||||
|
@ -53,13 +53,13 @@ def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
|
|||||||
def _wrap_rref_type_cont(fut):
|
def _wrap_rref_type_cont(fut):
|
||||||
try:
|
try:
|
||||||
_rref_type_cont(fut).then(_complete_op)
|
_rref_type_cont(fut).then(_complete_op)
|
||||||
except BaseException as ex:
|
except BaseException as ex: # noqa: B036
|
||||||
result.set_exception(ex)
|
result.set_exception(ex)
|
||||||
|
|
||||||
def _complete_op(fut):
|
def _complete_op(fut):
|
||||||
try:
|
try:
|
||||||
result.set_result(fut.value())
|
result.set_result(fut.value())
|
||||||
except BaseException as ex:
|
except BaseException as ex: # noqa: B036
|
||||||
result.set_exception(ex)
|
result.set_exception(ex)
|
||||||
|
|
||||||
rref_fut.then(_wrap_rref_type_cont)
|
rref_fut.then(_wrap_rref_type_cont)
|
||||||
|
@ -1149,7 +1149,7 @@ def spawn_threads_and_init_comms(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
callback()
|
callback()
|
||||||
except BaseException as ex:
|
except BaseException as ex: # noqa: B036
|
||||||
# Exceptions are handled in MultiThreadedTestCase
|
# Exceptions are handled in MultiThreadedTestCase
|
||||||
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
|
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
|
||||||
ProcessLocalGroup.exception_handle(
|
ProcessLocalGroup.exception_handle(
|
||||||
@ -1310,7 +1310,7 @@ class MultiThreadedTestCase(TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
getattr(self, test_name)()
|
getattr(self, test_name)()
|
||||||
except BaseException as ex:
|
except BaseException as ex: # noqa: B036
|
||||||
self.exception_queue.put((rank, sys.exc_info()))
|
self.exception_queue.put((rank, sys.exc_info()))
|
||||||
ProcessLocalGroup.exception_handle(
|
ProcessLocalGroup.exception_handle(
|
||||||
ex
|
ex
|
||||||
@ -1641,7 +1641,7 @@ class MultiProcContinousTest(TestCase):
|
|||||||
try:
|
try:
|
||||||
cls._run_test_given_id(test_id)
|
cls._run_test_given_id(test_id)
|
||||||
completion_queue.put(test_id)
|
completion_queue.put(test_id)
|
||||||
except BaseException as ex:
|
except BaseException as ex: # noqa: B036
|
||||||
raised_exception = True
|
raised_exception = True
|
||||||
# Send the exception and stack trace back to the dispatcher
|
# Send the exception and stack trace back to the dispatcher
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
|
@ -3324,7 +3324,7 @@ class TestCase(expecttest.TestCase):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
f(*args, **kwargs)
|
f(*args, **kwargs)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
self.skipTest(e)
|
self.skipTest(e)
|
||||||
raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
|
raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
|
||||||
return wrapper
|
return wrapper
|
||||||
@ -3346,7 +3346,7 @@ class TestCase(expecttest.TestCase):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
f(*args, **kwargs)
|
f(*args, **kwargs)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
self.skipTest(e)
|
self.skipTest(e)
|
||||||
method = getattr(self, self._testMethodName)
|
method = getattr(self, self._testMethodName)
|
||||||
if getattr(method, "__unittest_expecting_failure__", False):
|
if getattr(method, "__unittest_expecting_failure__", False):
|
||||||
|
@ -3560,7 +3560,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
|
|||||||
print(f"Got msg {msg}")
|
print(f"Got msg {msg}")
|
||||||
self.assertTrue("Original exception on remote side was" in msg)
|
self.assertTrue("Original exception on remote side was" in msg)
|
||||||
self.assertTrue("CustomException" in msg)
|
self.assertTrue("CustomException" in msg)
|
||||||
except BaseException as e:
|
except BaseException as e: # noqa: B036
|
||||||
raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
|
raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
|
||||||
finally:
|
finally:
|
||||||
self.assertTrue(exc_caught)
|
self.assertTrue(exc_caught)
|
||||||
|
@ -48,7 +48,7 @@ def _wrap_generator(ctx_factory, func):
|
|||||||
gen.close()
|
gen.close()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except BaseException:
|
except BaseException: # noqa: B036
|
||||||
# Propagate the exception thrown at us by the caller
|
# Propagate the exception thrown at us by the caller
|
||||||
with ctx_factory():
|
with ctx_factory():
|
||||||
response = gen.throw(*sys.exc_info())
|
response = gen.throw(*sys.exc_info())
|
||||||
|
Reference in New Issue
Block a user