From a029675f6f0b9cf48eb7943d4be8169c67960a8e Mon Sep 17 00:00:00 2001
From: Yuanyuan Chen
Date: Thu, 9 Oct 2025 03:24:46 +0000
Subject: [PATCH] More ruff SIM fixes (#164695)

This PR applies ruff `SIM` rules to more files. Most changes simplify
`dict.get(key, None)` to `dict.get(key)`, since `None` is already the
default for a missing key (a minimal sketch of the rewrite is appended
after the diff).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
---
 test/dynamo/test_guard_serialization.py       |  2 +-
 test/dynamo/test_recompile_ux.py              | 18 ++++-----
 test/inductor/test_group_batch_fusion.py      |  2 +-
 test/profiler/test_memory_profiler.py         |  2 +-
 test/test_cuda_nvml_based_avail.py            |  2 +-
 test/test_jit_string.py                       | 40 +++++++++----------
 tools/alerts/create_alerts.py                 |  4 +-
 tools/autograd/gen_inplace_or_view_type.py    |  2 +-
 .../src/coverage_plugins/jit_plugin.py        |  2 +-
 tools/flight_recorder/components/types.py     | 32 ++++++++-------
 tools/setup_helpers/cmake.py                  |  2 +-
 tools/testing/explicit_ci_jobs.py             |  2 +-
 torch/_dynamo/variables/constant.py           |  2 +-
 torch/_dynamo/variables/lists.py              |  2 +-
 torch/_dynamo/variables/misc.py               |  4 +-
 torch/_dynamo/variables/nn_module.py          |  2 +-
 torch/_dynamo/variables/sdpa.py               | 10 ++++-
 torch/_higher_order_ops/auto_functionalize.py |  8 ++--
 torch/_inductor/analysis/device_info.py       |  2 +-
 torch/_inductor/codegen/cpp_micro_gemm.py     |  6 +--
 .../codegen/rocm/ck_conv_template.py          |  2 +-
 .../rocm/ck_tile_universal_gemm_template.py   |  4 +-
 .../rocm/ck_universal_gemm_template.py        |  6 +--
 torch/_inductor/codegen/simd.py               |  2 +-
 torch/_inductor/compiler_bisector.py          |  2 +-
 torch/_inductor/fx_passes/pad_mm.py           |  2 +-
 torch/_inductor/fx_passes/pre_grad.py         |  2 +-
 torch/_inductor/fx_passes/reinplace.py        |  2 +-
 torch/_inductor/lowering.py                   |  4 +-
 torch/_inductor/runtime/triton_heuristics.py  |  6 +--
 torch/_inductor/select_algorithm.py           |  4 +-
 torch/_inductor/sizevars.py                   |  4 +-
 torch/_inductor/tiling_utils.py               |  2 +-
 torch/_logging/structured.py                  |  2 +-
 torch/_numpy/_funcs_impl.py                   |  2 +-
 torch/_subclasses/meta_utils.py               |  2 +-
 torch/export/_leakage_detection_utils.py      |  2 +-
 torch/fx/experimental/unification/utils.py    |  2 +-
 torch/fx/experimental/unification/variable.py |  2 +-
 torch/fx/graph_module.py                      |  2 +-
 torch/fx/passes/infra/partitioner.py          |  6 +--
 torch/jit/_state.py                           |  4 +-
 .../_internal/fx/passes/type_promotion.py     |  6 +--
 torch/testing/_internal/common_nn.py          |  6 +--
 torch/testing/_internal/common_utils.py       |  4 +-
 .../distributed/common_state_dict.py          |  2 +-
 torch/utils/cpp_extension.py                  |  2 +-
 torch/utils/data/datapipes/utils/decoder.py   | 20 +++++++---
 torchgen/model.py                             | 18 ++++++++-
 49 files changed, 153 insertions(+), 117 deletions(-)

diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py
index c0b3ff226c48..54520bba448a 100644
--- a/test/dynamo/test_guard_serialization.py
+++ b/test/dynamo/test_guard_serialization.py
@@ -321,7 +321,7 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
     def _test_serialization(self, guard_type, fn, *args, **kwargs):
         # kwargs might contain a callable that generates kwargs
         torch._dynamo.reset()
-        kwarg_gen_fn = kwargs.get("_gen_fn", None)
+        kwarg_gen_fn = kwargs.get("_gen_fn")
         if kwarg_gen_fn is not None:
             kwargs = kwarg_gen_fn()

diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py
index f945039b55d1..880d37434b31 100644
--- a/test/dynamo/test_recompile_ux.py
+++ b/test/dynamo/test_recompile_ux.py
@@ -242,11 +242,12 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
             opt_f(torch.randn(8 + i))

         failure_str = "\n".join(failure_reasons)
-        for line in """\
-tensor 'x' size mismatch at index 0. expected 11, actual 12
-tensor 'x' size mismatch at index 0. expected 10, actual 12
-tensor 'x' size mismatch at index 0. expected 9, actual 12
-tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
+        for line in [
+            "tensor 'x' size mismatch at index 0. expected 11, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 10, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 9, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 8, actual 12",
+        ]:
             self.assertIn(
                 line,
                 failure_str,
@@ -281,16 +282,13 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
             failure_reasons.clear()
             opt_f([7, 8])

-        for line in """\
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 3"]:
             self.assertIn(line, filter_reasons())

         failure_reasons.clear()
         opt_f([9])

-        for line in """\
-len(x) == 2
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 2", "len(x) == 3"]:
             self.assertIn(line, filter_reasons())

     @torch._dynamo.config.patch(recompile_limit=1)
diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 01c9962e0087..7111e10a69fc 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -686,7 +686,7 @@ class TestFindIndependentSubsetGreedy(TestCase):
             unsatisfied += 1
             assert unsatisfied <= len(desc)  # cycle or bad input?
             name, v = desc.popleft()
-            args = tuple(lookup.get(n, None) for n in v)
+            args = tuple(lookup.get(n) for n in v)
             if None in args:
                 desc.append((name, v))
                 continue
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index f9821d1bf3a2..c0966afa8059 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -901,7 +901,7 @@ class TestMemoryProfilerE2E(TestCase):
             ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key

         def format_categories(ptr_pair: int):
-            target_key = ptr_pair_to_key.get(ptr_pair, None)
+            target_key = ptr_pair_to_key.get(ptr_pair)
             if target_key is None:
                 return "???"

diff --git a/test/test_cuda_nvml_based_avail.py b/test/test_cuda_nvml_based_avail.py
index c47607f4c7ac..3da49da57ad4 100644
--- a/test/test_cuda_nvml_based_avail.py
+++ b/test/test_cuda_nvml_based_avail.py
@@ -127,7 +127,7 @@ class TestVisibleDeviceParses(TestCase):
            _transform_uuid_to_ordinals(["GPU-e4", "GPU-9e8d35e3"], uuids), [2, 1]
         )
         self.assertEqual(
-            _transform_uuid_to_ordinals("GPU-9e8d35e3,GPU-1,GPU-47".split(","), uuids),
+            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-1", "GPU-47"], uuids),
             [1, 7, 5],
         )
         # First invalid UUID aborts parsing
diff --git a/test/test_jit_string.py b/test/test_jit_string.py
index b4344229f1ae..55bd003cd9e3 100644
--- a/test/test_jit_string.py
+++ b/test/test_jit_string.py
@@ -241,17 +241,17 @@ class TestScript(JitTestCase):

        def test_split() -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".split(),
-                "a a a a a".split(),
-                " a a\ta \v a \v\f\n a \t ".split(),
-                " a a a a a ".split(" "),
-                "a a a a a ".split(" ", 10),
-                "a a a a a ".split(" ", -1),
-                "a a a a a ".split(" ", 3),
-                " a a a a a ".split("*"),
-                " a*a a*a a".split("*"),
-                " a*a a*a a ".split("*", -1),
-                " a*a a*a a ".split("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a a "],
+                [" a a a a a "],
+                [" a", "a a", "a a"],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a "],
             )

         self.checkScript(test_split, ())
@@ -266,15 +266,15 @@ class TestScript(JitTestCase):

        def test_rsplit() -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".rsplit(),
-                " a a a a a ".rsplit(" "),
-                "a a a a a ".rsplit(" ", 10),
-                "a a a a a ".rsplit(" ", -1),
-                "a a a a a ".rsplit(" ", 3),
-                " a a a a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*", -1),
-                " a*a a*a a".rsplit("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a a a", "a", "a", ""],
+                [" a a a a a "],
+                [" a", "a a", "a a "],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a"],
             )

         self.checkScript(test_rsplit, ())
diff --git a/tools/alerts/create_alerts.py b/tools/alerts/create_alerts.py
index 6b679a030682..b86e2368d440 100644
--- a/tools/alerts/create_alerts.py
+++ b/tools/alerts/create_alerts.py
@@ -190,12 +190,12 @@ def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:


 def is_job_failed(job: Any) -> bool:
-    conclusion = job["conclusion"] if "conclusion" in job else None
+    conclusion = job.get("conclusion", None)
     return conclusion is not None and conclusion != SUCCESS and conclusion != PENDING


 def is_job_skipped(job: Any) -> bool:
-    conclusion = job["conclusion"] if "conclusion" in job else None
+    conclusion = job.get("conclusion", None)
     return conclusion in (NEUTRAL, SKIPPED) or conclusion is None


diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 684290da0a72..4cb3429c3927 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -340,7 +340,7 @@ def get_base_name(f: NativeFunction) -> str:

 def get_view_info(f: NativeFunction) -> str | None:
     base_name = get_base_name(f)
-    view_info = VIEW_FUNCTIONS.get(base_name, None)
+    view_info = VIEW_FUNCTIONS.get(base_name)
     if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
         view_info = "self"
     return view_info
diff --git a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
index 72594abefd0a..5cea32d00dec 100644
--- a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
+++ b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
@@ -30,7 +30,7 @@ cov_data = CoverageData(basename=f".coverage.jit.{time()}")


 def is_not_builtin_class(obj: Any) -> bool:
-    return isclass(obj) and not type(obj).__module__ == "builtins"
+    return isclass(obj) and type(obj).__module__ != "builtins"


 class JitPlugin(CoveragePlugin):  # type: ignore[misc, no-any-unimported]
diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py
index f28c78b596b9..2c8fea5fb334 100644
--- a/tools/flight_recorder/components/types.py
+++ b/tools/flight_recorder/components/types.py
@@ -554,26 +554,30 @@ class Op:
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
                 f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
             )
-        if self.type in [
-            "all_gather",
-            "all_gather_base",
-            "all_gather_into_tensor_coalesced",
-        ] and not (
-            math.prod(other.output_sizes[0])
-            == math.prod(self.input_sizes[0]) * self.pg_size
+        if (
+            self.type
+            in [
+                "all_gather",
+                "all_gather_base",
+                "all_gather_into_tensor_coalesced",
+            ]
+            and math.prod(other.output_sizes[0])
+            != math.prod(self.input_sizes[0]) * self.pg_size
         ):
             return MatchInfo(
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
                 f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
                 f"does not match output numel '{math.prod(other.output_sizes[0])}'",
             )
-        if self.type in [
-            "reduce_scatter",
-            "_reduce_scatter_base",
-            "reduce_scatter_tensor_coalesced",
-        ] and not (
-            math.prod(other.input_sizes[0])
-            == math.prod(self.output_sizes[0]) * self.pg_size
+        if (
+            self.type
+            in [
+                "reduce_scatter",
+                "_reduce_scatter_base",
+                "reduce_scatter_tensor_coalesced",
+            ]
+            and math.prod(other.input_sizes[0])
+            != math.prod(self.output_sizes[0]) * self.pg_size
         ):
             return MatchInfo(
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py
index 02ab011dd482..0fd6de50a56b 100644
--- a/tools/setup_helpers/cmake.py
+++ b/tools/setup_helpers/cmake.py
@@ -326,7 +326,7 @@ class CMake:

         # The default value cannot be easily obtained in CMakeLists.txt. We set it here.
         py_lib_path = sysconfig.get_path("purelib")
-        cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH", None)
+        cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH")
         if cmake_prefix_path:
             build_options["CMAKE_PREFIX_PATH"] = (
                 py_lib_path + ";" + cast(str, cmake_prefix_path)
diff --git a/tools/testing/explicit_ci_jobs.py b/tools/testing/explicit_ci_jobs.py
index dcf406472353..0d25bc642678 100755
--- a/tools/testing/explicit_ci_jobs.py
+++ b/tools/testing/explicit_ci_jobs.py
@@ -43,7 +43,7 @@ def add_job(
     if workflow_name not in workflows:
         workflows[workflow_name] = {"when": "always", "jobs": []}

-    requires = job.get("requires", None)
+    requires = job.get("requires")
     if requires is not None:
         for requirement in requires:
             dependency = past_jobs[requirement]
diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py
index 82b804e8ce39..9b733340ec22 100644
--- a/torch/_dynamo/variables/constant.py
+++ b/torch/_dynamo/variables/constant.py
@@ -43,7 +43,7 @@ class ConstantVariable(VariableTracker):
         NOTE: the caller must install the proper guards if needed; most often
         the guard will be `CONSTANT_MATCH`.
         """
-        source = kwargs.get("source", None)
+        source = kwargs.get("source")

         # Routing for supported collection literals.
         if isinstance(value, set):
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index c2ba3ba56049..f51ba102342c 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -1154,7 +1154,7 @@ class NamedTupleVariable(TupleVariable):
     def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
         super().__init__(items, **kwargs)
         self.tuple_cls = tuple_cls
-        self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes
+        self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}

     def is_namedtuple(self):
         return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index b0a24c395c86..690357e55ab3 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -1425,7 +1425,7 @@ class NumpyVariable(VariableTracker):
     def get_constant_collection_for_func(cls, fn):
         mod = fn.__module__.split(".")
         assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
-        return np_constant_collections_map.get(fn, None)
+        return np_constant_collections_map.get(fn)

     def call_function(
         self,
@@ -1930,7 +1930,7 @@ class RandomVariable(VariableTracker):
 class WeakRefVariable(VariableTracker):
     @staticmethod
     def build(tx, weakref_value, **options):
-        source = options.get("source", None)
+        source = options.get("source")
         callback = weakref_value.__callback__
         callback_source = source and AttrSource(source, "__callback__")
         callback_vt = VariableTracker.build(tx, callback, callback_source)
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 431b26fab494..22329aeeb199 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -1219,7 +1219,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     """

     def __init__(self, value, **kwargs) -> None:
-        source = kwargs.get("source", None)
+        source = kwargs.get("source")
         assert source is not None, (
             "FSDPManagedNNModule depends on having an accurate source to control guarding."
         )
diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py
index 6edd4a7c8ea4..e63edf8e2b03 100644
--- a/torch/_dynamo/variables/sdpa.py
+++ b/torch/_dynamo/variables/sdpa.py
@@ -13,7 +13,15 @@ if TYPE_CHECKING:
     from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

-PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
+PARAM_NAMES = [
+    "query",
+    "key",
+    "value",
+    "attn_mask",
+    "dropout",
+    "is_causal",
+    "enable_gqa",
+]


 class SDPAParamsVariable(VariableTracker):
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 89f8eeffcce2..cca12066bc3e 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -239,7 +239,7 @@ def write_view_information_to_args(
             write_single_view(
                 f"_{arg_name}",
                 kwargs[arg_name],
-                arg_to_base_index.get(arg_name, None),  # type: ignore[arg-type]
+                arg_to_base_index.get(arg_name),  # type: ignore[arg-type]
             )
         else:
             raise RuntimeError(f"Unsupported type {arg_type}")
@@ -390,7 +390,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
         if isinstance(_mutable_op, HigherOrderOperator):
             _op_to_check = HopInstance(
                 _mutable_op,
-                SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema,  # type: ignore[arg-type]
+                SchemaHolder.from_tree_spec(kwargs.get("_op_schema")).schema,  # type: ignore[arg-type]
             )
         else:
             _op_to_check = _mutable_op
@@ -958,11 +958,11 @@ def auto_functionalized_v2_proxy(
    # hop node in the traced graph and graph module inputs to the hop. Finally, we replace the
    # original kwarg's callable with the graph module.
    all_bases = kwargs.get("_all_bases", [])
-    _only_clone_these_bases = kwargs.get("_only_clone_these_bases", None)
+    _only_clone_these_bases = kwargs.get("_only_clone_these_bases")
    if _only_clone_these_bases is None:
        _only_clone_these_bases = tuple(range(len(all_bases)))

-    schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema  # type: ignore[arg-type]
+    schema = pytree.tree_unflatten([], kwargs.get("_op_schema")).schema  # type: ignore[arg-type]
    new_kwargs, _ = _generate_new_op_kwargs_from_bases(
        schema,
        {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")},
diff --git a/torch/_inductor/analysis/device_info.py b/torch/_inductor/analysis/device_info.py
index 39d62392ebb7..6fc271458c77 100644
--- a/torch/_inductor/analysis/device_info.py
+++ b/torch/_inductor/analysis/device_info.py
@@ -163,7 +163,7 @@ def lookup_device_info(name: str) -> Optional[DeviceInfo]:
     If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
     name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
     """
-    return _device_mapping.get(name, None)
+    return _device_mapping.get(name)


 def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py
index 2f002e93d99e..e6060ff16f1d 100644
--- a/torch/_inductor/codegen/cpp_micro_gemm.py
+++ b/torch/_inductor/codegen/cpp_micro_gemm.py
@@ -974,7 +974,7 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):

 def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
     # We need avx512_bf16 to dequant int8 to bf16
-    vec_isa = kwargs.get("vec_isa", None)
+    vec_isa = kwargs.get("vec_isa")
     assert vec_isa is not None
     return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
         config, m, n, k, alpha, num_threads, **kwargs
@@ -984,7 +984,7 @@ def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
 # amx_fp16 need to be checked separately since it is not always supported when amx is supported
 def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
     assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
-    vec_isa = kwargs.get("vec_isa", None)
+    vec_isa = kwargs.get("vec_isa")
     assert vec_isa is not None
     vnni_size = 2
     return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
@@ -1419,7 +1419,7 @@ class CppMicroBrgemm(CppMicroGemm):
 def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
     if alpha != 1:
         return False
-    q_group_size = kwargs.get("q_group_size", None)
+    q_group_size = kwargs.get("q_group_size")
     assert q_group_size is not None
     if (
         q_group_size not in [32, 64, 128]
diff --git a/torch/_inductor/codegen/rocm/ck_conv_template.py b/torch/_inductor/codegen/rocm/ck_conv_template.py
index 032b0491a34f..37d9898f6be3 100644
--- a/torch/_inductor/codegen/rocm/ck_conv_template.py
+++ b/torch/_inductor/codegen/rocm/ck_conv_template.py
@@ -528,7 +528,7 @@ class CKGroupedConvFwdTemplate(CKTemplate):
         op: "CKGroupedConvFwdOp",  # type: ignore[name-defined]
         **kwargs,
     ) -> str:
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         X, W = self.input_nodes[0], self.input_nodes[1]
diff --git a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py
index b18010bda908..94a79297ef5e 100644
--- a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py
+++ b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py
@@ -750,9 +750,9 @@ class CKTileGemmTemplate(CKTileTemplate):
         """
         The primary entry point for the code rendering process used in this template.
         """
-        epilogue_nodes = kwargs.get("epilogue_nodes", None)
+        epilogue_nodes = kwargs.get("epilogue_nodes")
         assert epilogue_nodes is None or 0 == len(epilogue_nodes)
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         assert 2 == len(self.input_nodes)
diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
index bc0f75b919bb..b6add1e8dbdd 100644
--- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
+++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
@@ -547,7 +547,7 @@ class CKGemmTemplate(CKTemplate):
         # Define the mapping of versions to stages
         version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3}
         # Get the stages for the given version
-        stages = version_to_stages.get(version, None)
+        stages = version_to_stages.get(version)
         if stages is None:
             # This means we're at stage 2, and this requires computation
             # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143  # noqa: B950
@@ -612,9 +612,9 @@ class CKGemmTemplate(CKTemplate):
         """
         The primary entry point for the code rendering process used in this template.
         """
-        epilogue_nodes = kwargs.get("epilogue_nodes", None)
+        epilogue_nodes = kwargs.get("epilogue_nodes")
         assert epilogue_nodes is None or 0 == len(epilogue_nodes)
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         # input nodes:
diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index a225e7b6c9c9..2cedd993b38c 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -2296,7 +2296,7 @@ class SIMDScheduling(BaseScheduling):
             return ([], [])

         key = (repr(vars_to_use), use_split_var, is_pointwise)
-        if out := scored_sub_split.get(key, None):
+        if out := scored_sub_split.get(key):
             return out

         splitting_vars = all_iter_vars if is_pointwise else all_red_vars
diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py
index 2c3d3bb5bd74..bdfd02f76682 100644
--- a/torch/_inductor/compiler_bisector.py
+++ b/torch/_inductor/compiler_bisector.py
@@ -550,7 +550,7 @@ class CompilerBisector:
                 curr_backend,
                 curr_subsystem.name,
                 low,
-                call_counter_debug_info.get(low, None),
+                call_counter_debug_info.get(low),
             )

             next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index 4b833746bd20..42ee33a367f0 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -715,7 +715,7 @@ def run_autoheuristic(
     )
     choice = autoheuristic.get_choice()
     choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
-    ah_should_pad = choice2should_pad.get(choice, None)
+    ah_should_pad = choice2should_pad.get(choice)

     if torch._inductor.config.collect_autoheuristic(name):
         ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 8d3b7d0a52fd..3b851b0e27ae 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -634,7 +634,7 @@ class NormalizedLinearNode:
         if len(self.node.args) > 2:
             return self.node.args[2]  # type: ignore[return-value]
         else:
-            return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None  # type: ignore[return-value]
+            return self.node.kwargs.get("bias", None)  # type: ignore[return-value]


 class NormalizedMatmulNode:
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index aa1ce1f04343..242bb98d4584 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -503,7 +503,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:

         if mutated_arg.op in ("placeholder", "get_attr"):
             # Get the first copy_ node that mutates the mutated_arg.
-            copy_node = copy_nodes.get(mutated_arg, None)
+            copy_node = copy_nodes.get(mutated_arg)
             if copy_node is None:
                 # There is no copy_ back to the candidate mutated_arg (which is a graph input).
                 # Therefore the semantics of the program are that it does not mutate
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 77f0f32d54e7..1ad7976e21c6 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2365,7 +2365,7 @@ make_fallback(aten.randint)

 @register_lowering(aten.rand)
 def rand(*args, **kwargs):
-    if kwargs.get("generator", None) is not None:
+    if kwargs.get("generator") is not None:
         return fallback_rand_generator(*args, **kwargs)
     elif config.fallback_random:
         kwargs.pop("generator", None)
@@ -2375,7 +2375,7 @@ def rand(*args, **kwargs):

 @register_lowering(aten.randn)
 def randn(*args, **kwargs):
-    if kwargs.get("generator", None) is not None:
+    if kwargs.get("generator") is not None:
         return fallback_randn_generator(*args, **kwargs)
     elif config.fallback_random:
         kwargs.pop("generator", None)
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 57510e941564..b5278b17f3b9 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -1583,7 +1583,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
         return None

     def check_can_launch() -> StaticallyLaunchedCudaKernel:
-        if triton_meta.get("device_type", None) != "cuda":
+        if triton_meta.get("device_type") != "cuda":
             # Only cuda kernels
             raise CannotStaticallyLaunchKernel("Non-cuda device")

@@ -1600,7 +1600,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
             # Don't support user defined triton kernels yet
             raise CannotStaticallyLaunchKernel("User defined triton kernel")

-        if inductor_meta.get("store_cubin", None):
+        if inductor_meta.get("store_cubin"):
             # Requires storing the entire binary
             raise CannotStaticallyLaunchKernel("store_cubin is enabled")

@@ -2640,7 +2640,7 @@ def pointwise(
 def _reduction_configs(
     *, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
 ) -> list[Config]:
-    reduction_hint = inductor_meta.get("reduction_hint", None)
+    reduction_hint = inductor_meta.get("reduction_hint")

     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 31f74bfc14ee..c71adb3c9121 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -2089,8 +2089,8 @@ class TritonTemplate(KernelTemplate):
                 "num_stages": num_stages,
                 "num_warps": num_warps,
                 "GROUP_M": kwargs.get("GROUP_M", -1),
-                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
-                "acc_type": str(kwargs.get("ACC_TYPE", None)),
+                "allow_tf32": str(kwargs.get("ALLOW_TF32")),
+                "acc_type": str(kwargs.get("ACC_TYPE")),
                 "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                 "waves_per_eu": kwargs.get("waves_per_eu", 0),
                 "kpack": kwargs.get("kpack", 2),
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index d02e0d6df920..209b6f831e5b 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -112,7 +112,7 @@ class SizeVarAllocator:
                cache.clear()
            replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
-            result = cache.get(key, None)
+            result = cache.get(key)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
@@ -136,7 +136,7 @@ class SizeVarAllocator:
                cache.clear()
            replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
-            result = cache.get(key, None)
+            result = cache.get(key)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py
index cd12b6043455..d7b64ba1b867 100644
--- a/torch/_inductor/tiling_utils.py
+++ b/torch/_inductor/tiling_utils.py
@@ -575,7 +575,7 @@ def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
     # TODO - deduplicate with candidate_tilings
     var_sizes = []
     for v in addr.free_symbols:
-        v_size = var_ranges.get(v, None)
+        v_size = var_ranges.get(v)
         # TODO - reason about indirect vars
         if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
             var_sizes.append(v_size)
diff --git a/torch/_logging/structured.py b/torch/_logging/structured.py
index 4eae33227e61..e6dd36a6c696 100644
--- a/torch/_logging/structured.py
+++ b/torch/_logging/structured.py
@@ -21,7 +21,7 @@ def intern_string(s: Optional[str]) -> int:
     if s is None:
         return -1

-    r = INTERN_TABLE.get(s, None)
+    r = INTERN_TABLE.get(s)
     if r is None:
         r = len(INTERN_TABLE)
         INTERN_TABLE[s] = r
diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py
index 19748a08b9de..4ab3b29d34b8 100644
--- a/torch/_numpy/_funcs_impl.py
+++ b/torch/_numpy/_funcs_impl.py
@@ -1867,7 +1867,7 @@ def common_type(*tensors: ArrayLike):
         if not (t.is_floating_point or t.is_complex):
             p = 2  # array_precision[_nx.double]
         else:
-            p = array_precision.get(t, None)
+            p = array_precision.get(t)
             if p is None:
                 raise TypeError("can't get common type for non-numeric array")
             precision = builtins.max(precision, p)
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 1dd0adf42ffd..da3eed2b0c71 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -1284,7 +1284,7 @@ class MetaConverter(Generic[_TensorT]):

                # Fake inner tensors of view subclasses will come from the mapping built above.
                visited_id = self.describer.get_tensor_id(visited_t)
-                fake_visited_t = real_to_fake_mapping.get(visited_id, None)
+                fake_visited_t = real_to_fake_mapping.get(visited_id)
                if fake_visited_t is not None:
                    return fake_visited_t

diff --git a/torch/export/_leakage_detection_utils.py b/torch/export/_leakage_detection_utils.py
index c72152759d23..fe211e1dc079 100644
--- a/torch/export/_leakage_detection_utils.py
+++ b/torch/export/_leakage_detection_utils.py
@@ -43,7 +43,7 @@ def _is_tracked_fake(obj: typing.Any) -> bool:

 def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool:
     # Hope gm.meta was a custom dict we can assert on
-    return d.get("val", None) is o
+    return d.get("val") is o


 def _dict_is_attr_of_tracked_fake(d: dict) -> bool:
diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py
index a8035f75d302..ab99ad1b4f0d 100644
--- a/torch/fx/experimental/unification/utils.py
+++ b/torch/fx/experimental/unification/utils.py
@@ -60,7 +60,7 @@ def _toposort(edges):
             incoming_edges[m].remove(n)
             if not incoming_edges[m]:
                 S.add(m)
-    if any(incoming_edges.get(v, None) for v in edges):
+    if any(incoming_edges.get(v) for v in edges):
         raise ValueError("Input has cycles")
     return L

diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py
index 46e59851fdfa..8921dc77d923 100644
--- a/torch/fx/experimental/unification/variable.py
+++ b/torch/fx/experimental/unification/variable.py
@@ -55,7 +55,7 @@ isvar

 @dispatch(object)  # type: ignore[no-redef]
 def isvar(o):
-    return not not _glv and hashable(o) and o in _glv
+    return _glv and hashable(o) and o in _glv


 @contextmanager
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 315a4ba75c0c..338190c7a5e9 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -193,7 +193,7 @@ def _deserialize_graph_module(
     graph = KeepModules().trace(com, **tracer_extras)

     # Recover node.meta["stack_trace"] after re-tracing
-    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
+    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace")
     if node_meta_stack_trace is not None:
         del body["_graphmodule_graph_node_meta_stack_trace"]
         for node in graph.nodes:
diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py
index 6fc17b959424..7bb536dbba93 100644
--- a/torch/fx/passes/infra/partitioner.py
+++ b/torch/fx/passes/infra/partitioner.py
@@ -190,7 +190,7 @@ class CapabilityBasedPartitioner:
            # Iterate through all the users of this node and update the partition map to indicate
            # that there is a path from the partition id of this node to the target partition id.
            for user_node in node.users:
-                target_id = assignment.get(user_node, None)
+                target_id = assignment.get(user_node)
                if target_id is not None:
                    partition_map[id].add(target_id)
                    partition_map[id].update(partition_map[target_id])
@@ -267,9 +267,9 @@ class CapabilityBasedPartitioner:

            # node has tuple outputs, re-assign all following getitem node into node's partition
            if is_tuple_output:
-                id = assignment.get(node, None)  # type: ignore[arg-type]
+                id = assignment.get(node)  # type: ignore[arg-type]
                for user in node.users:
-                    if assignment.get(user, None) != id:  # type: ignore[arg-type]
+                    if assignment.get(user) != id:  # type: ignore[arg-type]
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, None, id)
diff --git a/torch/jit/_state.py b/torch/jit/_state.py
index 2c0c58b8c98a..f48dd80a0b36 100644
--- a/torch/jit/_state.py
+++ b/torch/jit/_state.py
@@ -76,11 +76,11 @@ def _get_script_class(python_class):
     override = getattr(python_class, "_jit_override_qualname", None)
     if override is not None:
         python_class = _get_python_class(override)
-    return _script_classes.get(python_class, None)
+    return _script_classes.get(python_class)


 def _get_python_class(qualified_name):
-    return _name_to_pyclass.get(qualified_name, None)
+    return _name_to_pyclass.get(qualified_name)


 def _clear_class_state():
diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py
index d388e44fd8c4..87220a453124 100644
--- a/torch/onnx/_internal/fx/passes/type_promotion.py
+++ b/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -228,7 +228,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
     ) -> TypePromotionSnapshot:
-        rounding_mode = kwargs.get("rounding_mode", None)
+        rounding_mode = kwargs.get("rounding_mode")
         if rounding_mode is None:
             # true_divide
             self.promotion_kind = (
@@ -287,7 +287,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
         )
         arg = args[0]
         assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
-        dtype: torch.dtype | None = kwargs.get("dtype", None)
+        dtype: torch.dtype | None = kwargs.get("dtype")

         computation_dtype, result_dtype = _prims_common.reduction_dtypes(
             arg, self.promotion_kind, dtype
@@ -351,7 +351,7 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
         )
         arg = args[0]
         assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
-        dtype: torch.dtype | None = kwargs.get("dtype", None)
+        dtype: torch.dtype | None = kwargs.get("dtype")
         # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
         if dtype is None:
             if _prims_common.is_boolean_dtype(
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 574e039326fe..aaca0efe1eb4 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -3421,7 +3421,7 @@ class ModuleTest(TestBase):
             kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
         self.precision = kwargs.get('precision', 2e-4)
         self.check_forward_only = kwargs.get('check_forward_only', False)
-        self.default_dtype = kwargs.get('default_dtype', None)
+        self.default_dtype = kwargs.get('default_dtype')
         if self.default_dtype is None:
             self.default_dtype = torch.get_default_dtype()

@@ -3632,7 +3632,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
         self.test_cpu = kwargs.get('test_cpu', True)
         self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
         self.check_batched_grad = kwargs.get('check_batched_grad', True)
-        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode')
         self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
         self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)

@@ -3836,7 +3836,7 @@ class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
         self.with_tf32 = kwargs.get('with_tf32', True)
         self.tf32_precision = kwargs.get('tf32_precision', 0.001)
         self.check_batched_grad = kwargs.get('check_batched_grad', True)
-        self.default_dtype = kwargs.get('default_dtype', None)
+        self.default_dtype = kwargs.get('default_dtype')
         if self.default_dtype is None:
             self.default_dtype = torch.get_default_dtype()

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 80c81507751b..ce1d42144aed 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -5124,7 +5124,7 @@ def gradcheck(fn, inputs, **kwargs):

     for key, value in default_values.items():
         # default value override values explicitly set to None
-        k = kwargs.get(key, None)
+        k = kwargs.get(key)
         kwargs[key] = k if k is not None else value

     return torch.autograd.gradcheck(fn, inputs, **kwargs)
@@ -5144,7 +5144,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):

     for key, value in default_values.items():
         # default value override values explicitly set to None
-        k = kwargs.get(key, None)
+        k = kwargs.get(key)
         kwargs[key] = k if k is not None else value

     return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py
index f7d79907bdbe..76b7800a8d2a 100644
--- a/torch/testing/_internal/distributed/common_state_dict.py
+++ b/torch/testing/_internal/distributed/common_state_dict.py
@@ -40,7 +40,7 @@ class VerifyStateDictMixin:
         if not options.ignore_frozen_params:
             self.assertEqual(len(msd), len(dist_msd))
         for fqn, param in msd.items():
-            dist_param = dist_msd.get(fqn, None)
+            dist_param = dist_msd.get(fqn)
             if not options.ignore_frozen_params:
                 self.assertIsNotNone(dist_param, f"{fqn=}")
                 try:
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 47eb183f4ee6..764ce87d7692 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -2309,7 +2309,7 @@ def _write_ninja_file_and_build_library(
 def is_ninja_available():
     """Return ``True`` if the `ninja `_ build system is available on the system, ``False`` otherwise."""
     try:
-        subprocess.check_output('ninja --version'.split())
+        subprocess.check_output(['ninja', '--version'])
     except Exception:
         return False
     else:
diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py
index 9db7309bdc52..000de3e70f72 100644
--- a/torch/utils/data/datapipes/utils/decoder.py
+++ b/torch/utils/data/datapipes/utils/decoder.py
@@ -61,7 +61,7 @@ def basichandlers(extension: str, data):
     if extension in "txt text transcript":
         return data.decode("utf-8")

-    if extension in "cls cls2 class count index inx id".split():
+    if extension in ["cls", "cls2", "class", "count", "index", "inx", "id"]:
         try:
             return int(data)
         except ValueError:
@@ -70,10 +70,10 @@ def basichandlers(extension: str, data):
     if extension in "json jsn":
         return json.loads(data)

-    if extension in "pyd pickle".split():
+    if extension in ["pyd", "pickle"]:
         return pickle.loads(data)

-    if extension in "pt".split():
+    if extension in ["pt"]:
         stream = io.BytesIO(data)
         return torch.load(stream)

@@ -175,7 +175,7 @@ class ImageHandler:
         self.imagespec = imagespec.lower()

     def __call__(self, extension, data):
-        if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split():
+        if extension.lower() not in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
             return None

         try:
@@ -235,7 +235,17 @@ def imagehandler(imagespec):
 # torch video
 ################################################################
 def videohandler(extension, data):
-    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
+    if extension not in [
+        "mp4",
+        "ogv",
+        "mjpeg",
+        "avi",
+        "mov",
+        "h264",
+        "mpg",
+        "webm",
+        "wmv",
+    ]:
         return None

     try:
diff --git a/torchgen/model.py b/torchgen/model.py
index 1712332128df..906b61e2f19c 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -60,7 +60,23 @@ class Variant(Enum):
 DEFAULT_KERNEL_NAMESPACE = "at::native"

 # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
-BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
+BACKEND_COMPONENTS = [
+    "CPU",
+    "CUDA",
+    "HIP",
+    "XLA",
+    "MTIA",
+    "MPS",
+    "IPU",
+    "XPU",
+    "HPU",
+    "VE",
+    "Lazy",
+    "Meta",
+    "PrivateUse1",
+    "PrivateUse2",
+    "PrivateUse3",
+]
 FUNCTIONALITY_KEYS = [
     "",
     "Quantized",
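
A minimal sketch of the rewrite described in the commit message (illustrative
only, not part of the merged diff; `meta` below is a hypothetical dict, not an
identifier taken from this patch):

    meta = {"device_type": "cuda"}

    # Before: the explicit default duplicates dict.get's built-in behavior.
    assert meta.get("missing", None) is None

    # After: the simplified form applied throughout this patch; behavior is
    # identical, because dict.get already returns None for a missing key.
    assert meta.get("missing") is None
    assert meta.get("device_type") == "cuda"

Both forms return the same value for present and missing keys alike, which is
why the replacement is safe to apply mechanically.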