mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix exception chaining in torch/ (#43836)
Summary: ## Motivation Fixes https://github.com/pytorch/pytorch/issues/43770. ## Description of the change This PR fixes exception chaining only in files under `torch/` where appropriate. To fix exception chaining, I used either: 1. `raise new_exception from old_exception` where `new_exception` itself seems not descriptive enough to debug or `old_exception` delivers valuable information. 2. `raise new_exception from None` where raising both of `new_exception` and `old_exception` seems a bit noisy and redundant. I subjectively chose which one to use from the above options. ## List of lines containing raise in except clause: I wrote [this simple script](https://gist.github.com/akihironitta/4223c1b32404b36c1b349d70c4c93b4d) using [ast](https://docs.python.org/3.8/library/ast.html#module-ast) to list lines where `raise`ing in `except` clause. - [x]000739c31a/torch/jit/annotations.py (L35)- [x]000739c31a/torch/jit/annotations.py (L150)- [x]000739c31a/torch/jit/annotations.py (L158)- [x]000739c31a/torch/jit/annotations.py (L231)- [x]000739c31a/torch/jit/_trace.py (L432)- [x]000739c31a/torch/nn/utils/prune.py (L192)- [x]000739c31a/torch/cuda/nvtx.py (L7)- [x]000739c31a/torch/utils/cpp_extension.py (L1537)- [x]000739c31a/torch/utils/tensorboard/_pytorch_graph.py (L292)- [x]000739c31a/torch/utils/data/dataloader.py (L835)- [x]000739c31a/torch/utils/data/dataloader.py (L849)- [x]000739c31a/torch/utils/data/dataloader.py (L856)- [x]000739c31a/torch/testing/_internal/common_utils.py (L186)- [x]000739c31a/torch/testing/_internal/common_utils.py (L189)- [x]000739c31a/torch/testing/_internal/common_utils.py (L424)- [x]000739c31a/torch/testing/_internal/common_utils.py (L1279)- [x]000739c31a/torch/testing/_internal/common_utils.py (L1283)- [x]000739c31a/torch/testing/_internal/common_utils.py (L1356)- [x]000739c31a/torch/testing/_internal/common_utils.py (L1388)- [x]000739c31a/torch/testing/_internal/common_utils.py (L1391)- [ ]000739c31a/torch/testing/_internal/common_utils.py (L1412)- [x]000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L310)- [x]000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L329)- [x]000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L332)- [x]000739c31a/torch/testing/_internal/jit_utils.py (L183)- [x]000739c31a/torch/testing/_internal/common_nn.py (L4789)- [x]000739c31a/torch/onnx/utils.py (L367)- [x]000739c31a/torch/onnx/utils.py (L659)- [x]000739c31a/torch/onnx/utils.py (L892)- [x]000739c31a/torch/onnx/utils.py (L897)- [x]000739c31a/torch/serialization.py (L108)- [x]000739c31a/torch/serialization.py (L754)- [x]000739c31a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py (L76)- [x]000739c31a/torch/distributed/rpc/backend_registry.py (L260)- [x]000739c31a/torch/distributed/distributed_c10d.py (L184)- [x]000739c31a/torch/_utils_internal.py (L57)- [x]000739c31a/torch/hub.py (L494)- [x]000739c31a/torch/contrib/_tensorboard_vis.py (L16)- [x]000739c31a/torch/distributions/lowrank_multivariate_normal.py (L100)- [x]000739c31a/torch/distributions/constraint_registry.py (L142)Pull Request resolved: https://github.com/pytorch/pytorch/pull/43836 Reviewed By: ailzhang Differential Revision: D23431212 Pulled By: malfet fbshipit-source-id: 5f7f41b391164a5ad0efc06e55cd58c23408a921
This commit is contained in:
committed by
Facebook GitHub Bot
parent
da32bf4cc6
commit
f17d7a5556
@ -49,12 +49,12 @@ def get_source_lines_and_file(obj, error_msg=None):
|
||||
filename = inspect.getsourcefile(obj)
|
||||
sourcelines, file_lineno = inspect.getsourcelines(obj)
|
||||
except OSError as e:
|
||||
msg = ("Can't get source for {}. TorchScript requires source access in "
|
||||
msg = (f"Can't get source for {obj}. TorchScript requires source access in "
|
||||
"order to carry out compilation, make sure original .py files are "
|
||||
"available. Original error: {}".format(obj, e))
|
||||
"available.")
|
||||
if error_msg:
|
||||
msg += '\n' + error_msg
|
||||
raise OSError(msg)
|
||||
raise OSError(msg) from e
|
||||
|
||||
return sourcelines, file_lineno, filename
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ try:
|
||||
from tensorflow.python.summary.writer.writer import FileWriter
|
||||
except ImportError:
|
||||
raise ImportError("TensorBoard visualization of GraphExecutors requires having "
|
||||
"TensorFlow installed")
|
||||
"TensorFlow installed") from None
|
||||
|
||||
|
||||
def dump_tensorboard_summary(graph_executor, logdir):
|
||||
|
||||
@ -181,7 +181,7 @@ def _get_group_rank(group, rank):
|
||||
try:
|
||||
group_rank = _pg_group_ranks[group][rank]
|
||||
except KeyError:
|
||||
raise RuntimeError("The global rank is not part of the group")
|
||||
raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
|
||||
return group_rank
|
||||
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ class ConstraintRegistry(object):
|
||||
factory = self._registry[type(constraint)]
|
||||
except KeyError:
|
||||
raise NotImplementedError(
|
||||
'Cannot transform {} constraints'.format(type(constraint).__name__))
|
||||
f'Cannot transform {type(constraint).__name__} constraints') from None
|
||||
return factory(constraint)
|
||||
|
||||
|
||||
|
||||
@ -96,9 +96,9 @@ class LowRankMultivariateNormal(Distribution):
|
||||
cov_diag_ = cov_diag.unsqueeze(-1)
|
||||
try:
|
||||
loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
|
||||
except RuntimeError:
|
||||
except RuntimeError as e:
|
||||
raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
|
||||
.format(loc.shape, cov_factor.shape, cov_diag.shape))
|
||||
.format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
|
||||
self.loc = loc_[..., 0]
|
||||
self.cov_diag = cov_diag_[..., 0]
|
||||
batch_shape = self.loc.shape[:-1]
|
||||
|
||||
@ -433,9 +433,8 @@ def _check_trace(
|
||||
*graph_diagnostic_info(),
|
||||
extra_msg="Encountered an exception while running the "
|
||||
+ running_what
|
||||
+ " with test inputs.\nException:\n"
|
||||
+ indent(str(e))
|
||||
)
|
||||
+ " with test inputs."
|
||||
) from e
|
||||
|
||||
has_warned = [False]
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ class Module(object):
|
||||
try:
|
||||
return self.members[name]
|
||||
except KeyError:
|
||||
raise RuntimeError("Module {} has no member called {}".format(self.name, name))
|
||||
raise RuntimeError("Module {} has no member called {}".format(self.name, name)) from None
|
||||
|
||||
|
||||
class EvalEnv(object):
|
||||
@ -147,7 +147,7 @@ def parse_type_line(type_line, rcb, loc):
|
||||
try:
|
||||
arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # noqa: P204
|
||||
except (NameError, SyntaxError) as e:
|
||||
raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e)))
|
||||
raise RuntimeError("Failed to parse the argument list of a type annotation") from e
|
||||
|
||||
if not isinstance(arg_ann, tuple):
|
||||
arg_ann = (arg_ann,)
|
||||
@ -155,7 +155,7 @@ def parse_type_line(type_line, rcb, loc):
|
||||
try:
|
||||
ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # noqa: P204
|
||||
except (NameError, SyntaxError) as e:
|
||||
raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e)))
|
||||
raise RuntimeError("Failed to parse the return type of a type annotation") from e
|
||||
|
||||
arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
|
||||
return arg_types, ann_to_type(ret_ann, loc)
|
||||
@ -228,7 +228,7 @@ def split_type_line(type_line):
|
||||
try:
|
||||
arrow_pos = type_line.index('->')
|
||||
except ValueError:
|
||||
raise RuntimeError("Syntax error in type annotation (cound't find `->`)")
|
||||
raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None
|
||||
return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
|
||||
|
||||
|
||||
|
||||
@ -363,8 +363,8 @@ def _model_to_graph(model, args, verbose=False,
|
||||
in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
|
||||
graph = _propagate_and_assign_input_shapes(
|
||||
method_graph, tuple(in_vars), False, propagate)
|
||||
except AttributeError:
|
||||
raise RuntimeError('\'forward\' method must be a script method')
|
||||
except AttributeError as e:
|
||||
raise RuntimeError('\'forward\' method must be a script method') from e
|
||||
elif isinstance(model, torch.jit.ScriptFunction):
|
||||
assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript ScriptFunction"
|
||||
method = model
|
||||
|
||||
@ -105,7 +105,7 @@ def check_module_version_greater_or_equal(module, req_version_tuple, error_if_ma
|
||||
module.__name__, module.__version__, str(req_version_tuple)
|
||||
)
|
||||
if error_if_malformed:
|
||||
raise RuntimeError(message)
|
||||
raise RuntimeError(message) from e
|
||||
else:
|
||||
warnings.warn(message + ', but continuing assuming that requirement is met')
|
||||
requirement_is_met = True
|
||||
@ -752,7 +752,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
if _is_zipfile(f):
|
||||
# .zip is used for torch.jit.save and will throw an un-pickling error here
|
||||
raise RuntimeError(
|
||||
"{filename} is a zip archive (did you mean to use torch.jit.load()?)".format(filename=f.name))
|
||||
f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
|
||||
# if not a tarfile, reset file offset and proceed
|
||||
f.seek(0)
|
||||
|
||||
|
||||
@ -307,8 +307,8 @@ def runTest(seed, args):
|
||||
for out in o:
|
||||
print("val size: ", out.size())
|
||||
except Exception as err:
|
||||
raise Exception("Testing script failure with error message {0}\n\trepro by running:\n\t{1}".format(
|
||||
str(err), reproString(seed, args)))
|
||||
raise Exception("Testing script failure with error message, repro by running:\n"
|
||||
f"\t{reproString(seed, args)}") from err
|
||||
try:
|
||||
traced_model = torch.jit.trace(random_topology_test, (seed_tensor, *tensor_list))
|
||||
if DEBUG_PRINT:
|
||||
@ -325,12 +325,12 @@ def runTest(seed, args):
|
||||
print("jit output: ", jit_oo)
|
||||
print("diff ", jit_oo - oo)
|
||||
raise WrongResultException()
|
||||
except WrongResultException:
|
||||
except WrongResultException as err:
|
||||
raise Exception("cuda fuser gives wrong results, repro by running:\n"
|
||||
"\t{0}".format(reproString(seed, args)))
|
||||
f"\t{reproString(seed, args)}") from err
|
||||
except Exception as err:
|
||||
raise Exception("something in cuda fuser went wrong {0}\n\trepro by running:\n\t{1}".format(
|
||||
str(err), reproString(seed, args)))
|
||||
raise Exception("something in cuda fuser went wrong, repro by running:\n"
|
||||
f"\t{reproString(seed, args)}") from err
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
@ -1283,7 +1283,7 @@ class TestCase(expecttest.TestCase):
|
||||
raise RuntimeError(
|
||||
("I got this output for {}{}:\n\n{}\n\n"
|
||||
"No expect file exists; to accept the current output, run:\n"
|
||||
"python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))
|
||||
"python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id)) from None
|
||||
|
||||
# a hack for JIT tests
|
||||
if IS_WINDOWS:
|
||||
@ -1350,10 +1350,10 @@ def download_file(url, binary=True):
|
||||
with open(path, 'wb' if binary else 'w') as f:
|
||||
f.write(data)
|
||||
return path
|
||||
except error.URLError:
|
||||
except error.URLError as e:
|
||||
msg = "could not download test file '{}'".format(url)
|
||||
warnings.warn(msg, RuntimeWarning)
|
||||
raise unittest.SkipTest(msg)
|
||||
raise unittest.SkipTest(msg) from e
|
||||
|
||||
|
||||
def find_free_port():
|
||||
|
||||
@ -1525,7 +1525,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=build_directory,
|
||||
env=env)
|
||||
except subprocess.CalledProcessError:
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Python 2 and 3 compatible way of getting the error object.
|
||||
_, error, _ = sys.exc_info()
|
||||
# error.output contains the stdout and stderr of the build attempt.
|
||||
@ -1534,7 +1534,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
|
||||
# mypy thinks it's Optional[BaseException] and doesn't narrow
|
||||
if hasattr(error, 'output') and error.output: # type: ignore
|
||||
message += ": {}".format(error.output.decode()) # type: ignore
|
||||
raise RuntimeError(message)
|
||||
raise RuntimeError(message) from e
|
||||
|
||||
|
||||
def _import_module_from_library(module_name, path, is_python_module):
|
||||
|
||||
@ -832,7 +832,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||||
self._shutdown_worker(worker_id)
|
||||
if len(failed_workers) > 0:
|
||||
pids_str = ', '.join(str(w.pid) for w in failed_workers)
|
||||
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
|
||||
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
|
||||
if isinstance(e, queue.Empty):
|
||||
return (False, None)
|
||||
import tempfile
|
||||
@ -852,7 +852,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||||
" limit using `ulimit -n` in the shell or change the"
|
||||
" sharing strategy by calling"
|
||||
" `torch.multiprocessing.set_sharing_strategy('file_system')`"
|
||||
" at the beginning of your code")
|
||||
" at the beginning of your code") from None
|
||||
raise
|
||||
|
||||
# NOTE [ DataLoader on Linux and open files limit ]
|
||||
|
||||
Reference in New Issue
Block a user