Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
More ruff SIM fixes (#164695)
This PR applies ruff `SIM` rules to more files. Most changes simplify `dict.get(key, None)` to `dict.get(key)`, since `None` is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 54ae61c573
Commit: a029675f6f
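For context, here is a minimal sketch (added for illustration, not part of the commit) of the simplification patterns that recur in the diff below; the rule codes refer to ruff's flake8-simplify checks, and the example values are made up:

# Illustration only: the three simplifications most common in this PR.
d = {"a": 1}

# SIM910: None is already the default for dict.get, so the explicit
# second argument is redundant.
before = d.get("b", None)
after = d.get("b")
assert before is None and after is None

# SIM905: calling .split() on a static string literal can be replaced
# by the resulting list literal, computed once at write time.
assert "CPU CUDA MPS".split() == ["CPU", "CUDA", "MPS"]

# SIM208: `not not x` is a double negation; `bool(x)` (or just `x` in a
# boolean context) expresses the same thing.
items = []
assert (not not items) == bool(items)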
@@ -321,7 +321,7 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
     def _test_serialization(self, guard_type, fn, *args, **kwargs):
         # kwargs might contain a callable that generates kwargs
         torch._dynamo.reset()
-        kwarg_gen_fn = kwargs.get("_gen_fn", None)
+        kwarg_gen_fn = kwargs.get("_gen_fn")
         if kwarg_gen_fn is not None:
             kwargs = kwarg_gen_fn()
@@ -242,11 +242,12 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
             opt_f(torch.randn(8 + i))

         failure_str = "\n".join(failure_reasons)
-        for line in """\
-tensor 'x' size mismatch at index 0. expected 11, actual 12
-tensor 'x' size mismatch at index 0. expected 10, actual 12
-tensor 'x' size mismatch at index 0. expected 9, actual 12
-tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
+        for line in [
+            "tensor 'x' size mismatch at index 0. expected 11, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 10, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 9, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 8, actual 12",
+        ]:
             self.assertIn(
                 line,
                 failure_str,
@@ -281,16 +282,13 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
         failure_reasons.clear()
         opt_f([7, 8])

-        for line in """\
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 3"]:
             self.assertIn(line, filter_reasons())

         failure_reasons.clear()
         opt_f([9])

-        for line in """\
-len(x) == 2
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 2", "len(x) == 3"]:
             self.assertIn(line, filter_reasons())

     @torch._dynamo.config.patch(recompile_limit=1)
@@ -686,7 +686,7 @@ class TestFindIndependentSubsetGreedy(TestCase):
                 unsatisfied += 1
                 assert unsatisfied <= len(desc)  # cycle or bad input?
             name, v = desc.popleft()
-            args = tuple(lookup.get(n, None) for n in v)
+            args = tuple(lookup.get(n) for n in v)
             if None in args:
                 desc.append((name, v))
                 continue
@@ -901,7 +901,7 @@ class TestMemoryProfilerE2E(TestCase):
             ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key

        def format_categories(ptr_pair: int):
-            target_key = ptr_pair_to_key.get(ptr_pair, None)
+            target_key = ptr_pair_to_key.get(ptr_pair)
            if target_key is None:
                return "???"

@@ -127,7 +127,7 @@ class TestVisibleDeviceParses(TestCase):
             _transform_uuid_to_ordinals(["GPU-e4", "GPU-9e8d35e3"], uuids), [2, 1]
         )
         self.assertEqual(
-            _transform_uuid_to_ordinals("GPU-9e8d35e3,GPU-1,GPU-47".split(","), uuids),
+            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-1", "GPU-47"], uuids),
             [1, 7, 5],
         )
         # First invalid UUID aborts parsing
@@ -241,17 +241,17 @@ class TestScript(JitTestCase):
         def test_split() -> tuple[list[str], list[str], list[str], list[str], list[str],
                                   list[str], list[str], list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".split(),
-                "a a a a a".split(),
-                " a a\ta \v a \v\f\n a \t ".split(),
-                " a a a a a ".split(" "),
-                "a a a a a ".split(" ", 10),
-                "a a a a a ".split(" ", -1),
-                "a a a a a ".split(" ", 3),
-                " a a a a a ".split("*"),
-                " a*a a*a a".split("*"),
-                " a*a a*a a ".split("*", -1),
-                " a*a a*a a ".split("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a a "],
+                [" a a a a a "],
+                [" a", "a a", "a a"],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a "],
             )
         self.checkScript(test_split, ())

@@ -266,15 +266,15 @@ class TestScript(JitTestCase):
         def test_rsplit() -> tuple[list[str], list[str], list[str], list[str], list[str],
                                    list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".rsplit(),
-                " a a a a a ".rsplit(" "),
-                "a a a a a ".rsplit(" ", 10),
-                "a a a a a ".rsplit(" ", -1),
-                "a a a a a ".rsplit(" ", 3),
-                " a a a a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*", -1),
-                " a*a a*a a".rsplit("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a a a", "a", "a", ""],
+                [" a a a a a "],
+                [" a", "a a", "a a "],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a"],
             )
         self.checkScript(test_rsplit, ())

@@ -190,12 +190,12 @@ def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:


 def is_job_failed(job: Any) -> bool:
-    conclusion = job["conclusion"] if "conclusion" in job else None
+    conclusion = job.get("conclusion", None)
     return conclusion is not None and conclusion != SUCCESS and conclusion != PENDING


 def is_job_skipped(job: Any) -> bool:
-    conclusion = job["conclusion"] if "conclusion" in job else None
+    conclusion = job.get("conclusion", None)
     return conclusion in (NEUTRAL, SKIPPED) or conclusion is None

@@ -340,7 +340,7 @@ def get_base_name(f: NativeFunction) -> str:

 def get_view_info(f: NativeFunction) -> str | None:
     base_name = get_base_name(f)
-    view_info = VIEW_FUNCTIONS.get(base_name, None)
+    view_info = VIEW_FUNCTIONS.get(base_name)
     if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
         view_info = "self"
     return view_info
@@ -30,7 +30,7 @@ cov_data = CoverageData(basename=f".coverage.jit.{time()}")


 def is_not_builtin_class(obj: Any) -> bool:
-    return isclass(obj) and not type(obj).__module__ == "builtins"
+    return isclass(obj) and type(obj).__module__ != "builtins"


 class JitPlugin(CoveragePlugin):  # type: ignore[misc, no-any-unimported]
@@ -554,26 +554,30 @@ class Op:
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
                 f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
             )
-        if self.type in [
-            "all_gather",
-            "all_gather_base",
-            "all_gather_into_tensor_coalesced",
-        ] and not (
-            math.prod(other.output_sizes[0])
-            == math.prod(self.input_sizes[0]) * self.pg_size
+        if (
+            self.type
+            in [
+                "all_gather",
+                "all_gather_base",
+                "all_gather_into_tensor_coalesced",
+            ]
+            and math.prod(other.output_sizes[0])
+            != math.prod(self.input_sizes[0]) * self.pg_size
         ):
             return MatchInfo(
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
                 f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
                 f"does not match output numel '{math.prod(other.output_sizes[0])}'",
             )
-        if self.type in [
-            "reduce_scatter",
-            "_reduce_scatter_base",
-            "reduce_scatter_tensor_coalesced",
-        ] and not (
-            math.prod(other.input_sizes[0])
-            == math.prod(self.output_sizes[0]) * self.pg_size
+        if (
+            self.type
+            in [
+                "reduce_scatter",
+                "_reduce_scatter_base",
+                "reduce_scatter_tensor_coalesced",
+            ]
+            and math.prod(other.input_sizes[0])
+            != math.prod(self.output_sizes[0]) * self.pg_size
         ):
             return MatchInfo(
                 MatchState.SIZE_OR_SYNTAX_MISMATCH,
@@ -326,7 +326,7 @@ class CMake:

         # The default value cannot be easily obtained in CMakeLists.txt. We set it here.
         py_lib_path = sysconfig.get_path("purelib")
-        cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH", None)
+        cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH")
         if cmake_prefix_path:
             build_options["CMAKE_PREFIX_PATH"] = (
                 py_lib_path + ";" + cast(str, cmake_prefix_path)
@@ -43,7 +43,7 @@ def add_job(
     if workflow_name not in workflows:
         workflows[workflow_name] = {"when": "always", "jobs": []}

-    requires = job.get("requires", None)
+    requires = job.get("requires")
     if requires is not None:
         for requirement in requires:
             dependency = past_jobs[requirement]
@@ -43,7 +43,7 @@ class ConstantVariable(VariableTracker):
         NOTE: the caller must install the proper guards if needed; most often
         the guard will be `CONSTANT_MATCH`.
         """
-        source = kwargs.get("source", None)
+        source = kwargs.get("source")

         # Routing for supported collection literals.
         if isinstance(value, set):
@@ -1154,7 +1154,7 @@ class NamedTupleVariable(TupleVariable):
     def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
         super().__init__(items, **kwargs)
         self.tuple_cls = tuple_cls
-        self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes
+        self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}

     def is_namedtuple(self):
         return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
@@ -1425,7 +1425,7 @@ class NumpyVariable(VariableTracker):
     def get_constant_collection_for_func(cls, fn):
         mod = fn.__module__.split(".")
         assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
-        return np_constant_collections_map.get(fn, None)
+        return np_constant_collections_map.get(fn)

     def call_function(
         self,
@@ -1930,7 +1930,7 @@ class RandomVariable(VariableTracker):
 class WeakRefVariable(VariableTracker):
     @staticmethod
     def build(tx, weakref_value, **options):
-        source = options.get("source", None)
+        source = options.get("source")
         callback = weakref_value.__callback__
         callback_source = source and AttrSource(source, "__callback__")
         callback_vt = VariableTracker.build(tx, callback, callback_source)
@@ -1219,7 +1219,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     """

     def __init__(self, value, **kwargs) -> None:
-        source = kwargs.get("source", None)
+        source = kwargs.get("source")
         assert source is not None, (
             "FSDPManagedNNModule depends on having an accurate source to control guarding."
         )
@@ -13,7 +13,15 @@ if TYPE_CHECKING:
     from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

-PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
+PARAM_NAMES = [
+    "query",
+    "key",
+    "value",
+    "attn_mask",
+    "dropout",
+    "is_causal",
+    "enable_gqa",
+]


 class SDPAParamsVariable(VariableTracker):
@@ -239,7 +239,7 @@ def write_view_information_to_args(
             write_single_view(
                 f"_{arg_name}",
                 kwargs[arg_name],
-                arg_to_base_index.get(arg_name, None),  # type: ignore[arg-type]
+                arg_to_base_index.get(arg_name),  # type: ignore[arg-type]
             )
         else:
             raise RuntimeError(f"Unsupported type {arg_type}")
@@ -390,7 +390,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
         if isinstance(_mutable_op, HigherOrderOperator):
             _op_to_check = HopInstance(
                 _mutable_op,
-                SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema,  # type: ignore[arg-type]
+                SchemaHolder.from_tree_spec(kwargs.get("_op_schema")).schema,  # type: ignore[arg-type]
             )
         else:
             _op_to_check = _mutable_op
@@ -958,11 +958,11 @@ def auto_functionalized_v2_proxy(
     # hop node in the traced graph and graph module inputs to the hop. Finally, we replace the
     # original kwarg's callable with the graph module.
     all_bases = kwargs.get("_all_bases", [])
-    _only_clone_these_bases = kwargs.get("_only_clone_these_bases", None)
+    _only_clone_these_bases = kwargs.get("_only_clone_these_bases")
     if _only_clone_these_bases is None:
         _only_clone_these_bases = tuple(range(len(all_bases)))

-    schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema  # type: ignore[arg-type]
+    schema = pytree.tree_unflatten([], kwargs.get("_op_schema")).schema  # type: ignore[arg-type]
     new_kwargs, _ = _generate_new_op_kwargs_from_bases(
         schema,
         {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")},
@@ -163,7 +163,7 @@ def lookup_device_info(name: str) -> Optional[DeviceInfo]:
     If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
     name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
     """
-    return _device_mapping.get(name, None)
+    return _device_mapping.get(name)


 def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
@@ -974,7 +974,7 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):

 def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
     # We need avx512_bf16 to dequant int8 to bf16
-    vec_isa = kwargs.get("vec_isa", None)
+    vec_isa = kwargs.get("vec_isa")
     assert vec_isa is not None
     return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
         config, m, n, k, alpha, num_threads, **kwargs
@@ -984,7 +984,7 @@ def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
 # amx_fp16 need to be checked separately since it is not always supported when amx is supported
 def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
     assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
-    vec_isa = kwargs.get("vec_isa", None)
+    vec_isa = kwargs.get("vec_isa")
     assert vec_isa is not None
     vnni_size = 2
     return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
@@ -1419,7 +1419,7 @@ class CppMicroBrgemm(CppMicroGemm):
 def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
     if alpha != 1:
         return False
-    q_group_size = kwargs.get("q_group_size", None)
+    q_group_size = kwargs.get("q_group_size")
     assert q_group_size is not None
     if (
         q_group_size not in [32, 64, 128]
@@ -528,7 +528,7 @@ class CKGroupedConvFwdTemplate(CKTemplate):
         op: "CKGroupedConvFwdOp",  # type: ignore[name-defined]
         **kwargs,
     ) -> str:
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         X, W = self.input_nodes[0], self.input_nodes[1]
@@ -750,9 +750,9 @@ class CKTileGemmTemplate(CKTileTemplate):
         """
         The primary entry point for the code rendering process used in this template.
         """
-        epilogue_nodes = kwargs.get("epilogue_nodes", None)
+        epilogue_nodes = kwargs.get("epilogue_nodes")
         assert epilogue_nodes is None or 0 == len(epilogue_nodes)
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         assert 2 == len(self.input_nodes)
@@ -547,7 +547,7 @@ class CKGemmTemplate(CKTemplate):
         # Define the mapping of versions to stages
         version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3}
         # Get the stages for the given version
-        stages = version_to_stages.get(version, None)
+        stages = version_to_stages.get(version)
         if stages is None:
             # This means we're at stage 2, and this requires computation
             # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143  # noqa: B950
@@ -612,9 +612,9 @@ class CKGemmTemplate(CKTemplate):
         """
         The primary entry point for the code rendering process used in this template.
         """
-        epilogue_nodes = kwargs.get("epilogue_nodes", None)
+        epilogue_nodes = kwargs.get("epilogue_nodes")
         assert epilogue_nodes is None or 0 == len(epilogue_nodes)
-        template_buffer_node = kwargs.get("template_buffer_node", None)
+        template_buffer_node = kwargs.get("template_buffer_node")
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
         # input nodes:
@@ -2296,7 +2296,7 @@ class SIMDScheduling(BaseScheduling):
             return ([], [])

         key = (repr(vars_to_use), use_split_var, is_pointwise)
-        if out := scored_sub_split.get(key, None):
+        if out := scored_sub_split.get(key):
             return out

         splitting_vars = all_iter_vars if is_pointwise else all_red_vars
@@ -550,7 +550,7 @@ class CompilerBisector:
                     curr_backend,
                     curr_subsystem.name,
                     low,
-                    call_counter_debug_info.get(low, None),
+                    call_counter_debug_info.get(low),
                 )

                 next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
@@ -715,7 +715,7 @@ def run_autoheuristic(
     )
     choice = autoheuristic.get_choice()
     choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
-    ah_should_pad = choice2should_pad.get(choice, None)
+    ah_should_pad = choice2should_pad.get(choice)

     if torch._inductor.config.collect_autoheuristic(name):
         ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
@@ -634,7 +634,7 @@ class NormalizedLinearNode:
         if len(self.node.args) > 2:
             return self.node.args[2]  # type: ignore[return-value]
         else:
-            return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None  # type: ignore[return-value]
+            return self.node.kwargs.get("bias", None)  # type: ignore[return-value]


 class NormalizedMatmulNode:
@@ -503,7 +503,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:

         if mutated_arg.op in ("placeholder", "get_attr"):
             # Get the first copy_ node that mutates the mutated_arg.
-            copy_node = copy_nodes.get(mutated_arg, None)
+            copy_node = copy_nodes.get(mutated_arg)
             if copy_node is None:
                 # There is no copy_ back to the candidate mutated_arg (which is a graph input).
                 # Therefore the semantics of the program are that it does not mutate
@@ -2365,7 +2365,7 @@ make_fallback(aten.randint)

 @register_lowering(aten.rand)
 def rand(*args, **kwargs):
-    if kwargs.get("generator", None) is not None:
+    if kwargs.get("generator") is not None:
         return fallback_rand_generator(*args, **kwargs)
     elif config.fallback_random:
         kwargs.pop("generator", None)
@@ -2375,7 +2375,7 @@ def rand(*args, **kwargs):

 @register_lowering(aten.randn)
 def randn(*args, **kwargs):
-    if kwargs.get("generator", None) is not None:
+    if kwargs.get("generator") is not None:
         return fallback_randn_generator(*args, **kwargs)
     elif config.fallback_random:
         kwargs.pop("generator", None)
@@ -1583,7 +1583,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
             return None

         def check_can_launch() -> StaticallyLaunchedCudaKernel:
-            if triton_meta.get("device_type", None) != "cuda":
+            if triton_meta.get("device_type") != "cuda":
                 # Only cuda kernels
                 raise CannotStaticallyLaunchKernel("Non-cuda device")

@@ -1600,7 +1600,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
                 # Don't support user defined triton kernels yet
                 raise CannotStaticallyLaunchKernel("User defined triton kernel")

             if inductor_meta.get("store_cubin", None):
-            if inductor_meta.get("store_cubin", None):
+            if inductor_meta.get("store_cubin"):
                 # Requires storing the entire binary
                 raise CannotStaticallyLaunchKernel("store_cubin is enabled")

@@ -2640,7 +2640,7 @@ def pointwise(
 def _reduction_configs(
     *, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
 ) -> list[Config]:
-    reduction_hint = inductor_meta.get("reduction_hint", None)
+    reduction_hint = inductor_meta.get("reduction_hint")

     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
@@ -2089,8 +2089,8 @@ class TritonTemplate(KernelTemplate):
                 "num_stages": num_stages,
                 "num_warps": num_warps,
                 "GROUP_M": kwargs.get("GROUP_M", -1),
-                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
-                "acc_type": str(kwargs.get("ACC_TYPE", None)),
+                "allow_tf32": str(kwargs.get("ALLOW_TF32")),
+                "acc_type": str(kwargs.get("ACC_TYPE")),
                 "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                 "waves_per_eu": kwargs.get("waves_per_eu", 0),
                 "kpack": kwargs.get("kpack", 2),
@@ -112,7 +112,7 @@ class SizeVarAllocator:
                 cache.clear()
                 replacement_count = len(self.replacements)
             key = (expr, *var_ranges.items())
-            result = cache.get(key, None)
+            result = cache.get(key)
             if result is None:
                 result = self._simplify_with_ranges(expr, var_ranges)
                 cache[key] = result
@@ -136,7 +136,7 @@ class SizeVarAllocator:
                 cache.clear()
                 replacement_count = len(self.replacements)
             key = (*index_vars, *sizes, *index_formulas)
-            result = cache.get(key, None)
+            result = cache.get(key)
             if result is None:
                 result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                 cache[key] = result
@@ -575,7 +575,7 @@ def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
     # TODO - deduplicate with candidate_tilings
     var_sizes = []
     for v in addr.free_symbols:
-        v_size = var_ranges.get(v, None)
+        v_size = var_ranges.get(v)
         # TODO - reason about indirect vars
         if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
             var_sizes.append(v_size)
@@ -21,7 +21,7 @@ def intern_string(s: Optional[str]) -> int:
     if s is None:
         return -1

-    r = INTERN_TABLE.get(s, None)
+    r = INTERN_TABLE.get(s)
     if r is None:
         r = len(INTERN_TABLE)
         INTERN_TABLE[s] = r
@@ -1867,7 +1867,7 @@ def common_type(*tensors: ArrayLike):
         if not (t.is_floating_point or t.is_complex):
             p = 2  # array_precision[_nx.double]
         else:
-            p = array_precision.get(t, None)
+            p = array_precision.get(t)
             if p is None:
                 raise TypeError("can't get common type for non-numeric array")
         precision = builtins.max(precision, p)
@@ -1284,7 +1284,7 @@ class MetaConverter(Generic[_TensorT]):

             # Fake inner tensors of view subclasses will come from the mapping built above.
             visited_id = self.describer.get_tensor_id(visited_t)
-            fake_visited_t = real_to_fake_mapping.get(visited_id, None)
+            fake_visited_t = real_to_fake_mapping.get(visited_id)
             if fake_visited_t is not None:
                 return fake_visited_t

@@ -43,7 +43,7 @@ def _is_tracked_fake(obj: typing.Any) -> bool:

 def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool:
     # Hope gm.meta was a custom dict we can assert on
-    return d.get("val", None) is o
+    return d.get("val") is o


 def _dict_is_attr_of_tracked_fake(d: dict) -> bool:
@@ -60,7 +60,7 @@ def _toposort(edges):
             incoming_edges[m].remove(n)
             if not incoming_edges[m]:
                 S.add(m)
-    if any(incoming_edges.get(v, None) for v in edges):
+    if any(incoming_edges.get(v) for v in edges):
         raise ValueError("Input has cycles")
     return L

@@ -55,7 +55,7 @@ isvar

 @dispatch(object)  # type: ignore[no-redef]
 def isvar(o):
-    return not not _glv and hashable(o) and o in _glv
+    return _glv and hashable(o) and o in _glv


 @contextmanager
@@ -193,7 +193,7 @@ def _deserialize_graph_module(
     graph = KeepModules().trace(com, **tracer_extras)

     # Recover node.meta["stack_trace"] after re-tracing
-    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
+    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace")
     if node_meta_stack_trace is not None:
         del body["_graphmodule_graph_node_meta_stack_trace"]
         for node in graph.nodes:
@@ -190,7 +190,7 @@ class CapabilityBasedPartitioner:
                 # Iterate through all the users of this node and update the partition map to indicate
                 # that there is a path from the partition id of this node to the target partition id.
                 for user_node in node.users:
-                    target_id = assignment.get(user_node, None)
+                    target_id = assignment.get(user_node)
                     if target_id is not None:
                         partition_map[id].add(target_id)
                         partition_map[id].update(partition_map[target_id])
@@ -267,9 +267,9 @@ class CapabilityBasedPartitioner:

             # node has tuple outputs, re-assign all following getitem node into node's partition
             if is_tuple_output:
-                id = assignment.get(node, None)  # type: ignore[arg-type]
+                id = assignment.get(node)  # type: ignore[arg-type]
                 for user in node.users:
-                    if assignment.get(user, None) != id:  # type: ignore[arg-type]
+                    if assignment.get(user) != id:  # type: ignore[arg-type]
                         nodes_reassignment[user] = id  # type: ignore[assignment]
         for node, id in nodes_reassignment.items():
             merge_single_node(node, None, id)
@@ -76,11 +76,11 @@ def _get_script_class(python_class):
     override = getattr(python_class, "_jit_override_qualname", None)
     if override is not None:
         python_class = _get_python_class(override)
-    return _script_classes.get(python_class, None)
+    return _script_classes.get(python_class)


 def _get_python_class(qualified_name):
-    return _name_to_pyclass.get(qualified_name, None)
+    return _name_to_pyclass.get(qualified_name)


 def _clear_class_state():
@@ -228,7 +228,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
     def preview_type_promotion(
         self, args: tuple, kwargs: dict
     ) -> TypePromotionSnapshot:
-        rounding_mode = kwargs.get("rounding_mode", None)
+        rounding_mode = kwargs.get("rounding_mode")
         if rounding_mode is None:
             # true_divide
             self.promotion_kind = (
@@ -287,7 +287,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
         )
         arg = args[0]
         assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
-        dtype: torch.dtype | None = kwargs.get("dtype", None)
+        dtype: torch.dtype | None = kwargs.get("dtype")

         computation_dtype, result_dtype = _prims_common.reduction_dtypes(
             arg, self.promotion_kind, dtype
@@ -351,7 +351,7 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
         )
         arg = args[0]
         assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
-        dtype: torch.dtype | None = kwargs.get("dtype", None)
+        dtype: torch.dtype | None = kwargs.get("dtype")
         # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
         if dtype is None:
             if _prims_common.is_boolean_dtype(
@@ -3421,7 +3421,7 @@ class ModuleTest(TestBase):
             kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
         self.precision = kwargs.get('precision', 2e-4)
         self.check_forward_only = kwargs.get('check_forward_only', False)
-        self.default_dtype = kwargs.get('default_dtype', None)
+        self.default_dtype = kwargs.get('default_dtype')
         if self.default_dtype is None:
             self.default_dtype = torch.get_default_dtype()

@@ -3632,7 +3632,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
         self.test_cpu = kwargs.get('test_cpu', True)
         self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
         self.check_batched_grad = kwargs.get('check_batched_grad', True)
-        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode')
         self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
         self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)

@@ -3836,7 +3836,7 @@ class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
         self.with_tf32 = kwargs.get('with_tf32', True)
         self.tf32_precision = kwargs.get('tf32_precision', 0.001)
         self.check_batched_grad = kwargs.get('check_batched_grad', True)
-        self.default_dtype = kwargs.get('default_dtype', None)
+        self.default_dtype = kwargs.get('default_dtype')
         if self.default_dtype is None:
             self.default_dtype = torch.get_default_dtype()

@@ -5124,7 +5124,7 @@ def gradcheck(fn, inputs, **kwargs):

     for key, value in default_values.items():
         # default value override values explicitly set to None
-        k = kwargs.get(key, None)
+        k = kwargs.get(key)
         kwargs[key] = k if k is not None else value

     return torch.autograd.gradcheck(fn, inputs, **kwargs)
@@ -5144,7 +5144,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):

     for key, value in default_values.items():
         # default value override values explicitly set to None
-        k = kwargs.get(key, None)
+        k = kwargs.get(key)
         kwargs[key] = k if k is not None else value

     return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
@@ -40,7 +40,7 @@ class VerifyStateDictMixin:
         if not options.ignore_frozen_params:
             self.assertEqual(len(msd), len(dist_msd))
         for fqn, param in msd.items():
-            dist_param = dist_msd.get(fqn, None)
+            dist_param = dist_msd.get(fqn)
             if not options.ignore_frozen_params:
                 self.assertIsNotNone(dist_param, f"{fqn=}")
                 try:
@@ -2309,7 +2309,7 @@ def _write_ninja_file_and_build_library(
 def is_ninja_available():
     """Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise."""
     try:
-        subprocess.check_output('ninja --version'.split())
+        subprocess.check_output(['ninja', '--version'])
     except Exception:
         return False
     else:
@@ -61,7 +61,7 @@ def basichandlers(extension: str, data):
     if extension in "txt text transcript":
         return data.decode("utf-8")

-    if extension in "cls cls2 class count index inx id".split():
+    if extension in ["cls", "cls2", "class", "count", "index", "inx", "id"]:
         try:
             return int(data)
         except ValueError:
@@ -70,10 +70,10 @@ def basichandlers(extension: str, data):
     if extension in "json jsn":
         return json.loads(data)

-    if extension in "pyd pickle".split():
+    if extension in ["pyd", "pickle"]:
         return pickle.loads(data)

-    if extension in "pt".split():
+    if extension in ["pt"]:
         stream = io.BytesIO(data)
         return torch.load(stream)

@@ -175,7 +175,7 @@ class ImageHandler:
         self.imagespec = imagespec.lower()

     def __call__(self, extension, data):
-        if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split():
+        if extension.lower() not in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
             return None

         try:
@@ -235,7 +235,17 @@ def imagehandler(imagespec):
 # torch video
 ################################################################
 def videohandler(extension, data):
-    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
+    if extension not in [
+        "mp4",
+        "ogv",
+        "mjpeg",
+        "avi",
+        "mov",
+        "h264",
+        "mpg",
+        "webm",
+        "wmv",
+    ]:
         return None

     try:
@@ -60,7 +60,23 @@ class Variant(Enum):
 DEFAULT_KERNEL_NAMESPACE = "at::native"

 # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
-BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
+BACKEND_COMPONENTS = [
+    "CPU",
+    "CUDA",
+    "HIP",
+    "XLA",
+    "MTIA",
+    "MPS",
+    "IPU",
+    "XPU",
+    "HPU",
+    "VE",
+    "Lazy",
+    "Meta",
+    "PrivateUse1",
+    "PrivateUse2",
+    "PrivateUse3",
+]
 FUNCTIONALITY_KEYS = [
     "",
     "Quantized",