From 5471621497ea0068bd453d251bf5ec2621e8119f Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 20 Mar 2023 00:56:57 +0000 Subject: [PATCH] [BE] Remove unnecessary dict comprehensions (#97116) Removes unnecessary dict comprehensions that optimize creation of dicts from iterables Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116 Approved by: https://github.com/kit1980 --- .../maml_omniglot/maml-omniglot-transforms.py | 12 ++++++------ functorch/op_analysis/gen_data.py | 2 +- test/benchmark_utils/test_benchmark_utils.py | 2 +- test/cpp_api_parity/utils.py | 4 +--- test/distributed/test_c10d_common.py | 2 +- test/dynamo/test_repros.py | 2 +- test/jit/test_typing.py | 4 ++-- test/profiler/test_memory_profiler.py | 2 +- torch/_dynamo/convert_frame.py | 4 +--- torch/_functorch/aot_autograd.py | 2 +- torch/_functorch/top_operators_github_usage.py | 2 +- torch/cuda/__init__.py | 2 +- torch/distributed/elastic/multiprocessing/api.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 6 +++--- torch/distributed/utils.py | 2 +- torch/fx/_symbolic_trace.py | 2 +- torch/fx/experimental/accelerator_partitioner.py | 7 +------ torch/fx/immutable_collections.py | 2 +- torch/nn/utils/_expanded_weights/conv_utils.py | 2 +- .../_expanded_weights/expanded_weights_utils.py | 2 +- torch/nn/utils/rnn.py | 2 +- torch/optim/optimizer.py | 5 ++--- torch/profiler/_memory_profiler.py | 2 +- torch/testing/_internal/common_fsdp.py | 4 +--- torch/testing/_internal/common_utils.py | 4 +--- .../_internal/distributed/distributed_test.py | 2 +- torch/utils/_pytree.py | 2 +- torchgen/operator_versions/gen_mobile_upgraders.py | 2 +- 28 files changed, 37 insertions(+), 51 deletions(-) diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py index 890fcf38f9db..7883d77aaff7 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py @@ -113,8 +113,8 @@ def main(): # Trains a model for n_inner_iter using the support and returns a loss # using the query. def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): - params = {k: v for k, v in net.named_parameters()} - buffers = {k: v for k, v in net.named_buffers()} + params = dict(net.named_parameters()) + buffers = dict(net.named_buffers()) querysz = x_qry.size(0) def compute_loss(new_params, buffers, x, y): @@ -139,8 +139,8 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): def train(db, net, device, meta_opt, epoch, log): - params = {k: v for k, v in net.named_parameters()} - buffers = {k: v for k, v in net.named_buffers()} + params = dict(net.named_parameters()) + buffers = dict(net.named_buffers()) n_train_iter = db.x_train.shape[0] // db.batchsz for batch_idx in range(n_train_iter): @@ -186,8 +186,8 @@ def test(db, net, device, epoch, log): # Most research papers using MAML for this task do an extra # stage of fine-tuning here that should be added if you are # adapting this code for research. 
- params = {k: v for k, v in net.named_parameters()} - buffers = {k: v for k, v in net.named_buffers()} + params = dict(net.named_parameters()) + buffers = dict(net.named_buffers()) n_test_iter = db.x_test.shape[0] // db.batchsz qry_losses = [] diff --git a/functorch/op_analysis/gen_data.py b/functorch/op_analysis/gen_data.py index 71502ae84ff9..ca49fe5f4f20 100644 --- a/functorch/op_analysis/gen_data.py +++ b/functorch/op_analysis/gen_data.py @@ -136,7 +136,7 @@ if True: opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()] with open('count_ops.txt', 'r') as f: opinfo_counts = [i.strip() for i in f.readlines()] - opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)}) + opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts))) def count_fn(x): return opinfo_counts[x['full_name']] diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index a1e2adaacfa9..8e8bb1514ed8 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -216,7 +216,7 @@ class TestBenchmarkUtils(TestCase): def __init__(self, stmt, setup, timer, globals): self._random_state = np.random.RandomState(seed=self._seed) - self._mean_cost = {k: v for k, v in self._function_costs}[stmt] + self._mean_cost = dict(self._function_costs)[stmt] def sample(self, mean, noise_level): return max(self._random_state.normal(mean, mean * noise_level), 5e-9) diff --git a/test/cpp_api_parity/utils.py b/test/cpp_api_parity/utils.py index 226f750a091b..62a85b0f6959 100644 --- a/test/cpp_api_parity/utils.py +++ b/test/cpp_api_parity/utils.py @@ -232,9 +232,7 @@ def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params): return cpp_args_construction_stmts, cpp_forward_args_symbols def serialize_arg_dict_as_script_module(arg_dict): - arg_dict_flat = {arg_name: arg_value - for arg_name, arg_value in - arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other']} + arg_dict_flat = dict(arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other']) arg_dict_module = torch.nn.Module() for arg_name, arg_value in arg_dict_flat.items(): assert isinstance(arg_value, torch.Tensor) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index a91288320a64..973bcc92b24e 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1098,7 +1098,7 @@ class AbstractCommTest: rank = dist.get_rank(process_group) obj_list = [None for _ in range(dist.get_world_size(verify_pg))] dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg) - rank_to_seq_num = {rank: num for (rank, num) in obj_list} + rank_to_seq_num = dict(obj_list) self.assertEqual(len(set(rank_to_seq_num.values())), 2) self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2]) expected_same = { diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c8c71dbd854e..cc85a65fa83f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -529,7 +529,7 @@ class ModelOutput(collections.OrderedDict): def __getitem__(self, k): if isinstance(k, str): - inner_dict = {k: v for (k, v) in self.items()} + inner_dict = dict(self.items()) return inner_dict[k] else: return self.to_tuple()[k] diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index a461199b5d94..30120f70639b 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -89,7 +89,7 @@ class 
TestTyping(JitTestCase): def fn(): l1 = [1, 2, "foo", 3] l2 = ["foo", "bar", "baz", "qux"] - d: Dict[int, str] = {k : v for k, v in zip(l1, l2)} + d: Dict[int, str] = dict(zip(l1, l2)) return d with self.assertRaisesRegex(RuntimeError, "Dicts may only " @@ -102,7 +102,7 @@ class TestTyping(JitTestCase): def fn(): l1 = ["foo", "bar", "baz", "qux"] l2 = [1, 2, "foo", 3] - d: Dict[str, int] = {k : v for k, v in zip(l1, l2)} + d: Dict[str, int] = dict(zip(l1, l2)) return d with self.assertRaisesRegex(RuntimeError, "Dict type annotation" diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index 70b21b6b610f..e41a28348324 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -177,7 +177,7 @@ class TestIdentifyGradients(TestCase): def test_extract_gradients_from_module(self) -> None: model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()) - named_parameters = {name: p for name, p in model.named_parameters()} + named_parameters = dict(model.named_parameters()) self.assertEqual(len(named_parameters), 3) def assert_only_gradients(prof: torch.profiler.profile): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d9e652c48cf5..a75737e3aaee 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -432,9 +432,7 @@ def replay(filename): config.replay_record_enabled = False with open(filename, "rb") as in_file: record = ExecutionRecord.load(in_file) - record.globals = { - k: v for k, v in itertools.chain(record.globals.items(), globals().items()) - } + record.globals = dict(itertools.chain(record.globals.items(), globals().items())) try: _compile( diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index c2ab300bfd02..c1b7d2ab27d4 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -71,7 +71,7 @@ pytree._register_pytree_node( immutable_collections.immutable_dict, lambda x: (list(x.values()), list(x.keys())), lambda x, c: immutable_collections.immutable_dict( - {key: value for key, value in zip(c, x)} + dict(zip(c, x)) ), ) diff --git a/torch/_functorch/top_operators_github_usage.py b/torch/_functorch/top_operators_github_usage.py index 9161f98d66fa..0e361ad3a1cb 100644 --- a/torch/_functorch/top_operators_github_usage.py +++ b/torch/_functorch/top_operators_github_usage.py @@ -600,7 +600,7 @@ method_only_ops = [ def get_nn_functional_top_list(): - top_nn_functional_ = {k: v for k, v in top_nn_functional} + top_nn_functional_ = dict(top_nn_functional) for _, count, functional_name in top_nn_module: if functional_name is None: continue diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index c20036338891..b9c37afa0611 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -660,7 +660,7 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: if uuids is None: raise RuntimeError("Can't get device UUIDs") visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids) - idx_map = {idx: real_idx for idx, real_idx in enumerate(cast(List[int], visible_devices))} + idx_map = dict(enumerate(cast(List[int], visible_devices))) if idx not in idx_map: raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})") return idx_map[idx] diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index fde50a686964..877b7f95c74e 100644 --- 
a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -513,7 +513,7 @@ class MultiprocessContext(PContext): def pids(self) -> Dict[int, int]: assert self._pc is not None # assertion for mypy type checking - return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())} + return dict(enumerate(self._pc.pids())) def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: if not self._pc: diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 6f2502734a5d..129b4c410b6f 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1056,7 +1056,7 @@ def _get_param_id_to_param_from_optim_input( # Assume the standard case of passing `model.parameters()` to the optimizer # if `optim_input` is not specified if optim_input is None: - return {pid: param for pid, param in enumerate(model.parameters())} + return dict(enumerate(model.parameters())) try: params = cast(List[nn.Parameter], list(optim_input)) except TypeError as e: @@ -1076,7 +1076,7 @@ def _get_param_id_to_param_from_optim_input( if not all_tensors and not all_dicts: raise TypeError("Optimizer input should be an iterable of Tensors or dicts") if all_tensors: - return {pid: param for pid, param in enumerate(params)} + return dict(enumerate(params)) assert all_dicts param_id_to_param: List[nn.Parameter] = [] for param_group in params: @@ -1089,7 +1089,7 @@ def _get_param_id_to_param_from_optim_input( # Implicitly map `flat_param_id` (current length of the list) to # `param` param_id_to_param.append(param) - return {pid: param for pid, param in enumerate(param_id_to_param)} + return dict(enumerate(param_id_to_param)) def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]: diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 5848c0ecab0e..6c3548f0f786 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -46,7 +46,7 @@ def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> T if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] - kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) return args, kwargs def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies): diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index a88dc3e90adc..b580866de5c4 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -627,7 +627,7 @@ class Tracer(TracerBase): raise RuntimeError( f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" ) - concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} + concrete_args = dict(zip(arg_names, concrete_args)) args.extend(proxy_placeholder(names) for names in arg_names) if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 3b5d5afe0f20..26cb19ff5914 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -259,12 +259,7 @@ def get_device_to_partitions_mapping( # Find devices for all the partitions without a device found_device = True for partition in no_device_partitions: - device_to_left_mem_bytes = { - d: left_mem_bytes - for d, left_mem_bytes in sorted( - 
device_to_left_mem_bytes.items(), key=lambda item: item[1] - ) - } + device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1])) found_device = find_device_for(partition) if not found_device: break diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 063ce8bafef9..73478701cbea 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -39,7 +39,7 @@ def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: return list(d.values()), list(d.keys()) def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return immutable_dict({key: value for key, value in zip(context, values)}) + return immutable_dict(dict(zip(context, values))) def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: return d, None diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index 5cea94b44933..0af61f4ba66e 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -23,7 +23,7 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)] kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):] - kwargs = {name: arg for (name, arg) in zip(kwarg_names, kwargs)} + kwargs = dict(zip(kwarg_names, kwargs)) return conv_normalizer(*args, **kwargs) diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index b3c91481c18c..376bd9507bb6 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -21,7 +21,7 @@ def standard_kwargs(kwarg_names, expanded_args): ''' kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names):] expanded_args_without_kwargs = expanded_args[:len(expanded_args) - len(kwarg_names)] - expanded_kwargs = {name: value for (name, value) in zip(kwarg_names, kwarg_values)} + expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) return expanded_args_without_kwargs, expanded_kwargs def forward_helper(func, expanded_args, expanded_kwargs): diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index e9441a50d42d..bf2eedb86ccc 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -136,7 +136,7 @@ class PackedSequence(PackedSequence_): return self else: # Does not forward device or dtype arg/kwargs, device is set from data.device - kwargs = {k : v for k, v in filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())} + kwargs = dict(filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())) sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs)) unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs)) return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 99954c2363bd..cf7cbd8fc2a5 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -391,9 +391,8 @@ class Optimizer: "that doesn't match the size of optimizer's group") # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + 
id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)), + chain.from_iterable((g['params'] for g in groups)))) def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index f80a38091bac..b1cf8bb73c54 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -670,7 +670,7 @@ class MemoryProfile: allocation_times[(key, is_allocation)] = event.start_time_ns snapshot = self._category_snapshot() - last_version = {key: version for key, version in sorted(snapshot.keys())} + last_version = dict(sorted(snapshot.keys())) events: List[Tuple[int, Action, TensorAndID]] = [ (-1, Action.PREEXISTING, (key, version)) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 02725f2eede4..75e648ff25b3 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -767,9 +767,7 @@ class FSDPTest(MultiProcessTestCase): ] for values in itertools.product(*subtest_config_values): # Map keyword to chosen value - subtest_kwargs = { - kwarg: value for kwarg, value in zip(subtest_config_keys, values) - } + subtest_kwargs = dict(zip(subtest_config_keys, values)) with self.subTest(**subtest_kwargs): test_fn(*test_args, **test_kwargs, **subtest_kwargs) dist.barrier() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index f6aa895d56ca..a0cfb2f06de4 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -410,9 +410,7 @@ class parametrize(_TestParametrizer): 'values and {} names for test "{}"'.format( len(values), len(self.arg_names), test.__name__)) - param_kwargs = { - name: value for name, value in zip(self.arg_names, values) - } + param_kwargs = dict(zip(self.arg_names, values)) test_name = self._get_subtest_name(values, explicit_name=maybe_name) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index cd9fc12088d0..8bec13962630 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -628,7 +628,7 @@ class DistributedTest: def _verify_buffers_equal(self, m1, m2): # verify buffers across models - m1_buf_dict = {k: v for k, v in m1.module.named_buffers()} + m1_buf_dict = dict(m1.module.named_buffers()) for name, buf in m2.module.named_buffers(): self.assertEqual(buf, m1_buf_dict[name]) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 66dfb9a8a7e4..0b47e9f0adba 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -51,7 +51,7 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: return list(d.values()), list(d.keys()) def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} + return dict(zip(context, values)) def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: return d, None diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 5e4ff6234daf..13910db85c98 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -264,7 +264,7 @@ def construct_version_maps( ) -> str: version_map = torch._C._get_operator_version_map() sorted_version_map_ = 
sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return] - sorted_version_map = {name: lst for name, lst in sorted_version_map_} + sorted_version_map = dict(sorted_version_map_) operator_list_in_version_map_part = [] for op_name in sorted_version_map:
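
Editor's note (not part of the patch): every hunk above applies the same idiom — a dict comprehension that merely re-packs (key, value) pairs from an iterable of pairs is replaced by a direct call to the dict() constructor, which accepts any iterable of two-element pairs. A minimal standalone sketch of the idiom follows; the variable names (pairs, copied, keys, values, d) are illustrative only and do not appear in the patch.

    # A comprehension that only copies (key, value) pairs from an iterable of pairs...
    pairs = [("a", 1), ("b", 2)]
    copied = {k: v for k, v in pairs}   # before: identity dict comprehension

    # ...is equivalent to passing that iterable straight to dict():
    copied = dict(pairs)                # after: direct construction

    # zip() and enumerate() already yield pairs, so they can feed dict() directly:
    keys, values = ["x", "y"], [10, 20]
    assert dict(zip(keys, values)) == {k: v for k, v in zip(keys, values)}
    assert dict(enumerate(values)) == {i: v for i, v in enumerate(values)}

    # sorted(d.items(), ...) also returns (key, value) tuples, so a dict ordered
    # by value needs no comprehension either:
    d = {"a": 3, "b": 1}
    by_value = dict(sorted(d.items(), key=lambda item: item[1]))

The comprehension form adds no behavior over dict(); it only introduces an extra Python-level loop, which is why the patch removes it wherever the comprehension is a pure identity mapping.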