[BE]: Merge startswith calls - rule PIE810 (#96754)

Merges consecutive startswith/endswith calls into a single call that takes a tuple of prefixes or suffixes. Not only are these calls more readable, but they are also more efficient, as each string is iterated over only once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96754
Approved by: https://github.com/ezyang
Authored by Aaron Gokaslan on 2023-03-14 22:05:16 +00:00; committed by PyTorch MergeBot
parent be4eaa69c2
commit dd5e6e8553
12 changed files with 20 additions and 34 deletions
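
As a quick illustration of the pattern this PR applies (a minimal sketch; the helper function below is hypothetical and not part of the diff, though it mirrors the .cu/.cuh check updated below): str.startswith and str.endswith accept a tuple, so chained calls collapse into one.

# Before: one method call per alternative suffix.
def is_cuda_file_old(filename: str) -> bool:
    return filename.endswith(".cu") or filename.endswith(".cuh")

# After (rule PIE810): a single call fed a tuple of suffixes.
def is_cuda_file(filename: str) -> bool:
    return filename.endswith((".cu", ".cuh"))

assert is_cuda_file("kernel.cu") and is_cuda_file("kernel.cuh")
assert not is_cuda_file("kernel.cpp")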

View File

@@ -41,8 +41,7 @@ def share_grad_blobs(
name = str(b)
# Note: need to look at _{namescope} pattern as it matches
# to handle the auto-split gradients
-return name.endswith("_grad") and (name.startswith(namescope) or
-        name.startswith("_" + namescope)) and name not in param_grads
+return name.endswith("_grad") and (name.startswith((namescope, "_" + namescope))) and name not in param_grads
def is_grad_op(op):
# TODO: something smarter

View File

@@ -199,7 +199,7 @@ class CommitList:
else:
# Below are some extra quick checks that aren't necessarily file-path related,
# but I found that to catch a decent number of extra commits.
-if len(files_changed) > 0 and all([f_name.endswith('.cu') or f_name.endswith('.cuh') for f_name in files_changed]):
+if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]):
category = 'cuda'
elif '[PyTorch Edge]' in title:
category = 'mobile'

View File

@@ -172,7 +172,7 @@ def test_forward_backward(unit_test_class, test_params):
param_name = key[:-len(suffix)]
break
assert param_name is not None
-sparsity_str = 'sparse' if key.endswith('_grad_indices') or key.endswith('_grad_values') else 'dense'
+sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'
unit_test_class.assertTrue(
key in cpp_grad_dict,

View File

@@ -161,7 +161,7 @@ class DTensorAPITest(DTensorTestBase):
dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
for name, param in dist_module.named_parameters():
self.assertIsInstance(param, DTensor)
-if name.startswith("seq.0") or name.startswith("seq.8"):
+if name.startswith(("seq.0", "seq.8")):
self.assertEqual(param.placements, shard_spec)
else:
self.assertEqual(param.placements, replica_spec)

View File

@@ -80,11 +80,7 @@ def check_labels(
)
)
-if (
-    label.startswith("module:")
-    or label.startswith("oncall:")
-    or label in ACCEPTABLE_OWNER_LABELS
-):
+if label.startswith(("module:", "oncall:")) or label in ACCEPTABLE_OWNER_LABELS:
continue
lint_messages.append(

View File

@@ -53,11 +53,7 @@ class UnsupportedOperatorError(OnnxExporterError):
msg = diagnostic_rule.format_message(name, version, supported_version)
diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
else:
-if (
-    name.startswith("aten::")
-    or name.startswith("prim::")
-    or name.startswith("quantized::")
-):
+if name.startswith(("aten::", "prim::", "quantized::")):
diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
msg = diagnostic_rule.format_message(
name, version, _constants.PYTORCH_GITHUB_ISSUES_URL

View File

@@ -1352,7 +1352,7 @@ def unconvertible_ops(
unsupported_ops = []
for node in graph.nodes():
domain_op = node.kind()
-if domain_op.startswith("onnx::") or domain_op.startswith("prim::"):
+if domain_op.startswith(("onnx::", "prim::")):
# We consider onnx and prim ops as supported ops, even though some "prim"
# ops are not implemented as symbolic functions, because they may be
# eliminated in the conversion passes. Users may still see errors caused
@@ -1455,7 +1455,6 @@ def _reset_trace_module_map():
@_beartype.beartype
def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module))
base_m_annotations = typing.get_type_hints(torch.nn.Module)
[annotations.pop(k, None) for k in base_m_annotations]
@@ -1800,7 +1799,6 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
"""Decorator that provides the symbolic context to the symbolic function if needed."""
if _need_symbolic_context(symbolic_fn):
# TODO(justinchuby): Update the module name of GraphContext when it is public
warnings.warn(
"The first argument to symbolic functions is deprecated in 1.13 and will be "
@@ -1824,7 +1822,6 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
@_beartype.beartype
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
schema = n.schema()
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():

View File

@@ -107,7 +107,7 @@ def check_file(filename):
Returns:
The number of unsafe kernel launches in the file
"""
-if not (filename.endswith(".cu") or filename.endswith(".cuh")):
+if not (filename.endswith((".cu", ".cuh"))):
return 0
if should_exclude_file(filename):
return 0

View File

@@ -957,7 +957,7 @@ def largeTensorTest(size, device=None):
In other tests, the `device=` argument needs to be specified.
"""
if isinstance(size, str):
-assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported"
+assert size.endswith(('GB', 'gb')), "only bytes or GB supported"
size = 1024 ** 3 * int(size[:-2])
def inner(fn):

View File

@@ -552,7 +552,7 @@ class BuildExtension(build_ext):
_ccbin = os.getenv("CC")
if (
_ccbin is not None
-    and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags])
+    and not any([flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags])
):
cflags.extend(['-ccbin', _ccbin])

View File

@@ -799,14 +799,14 @@ def preprocessor(
f = m.group(1)
dirpath, filename = os.path.split(f)
if (
-    f.startswith("ATen/cuda")
-    or f.startswith("ATen/native/cuda")
-    or f.startswith("ATen/native/nested/cuda")
-    or f.startswith("ATen/native/quantized/cuda")
-    or f.startswith("ATen/native/sparse/cuda")
-    or f.startswith("ATen/native/transformers/cuda")
-    or f.startswith("THC/")
-    or (f.startswith("THC") and not f.startswith("THCP"))
+    f.startswith(("ATen/cuda",
+                  "ATen/native/cuda",
+                  "ATen/native/nested/cuda",
+                  "ATen/native/quantized/cuda",
+                  "ATen/native/sparse/cuda",
+                  "ATen/native/transformers/cuda",
+                  "THC/")) or
+    (f.startswith("THC") and not f.startswith("THCP"))
):
return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
# if filename is one of the files being hipified for this extension
@@ -858,7 +858,7 @@
output_source = processKernelLaunches(output_source, stats)
# Replace std:: with non-std:: versions
-if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath:
+if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
output_source = replace_math_functions(output_source)
# Include header if device code is contained.

View File

@@ -123,9 +123,7 @@ def hierarchical_pickle(data):
if isinstance(data, torch.utils.show_pickle.FakeObject):
typename = f"{data.module}.{data.name}"
if (
-    typename.startswith("__torch__.") or
-    typename.startswith("torch.jit.LoweredWrapper.") or
-    typename.startswith("torch.jit.LoweredModule.")
+    typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
):
assert data.args == ()
return {