Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
[BE]: Apply RUF025 dict.fromkeys preview rule (#118637)
Simplifies and optimizes dict construction using the `dict.fromkeys` classmethod constructor. This also makes it obvious when every key will map to the same static value, which could be a bug if unintentional, and it is significantly faster than an equivalent dict comprehension. The rule is still in preview, but I am adding a forward fix for when it becomes stable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: e33e88e5bc
Commit: 1562dae62c
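For context, a minimal sketch of the rewrite this PR applies (illustrative only; the `keys` list below is made up, not taken from the diff). It shows that `dict.fromkeys` builds the same mapping as a constant-value dict comprehension, and it makes the shared-value caveat from the description easy to see:

    # Illustrative sketch only; `keys` is a hypothetical example list.
    keys = ["conv", "linear", "lstm"]

    # The comprehension and the classmethod build equal dicts;
    # fromkeys defaults the shared value to None.
    assert {k: None for k in keys} == dict.fromkeys(keys)
    assert {k: 0 for k in keys} == dict.fromkeys(keys, 0)

    # Caveat the description points at: every key references the *same*
    # value object, which matters when that value is mutable.
    shared = dict.fromkeys(keys, [])
    shared["conv"].append(1)
    assert shared["linear"] == [1]  # all keys see the mutation

    # The speed claim can be spot-checked with timeit, e.g.:
    #   python -m timeit -s "ks = range(1000)" "dict.fromkeys(ks, 0)"
    #   python -m timeit -s "ks = range(1000)" "{k: 0 for k in ks}"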
@@ -222,7 +222,7 @@ def print_results(results: List[Experiment]):
             {
                 "Type": "Average",
                 "Speedup": np.mean(speedups),
-                **{key: None for key in max_config_dict},
+                **dict.fromkeys(max_config_dict),
             },
             {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
             {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
@@ -24,7 +24,7 @@ class DimensionBindError(Exception):
 from . import op_properties
 
 # use dict to avoid writing C++ bindings for set
-pointwise = {t: True for t in op_properties.pointwise}
+pointwise = dict.fromkeys(op_properties.pointwise, True)
 
 use_c = True
 if not use_c:
@@ -1468,7 +1468,7 @@ class TestDict(JitTestCase):
 
         def test_dictcomprehension_is_typed_from_annotation():
             metasyntactics = ["foo", "bar", "baz"]
-            x: Dict[str, Optional[int]] = {word: None for word in metasyntactics}
+            x: Dict[str, Optional[int]] = {word: None for word in metasyntactics}  # noqa: RUF025
             return x
 
         self.checkScript(test_dictcomprehension_is_typed_from_annotation, ())
@@ -8718,7 +8718,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
                 continue
             # fp16 dynamic quant is not supported for qnnpack
 
-            eager_qconfig_dict = {x : qconfig for x in module_types}
+            eager_qconfig_dict = dict.fromkeys(module_types, qconfig)
             model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)
 
             graph_qconfig_dict = {
@@ -1083,8 +1083,8 @@ class TestDataLoader(TestCase):
         self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
 
     def _test_shuffle(self, loader):
-        found_data = {i: 0 for i in range(self.data.size(0))}
-        found_labels = {i: 0 for i in range(self.labels.size(0))}
+        found_data = dict.fromkeys(range(self.data.size(0)), 0)
+        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
         batch_size = loader.batch_size
         if batch_size is None:
             for i, (batch_samples, batch_targets) in enumerate(loader):
@@ -1066,7 +1066,7 @@ TreeSpec(tuple, None, [*,
         all_zeros = py_pytree.tree_map_with_path(
             lambda kp, val: val - kp[1].key + kp[0].idx, tree
         )
-        self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
+        self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
 
     def test_tree_map_with_path_multiple_trees(self):
         @dataclass
@@ -1485,15 +1485,7 @@ class TestMakeTensor(TestCase):
         low_inclusive, high_exclusive = {
             torch.bool: (0, 2),
             torch.uint8: (0, 10),
-            **{
-                signed_integral_dtype: (-9, 10)
-                for signed_integral_dtype in [
-                    torch.int8,
-                    torch.int16,
-                    torch.int32,
-                    torch.int64,
-                ]
-            },
+            **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
         }.get(dtype, (-9, 9))
 
         t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
@@ -155,9 +155,8 @@ manual_torch_name_rule_map = {
 
 
 # In graph functions (including constant folding) that are C bindings
-torch_c_binding_in_graph_functions = {
-    k: TorchInGraphFunctionVariable
-    for k in [
+torch_c_binding_in_graph_functions = dict.fromkeys(
+    [
         "math.acos",
         "math.acosh",
         "math.asin",
@@ -2034,8 +2033,9 @@ torch_c_binding_in_graph_functions = {
         "torch.xlogy",
         "torch.zero_",
         "torch.zeros",
-    ]
-}
+    ],
+    TorchInGraphFunctionVariable,
+)
 
 
 if sys.version_info >= (3, 9):
@@ -2046,9 +2046,8 @@ if sys.version_info >= (3, 11):
 
 
 # In graph functions (including constant folding) that are not C bindings
-torch_non_c_binding_in_graph_functions = {
-    k: TorchInGraphFunctionVariable
-    for k in [
+torch_non_c_binding_in_graph_functions = dict.fromkeys(
+    [
         "torch.__future__.get_overwrite_module_params_on_conversion",
         "torch.__future__.set_overwrite_module_params_on_conversion",
         "torch.__getattr__",
@@ -2717,8 +2716,9 @@ torch_non_c_binding_in_graph_functions = {
         "torch.typename",
         "torch.unique_consecutive",
         "torch.use_deterministic_algorithms",
-    ]
-}
+    ],
+    TorchInGraphFunctionVariable,
+)
 
 
 torch_name_rule_map = [
@@ -147,7 +147,7 @@ def register_dataclass_as_pytree_node(
 
     def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
         typ, flat_names, none_names = context
-        return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})
+        return typ(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
 
     def default_to_dumpable_context(context: Context) -> DumpableContext:
         return (serialized_type, context[1], context[2])
@@ -291,8 +291,8 @@ def default_partition(
             saved_sym_nodes.extend(backward_usages)
         else:
             saved_values.append(node)
-    saved_values = list({k: None for k in saved_values}.keys())
-    saved_sym_nodes = list({k: None for k in saved_sym_nodes}.keys())
+    saved_values = list(dict.fromkeys(saved_values).keys())
+    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
 
     return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
 
@@ -24,7 +24,7 @@ def get_all_kernel_argdefs(kernels):
         Any, None
     ] = {}  # use a dict rather than set to maintain insertion order
     for argdefs in argdefs_list:
-        all_argdefs.update({arg: None for arg in argdefs})
+        all_argdefs.update(dict.fromkeys(argdefs))
 
     return list(all_argdefs.keys())
 
@@ -150,7 +150,7 @@ def reorder_compute_for_overlap(
     comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
    comm_descendants = {node: get_descendants(node) for node in comm_nodes}
 
-    indeg = {k: 0 for k in snodes}
+    indeg = dict.fromkeys(snodes, 0)
     for snode in snodes:
         for user in snode.node_users:
             if user in indeg:
@@ -249,7 +249,7 @@ def _register_foreach_lowering(aten_fn, decomp_fn):
 
    aten_fns = get_overloads(aten_fn)
     foreach_ops.update(aten_fns)
-    lowerings.update({fn: wrapped for fn in aten_fns})
+    lowerings.update(dict.fromkeys(aten_fns, wrapped))
     return wrapped
 
 
@@ -299,7 +299,7 @@ def _register_lowering(
 
     aten_fn = get_overloads(aten_fn)
 
-    lowerings.update({fn: wrapped for fn in aten_fn})
+    lowerings.update(dict.fromkeys(aten_fn, wrapped))
     return wrapped
 
 
@@ -107,7 +107,7 @@ class UnsupportedOperatorException(RuntimeError):
 
 
 def ordered_set(*items):
-    return {k: True for k in items}
+    return dict.fromkeys(items, True)
 
 
 _device_not_kwarg_ops = ordered_set(
@@ -296,9 +296,7 @@ class BackendConfig:
             # e.g. "nccl", "gloo", "ucc", "mpi"
             supported_devices = Backend.backend_capability[backend.lower()]
             backend_val = Backend(backend)
-            self.device_backend_map = {
-                device : backend_val for device in supported_devices
-            }
+            self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
         elif ":" in backend.lower():
             # Backend specified in "device:backend" format
             # make sure the backend string is in the correct format
@@ -223,8 +223,8 @@ def start_processes(
             redirect_std = redirs[local_rank]
             redirs[local_rank] = redirect_std | tee_std
 
-    stdouts = {local_rank: "" for local_rank in range(nprocs)}
-    stderrs = {local_rank: "" for local_rank in range(nprocs)}
+    stdouts = dict.fromkeys(range(nprocs), "")
+    stderrs = dict.fromkeys(range(nprocs), "")
     tee_stdouts: Dict[int, str] = {}
     tee_stderrs: Dict[int, str] = {}
     error_files = {}
@@ -148,7 +148,7 @@ def to_map(
        to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
-        return {i: val_or_map for i in range(local_world_size)}
+        return dict.fromkeys(range(local_world_size), val_or_map)
    else:
        map = {}
        for i in range(local_world_size):
@@ -674,9 +674,7 @@ class SubprocessContext(PContext):
                 )
             else:
                 # Populate return with dummy values. This provides consistency with MultiprocessingHandler
-                result.return_values = {
-                    local_rank: None for local_rank in range(self.nprocs)
-                }
+                result.return_values = dict.fromkeys(range(self.nprocs))
 
             return result
         else:  # there are no failures and procs still running
@@ -75,7 +75,7 @@ def _topological_sort_passes(
 
     # Contruct a graph mapping nodes to a list of their users
     graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
-    indegree_map: Dict[Callable, int] = {p : 0 for p in passes}
+    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
     candidates: Queue = Queue()
     for a in passes:
         for b in passes:
@@ -90,7 +90,7 @@ def _topological_sort_passes(
         if indegree_map[a] == 0:
             candidates.put(a)
 
-    visited: Dict[Callable, bool] = {p : False for p in passes}
+    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
     sorted_passes: List[Callable] = []
 
     while not candidates.empty():
@@ -243,7 +243,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     Returns:
         The graph module in-place sorted
     """
-    indeg = {node: 0 for node in gm.graph.nodes}
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
     new_graph = torch.fx.Graph()
     # Track how many unfulfilled dependencies each node has
     for node in gm.graph.nodes:
@@ -13,7 +13,7 @@ from torch.fx._compatibility import compatibility
 @compatibility(is_backward_compatible=False)
 def topo_sort(nodes: NodeList) -> NodeList:
     # sort nodes according to the topological order
-    indegree_map = {node : 0 for node in nodes}
+    indegree_map = dict.fromkeys(nodes, 0)
     candidates: SimpleQueue = SimpleQueue()
 
     for node in nodes:
@@ -216,8 +216,8 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
         )
 
         return TypePromotionSnapshot(
-            {i: consolidated_input_dtype for i in candidate_args.keys()},
-            {name: consolidated_input_dtype for name in candidate_kwargs.keys()},
+            dict.fromkeys(candidate_args.keys(), consolidated_input_dtype),
+            dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype),
             result_dtype,
         )
 
@@ -71,16 +71,10 @@ _DTYPE_PRECISIONS = {
 # The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
 # their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
 _DTYPE_PRECISIONS.update(
-    {
-        dtype: _DTYPE_PRECISIONS[torch.float32]
-        for dtype in (
-            torch.quint8,
-            torch.quint2x4,
-            torch.quint4x2,
-            torch.qint8,
-            torch.qint32,
-        )
-    }
+    dict.fromkeys(
+        (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32),
+        _DTYPE_PRECISIONS[torch.float32],
+    )
 )
 
 
@@ -7273,10 +7273,7 @@ class DistributedTest:
         for num_early_join_ranks in num_uneven_ranks:
             for baseline_iter in baseline_num_iters:
                 for offset in iteration_offsets:
-                    mapping = {
-                        rank: baseline_iter
-                        for rank in range(0, num_early_join_ranks)
-                    }
+                    mapping = dict.fromkeys(range(0, num_early_join_ranks), baseline_iter)
                     # if num_early_join_ranks > 1, ranks > 0 that will join early
                     # iterate offset//2 more times than rank 0, to test nodes
                     # depleting inputs at different times.
@@ -7285,12 +7282,7 @@ class DistributedTest:
                     if rank > 0:
                         mapping[rank] += offset // 2
                     mapping.update(
-                        {
-                            rank: baseline_iter + offset
-                            for rank in range(
-                                num_early_join_ranks, dist.get_world_size()
-                            )
-                        }
+                        dict.fromkeys(range(num_early_join_ranks, dist.get_world_size()), baseline_iter + offset)
                     )
                     iteration_mappings.append(mapping)
 
@@ -282,7 +282,7 @@ def trim_sigfig(x: float, n: int) -> float:
 
 
 def ordered_unique(elements: Iterable[Any]) -> List[Any]:
-    return list(collections.OrderedDict({i: None for i in elements}).keys())
+    return list(collections.OrderedDict(dict.fromkeys(elements)).keys())
 
 
 @contextlib.contextmanager
@@ -273,9 +273,9 @@ class LazyIrProperties:
     )
 
     def __init__(self, *default_properties: str):
-        properties: Dict[Tuple[str, ...], Optional[str]] = {
-            p: None for p in LazyIrProperties.Properties
-        }
+        properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
+            LazyIrProperties.Properties
+        )
         self.__dict__["properties"] = properties
         for p in default_properties:
             setattr(self, p, True)
@@ -459,7 +459,7 @@ class OrderedSet(Generic[T]):
         if iterable is None:
             self.storage = {}
         else:
-            self.storage = {k: None for k in iterable}
+            self.storage = dict.fromkeys(iterable)
 
     def __contains__(self, item: T) -> bool:
         return item in self.storage