More ruff SIM fixes (#164695)

This PR applies ruff `SIM` rules to more files. Most changes are about simplifying `dict.get` because `None` is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
This commit is contained in:
Yuanyuan Chen
2025-10-09 03:24:46 +00:00
committed by PyTorch MergeBot
parent 54ae61c573
commit a029675f6f
49 changed files with 153 additions and 117 deletions

View File

@ -321,7 +321,7 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
def _test_serialization(self, guard_type, fn, *args, **kwargs):
# kwargs might contain a callable that generates kwargs
torch._dynamo.reset()
kwarg_gen_fn = kwargs.get("_gen_fn", None)
kwarg_gen_fn = kwargs.get("_gen_fn")
if kwarg_gen_fn is not None:
kwargs = kwarg_gen_fn()

View File

@ -242,11 +242,12 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
opt_f(torch.randn(8 + i))
failure_str = "\n".join(failure_reasons)
for line in """\
tensor 'x' size mismatch at index 0. expected 11, actual 12
tensor 'x' size mismatch at index 0. expected 10, actual 12
tensor 'x' size mismatch at index 0. expected 9, actual 12
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
for line in [
"tensor 'x' size mismatch at index 0. expected 11, actual 12",
"tensor 'x' size mismatch at index 0. expected 10, actual 12",
"tensor 'x' size mismatch at index 0. expected 9, actual 12",
"tensor 'x' size mismatch at index 0. expected 8, actual 12",
]:
self.assertIn(
line,
failure_str,
@ -281,16 +282,13 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
failure_reasons.clear()
opt_f([7, 8])
for line in """\
len(x) == 3""".split("\n"):
for line in ["len(x) == 3"]:
self.assertIn(line, filter_reasons())
failure_reasons.clear()
opt_f([9])
for line in """\
len(x) == 2
len(x) == 3""".split("\n"):
for line in ["len(x) == 2", "len(x) == 3"]:
self.assertIn(line, filter_reasons())
@torch._dynamo.config.patch(recompile_limit=1)

View File

@ -686,7 +686,7 @@ class TestFindIndependentSubsetGreedy(TestCase):
unsatisfied += 1
assert unsatisfied <= len(desc) # cycle or bad input?
name, v = desc.popleft()
args = tuple(lookup.get(n, None) for n in v)
args = tuple(lookup.get(n) for n in v)
if None in args:
desc.append((name, v))
continue

View File

@ -901,7 +901,7 @@ class TestMemoryProfilerE2E(TestCase):
ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key
def format_categories(ptr_pair: int):
target_key = ptr_pair_to_key.get(ptr_pair, None)
target_key = ptr_pair_to_key.get(ptr_pair)
if target_key is None:
return "???"

View File

@ -127,7 +127,7 @@ class TestVisibleDeviceParses(TestCase):
_transform_uuid_to_ordinals(["GPU-e4", "GPU-9e8d35e3"], uuids), [2, 1]
)
self.assertEqual(
_transform_uuid_to_ordinals("GPU-9e8d35e3,GPU-1,GPU-47".split(","), uuids),
_transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-1", "GPU-47"], uuids),
[1, 7, 5],
)
# First invalid UUID aborts parsing

View File

@ -241,17 +241,17 @@ class TestScript(JitTestCase):
def test_split() -> tuple[list[str], list[str], list[str], list[str], list[str],
list[str], list[str], list[str], list[str], list[str], list[str]]:
return (
"a a a a a".split(),
"a a a a a".split(),
" a a\ta \v a \v\f\n a \t ".split(),
" a a a a a ".split(" "),
"a a a a a ".split(" ", 10),
"a a a a a ".split(" ", -1),
"a a a a a ".split(" ", 3),
" a a a a a ".split("*"),
" a*a a*a a".split("*"),
" a*a a*a a ".split("*", -1),
" a*a a*a a ".split("a*", 10),
["a", "a", "a", "a", "a"],
["a", "a", "a", "a", "a"],
["a", "a", "a", "a", "a"],
["", "a", "a", "a", "a", "a", ""],
["a", "a", "a", "a", "a", ""],
["a", "a", "a", "a", "a", ""],
["a", "a", "a", "a a "],
[" a a a a a "],
[" a", "a a", "a a"],
[" a", "a a", "a a "],
[" ", "a ", "a a "],
)
self.checkScript(test_split, ())
@ -266,15 +266,15 @@ class TestScript(JitTestCase):
def test_rsplit() -> tuple[list[str], list[str], list[str], list[str], list[str],
list[str], list[str], list[str], list[str]]:
return (
"a a a a a".rsplit(),
" a a a a a ".rsplit(" "),
"a a a a a ".rsplit(" ", 10),
"a a a a a ".rsplit(" ", -1),
"a a a a a ".rsplit(" ", 3),
" a a a a a ".rsplit("*"),
" a*a a*a a ".rsplit("*"),
" a*a a*a a ".rsplit("*", -1),
" a*a a*a a".rsplit("a*", 10),
["a", "a", "a", "a", "a"],
["", "a", "a", "a", "a", "a", ""],
["a", "a", "a", "a", "a", ""],
["a", "a", "a", "a", "a", ""],
["a a a", "a", "a", ""],
[" a a a a a "],
[" a", "a a", "a a "],
[" a", "a a", "a a "],
[" ", "a ", "a a"],
)
self.checkScript(test_rsplit, ())

View File

@ -190,12 +190,12 @@ def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
def is_job_failed(job: Any) -> bool:
conclusion = job["conclusion"] if "conclusion" in job else None
conclusion = job.get("conclusion", None)
return conclusion is not None and conclusion != SUCCESS and conclusion != PENDING
def is_job_skipped(job: Any) -> bool:
conclusion = job["conclusion"] if "conclusion" in job else None
conclusion = job.get("conclusion", None)
return conclusion in (NEUTRAL, SKIPPED) or conclusion is None

View File

@ -340,7 +340,7 @@ def get_base_name(f: NativeFunction) -> str:
def get_view_info(f: NativeFunction) -> str | None:
base_name = get_base_name(f)
view_info = VIEW_FUNCTIONS.get(base_name, None)
view_info = VIEW_FUNCTIONS.get(base_name)
if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
view_info = "self"
return view_info

View File

@ -30,7 +30,7 @@ cov_data = CoverageData(basename=f".coverage.jit.{time()}")
def is_not_builtin_class(obj: Any) -> bool:
return isclass(obj) and not type(obj).__module__ == "builtins"
return isclass(obj) and type(obj).__module__ != "builtins"
class JitPlugin(CoveragePlugin): # type: ignore[misc, no-any-unimported]

View File

@ -554,26 +554,30 @@ class Op:
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
)
if self.type in [
"all_gather",
"all_gather_base",
"all_gather_into_tensor_coalesced",
] and not (
math.prod(other.output_sizes[0])
== math.prod(self.input_sizes[0]) * self.pg_size
if (
self.type
in [
"all_gather",
"all_gather_base",
"all_gather_into_tensor_coalesced",
]
and math.prod(other.output_sizes[0])
!= math.prod(self.input_sizes[0]) * self.pg_size
):
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
f"does not match output numel '{math.prod(other.output_sizes[0])}'",
)
if self.type in [
"reduce_scatter",
"_reduce_scatter_base",
"reduce_scatter_tensor_coalesced",
] and not (
math.prod(other.input_sizes[0])
== math.prod(self.output_sizes[0]) * self.pg_size
if (
self.type
in [
"reduce_scatter",
"_reduce_scatter_base",
"reduce_scatter_tensor_coalesced",
]
and math.prod(other.input_sizes[0])
!= math.prod(self.output_sizes[0]) * self.pg_size
):
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,

View File

@ -326,7 +326,7 @@ class CMake:
# The default value cannot be easily obtained in CMakeLists.txt. We set it here.
py_lib_path = sysconfig.get_path("purelib")
cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH", None)
cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH")
if cmake_prefix_path:
build_options["CMAKE_PREFIX_PATH"] = (
py_lib_path + ";" + cast(str, cmake_prefix_path)

View File

@ -43,7 +43,7 @@ def add_job(
if workflow_name not in workflows:
workflows[workflow_name] = {"when": "always", "jobs": []}
requires = job.get("requires", None)
requires = job.get("requires")
if requires is not None:
for requirement in requires:
dependency = past_jobs[requirement]

View File

@ -43,7 +43,7 @@ class ConstantVariable(VariableTracker):
NOTE: the caller must install the proper guards if needed; most often
the guard will be `CONSTANT_MATCH`.
"""
source = kwargs.get("source", None)
source = kwargs.get("source")
# Routing for supported collection literals.
if isinstance(value, set):

View File

@ -1154,7 +1154,7 @@ class NamedTupleVariable(TupleVariable):
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
def is_namedtuple(self):
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(

View File

@ -1425,7 +1425,7 @@ class NumpyVariable(VariableTracker):
def get_constant_collection_for_func(cls, fn):
mod = fn.__module__.split(".")
assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
return np_constant_collections_map.get(fn, None)
return np_constant_collections_map.get(fn)
def call_function(
self,
@ -1930,7 +1930,7 @@ class RandomVariable(VariableTracker):
class WeakRefVariable(VariableTracker):
@staticmethod
def build(tx, weakref_value, **options):
source = options.get("source", None)
source = options.get("source")
callback = weakref_value.__callback__
callback_source = source and AttrSource(source, "__callback__")
callback_vt = VariableTracker.build(tx, callback, callback_source)

View File

@ -1219,7 +1219,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
"""
def __init__(self, value, **kwargs) -> None:
source = kwargs.get("source", None)
source = kwargs.get("source")
assert source is not None, (
"FSDPManagedNNModule depends on having an accurate source to control guarding."
)

View File

@ -13,7 +13,15 @@ if TYPE_CHECKING:
from torch._dynamo.codegen import PyCodegen
from torch._dynamo.symbolic_convert import InstructionTranslator
PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
PARAM_NAMES = [
"query",
"key",
"value",
"attn_mask",
"dropout",
"is_causal",
"enable_gqa",
]
class SDPAParamsVariable(VariableTracker):

View File

@ -239,7 +239,7 @@ def write_view_information_to_args(
write_single_view(
f"_{arg_name}",
kwargs[arg_name],
arg_to_base_index.get(arg_name, None), # type: ignore[arg-type]
arg_to_base_index.get(arg_name), # type: ignore[arg-type]
)
else:
raise RuntimeError(f"Unsupported type {arg_type}")
@ -390,7 +390,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
if isinstance(_mutable_op, HigherOrderOperator):
_op_to_check = HopInstance(
_mutable_op,
SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema, # type: ignore[arg-type]
SchemaHolder.from_tree_spec(kwargs.get("_op_schema")).schema, # type: ignore[arg-type]
)
else:
_op_to_check = _mutable_op
@ -958,11 +958,11 @@ def auto_functionalized_v2_proxy(
# hop node in the traced graph and graph module inputs to the hop. Finally, we replace the
# original kwarg's callable with the graph module.
all_bases = kwargs.get("_all_bases", [])
_only_clone_these_bases = kwargs.get("_only_clone_these_bases", None)
_only_clone_these_bases = kwargs.get("_only_clone_these_bases")
if _only_clone_these_bases is None:
_only_clone_these_bases = tuple(range(len(all_bases)))
schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema # type: ignore[arg-type]
schema = pytree.tree_unflatten([], kwargs.get("_op_schema")).schema # type: ignore[arg-type]
new_kwargs, _ = _generate_new_op_kwargs_from_bases(
schema,
{k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")},

View File

@ -163,7 +163,7 @@ def lookup_device_info(name: str) -> Optional[DeviceInfo]:
If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
"""
return _device_mapping.get(name, None)
return _device_mapping.get(name)
def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:

View File

@ -974,7 +974,7 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
# We need avx512_bf16 to dequant int8 to bf16
vec_isa = kwargs.get("vec_isa", None)
vec_isa = kwargs.get("vec_isa")
assert vec_isa is not None
return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
config, m, n, k, alpha, num_threads, **kwargs
@ -984,7 +984,7 @@ def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
# amx_fp16 need to be checked separately since it is not always supported when amx is supported
def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
vec_isa = kwargs.get("vec_isa", None)
vec_isa = kwargs.get("vec_isa")
assert vec_isa is not None
vnni_size = 2
return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
@ -1419,7 +1419,7 @@ class CppMicroBrgemm(CppMicroGemm):
def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
if alpha != 1:
return False
q_group_size = kwargs.get("q_group_size", None)
q_group_size = kwargs.get("q_group_size")
assert q_group_size is not None
if (
q_group_size not in [32, 64, 128]

View File

@ -528,7 +528,7 @@ class CKGroupedConvFwdTemplate(CKTemplate):
op: "CKGroupedConvFwdOp", # type: ignore[name-defined]
**kwargs,
) -> str:
template_buffer_node = kwargs.get("template_buffer_node", None)
template_buffer_node = kwargs.get("template_buffer_node")
if template_buffer_node is not None:
self.output_node = template_buffer_node
X, W = self.input_nodes[0], self.input_nodes[1]

View File

@ -750,9 +750,9 @@ class CKTileGemmTemplate(CKTileTemplate):
"""
The primary entry point for the code rendering process used in this template.
"""
epilogue_nodes = kwargs.get("epilogue_nodes", None)
epilogue_nodes = kwargs.get("epilogue_nodes")
assert epilogue_nodes is None or 0 == len(epilogue_nodes)
template_buffer_node = kwargs.get("template_buffer_node", None)
template_buffer_node = kwargs.get("template_buffer_node")
if template_buffer_node is not None:
self.output_node = template_buffer_node
assert 2 == len(self.input_nodes)

View File

@ -547,7 +547,7 @@ class CKGemmTemplate(CKTemplate):
# Define the mapping of versions to stages
version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3}
# Get the stages for the given version
stages = version_to_stages.get(version, None)
stages = version_to_stages.get(version)
if stages is None:
# This means we're at stage 2, and this requires computation
# See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950
@ -612,9 +612,9 @@ class CKGemmTemplate(CKTemplate):
"""
The primary entry point for the code rendering process used in this template.
"""
epilogue_nodes = kwargs.get("epilogue_nodes", None)
epilogue_nodes = kwargs.get("epilogue_nodes")
assert epilogue_nodes is None or 0 == len(epilogue_nodes)
template_buffer_node = kwargs.get("template_buffer_node", None)
template_buffer_node = kwargs.get("template_buffer_node")
if template_buffer_node is not None:
self.output_node = template_buffer_node
# input nodes:

View File

@ -2296,7 +2296,7 @@ class SIMDScheduling(BaseScheduling):
return ([], [])
key = (repr(vars_to_use), use_split_var, is_pointwise)
if out := scored_sub_split.get(key, None):
if out := scored_sub_split.get(key):
return out
splitting_vars = all_iter_vars if is_pointwise else all_red_vars

View File

@ -550,7 +550,7 @@ class CompilerBisector:
curr_backend,
curr_subsystem.name,
low,
call_counter_debug_info.get(low, None),
call_counter_debug_info.get(low),
)
next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)

