[BE]: Apply RUF025 dict.fromkeys preview rule (#118637)

Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
This commit is contained in:
Aaron Gokaslan
2024-01-30 20:46:48 +00:00
committed by PyTorch MergeBot
parent e33e88e5bc
commit 1562dae62c
26 changed files with 48 additions and 74 deletions

View File

@ -222,7 +222,7 @@ def print_results(results: List[Experiment]):
{
"Type": "Average",
"Speedup": np.mean(speedups),
**{key: None for key in max_config_dict},
**dict.fromkeys(max_config_dict),
},
{"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
{"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},

View File

@ -24,7 +24,7 @@ class DimensionBindError(Exception):
from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = {t: True for t in op_properties.pointwise}
pointwise = dict.fromkeys(op_properties.pointwise, True)
use_c = True
if not use_c:

View File

@ -1468,7 +1468,7 @@ class TestDict(JitTestCase):
def test_dictcomprehension_is_typed_from_annotation():
metasyntactics = ["foo", "bar", "baz"]
x: Dict[str, Optional[int]] = {word: None for word in metasyntactics}
x: Dict[str, Optional[int]] = {word: None for word in metasyntactics} # noqa: RUF025
return x
self.checkScript(test_dictcomprehension_is_typed_from_annotation, ())

View File

@ -8718,7 +8718,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
continue
# fp16 dynamic quant is not supported for qnnpack
eager_qconfig_dict = {x : qconfig for x in module_types}
eager_qconfig_dict = dict.fromkeys(module_types, qconfig)
model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)
graph_qconfig_dict = {

View File

@ -1083,8 +1083,8 @@ class TestDataLoader(TestCase):
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
def _test_shuffle(self, loader):
found_data = {i: 0 for i in range(self.data.size(0))}
found_labels = {i: 0 for i in range(self.labels.size(0))}
found_data = dict.fromkeys(range(self.data.size(0)), 0)
found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
batch_size = loader.batch_size
if batch_size is None:
for i, (batch_samples, batch_targets) in enumerate(loader):

View File

@ -1066,7 +1066,7 @@ TreeSpec(tuple, None, [*,
all_zeros = py_pytree.tree_map_with_path(
lambda kp, val: val - kp[1].key + kp[0].idx, tree
)
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
def test_tree_map_with_path_multiple_trees(self):
@dataclass

View File

@ -1485,15 +1485,7 @@ class TestMakeTensor(TestCase):
low_inclusive, high_exclusive = {
torch.bool: (0, 2),
torch.uint8: (0, 10),
**{
signed_integral_dtype: (-9, 10)
for signed_integral_dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]
},
**dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
}.get(dtype, (-9, 9))
t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)

View File

@ -155,9 +155,8 @@ manual_torch_name_rule_map = {
# In graph functions (including constant folding) that are C bindings
torch_c_binding_in_graph_functions = {
k: TorchInGraphFunctionVariable
for k in [
torch_c_binding_in_graph_functions = dict.fromkeys(
[
"math.acos",
"math.acosh",
"math.asin",
@ -2034,8 +2033,9 @@ torch_c_binding_in_graph_functions = {
"torch.xlogy",
"torch.zero_",
"torch.zeros",
]
}
],
TorchInGraphFunctionVariable,
)
if sys.version_info >= (3, 9):
@ -2046,9 +2046,8 @@ if sys.version_info >= (3, 11):
# In graph functions (including constant folding) that are not C bindings
torch_non_c_binding_in_graph_functions = {
k: TorchInGraphFunctionVariable
for k in [
torch_non_c_binding_in_graph_functions = dict.fromkeys(
[
"torch.__future__.get_overwrite_module_params_on_conversion",
"torch.__future__.set_overwrite_module_params_on_conversion",
"torch.__getattr__",
@ -2717,8 +2716,9 @@ torch_non_c_binding_in_graph_functions = {
"torch.typename",
"torch.unique_consecutive",
"torch.use_deterministic_algorithms",
]
}
],
TorchInGraphFunctionVariable,
)
torch_name_rule_map = [

View File

@ -147,7 +147,7 @@ def register_dataclass_as_pytree_node(
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
typ, flat_names, none_names = context
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})
return typ(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
def default_to_dumpable_context(context: Context) -> DumpableContext:
return (serialized_type, context[1], context[2])

View File

@ -291,8 +291,8 @@ def default_partition(
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
saved_values = list({k: None for k in saved_values}.keys())
saved_sym_nodes = list({k: None for k in saved_sym_nodes}.keys())
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)

View File

@ -24,7 +24,7 @@ def get_all_kernel_argdefs(kernels):
Any, None
] = {} # use a dict rather than set to maintain insertion order
for argdefs in argdefs_list:
all_argdefs.update({arg: None for arg in argdefs})
all_argdefs.update(dict.fromkeys(argdefs))
return list(all_argdefs.keys())

View File

@ -150,7 +150,7 @@ def reorder_compute_for_overlap(
comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
comm_descendants = {node: get_descendants(node) for node in comm_nodes}
indeg = {k: 0 for k in snodes}
indeg = dict.fromkeys(snodes, 0)
for snode in snodes:
for user in snode.node_users:
if user in indeg:

View File

@ -249,7 +249,7 @@ def _register_foreach_lowering(aten_fn, decomp_fn):
aten_fns = get_overloads(aten_fn)
foreach_ops.update(aten_fns)
lowerings.update({fn: wrapped for fn in aten_fns})
lowerings.update(dict.fromkeys(aten_fns, wrapped))
return wrapped
@ -299,7 +299,7 @@ def _register_lowering(
aten_fn = get_overloads(aten_fn)
lowerings.update({fn: wrapped for fn in aten_fn})
lowerings.update(dict.fromkeys(aten_fn, wrapped))
return wrapped

View File

@ -107,7 +107,7 @@ class UnsupportedOperatorException(RuntimeError):
def ordered_set(*items):
return {k: True for k in items}
return dict.fromkeys(items, True)
_device_not_kwarg_ops = ordered_set(

View File

@ -296,9 +296,7 @@ class BackendConfig:
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
self.device_backend_map = {
device : backend_val for device in supported_devices
}
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
elif ":" in backend.lower():
# Backend specified in "device:backend" format
# make sure the backend string is in the correct format

View File

@ -223,8 +223,8 @@ def start_processes(
redirect_std = redirs[local_rank]
redirs[local_rank] = redirect_std | tee_std
stdouts = {local_rank: "" for local_rank in range(nprocs)}
stderrs = {local_rank: "" for local_rank in range(nprocs)}
stdouts = dict.fromkeys(range(nprocs), "")
stderrs = dict.fromkeys(range(nprocs), "")
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
error_files = {}

View File

@ -148,7 +148,7 @@ def to_map(
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
"""
if isinstance(val_or_map, Std):
return {i: val_or_map for i in range(local_world_size)}
return dict.fromkeys(range(local_world_size), val_or_map)
else:
map = {}
for i in range(local_world_size):
@ -674,9 +674,7 @@ class SubprocessContext(PContext):
)
else:
# Populate return with dummy values. This provides consistency with MultiprocessingHandler
result.return_values = {
local_rank: None for local_rank in range(self.nprocs)
}
result.return_values = dict.fromkeys(range(self.nprocs))
return result
else: # there are no failures and procs still running

View File

@ -75,7 +75,7 @@ def _topological_sort_passes(
# Contruct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
indegree_map: Dict[Callable, int] = {p : 0 for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
for b in passes:
@ -90,7 +90,7 @@ def _topological_sort_passes(
if indegree_map[a] == 0:
candidates.put(a)
visited: Dict[Callable, bool] = {p : False for p in passes}
visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: List[Callable] = []
while not candidates.empty():

View File

@ -243,7 +243,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Returns:
The graph module in-place sorted
"""
indeg = {node: 0 for node in gm.graph.nodes}
indeg = dict.fromkeys(gm.graph.nodes, 0)
new_graph = torch.fx.Graph()
# Track how many unfulfilled dependencies each node has
for node in gm.graph.nodes:

View File

@ -13,7 +13,7 @@ from torch.fx._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
# sort nodes according to the topological order
indegree_map = {node : 0 for node in nodes}
indegree_map = dict.fromkeys(nodes, 0)
candidates: SimpleQueue = SimpleQueue()
for node in nodes:

View File

@ -216,8 +216,8 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
)
return TypePromotionSnapshot(
{i: consolidated_input_dtype for i in candidate_args.keys()},
{name: consolidated_input_dtype for name in candidate_kwargs.keys()},
dict.fromkeys(candidate_args.keys(), consolidated_input_dtype),
dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype),
result_dtype,
)

View File

@ -71,16 +71,10 @@ _DTYPE_PRECISIONS = {
# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
_DTYPE_PRECISIONS.update(
{
dtype: _DTYPE_PRECISIONS[torch.float32]
for dtype in (
torch.quint8,
torch.quint2x4,
torch.quint4x2,
torch.qint8,
torch.qint32,
dict.fromkeys(
(torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32),
_DTYPE_PRECISIONS[torch.float32],
)
}
)

View File

@ -7273,10 +7273,7 @@ class DistributedTest:
for num_early_join_ranks in num_uneven_ranks:
for baseline_iter in baseline_num_iters:
for offset in iteration_offsets:
mapping = {
rank: baseline_iter
for rank in range(0, num_early_join_ranks)
}
mapping = dict.fromkeys(range(0, num_early_join_ranks), baseline_iter)
# if num_early_join_ranks > 1, ranks > 0 that will join early
# iterate offset//2 more times than rank 0, to test nodes
# depleting inputs at different times.
@ -7285,12 +7282,7 @@ class DistributedTest:
if rank > 0:
mapping[rank] += offset // 2
mapping.update(
{
rank: baseline_iter + offset
for rank in range(
num_early_join_ranks, dist.get_world_size()
)
}
dict.fromkeys(range(num_early_join_ranks, dist.get_world_size()), baseline_iter + offset)
)
iteration_mappings.append(mapping)

View File

@ -282,7 +282,7 @@ def trim_sigfig(x: float, n: int) -> float:
def ordered_unique(elements: Iterable[Any]) -> List[Any]:
return list(collections.OrderedDict({i: None for i in elements}).keys())
return list(collections.OrderedDict(dict.fromkeys(elements)).keys())
@contextlib.contextmanager

View File

@ -273,9 +273,9 @@ class LazyIrProperties:
)
def __init__(self, *default_properties: str):
properties: Dict[Tuple[str, ...], Optional[str]] = {
p: None for p in LazyIrProperties.Properties
}
properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
LazyIrProperties.Properties
)
self.__dict__["properties"] = properties
for p in default_properties:
setattr(self, p, True)

View File

@ -459,7 +459,7 @@ class OrderedSet(Generic[T]):
if iterable is None:
self.storage = {}
else:
self.storage = {k: None for k in iterable}
self.storage = dict.fromkeys(iterable)
def __contains__(self, item: T) -> bool:
return item in self.storage