Fix exception chaining in torch/ (#43836)

Summary:
## Motivation
Fixes https://github.com/pytorch/pytorch/issues/43770.

## Description of the change
This PR fixes exception chaining only in files under `torch/` where appropriate.
To fix exception chaining, I used either:
1. `raise new_exception from old_exception`, where `new_exception` on its own does not seem descriptive enough to debug the problem, or where `old_exception` carries valuable information.
2. `raise new_exception from None`, where surfacing both `new_exception` and `old_exception` would be noisy and redundant.

The choice between the two options is subjective; the sketch below illustrates the difference.
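
A minimal, hypothetical sketch of the two forms (the function and exception names below are made up for illustration and do not appear in this PR):

```python
class DataPrepError(Exception):
    pass


def load_batch_size(config):
    try:
        raw = config["batch_size"]
    except KeyError:
        # Option 2: the KeyError adds nothing beyond "key missing", so suppress it.
        # The traceback shows only DataPrepError.
        raise DataPrepError("config is missing the 'batch_size' entry") from None
    try:
        return int(raw)
    except ValueError as e:
        # Option 1: the ValueError explains *why* the value is bad, so keep it chained.
        # The traceback shows both exceptions, joined by
        # "The above exception was the direct cause of the following exception:".
        raise DataPrepError(f"'batch_size' must be an integer, got {raw!r}") from e
```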

## List of lines containing `raise` in an `except` clause
I wrote [this simple script](https://gist.github.com/akihironitta/4223c1b32404b36c1b349d70c4c93b4d) using [ast](https://docs.python.org/3.8/library/ast.html#module-ast) to list the lines that `raise` inside an `except` clause.
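
The gist is not reproduced here; the following is only a rough approximation of that kind of `ast`-based check, not the linked script itself:

```python
import ast
import sys


def raises_inside_except(path):
    """Yield (lineno, source line) for raises of a new exception inside an except block."""
    with open(path) as f:
        source = f.read()
    lines = source.splitlines()
    for node in ast.walk(ast.parse(source, filename=path)):
        if not isinstance(node, ast.ExceptHandler):
            continue
        for child in ast.walk(node):
            # Skip bare `raise` (re-raise) and raises that already carry an explicit cause.
            if isinstance(child, ast.Raise) and child.exc is not None and child.cause is None:
                yield child.lineno, lines[child.lineno - 1].strip()


if __name__ == "__main__":
    for filename in sys.argv[1:]:
        for lineno, line in raises_inside_except(filename):
            print(f"{filename}:{lineno}: {line}")
```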

- [x] 000739c31a/torch/jit/annotations.py (L35)
- [x] 000739c31a/torch/jit/annotations.py (L150)
- [x] 000739c31a/torch/jit/annotations.py (L158)
- [x] 000739c31a/torch/jit/annotations.py (L231)
- [x] 000739c31a/torch/jit/_trace.py (L432)
- [x] 000739c31a/torch/nn/utils/prune.py (L192)
- [x] 000739c31a/torch/cuda/nvtx.py (L7)
- [x] 000739c31a/torch/utils/cpp_extension.py (L1537)
- [x] 000739c31a/torch/utils/tensorboard/_pytorch_graph.py (L292)
- [x] 000739c31a/torch/utils/data/dataloader.py (L835)
- [x] 000739c31a/torch/utils/data/dataloader.py (L849)
- [x] 000739c31a/torch/utils/data/dataloader.py (L856)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L186)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L189)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L424)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1279)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1283)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1356)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1388)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1391)
- [ ] 000739c31a/torch/testing/_internal/common_utils.py (L1412)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L310)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L329)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L332)
- [x] 000739c31a/torch/testing/_internal/jit_utils.py (L183)
- [x] 000739c31a/torch/testing/_internal/common_nn.py (L4789)
- [x] 000739c31a/torch/onnx/utils.py (L367)
- [x] 000739c31a/torch/onnx/utils.py (L659)
- [x] 000739c31a/torch/onnx/utils.py (L892)
- [x] 000739c31a/torch/onnx/utils.py (L897)
- [x] 000739c31a/torch/serialization.py (L108)
- [x] 000739c31a/torch/serialization.py (L754)
- [x] 000739c31a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py (L76)
- [x] 000739c31a/torch/distributed/rpc/backend_registry.py (L260)
- [x] 000739c31a/torch/distributed/distributed_c10d.py (L184)
- [x] 000739c31a/torch/_utils_internal.py (L57)
- [x] 000739c31a/torch/hub.py (L494)
- [x] 000739c31a/torch/contrib/_tensorboard_vis.py (L16)
- [x] 000739c31a/torch/distributions/lowrank_multivariate_normal.py (L100)
- [x] 000739c31a/torch/distributions/constraint_registry.py (L142)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43836

Reviewed By: ailzhang

Differential Revision: D23431212

Pulled By: malfet

fbshipit-source-id: 5f7f41b391164a5ad0efc06e55cd58c23408a921
Author: Akihiro Nitta
Date: 2020-08-31 19:28:48 -07:00
Committed by: Facebook GitHub Bot
Commit: f17d7a5556 (parent da32bf4cc6)
13 changed files with 31 additions and 32 deletions

torch/_utils_internal.py

@@ -49,12 +49,12 @@ def get_source_lines_and_file(obj, error_msg=None):
         filename = inspect.getsourcefile(obj)
         sourcelines, file_lineno = inspect.getsourcelines(obj)
     except OSError as e:
-        msg = ("Can't get source for {}. TorchScript requires source access in "
+        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
                "order to carry out compilation, make sure original .py files are "
-               "available. Original error: {}".format(obj, e))
+               "available.")
         if error_msg:
             msg += '\n' + error_msg
-        raise OSError(msg)
+        raise OSError(msg) from e
     return sourcelines, file_lineno, filename

torch/contrib/_tensorboard_vis.py

@@ -14,7 +14,7 @@ try:
     from tensorflow.python.summary.writer.writer import FileWriter
 except ImportError:
     raise ImportError("TensorBoard visualization of GraphExecutors requires having "
-                      "TensorFlow installed")
+                      "TensorFlow installed") from None
 def dump_tensorboard_summary(graph_executor, logdir):

torch/distributed/distributed_c10d.py

@@ -181,7 +181,7 @@ def _get_group_rank(group, rank):
     try:
         group_rank = _pg_group_ranks[group][rank]
     except KeyError:
-        raise RuntimeError("The global rank is not part of the group")
+        raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
     return group_rank

torch/distributions/constraint_registry.py

@@ -140,7 +140,7 @@ class ConstraintRegistry(object):
             factory = self._registry[type(constraint)]
         except KeyError:
             raise NotImplementedError(
-                'Cannot transform {} constraints'.format(type(constraint).__name__))
+                f'Cannot transform {type(constraint).__name__} constraints') from None
         return factory(constraint)

torch/distributions/lowrank_multivariate_normal.py

@@ -96,9 +96,9 @@ class LowRankMultivariateNormal(Distribution):
             cov_diag_ = cov_diag.unsqueeze(-1)
             try:
                 loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
-            except RuntimeError:
+            except RuntimeError as e:
                 raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
-                                 .format(loc.shape, cov_factor.shape, cov_diag.shape))
+                                 .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
             self.loc = loc_[..., 0]
             self.cov_diag = cov_diag_[..., 0]
         batch_shape = self.loc.shape[:-1]

torch/jit/_trace.py

@@ -433,9 +433,8 @@ def _check_trace(
                     *graph_diagnostic_info(),
                     extra_msg="Encountered an exception while running the "
                     + running_what
-                    + " with test inputs.\nException:\n"
-                    + indent(str(e))
-                )
+                    + " with test inputs."
+                ) from e
         has_warned = [False]

torch/jit/annotations.py

@@ -32,7 +32,7 @@ class Module(object):
         try:
             return self.members[name]
         except KeyError:
-            raise RuntimeError("Module {} has no member called {}".format(self.name, name))
+            raise RuntimeError("Module {} has no member called {}".format(self.name, name)) from None
 class EvalEnv(object):
@@ -147,7 +147,7 @@ def parse_type_line(type_line, rcb, loc):
     try:
         arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # noqa: P204
     except (NameError, SyntaxError) as e:
-        raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e)))
+        raise RuntimeError("Failed to parse the argument list of a type annotation") from e
     if not isinstance(arg_ann, tuple):
         arg_ann = (arg_ann,)
@@ -155,7 +155,7 @@ def parse_type_line(type_line, rcb, loc):
     try:
         ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # noqa: P204
     except (NameError, SyntaxError) as e:
-        raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e)))
+        raise RuntimeError("Failed to parse the return type of a type annotation") from e
     arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
     return arg_types, ann_to_type(ret_ann, loc)
@@ -228,7 +228,7 @@ def split_type_line(type_line):
     try:
         arrow_pos = type_line.index('->')
     except ValueError:
-        raise RuntimeError("Syntax error in type annotation (cound't find `->`)")
+        raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None
     return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()

torch/onnx/utils.py

@@ -363,8 +363,8 @@ def _model_to_graph(model, args, verbose=False,
             in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
             graph = _propagate_and_assign_input_shapes(
                 method_graph, tuple(in_vars), False, propagate)
-        except AttributeError:
-            raise RuntimeError('\'forward\' method must be a script method')
+        except AttributeError as e:
+            raise RuntimeError('\'forward\' method must be a script method') from e
     elif isinstance(model, torch.jit.ScriptFunction):
         assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript ScriptFunction"
         method = model

torch/serialization.py

@@ -105,7 +105,7 @@ def check_module_version_greater_or_equal(module, req_version_tuple, error_if_ma
             module.__name__, module.__version__, str(req_version_tuple)
         )
         if error_if_malformed:
-            raise RuntimeError(message)
+            raise RuntimeError(message) from e
         else:
             warnings.warn(message + ', but continuing assuming that requirement is met')
             requirement_is_met = True
@@ -752,7 +752,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
         if _is_zipfile(f):
             # .zip is used for torch.jit.save and will throw an un-pickling error here
             raise RuntimeError(
-                "{filename} is a zip archive (did you mean to use torch.jit.load()?)".format(filename=f.name))
+                f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
         # if not a tarfile, reset file offset and proceed
         f.seek(0)

torch/testing/_internal/codegen/random_topo_test.py

@@ -307,8 +307,8 @@ def runTest(seed, args):
                 for out in o:
                     print("val size: ", out.size())
     except Exception as err:
-        raise Exception("Testing script failure with error message {0}\n\trepro by running:\n\t{1}".format(
-            str(err), reproString(seed, args)))
+        raise Exception("Testing script failure with error message, repro by running:\n"
+                        f"\t{reproString(seed, args)}") from err
     try:
         traced_model = torch.jit.trace(random_topology_test, (seed_tensor, *tensor_list))
         if DEBUG_PRINT:
@@ -325,12 +325,12 @@ def runTest(seed, args):
                     print("jit output: ", jit_oo)
                     print("diff ", jit_oo - oo)
                     raise WrongResultException()
-    except WrongResultException:
+    except WrongResultException as err:
         raise Exception("cuda fuser gives wrong results, repro by running:\n"
-                        "\t{0}".format(reproString(seed, args)))
+                        f"\t{reproString(seed, args)}") from err
     except Exception as err:
-        raise Exception("something in cuda fuser went wrong {0}\n\trepro by running:\n\t{1}".format(
-            str(err), reproString(seed, args)))
+        raise Exception("something in cuda fuser went wrong, repro by running:\n"
+                        f"\t{reproString(seed, args)}") from err
 def parse_args():

torch/testing/_internal/common_utils.py

@@ -1283,7 +1283,7 @@ class TestCase(expecttest.TestCase):
                 raise RuntimeError(
                     ("I got this output for {}{}:\n\n{}\n\n"
                      "No expect file exists; to accept the current output, run:\n"
-                     "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))
+                     "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id)) from None
         # a hack for JIT tests
         if IS_WINDOWS:
@@ -1350,10 +1350,10 @@ def download_file(url, binary=True):
         with open(path, 'wb' if binary else 'w') as f:
             f.write(data)
         return path
-    except error.URLError:
+    except error.URLError as e:
         msg = "could not download test file '{}'".format(url)
         warnings.warn(msg, RuntimeWarning)
-        raise unittest.SkipTest(msg)
+        raise unittest.SkipTest(msg) from e
 def find_free_port():

torch/utils/cpp_extension.py

@@ -1525,7 +1525,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
             stderr=subprocess.STDOUT,
             cwd=build_directory,
             env=env)
-    except subprocess.CalledProcessError:
+    except subprocess.CalledProcessError as e:
         # Python 2 and 3 compatible way of getting the error object.
         _, error, _ = sys.exc_info()
         # error.output contains the stdout and stderr of the build attempt.
@@ -1534,7 +1534,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
         # mypy thinks it's Optional[BaseException] and doesn't narrow
         if hasattr(error, 'output') and error.output: # type: ignore
             message += ": {}".format(error.output.decode()) # type: ignore
-        raise RuntimeError(message)
+        raise RuntimeError(message) from e
 def _import_module_from_library(module_name, path, is_python_module):

torch/utils/data/dataloader.py

@@ -832,7 +832,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                     self._shutdown_worker(worker_id)
             if len(failed_workers) > 0:
                 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
+                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
             if isinstance(e, queue.Empty):
                 return (False, None)
             import tempfile
@@ -852,7 +852,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                         " limit using `ulimit -n` in the shell or change the"
                         " sharing strategy by calling"
                         " `torch.multiprocessing.set_sharing_strategy('file_system')`"
-                        " at the beginning of your code")
+                        " at the beginning of your code") from None
             raise
 # NOTE [ DataLoader on Linux and open files limit ]