View File

@ -715,7 +715,7 @@ def run_autoheuristic(
)
choice = autoheuristic.get_choice()
choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
ah_should_pad = choice2should_pad.get(choice, None)
ah_should_pad = choice2should_pad.get(choice)
if torch._inductor.config.collect_autoheuristic(name):
ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)

View File

@ -634,7 +634,7 @@ class NormalizedLinearNode:
if len(self.node.args) > 2:
return self.node.args[2] # type: ignore[return-value]
else:
return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value]
return self.node.kwargs.get("bias", None) # type: ignore[return-value]
class NormalizedMatmulNode:

View File

@ -503,7 +503,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
if mutated_arg.op in ("placeholder", "get_attr"):
# Get the first copy_ node that mutates the mutated_arg.
copy_node = copy_nodes.get(mutated_arg, None)
copy_node = copy_nodes.get(mutated_arg)
if copy_node is None:
# There is no copy_ back to the candidate mutated_arg (which is a graph input).
# Therefore the semantics of the program are that it does not mutate

View File

@ -2365,7 +2365,7 @@ make_fallback(aten.randint)
@register_lowering(aten.rand)
def rand(*args, **kwargs):
if kwargs.get("generator", None) is not None:
if kwargs.get("generator") is not None:
return fallback_rand_generator(*args, **kwargs)
elif config.fallback_random:
kwargs.pop("generator", None)
@ -2375,7 +2375,7 @@ def rand(*args, **kwargs):
@register_lowering(aten.randn)
def randn(*args, **kwargs):
if kwargs.get("generator", None) is not None:
if kwargs.get("generator") is not None:
return fallback_randn_generator(*args, **kwargs)
elif config.fallback_random:
kwargs.pop("generator", None)

