[BE]: Merge startswith calls - rule PIE810 (#96754)
Merges consecutive startswith/endswith calls into a single call that takes a tuple of prefixes or suffixes. Not only is the merged form more readable, it is also more efficient, since each string is checked with one method call instead of one call per prefix or suffix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96754 Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: be4eaa69c2
Commit: dd5e6e8553
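
For context on the rewrite: Python's str.startswith and str.endswith accept either a single string or a tuple of strings, and the tuple form matches if any element matches, so the merged call keeps the semantics of the original or-chain. Below is a minimal sketch of the pattern this rule rewrites (the function names are invented for illustration, not taken from the diff):

    # Before: one method call per suffix.
    def is_cuda_source(filename: str) -> bool:
        return filename.endswith(".cu") or filename.endswith(".cuh")

    # After (PIE810): a single call with a tuple of suffixes.
    def is_cuda_source_merged(filename: str) -> bool:
        return filename.endswith((".cu", ".cuh"))

    # Both forms agree on every input.
    for name in ("kernel.cu", "header.cuh", "main.cpp"):
        assert is_cuda_source(name) == is_cuda_source_merged(name)

The same tuple form applies to startswith, e.g. name.startswith(("module:", "oncall:")) in the hunks below.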
@@ -41,8 +41,7 @@ def share_grad_blobs(
         name = str(b)
         # Note: need to look at _{namescope} pattern as it matches
         # to handle the auto-split gradients
-        return name.endswith("_grad") and (name.startswith(namescope) or
-            name.startswith("_" + namescope)) and name not in param_grads
+        return name.endswith("_grad") and (name.startswith((namescope, "_" + namescope))) and name not in param_grads

     def is_grad_op(op):
         # TODO: something smarter
@@ -199,7 +199,7 @@ class CommitList:
         else:
             # Below are some extra quick checks that aren't necessarily file-path related,
             # but I found that to catch a decent number of extra commits.
-            if len(files_changed) > 0 and all([f_name.endswith('.cu') or f_name.endswith('.cuh') for f_name in files_changed]):
+            if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]):
                 category = 'cuda'
             elif '[PyTorch Edge]' in title:
                 category = 'mobile'
@@ -172,7 +172,7 @@ def test_forward_backward(unit_test_class, test_params):
                 param_name = key[:-len(suffix)]
                 break
         assert param_name is not None
-        sparsity_str = 'sparse' if key.endswith('_grad_indices') or key.endswith('_grad_values') else 'dense'
+        sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'

         unit_test_class.assertTrue(
             key in cpp_grad_dict,
@@ -161,7 +161,7 @@ class DTensorAPITest(DTensorTestBase):
         dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
         for name, param in dist_module.named_parameters():
             self.assertIsInstance(param, DTensor)
-            if name.startswith("seq.0") or name.startswith("seq.8"):
+            if name.startswith(("seq.0", "seq.8")):
                 self.assertEqual(param.placements, shard_spec)
             else:
                 self.assertEqual(param.placements, replica_spec)
@@ -80,11 +80,7 @@ def check_labels(
             )
         )

-        if (
-            label.startswith("module:")
-            or label.startswith("oncall:")
-            or label in ACCEPTABLE_OWNER_LABELS
-        ):
+        if label.startswith(("module:", "oncall:")) or label in ACCEPTABLE_OWNER_LABELS:
             continue

         lint_messages.append(
@@ -53,11 +53,7 @@ class UnsupportedOperatorError(OnnxExporterError):
             msg = diagnostic_rule.format_message(name, version, supported_version)
             diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
         else:
-            if (
-                name.startswith("aten::")
-                or name.startswith("prim::")
-                or name.startswith("quantized::")
-            ):
+            if name.startswith(("aten::", "prim::", "quantized::")):
                 diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
                 msg = diagnostic_rule.format_message(
                     name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
@@ -1352,7 +1352,7 @@ def unconvertible_ops(
     unsupported_ops = []
     for node in graph.nodes():
         domain_op = node.kind()
-        if domain_op.startswith("onnx::") or domain_op.startswith("prim::"):
+        if domain_op.startswith(("onnx::", "prim::")):
             # We consider onnx and prim ops as supported ops, even though some "prim"
             # ops are not implemented as symbolic functions, because they may be
             # eliminated in the conversion passes. Users may still see errors caused
@@ -1455,7 +1455,6 @@ def _reset_trace_module_map():

 @_beartype.beartype
 def _get_module_attributes(module):
-
     annotations = typing.get_type_hints(type(module))
     base_m_annotations = typing.get_type_hints(torch.nn.Module)
     [annotations.pop(k, None) for k in base_m_annotations]
@@ -1800,7 +1799,6 @@ def _need_symbolic_context(symbolic_fn: Callable) -> bool:
 def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
     """Decorator that provides the symbolic context to the symbolic function if needed."""
     if _need_symbolic_context(symbolic_fn):
-
         # TODO(justinchuby): Update the module name of GraphContext when it is public
         warnings.warn(
             "The first argument to symbolic functions is deprecated in 1.13 and will be "
|
|||||||
|
|
||||||
@_beartype.beartype
|
@_beartype.beartype
|
||||||
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||||
|
|
||||||
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
|
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
|
||||||
schema = n.schema()
|
schema = n.schema()
|
||||||
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
|
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
|
||||||
|
@@ -107,7 +107,7 @@ def check_file(filename):
     Returns:
         The number of unsafe kernel launches in the file
     """
-    if not (filename.endswith(".cu") or filename.endswith(".cuh")):
+    if not (filename.endswith((".cu", ".cuh"))):
         return 0
     if should_exclude_file(filename):
         return 0
|
@ -957,7 +957,7 @@ def largeTensorTest(size, device=None):
|
|||||||
In other tests, the `device=` argument needs to be specified.
|
In other tests, the `device=` argument needs to be specified.
|
||||||
"""
|
"""
|
||||||
if isinstance(size, str):
|
if isinstance(size, str):
|
||||||
assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported"
|
assert size.endswith(('GB', 'gb')), "only bytes or GB supported"
|
||||||
size = 1024 ** 3 * int(size[:-2])
|
size = 1024 ** 3 * int(size[:-2])
|
||||||
|
|
||||||
def inner(fn):
|
def inner(fn):
|
||||||
|
@@ -552,7 +552,7 @@ class BuildExtension(build_ext):
                 _ccbin = os.getenv("CC")
                 if (
                     _ccbin is not None
-                    and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags])
+                    and not any([flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags])
                 ):
                     cflags.extend(['-ccbin', _ccbin])

@@ -799,14 +799,14 @@ def preprocessor(
         f = m.group(1)
         dirpath, filename = os.path.split(f)
         if (
-            f.startswith("ATen/cuda")
-            or f.startswith("ATen/native/cuda")
-            or f.startswith("ATen/native/nested/cuda")
-            or f.startswith("ATen/native/quantized/cuda")
-            or f.startswith("ATen/native/sparse/cuda")
-            or f.startswith("ATen/native/transformers/cuda")
-            or f.startswith("THC/")
-            or (f.startswith("THC") and not f.startswith("THCP"))
+            f.startswith(("ATen/cuda",
+                          "ATen/native/cuda",
+                          "ATen/native/nested/cuda",
+                          "ATen/native/quantized/cuda",
+                          "ATen/native/sparse/cuda",
+                          "ATen/native/transformers/cuda",
+                          "THC/")) or
+            (f.startswith("THC") and not f.startswith("THCP"))
         ):
             return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
         # if filename is one of the files being hipified for this extension
@@ -858,7 +858,7 @@ def preprocessor(
     output_source = processKernelLaunches(output_source, stats)

     # Replace std:: with non-std:: versions
-    if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath:
+    if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
         output_source = replace_math_functions(output_source)

     # Include header if device code is contained.
@@ -123,9 +123,7 @@ def hierarchical_pickle(data):
     if isinstance(data, torch.utils.show_pickle.FakeObject):
         typename = f"{data.module}.{data.name}"
         if (
-            typename.startswith("__torch__.") or
-            typename.startswith("torch.jit.LoweredWrapper.") or
-            typename.startswith("torch.jit.LoweredModule.")
+            typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
         ):
             assert data.args == ()
             return {