[BE]: Merge startswith calls - rule PIE810 (#96754)

Merges startswith, endswith calls to into a single call that feeds in a tuple. Not only are these calls more readable, but it will be more efficient as it iterates through each string only once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96754
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2023-03-14 22:05:16 +00:00
committed by PyTorch MergeBot
parent be4eaa69c2
commit dd5e6e8553
12 changed files with 20 additions and 34 deletions

View File

@ -41,8 +41,7 @@ def share_grad_blobs(
name = str(b) name = str(b)
# Note: need to look at _{namescope} pattern as it matches # Note: need to look at _{namescope} pattern as it matches
# to handle the auto-split gradients # to handle the auto-split gradients
return name.endswith("_grad") and (name.startswith(namescope) or return name.endswith("_grad") and (name.startswith((namescope, "_" + namescope))) and name not in param_grads
name.startswith("_" + namescope)) and name not in param_grads
def is_grad_op(op): def is_grad_op(op):
# TODO: something smarter # TODO: something smarter

View File

@ -199,7 +199,7 @@ class CommitList:
else: else:
# Below are some extra quick checks that aren't necessarily file-path related, # Below are some extra quick checks that aren't necessarily file-path related,
# but I found that to catch a decent number of extra commits. # but I found that to catch a decent number of extra commits.
if len(files_changed) > 0 and all([f_name.endswith('.cu') or f_name.endswith('.cuh') for f_name in files_changed]): if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]):
category = 'cuda' category = 'cuda'
elif '[PyTorch Edge]' in title: elif '[PyTorch Edge]' in title:
category = 'mobile' category = 'mobile'

View File