View File

@ -1583,7 +1583,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
return None
def check_can_launch() -> StaticallyLaunchedCudaKernel:
if triton_meta.get("device_type", None) != "cuda":
if triton_meta.get("device_type") != "cuda":
# Only cuda kernels
raise CannotStaticallyLaunchKernel("Non-cuda device")
@ -1600,7 +1600,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
# Don't support user defined triton kernels yet
raise CannotStaticallyLaunchKernel("User defined triton kernel")
if inductor_meta.get("store_cubin", None):
if inductor_meta.get("store_cubin"):
# Requires storing the entire binary
raise CannotStaticallyLaunchKernel("store_cubin is enabled")
@ -2640,7 +2640,7 @@ def pointwise(
def _reduction_configs(
*, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
) -> list[Config]:
reduction_hint = inductor_meta.get("reduction_hint", None)
reduction_hint = inductor_meta.get("reduction_hint")
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)

View File

@ -2089,8 +2089,8 @@ class TritonTemplate(KernelTemplate):
"num_stages": num_stages,
"num_warps": num_warps,
"GROUP_M": kwargs.get("GROUP_M", -1),
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
"allow_tf32": str(kwargs.get("ALLOW_TF32")),
"acc_type": str(kwargs.get("ACC_TYPE")),
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
"kpack": kwargs.get("kpack", 2),

