Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE]: ruff PLC0207 - use maxsplit kwarg (#160107)
Automatically replaces split with rsplit where relevant, and performs the split only up to the first (or last) occurrence of the separator. This lets the split stop early and improves efficiency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160107
Approved by: https://github.com/albanD
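For illustration, a minimal sketch of the pattern this rule rewrites (the strings below are hypothetical examples, not taken from the diff):

# Before: every separator is scanned and every segment is built,
# even though only the first piece is used.
version = "2.9.0-rc1-gabcdef".split("-")[0]                 # "2.9.0"

# After: stop after the first separator (ruff PLC0207).
version = "2.9.0-rc1-gabcdef".split("-", maxsplit=1)[0]     # "2.9.0"

# When only the last piece is needed, split from the right instead.
suffix = "pkg.module.TestCase".rsplit(".", maxsplit=1)[-1]  # "TestCase"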
Committed by: PyTorch MergeBot
Parent: 3fcd79e023
Commit: beb4d7816d
@@ -438,9 +438,7 @@ def build_torchvision(
         )
         build_vars += f"BUILD_VERSION={version}.dev{build_date}"
     elif build_version is not None:
-        build_vars += (
-            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
-        )
+        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
     if host.using_docker():
         build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

@@ -495,9 +493,7 @@ def build_torchdata(
         )
         build_vars += f"BUILD_VERSION={version}.dev{build_date}"
     elif build_version is not None:
-        build_vars += (
-            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
-        )
+        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
     if host.using_docker():
         build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

@@ -553,9 +549,7 @@ def build_torchtext(
         )
         build_vars += f"BUILD_VERSION={version}.dev{build_date}"
     elif build_version is not None:
-        build_vars += (
-            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
-        )
+        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
     if host.using_docker():
         build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

@@ -613,9 +607,7 @@ def build_torchaudio(
         )
         build_vars += f"BUILD_VERSION={version}.dev{build_date}"
     elif build_version is not None:
-        build_vars += (
-            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
-        )
+        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
     if host.using_docker():
         build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

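As a usage sketch for the build-script hunks above (the branch name is a hypothetical example), the new one-line form extracts the version prefix from a release branch while splitting at most once:

branch = "v2.5.0-rc3"                                   # hypothetical branch name
pytorch_version = branch[1:].split("-", maxsplit=1)[0]  # "2.5.0"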
.github/scripts/runner_determinator.py (vendored, 7 changed lines)
@@ -262,7 +262,12 @@ def is_exception_branch(branch: str) -> bool:
     """
     Branches that get opted out of experiments by default, until they're explicitly enabled.
     """
-    return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}
+    return branch.split("/", maxsplit=1)[0] in {
+        "main",
+        "nightly",
+        "release",
+        "landchecks",
+    }


 def load_yaml(yaml_text: str) -> Any:
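A quick check of the reworked predicate, using hypothetical branch names:

print("release/2.5".split("/", maxsplit=1)[0])           # "release" -> opted out by default
print("gh/someuser/123/head".split("/", maxsplit=1)[0])  # "gh" -> not an exception branch; only one split is performed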
@@ -205,7 +205,7 @@ onnxruntime=={ort.__version__}
 onnxscript=={onnxscript.__version__}
 numpy=={np.__version__}
 torch=={torch.__version__}"""
-    short_test_name = test_name.split(".")[-1]
+    short_test_name = test_name.rsplit(".", maxsplit=1)[-1]
     reproduction_code = _REPRODUCTION_TEMPLATE.format(
         onnx_model_text=onnx_model_text,
         ort_inputs=input_text,
@ -245,7 +245,7 @@ def create_mismatch_report(
|
|||||||
|
|
||||||
error_text = str(error)
|
error_text = str(error)
|
||||||
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
|
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
|
||||||
short_test_name = test_name.split(".")[-1]
|
short_test_name = test_name.rsplit(".", maxsplit=1)[-1]
|
||||||
diff = difflib.unified_diff(
|
diff = difflib.unified_diff(
|
||||||
str(actual).splitlines(),
|
str(actual).splitlines(),
|
||||||
str(expected).splitlines(),
|
str(expected).splitlines(),
|
||||||
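For the two test-name changes above, a sketch showing that rsplit with maxsplit=1 yields the same last component while splitting only once (the test id is hypothetical):

test_id = "onnx.test_fx_to_onnx.TestExport.test_add"   # hypothetical test id
assert test_id.split(".")[-1] == test_id.rsplit(".", maxsplit=1)[-1] == "test_add"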
@@ -13,7 +13,7 @@ REPO_ROOT = Path(__file__).resolve().parents[2]


 def parse_test_module(test: str) -> str:
-    return test.split(".")[0]
+    return test.split(".", maxsplit=1)[0]


 def discover_tests(
@@ -186,7 +186,7 @@ def get_dep_modules(test: str) -> set[str]:


 def parse_test_module(test: str) -> str:
-    return test.split(".")[0]
+    return test.split(".", maxsplit=1)[0]


 def print_to_stderr(message: str) -> None:
@@ -648,7 +648,7 @@ def custom_op_from_existing(op):
     name = op.name().split("::")[-1]
     schema_str = str(op._schema)
     # CustomOp expects the schema string without the namespace
-    schema_str = schema_str.split("::")[-1]
+    schema_str = schema_str.rsplit("::", maxsplit=1)[-1]
     schema = FunctionSchema.parse(schema_str)
     return CustomOp(lib, ns, schema, name, op, _private_access=True)

@@ -2552,7 +2552,7 @@ def _get_cpp_prefix_header(device: str) -> Optional[str]:
 def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str:
     """Given a device type (and optionally whether we're in AOT Inductor mode), returns
     the path to the cpp_wrapper header file to be precompiled."""
-    base_device = device.split(":")[0]
+    base_device = device.split(":", maxsplit=1)[0]
     is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu"
     return (
         "torch/csrc/inductor/"
@@ -605,7 +605,7 @@ class BaseSchedulerNode:
                 out_lines.append(op_info_str)
                 if "stack_trace" in o.meta:
                     stack_trace = f"{o.meta['stack_trace']}"
-                    stack_trace_last_line = stack_trace.split("|")[-1]
+                    stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1]
                     out_lines.append(
                         "#pragma CMT "
                         + stack_trace_last_line.replace("{", "{{")
@@ -302,7 +302,7 @@ def _make_prim(
         else:
             return _prim_impl(*args, **kwargs)

-    name = schema.split("(")[0]
+    name = schema.split("(", maxsplit=1)[0]
     schema = schema[len(name) :]

     # register non-functional ops with old custom ops API
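A sketch of the schema-name split above with a hypothetical prim schema string; the result is unchanged, maxsplit=1 only avoids splitting at every later parenthesis:

schema = "broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] dims) -> Tensor(a)"  # hypothetical
name = schema.split("(", maxsplit=1)[0]   # "broadcast_in_dim"
rest = schema[len(name) :]                # "(Tensor(a) a, SymInt[] shape, int[] dims) -> Tensor(a)"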
@@ -98,7 +98,7 @@ def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str,
     # string manip to split tensor_fqn into module_fqn and tensor_name
     # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
     # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
-    tensor_name = tensor_fqn.split(".")[-1]
+    tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1]
     module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

     module = fqn_to_module(model, module_fqn)
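Sketch of the FQN split on the two cases already spelled out in the comments above:

for tensor_fqn in ("weight", "linear.weight"):
    tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
    print(repr(module_fqn), repr(tensor_name))   # '' 'weight', then 'linear' 'weight'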
@@ -719,7 +719,7 @@ class _SplitterBase:
         """
         # Dict that maps node to its users and ignore users that
         # are in the subgraph that has greater tag
-        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
+        deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1]))
         self.update_reverse_deps_for_fusions(deps)

         # Parent nodes of the subgraph
@@ -291,7 +291,7 @@ def _get_torch_rocm_version():
     if not TEST_WITH_ROCM or torch.version.hip is None:
         return (0, 0)
     rocm_version = str(torch.version.hip)
-    rocm_version = rocm_version.split("-")[0]  # ignore git sha
+    rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
     return tuple(int(x) for x in rocm_version.split("."))

 def _check_cusparse_generic_available():
@@ -304,7 +304,7 @@ def _check_hipsparse_generic_available():
         return False

     rocm_version = str(torch.version.hip)
-    rocm_version = rocm_version.split("-")[0]  # ignore git sha
+    rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
     rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
     return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))

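Sketch of the ROCm version parsing used in the two hunks above, with a hypothetical torch.version.hip string:

hip = "5.7.31921-d1770ee1b"                      # hypothetical torch.version.hip value
rocm_version = hip.split("-", maxsplit=1)[0]     # "5.7.31921" (git sha dropped)
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
print(rocm_version_tuple)                        # (5, 7, 31921)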
@@ -1605,7 +1605,7 @@ class MultiProcContinousTest(TestCase):
     @classmethod
     def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
         # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
-        test_name = test_id.split(".")[-1]
+        test_name = test_id.rsplit(".", maxsplit=1)[-1]
         # Get the test function from the test class
         self = cls(test_name)
         self.rank = cls.rank
@@ -2017,7 +2017,7 @@ def skipIfRocmVersionLessThan(version=None):
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
                 rocm_version = str(torch.version.hip)
-                rocm_version = rocm_version.split("-")[0]  # ignore git sha
+                rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
                 rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
                 if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
                     reason = f"ROCm {rocm_version_tuple} is available but {version} required"
@ -168,4 +168,4 @@ def merge_operator_dicts(
|
|||||||
|
|
||||||
|
|
||||||
def strip_operator_overload_name(op_name: str) -> str:
|
def strip_operator_overload_name(op_name: str) -> str:
|
||||||
return op_name.split(".")[0]
|
return op_name.split(".", maxsplit=1)[0]
|
||||||