@ -172,7 +172,7 @@ def test_forward_backward(unit_test_class, test_params):
param_name = key[:-len(suffix)] param_name = key[:-len(suffix)]
break break
assert param_name is not None assert param_name is not None
sparsity_str = 'sparse' if key.endswith('_grad_indices') or key.endswith('_grad_values') else 'dense' sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'
unit_test_class.assertTrue( unit_test_class.assertTrue(
key in cpp_grad_dict, key in cpp_grad_dict,

View File

@ -161,7 +161,7 @@ class DTensorAPITest(DTensorTestBase):
dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn) dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
for name, param in dist_module.named_parameters(): for name, param in dist_module.named_parameters():
self.assertIsInstance(param, DTensor) self.assertIsInstance(param, DTensor)
if name.startswith("seq.0") or name.startswith("seq.8"): if name.startswith(("seq.0", "seq.8")):
self.assertEqual(param.placements, shard_spec) self.assertEqual(param.placements, shard_spec)
else: else:
self.assertEqual(param.placements, replica_spec) self.assertEqual(param.placements, replica_spec)

View File

@ -80,11 +80,7 @@ def check_labels(
) )
) )
if ( if label.startswith(("module:", "oncall:")) or label in ACCEPTABLE_OWNER_LABELS:
label.startswith("module:")
or label.startswith("oncall:")
or label in ACCEPTABLE_OWNER_LABELS
):
continue continue
lint_messages.append( lint_messages.append(

View File

@ -53,11 +53,7 @@ class UnsupportedOperatorError(OnnxExporterError):
msg = diagnostic_rule.format_message(name, version, supported_version) msg = diagnostic_rule.format_message(name, version, supported_version)
diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
else: else:
if ( if name.startswith(("aten::", "prim::", "quantized::")):
name.startswith("aten::")
or name.startswith("prim::")
or name.startswith("quantized::")
):
diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
msg = diagnostic_rule.format_message( msg = diagnostic_rule.format_message(
name, version, _constants.PYTORCH_GITHUB_ISSUES_URL name, version, _constants.PYTORCH_GITHUB_ISSUES_URL

View File

@ -1352,7 +1352,7 @@ def unconvertible_ops(
unsupported_ops = [] unsupported_ops = []
for node in graph.nodes(): for node in graph.nodes():
domain_op = node.kind() domain_op = node.kind()
if domain_op.startswith("onnx::") or domain_op.startswith("prim::"): if domain_op.startswith(("onnx::", "prim::")):
# We consider onnx and prim ops as supported ops, even though some "prim" # We consider onnx and prim ops as supported ops, even though some "prim"
# ops are not implemented as symbolic functions, because they may be # ops are not implemented as symbolic functions, because they may be
# eliminated in the conversion passes. Users may still see errors caused # eliminated in the conversion passes. Users may still see errors caused
@ -1455,7 +1455,6 @@ def _reset_trace_module_map():
@_beartype.beartype @_beartype.beartype
def _get_module_attributes(module): def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module)) annotations = typing.get_type_hints(type(module))
base_m_annotations = typing.get_type_hints(torch.nn.Module) base_m_annotations = typing.get_type_hints(torch.nn.Module)
[annotations.pop(k, None) for k in base_m_annotations] [annotations.pop(k, None) for k in base_m_annotations]
@ -1800,7 +1799,6 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable: def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
"""Decorator that provides the symbolic context to the symbolic function if needed.""" """Decorator that provides the symbolic context to the symbolic function if needed."""
if _need_symbolic_context(symbolic_fn): if _need_symbolic_context(symbolic_fn):
# TODO(justinchuby): Update the module name of GraphContext when it is public # TODO(justinchuby): Update the module name of GraphContext when it is public
warnings.warn( warnings.warn(
"The first argument to symbolic functions is deprecated in 1.13 and will be " "The first argument to symbolic functions is deprecated in 1.13 and will be "
@ -1824,7 +1822,6 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
@_beartype.beartype @_beartype.beartype
def _get_aten_op_overload_name(n: _C.Node) -> str: def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
schema = n.schema() schema = n.schema()
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback(): if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():

View File

@ -107,7 +107,7 @@ def check_file(filename):
Returns: Returns:
The number of unsafe kernel launches in the file The number of unsafe kernel launches in the file
""" """
if not (filename.endswith(".cu") or filename.endswith(".cuh")): if not (filename.endswith((".cu", ".cuh"))):
return 0 return 0
if should_exclude_file(filename): if should_exclude_file(filename):
return 0 return 0

View File

@ -957,7 +957,7 @@ def largeTensorTest(size, device=None):
In other tests, the `device=` argument needs to be specified. In other tests, the `device=` argument needs to be specified.
""" """
if isinstance(size, str): if isinstance(size, str):
assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported" assert size.endswith(('GB', 'gb')), "only bytes or GB supported"
size = 1024 ** 3 * int(size[:-2]) size = 1024 ** 3 * int(size[:-2])
def inner(fn): def inner(fn):

View File

@ -552,7 +552,7 @@ class BuildExtension(build_ext):
_ccbin = os.getenv("CC") _ccbin = os.getenv("CC")
if ( if (
_ccbin is not None _ccbin is not None
and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags]) and not any([flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags])
): ):
cflags.extend(['-ccbin', _ccbin]) cflags.extend(['-ccbin', _ccbin])

View File

@ -799,14 +799,14 @@ def preprocessor(
f = m.group(1) f = m.group(1)
dirpath, filename = os.path.split(f) dirpath, filename = os.path.split(f)
if ( if (
f.startswith("ATen/cuda") f.startswith(("ATen/cuda",
or f.startswith("ATen/native/cuda") "ATen/native/cuda",
or f.startswith("ATen/native/nested/cuda") "ATen/native/nested/cuda",
or f.startswith("ATen/native/quantized/cuda") "ATen/native/quantized/cuda",
or f.startswith("ATen/native/sparse/cuda") "ATen/native/sparse/cuda",
or f.startswith("ATen/native/transformers/cuda") "ATen/native/transformers/cuda",
or f.startswith("THC/") "THC/")) or
or (f.startswith("THC") and not f.startswith("THCP")) (f.startswith("THC") and not f.startswith("THCP"))
): ):
return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
# if filename is one of the files being hipified for this extension # if filename is one of the files being hipified for this extension
@ -858,7 +858,7 @@ def preprocessor(
output_source = processKernelLaunches(output_source, stats) output_source = processKernelLaunches(output_source, stats)
# Replace std:: with non-std:: versions # Replace std:: with non-std:: versions
if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath: if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
output_source = replace_math_functions(output_source) output_source = replace_math_functions(output_source)
# Include header if device code is contained. # Include header if device code is contained.

View File

@ -123,9 +123,7 @@ def hierarchical_pickle(data):
if isinstance(data, torch.utils.show_pickle.FakeObject): if isinstance(data, torch.utils.show_pickle.FakeObject):
typename = f"{data.module}.{data.name}" typename = f"{data.module}.{data.name}"
if ( if (
typename.startswith("__torch__.") or typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
typename.startswith("torch.jit.LoweredWrapper.") or
typename.startswith("torch.jit.LoweredModule.")
): ):
assert data.args == () assert data.args == ()
return { return {