View File

@ -112,7 +112,7 @@ class SizeVarAllocator:
cache.clear()
replacement_count = len(self.replacements)
key = (expr, *var_ranges.items())
result = cache.get(key, None)
result = cache.get(key)
if result is None:
result = self._simplify_with_ranges(expr, var_ranges)
cache[key] = result
@ -136,7 +136,7 @@ class SizeVarAllocator:
cache.clear()
replacement_count = len(self.replacements)
key = (*index_vars, *sizes, *index_formulas)
result = cache.get(key, None)
result = cache.get(key)
if result is None:
result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
cache[key] = result

View File

@ -575,7 +575,7 @@ def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
# TODO - deduplicate with candidate_tilings
var_sizes = []
for v in addr.free_symbols:
v_size = var_ranges.get(v, None)
v_size = var_ranges.get(v)
# TODO - reason about indirect vars
if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
var_sizes.append(v_size)

View File

@ -21,7 +21,7 @@ def intern_string(s: Optional[str]) -> int:
if s is None:
return -1
r = INTERN_TABLE.get(s, None)
r = INTERN_TABLE.get(s)
if r is None:
r = len(INTERN_TABLE)
INTERN_TABLE[s] = r

View File

@ -1867,7 +1867,7 @@ def common_type(*tensors: ArrayLike):
if not (t.is_floating_point or t.is_complex):
p = 2 # array_precision[_nx.double]
else:
p = array_precision.get(t, None)
p = array_precision.get(t)
if p is None:
raise TypeError("can't get common type for non-numeric array")
precision = builtins.max(precision, p)

