mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[BE]: Merge startswith calls - rule PIE810 (#96754)
Merges startswith, endswith calls to into a single call that feeds in a tuple. Not only are these calls more readable, but it will be more efficient as it iterates through each string only once. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96754 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
be4eaa69c2
commit
dd5e6e8553
@ -41,8 +41,7 @@ def share_grad_blobs(
|
||||
name = str(b)
|
||||
# Note: need to look at _{namescope} pattern as it matches
|
||||
# to handle the auto-split gradients
|
||||
return name.endswith("_grad") and (name.startswith(namescope) or
|
||||
name.startswith("_" + namescope)) and name not in param_grads
|
||||
return name.endswith("_grad") and (name.startswith((namescope, "_" + namescope))) and name not in param_grads
|
||||
|
||||
def is_grad_op(op):
|
||||
# TODO: something smarter
|
||||
|
@ -199,7 +199,7 @@ class CommitList:
|
||||
else:
|
||||
# Below are some extra quick checks that aren't necessarily file-path related,
|
||||
# but I found that to catch a decent number of extra commits.
|
||||
if len(files_changed) > 0 and all([f_name.endswith('.cu') or f_name.endswith('.cuh') for f_name in files_changed]):
|
||||
if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]):
|
||||
category = 'cuda'
|
||||
elif '[PyTorch Edge]' in title:
|
||||
category = 'mobile'
|
||||
|
@ -172,7 +172,7 @@ def test_forward_backward(unit_test_class, test_params):
|
||||
param_name = key[:-len(suffix)]
|
||||
break
|
||||
assert param_name is not None
|
||||
sparsity_str = 'sparse' if key.endswith('_grad_indices') or key.endswith('_grad_values') else 'dense'
|
||||
sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'
|
||||
|
||||
unit_test_class.assertTrue(
|
||||
key in cpp_grad_dict,
|
||||
|
@ -161,7 +161,7 @@ class DTensorAPITest(DTensorTestBase):
|
||||
dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
|
||||
for name, param in dist_module.named_parameters():
|
||||
self.assertIsInstance(param, DTensor)
|
||||
if name.startswith("seq.0") or name.startswith("seq.8"):
|
||||
if name.startswith(("seq.0", "seq.8")):
|
||||
self.assertEqual(param.placements, shard_spec)
|
||||
else:
|
||||
self.assertEqual(param.placements, replica_spec)
|
||||
|
@ -80,11 +80,7 @@ def check_labels(
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
label.startswith("module:")
|
||||
or label.startswith("oncall:")
|
||||
or label in ACCEPTABLE_OWNER_LABELS
|
||||
):
|
||||
if label.startswith(("module:", "oncall:")) or label in ACCEPTABLE_OWNER_LABELS:
|
||||
continue
|
||||
|
||||
lint_messages.append(
|
||||
|
@ -53,11 +53,7 @@ class UnsupportedOperatorError(OnnxExporterError):
|
||||
msg = diagnostic_rule.format_message(name, version, supported_version)
|
||||
diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
|
||||
else:
|
||||
if (
|
||||
name.startswith("aten::")
|
||||
or name.startswith("prim::")
|
||||
or name.startswith("quantized::")
|
||||
):
|
||||
if name.startswith(("aten::", "prim::", "quantized::")):
|
||||
diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
|
||||
msg = diagnostic_rule.format_message(
|
||||
name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
|
||||
|
@ -1352,7 +1352,7 @@ def unconvertible_ops(
|
||||
unsupported_ops = []
|
||||
for node in graph.nodes():
|
||||
domain_op = node.kind()
|
||||
if domain_op.startswith("onnx::") or domain_op.startswith("prim::"):
|
||||
if domain_op.startswith(("onnx::", "prim::")):
|
||||
# We consider onnx and prim ops as supported ops, even though some "prim"
|
||||
# ops are not implemented as symbolic functions, because they may be
|
||||
# eliminated in the conversion passes. Users may still see errors caused
|
||||
@ -1455,7 +1455,6 @@ def _reset_trace_module_map():
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_module_attributes(module):
|
||||
|
||||
annotations = typing.get_type_hints(type(module))
|
||||
base_m_annotations = typing.get_type_hints(torch.nn.Module)
|
||||
[annotations.pop(k, None) for k in base_m_annotations]
|
||||
@ -1800,7 +1799,6 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
|
||||
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
||||
"""Decorator that provides the symbolic context to the symbolic function if needed."""
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
|
||||
# TODO(justinchuby): Update the module name of GraphContext when it is public
|
||||
warnings.warn(
|
||||
"The first argument to symbolic functions is deprecated in 1.13 and will be "
|
||||
@ -1824,7 +1822,6 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||
|
||||
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
|
||||
schema = n.schema()
|
||||
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
|
||||
|
@ -107,7 +107,7 @@ def check_file(filename):
|
||||
Returns:
|
||||
The number of unsafe kernel launches in the file
|
||||
"""
|
||||
if not (filename.endswith(".cu") or filename.endswith(".cuh")):
|
||||
if not (filename.endswith((".cu", ".cuh"))):
|
||||
return 0
|
||||
if should_exclude_file(filename):
|
||||
return 0
|
||||
|
@ -957,7 +957,7 @@ def largeTensorTest(size, device=None):
|
||||
In other tests, the `device=` argument needs to be specified.
|
||||
"""
|
||||
if isinstance(size, str):
|
||||
assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported"
|
||||
assert size.endswith(('GB', 'gb')), "only bytes or GB supported"
|
||||
size = 1024 ** 3 * int(size[:-2])
|
||||
|
||||
def inner(fn):
|
||||
|
@ -552,7 +552,7 @@ class BuildExtension(build_ext):
|
||||
_ccbin = os.getenv("CC")
|
||||
if (
|
||||
_ccbin is not None
|
||||
and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags])
|
||||
and not any([flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags])
|
||||
):
|
||||
cflags.extend(['-ccbin', _ccbin])
|
||||
|
||||
|
@ -799,14 +799,14 @@ def preprocessor(
|
||||
f = m.group(1)
|
||||
dirpath, filename = os.path.split(f)
|
||||
if (
|
||||
f.startswith("ATen/cuda")
|
||||
or f.startswith("ATen/native/cuda")
|
||||
or f.startswith("ATen/native/nested/cuda")
|
||||
or f.startswith("ATen/native/quantized/cuda")
|
||||
or f.startswith("ATen/native/sparse/cuda")
|
||||
or f.startswith("ATen/native/transformers/cuda")
|
||||
or f.startswith("THC/")
|
||||
or (f.startswith("THC") and not f.startswith("THCP"))
|
||||
f.startswith(("ATen/cuda",
|
||||
"ATen/native/cuda",
|
||||
"ATen/native/nested/cuda",
|
||||
"ATen/native/quantized/cuda",
|
||||
"ATen/native/sparse/cuda",
|
||||
"ATen/native/transformers/cuda",
|
||||
"THC/")) or
|
||||
(f.startswith("THC") and not f.startswith("THCP"))
|
||||
):
|
||||
return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
|
||||
# if filename is one of the files being hipified for this extension
|
||||
@ -858,7 +858,7 @@ def preprocessor(
|
||||
output_source = processKernelLaunches(output_source, stats)
|
||||
|
||||
# Replace std:: with non-std:: versions
|
||||
if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath:
|
||||
if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
|
||||
output_source = replace_math_functions(output_source)
|
||||
|
||||
# Include header if device code is contained.
|
||||
|
@ -123,9 +123,7 @@ def hierarchical_pickle(data):
|
||||
if isinstance(data, torch.utils.show_pickle.FakeObject):
|
||||
typename = f"{data.module}.{data.name}"
|
||||
if (
|
||||
typename.startswith("__torch__.") or
|
||||
typename.startswith("torch.jit.LoweredWrapper.") or
|
||||
typename.startswith("torch.jit.LoweredModule.")
|
||||
typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
|
||||
):
|
||||
assert data.args == ()
|
||||
return {
|
||||
|
Reference in New Issue
Block a user