View File

@ -1284,7 +1284,7 @@ class MetaConverter(Generic[_TensorT]):
# Fake inner tensors of view subclasses will come from the mapping built above.
visited_id = self.describer.get_tensor_id(visited_t)
fake_visited_t = real_to_fake_mapping.get(visited_id, None)
fake_visited_t = real_to_fake_mapping.get(visited_id)
if fake_visited_t is not None:
return fake_visited_t

View File

@ -43,7 +43,7 @@ def _is_tracked_fake(obj: typing.Any) -> bool:
def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool:
# Hope gm.meta was a custom dict we can assert on
return d.get("val", None) is o
return d.get("val") is o
def _dict_is_attr_of_tracked_fake(d: dict) -> bool:

View File

@ -60,7 +60,7 @@ def _toposort(edges):
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
if any(incoming_edges.get(v) for v in edges):
raise ValueError("Input has cycles")
return L

View File

@ -55,7 +55,7 @@ isvar
@dispatch(object) # type: ignore[no-redef]
def isvar(o):
return not not _glv and hashable(o) and o in _glv
return _glv and hashable(o) and o in _glv
@contextmanager

View File

@ -193,7 +193,7 @@ def _deserialize_graph_module(
graph = KeepModules().trace(com, **tracer_extras)
# Recover node.meta["stack_trace"] after re-tracing
node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace")
if node_meta_stack_trace is not None:
del body["_graphmodule_graph_node_meta_stack_trace"]
for node in graph.nodes:

View File

@ -190,7 +190,7 @@ class CapabilityBasedPartitioner:
# Iterate through all the users of this node and update the partition map to indicate
# that there is a path from the partition id of this node to the target partition id.
for user_node in node.users:
target_id = assignment.get(user_node, None)
target_id = assignment.get(user_node)
if target_id is not None:
partition_map[id].add(target_id)
partition_map[id].update(partition_map[target_id])
@ -267,9 +267,9 @@ class CapabilityBasedPartitioner:
# node has tuple outputs, re-assign all following getitem node into node's partition
if is_tuple_output:
id = assignment.get(node, None) # type: ignore[arg-type]
id = assignment.get(node) # type: ignore[arg-type]
for user in node.users:
if assignment.get(user, None) != id: # type: ignore[arg-type]
if assignment.get(user) != id: # type: ignore[arg-type]
nodes_reassignment[user] = id # type: ignore[assignment]
for node, id in nodes_reassignment.items():
merge_single_node(node, None, id)

View File

@ -76,11 +76,11 @@ def _get_script_class(python_class):
override = getattr(python_class, "_jit_override_qualname", None)
if override is not None:
python_class = _get_python_class(override)
return _script_classes.get(python_class, None)
return _script_classes.get(python_class)
def _get_python_class(qualified_name):
return _name_to_pyclass.get(qualified_name, None)
return _name_to_pyclass.get(qualified_name)
def _clear_class_state():

View File

@ -228,7 +228,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
def preview_type_promotion(
self, args: tuple, kwargs: dict
) -> TypePromotionSnapshot:
rounding_mode = kwargs.get("rounding_mode", None)
rounding_mode = kwargs.get("rounding_mode")
if rounding_mode is None:
# true_divide
self.promotion_kind = (
@ -287,7 +287,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
)
arg = args[0]
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
dtype: torch.dtype | None = kwargs.get("dtype", None)
dtype: torch.dtype | None = kwargs.get("dtype")
computation_dtype, result_dtype = _prims_common.reduction_dtypes(
arg, self.promotion_kind, dtype
@ -351,7 +351,7 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
)
arg = args[0]
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
dtype: torch.dtype | None = kwargs.get("dtype", None)
dtype: torch.dtype | None = kwargs.get("dtype")
# The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
if dtype is None:
if _prims_common.is_boolean_dtype(

View File

@ -3421,7 +3421,7 @@ class ModuleTest(TestBase):
kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
self.precision = kwargs.get('precision', 2e-4)
self.check_forward_only = kwargs.get('check_forward_only', False)
self.default_dtype = kwargs.get('default_dtype', None)
self.default_dtype = kwargs.get('default_dtype')
if self.default_dtype is None:
self.default_dtype = torch.get_default_dtype()
@ -3632,7 +3632,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc]
self.test_cpu = kwargs.get('test_cpu', True)
self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
self.check_batched_grad = kwargs.get('check_batched_grad', True)
self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode')
self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
@ -3836,7 +3836,7 @@ class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc]
self.with_tf32 = kwargs.get('with_tf32', True)
self.tf32_precision = kwargs.get('tf32_precision', 0.001)
self.check_batched_grad = kwargs.get('check_batched_grad', True)
self.default_dtype = kwargs.get('default_dtype', None)
self.default_dtype = kwargs.get('default_dtype')
if self.default_dtype is None:
self.default_dtype = torch.get_default_dtype()

View File

@ -5124,7 +5124,7 @@ def gradcheck(fn, inputs, **kwargs):
for key, value in default_values.items():
# default value override values explicitly set to None
k = kwargs.get(key, None)
k = kwargs.get(key)
kwargs[key] = k if k is not None else value
return torch.autograd.gradcheck(fn, inputs, **kwargs)
@ -5144,7 +5144,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
for key, value in default_values.items():
# default value override values explicitly set to None
k = kwargs.get(key, None)
k = kwargs.get(key)
kwargs[key] = k if k is not None else value
return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)

View File

@ -40,7 +40,7 @@ class VerifyStateDictMixin:
if not options.ignore_frozen_params:
self.assertEqual(len(msd), len(dist_msd))
for fqn, param in msd.items():
dist_param = dist_msd.get(fqn, None)
dist_param = dist_msd.get(fqn)
if not options.ignore_frozen_params:
self.assertIsNotNone(dist_param, f"{fqn=}")
try:

View File

@ -2309,7 +2309,7 @@ def _write_ninja_file_and_build_library(
def is_ninja_available():
"""Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise."""
try:
subprocess.check_output('ninja --version'.split())
subprocess.check_output(['ninja', '--version'])
except Exception:
return False
else:

View File

@ -61,7 +61,7 @@ def basichandlers(extension: str, data):
if extension in "txt text transcript":
return data.decode("utf-8")
if extension in "cls cls2 class count index inx id".split():
if extension in ["cls", "cls2", "class", "count", "index", "inx", "id"]:
try:
return int(data)
except ValueError:
@ -70,10 +70,10 @@ def basichandlers(extension: str, data):
if extension in "json jsn":
return json.loads(data)
if extension in "pyd pickle".split():
if extension in ["pyd", "pickle"]:
return pickle.loads(data)
if extension in "pt".split():
if extension in ["pt"]:
stream = io.BytesIO(data)
return torch.load(stream)
@ -175,7 +175,7 @@ class ImageHandler:
self.imagespec = imagespec.lower()
def __call__(self, extension, data):
if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split():
if extension.lower() not in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
return None
try:
@ -235,7 +235,17 @@ def imagehandler(imagespec):
# torch video
################################################################
def videohandler(extension, data):
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
if extension not in [
"mp4",
"ogv",
"mjpeg",
"avi",
"mov",
"h264",
"mpg",
"webm",
"wmv",
]:
return None
try:

View File

@ -60,7 +60,23 @@ class Variant(Enum):
DEFAULT_KERNEL_NAMESPACE = "at::native"
# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
BACKEND_COMPONENTS = [
"CPU",
"CUDA",
"HIP",
"XLA",
"MTIA",
"MPS",
"IPU",
"XPU",
"HPU",
"VE",
"Lazy",
"Meta",
"PrivateUse1",
"PrivateUse2",
"PrivateUse3",
]
FUNCTIONALITY_KEYS = [
"",
"Quantized",