[BE] Remove unnecessary dict comprehensions (#97116)

Removes unnecessary dict comprehensions in favor of building the dicts directly from iterables of key-value pairs with dict(), which is clearer and slightly faster.
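
For illustration, a minimal self-contained sketch of the pattern being cleaned up (the names below are invented for the example, not taken from any file in this PR):

# Before: a dict comprehension that only repackages an iterable of (key, value) pairs.
pairs = [("weight", 1.0), ("bias", 0.0)]
params_old = {k: v for k, v in pairs}

# After: dict() accepts any iterable of (key, value) pairs directly.
params_new = dict(pairs)

assert params_old == params_new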

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
Author: Aaron Gokaslan
Date: 2023-03-20 00:56:57 +00:00
Committed by: PyTorch MergeBot
Parent: be0b415a5a
Commit: 5471621497
28 changed files with 37 additions and 51 deletions

View File

@ -113,8 +113,8 @@ def main():
# Trains a model for n_inner_iter using the support and returns a loss
# using the query.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
-params = {k: v for k, v in net.named_parameters()}
-buffers = {k: v for k, v in net.named_buffers()}
+params = dict(net.named_parameters())
+buffers = dict(net.named_buffers())
querysz = x_qry.size(0)
def compute_loss(new_params, buffers, x, y):
@ -139,8 +139,8 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
def train(db, net, device, meta_opt, epoch, log):
-params = {k: v for k, v in net.named_parameters()}
-buffers = {k: v for k, v in net.named_buffers()}
+params = dict(net.named_parameters())
+buffers = dict(net.named_buffers())
n_train_iter = db.x_train.shape[0] // db.batchsz
for batch_idx in range(n_train_iter):
@ -186,8 +186,8 @@ def test(db, net, device, epoch, log):
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
-params = {k: v for k, v in net.named_parameters()}
-buffers = {k: v for k, v in net.named_buffers()}
+params = dict(net.named_parameters())
+buffers = dict(net.named_buffers())
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []

View File

@ -136,7 +136,7 @@ if True:
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open('count_ops.txt', 'r') as f:
opinfo_counts = [i.strip() for i in f.readlines()]
-opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)})
+opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
def count_fn(x):
return opinfo_counts[x['full_name']]
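
For reference, a small standalone sketch of the defaultdict(int, dict(zip(...))) idiom used above; the op names and counts here are made up, and note that values read from a file stay strings:

from collections import defaultdict

ops = ["aten::add", "aten::mul"]
counts = ["12", "3"]                      # readlines()/strip() yields strings, not ints

op_counts = defaultdict(int, dict(zip(ops, counts)))
assert op_counts["aten::add"] == "12"     # present keys keep their values
assert op_counts["aten::sub"] == 0        # missing keys fall back to int() == 0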

View File

@ -216,7 +216,7 @@ class TestBenchmarkUtils(TestCase):
def __init__(self, stmt, setup, timer, globals):
self._random_state = np.random.RandomState(seed=self._seed)
-self._mean_cost = {k: v for k, v in self._function_costs}[stmt]
+self._mean_cost = dict(self._function_costs)[stmt]
def sample(self, mean, noise_level):
return max(self._random_state.normal(mean, mean * noise_level), 5e-9)

View File

@ -232,9 +232,7 @@ def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
return cpp_args_construction_stmts, cpp_forward_args_symbols
def serialize_arg_dict_as_script_module(arg_dict):
-arg_dict_flat = {arg_name: arg_value
-for arg_name, arg_value in
-arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other']}
+arg_dict_flat = dict(arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other'])
arg_dict_module = torch.nn.Module()
for arg_name, arg_value in arg_dict_flat.items():
assert isinstance(arg_value, torch.Tensor)

View File

@ -1098,7 +1098,7 @@ class AbstractCommTest:
rank = dist.get_rank(process_group)
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
-rank_to_seq_num = {rank: num for (rank, num) in obj_list}
+rank_to_seq_num = dict(obj_list)
self.assertEqual(len(set(rank_to_seq_num.values())), 2)
self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
expected_same = {

View File

@ -529,7 +529,7 @@ class ModelOutput(collections.OrderedDict):
def __getitem__(self, k):
if isinstance(k, str):
-inner_dict = {k: v for (k, v) in self.items()}
+inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]

View File

@ -89,7 +89,7 @@ class TestTyping(JitTestCase):
def fn():
l1 = [1, 2, "foo", 3]
l2 = ["foo", "bar", "baz", "qux"]
-d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
+d: Dict[int, str] = dict(zip(l1, l2))
return d
with self.assertRaisesRegex(RuntimeError, "Dicts may only "
@ -102,7 +102,7 @@ class TestTyping(JitTestCase):
def fn():
l1 = ["foo", "bar", "baz", "qux"]
l2 = [1, 2, "foo", 3]
-d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
+d: Dict[str, int] = dict(zip(l1, l2))
return d
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"

View File

@ -177,7 +177,7 @@ class TestIdentifyGradients(TestCase):
def test_extract_gradients_from_module(self) -> None:
model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
-named_parameters = {name: p for name, p in model.named_parameters()}
+named_parameters = dict(model.named_parameters())
self.assertEqual(len(named_parameters), 3)
def assert_only_gradients(prof: torch.profiler.profile):

View File

@ -432,9 +432,7 @@ def replay(filename):
config.replay_record_enabled = False
with open(filename, "rb") as in_file:
record = ExecutionRecord.load(in_file)
-record.globals = {
-k: v for k, v in itertools.chain(record.globals.items(), globals().items())
-}
+record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
try:
_compile(

View File

@ -71,7 +71,7 @@ pytree._register_pytree_node(
immutable_collections.immutable_dict,
lambda x: (list(x.values()), list(x.keys())),
lambda x, c: immutable_collections.immutable_dict(
-{key: value for key, value in zip(c, x)}
+dict(zip(c, x))
),
)

View File

@ -600,7 +600,7 @@ method_only_ops = [
def get_nn_functional_top_list():
-top_nn_functional_ = {k: v for k, v in top_nn_functional}
+top_nn_functional_ = dict(top_nn_functional)
for _, count, functional_name in top_nn_module:
if functional_name is None:
continue

View File

@ -660,7 +660,7 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
-idx_map = {idx: real_idx for idx, real_idx in enumerate(cast(List[int], visible_devices))}
+idx_map = dict(enumerate(cast(List[int], visible_devices)))
if idx not in idx_map:
raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})")
return idx_map[idx]
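
For reference, dict(enumerate(...)) builds the same ordinal-to-device mapping as the comprehension it replaces; a tiny sketch with made-up device ordinals:

visible_devices = [3, 5, 7]               # hypothetical ordinals after UUID translation
idx_map = dict(enumerate(visible_devices))
assert idx_map == {0: 3, 1: 5, 2: 7}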

View File

@ -513,7 +513,7 @@ class MultiprocessContext(PContext):
def pids(self) -> Dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
-return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}
+return dict(enumerate(self._pc.pids()))
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if not self._pc:

View File

@ -1056,7 +1056,7 @@ def _get_param_id_to_param_from_optim_input(
# Assume the standard case of passing `model.parameters()` to the optimizer
# if `optim_input` is not specified
if optim_input is None:
-return {pid: param for pid, param in enumerate(model.parameters())}
+return dict(enumerate(model.parameters()))
try:
params = cast(List[nn.Parameter], list(optim_input))
except TypeError as e:
@ -1076,7 +1076,7 @@ def _get_param_id_to_param_from_optim_input(
if not all_tensors and not all_dicts:
raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
if all_tensors:
-return {pid: param for pid, param in enumerate(params)}
+return dict(enumerate(params))
assert all_dicts
param_id_to_param: List[nn.Parameter] = []
for param_group in params:
@ -1089,7 +1089,7 @@ def _get_param_id_to_param_from_optim_input(
# Implicitly map `flat_param_id` (current length of the list) to
# `param`
param_id_to_param.append(param)
-return {pid: param for pid, param in enumerate(param_id_to_param)}
+return dict(enumerate(param_id_to_param))
def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:

View File

@ -46,7 +46,7 @@ def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> T
if len(kwarg_keys) == 0:
return flat_args, {}
args = flat_args[: -len(kwarg_keys)]
-kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])}
+kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
return args, kwargs
def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
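
For reference, a self-contained sketch of the kwargs reconstruction above, with invented flat args and key names:

flat_args = (1, 2, "cuda", True)          # positional values followed by kwarg values
kwarg_keys = ("device", "non_blocking")

args = flat_args[: -len(kwarg_keys)]
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
assert args == (1, 2)
assert kwargs == {"device": "cuda", "non_blocking": True}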

View File

@ -627,7 +627,7 @@ class Tracer(TracerBase):
raise RuntimeError(
f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
)
-concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
+concrete_args = dict(zip(arg_names, concrete_args))
args.extend(proxy_placeholder(names) for names in arg_names)
if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:

View File

@ -259,12 +259,7 @@ def get_device_to_partitions_mapping(
# Find devices for all the partitions without a device
found_device = True
for partition in no_device_partitions:
-device_to_left_mem_bytes = {
-d: left_mem_bytes
-for d, left_mem_bytes in sorted(
-device_to_left_mem_bytes.items(), key=lambda item: item[1]
-)
-}
+device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]))
found_device = find_device_for(partition)
if not found_device:
break
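
For reference, dict(sorted(...)) keeps the sorted ordering because dicts preserve insertion order (Python 3.7+), so the rebuilt mapping still iterates devices by ascending leftover memory; a sketch with made-up device names and byte counts:

device_to_left_mem_bytes = {"dev0": 512, "dev1": 128, "dev2": 256}
ordered = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]))
assert list(ordered) == ["dev1", "dev2", "dev0"]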

View File

@ -39,7 +39,7 @@ def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
return list(d.values()), list(d.keys())
def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
-return immutable_dict({key: value for key, value in zip(context, values)})
+return immutable_dict(dict(zip(context, values)))
def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
return d, None

View File

@ -23,7 +23,7 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)]
kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):]
-kwargs = {name: arg for (name, arg) in zip(kwarg_names, kwargs)}
+kwargs = dict(zip(kwarg_names, kwargs))
return conv_normalizer(*args, **kwargs)

View File

@ -21,7 +21,7 @@ def standard_kwargs(kwarg_names, expanded_args):
'''
kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names):]
expanded_args_without_kwargs = expanded_args[:len(expanded_args) - len(kwarg_names)]
-expanded_kwargs = {name: value for (name, value) in zip(kwarg_names, kwarg_values)}
+expanded_kwargs = dict(zip(kwarg_names, kwarg_values))
return expanded_args_without_kwargs, expanded_kwargs
def forward_helper(func, expanded_args, expanded_kwargs):

View File

@ -136,7 +136,7 @@ class PackedSequence(PackedSequence_):
return self
else:
# Does not forward device or dtype arg/kwargs, device is set from data.device
-kwargs = {k : v for k, v in filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())}
+kwargs = dict(filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items()))
sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
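
For reference, a minimal sketch of the dict(filter(...)) form above, which drops the 'device' and 'dtype' entries while keeping everything else (example values invented):

kwargs = {"device": "cuda", "dtype": None, "non_blocking": True}
filtered = dict(filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items()))
assert filtered == {"non_blocking": True}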

View File

@ -391,9 +391,8 @@ class Optimizer:
"that doesn't match the size of optimizer's group")
# Update the state
-id_map = {old_id: p for old_id, p in
-zip(chain.from_iterable((g['params'] for g in saved_groups)),
-chain.from_iterable((g['params'] for g in groups)))}
+id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)),
+chain.from_iterable((g['params'] for g in groups))))
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""

View File

@ -670,7 +670,7 @@ class MemoryProfile:
allocation_times[(key, is_allocation)] = event.start_time_ns
snapshot = self._category_snapshot()
-last_version = {key: version for key, version in sorted(snapshot.keys())}
+last_version = dict(sorted(snapshot.keys()))
events: List[Tuple[int, Action, TensorAndID]] = [
(-1, Action.PREEXISTING, (key, version))
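
For reference, building a dict from sorted (key, version) pairs keeps only the last pair per key, i.e. the highest version, which is what last_version relies on; a sketch with invented keys:

snapshot_keys = [("a", 0), ("a", 2), ("b", 1)]   # hypothetical (key, version) pairs
last_version = dict(sorted(snapshot_keys))       # later duplicates overwrite earlier ones
assert last_version == {"a": 2, "b": 1}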

View File

@ -767,9 +767,7 @@ class FSDPTest(MultiProcessTestCase):
]
for values in itertools.product(*subtest_config_values):
# Map keyword to chosen value
-subtest_kwargs = {
-kwarg: value for kwarg, value in zip(subtest_config_keys, values)
-}
+subtest_kwargs = dict(zip(subtest_config_keys, values))
with self.subTest(**subtest_kwargs):
test_fn(*test_args, **test_kwargs, **subtest_kwargs)
dist.barrier()
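
For reference, a self-contained sketch of generating one subtest kwargs dict per configuration via itertools.product plus dict(zip(...)); the config keys below are invented:

import itertools

subtest_config = {"use_orig_params": [False, True], "cpu_offload": [False, True]}
subtest_config_keys = list(subtest_config.keys())
subtest_config_values = list(subtest_config.values())

for values in itertools.product(*subtest_config_values):
    subtest_kwargs = dict(zip(subtest_config_keys, values))
    # e.g. {"use_orig_params": False, "cpu_offload": False}, then the other three combinations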

View File

@ -410,9 +410,7 @@ class parametrize(_TestParametrizer):
'values and {} names for test "{}"'.format(
len(values), len(self.arg_names), test.__name__))
-param_kwargs = {
-name: value for name, value in zip(self.arg_names, values)
-}
+param_kwargs = dict(zip(self.arg_names, values))
test_name = self._get_subtest_name(values, explicit_name=maybe_name)

View File

@ -628,7 +628,7 @@ class DistributedTest:
def _verify_buffers_equal(self, m1, m2):
# verify buffers across models
-m1_buf_dict = {k: v for k, v in m1.module.named_buffers()}
+m1_buf_dict = dict(m1.module.named_buffers())
for name, buf in m2.module.named_buffers():
self.assertEqual(buf, m1_buf_dict[name])

View File

@ -51,7 +51,7 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
return list(d.values()), list(d.keys())
def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
-return {key: value for key, value in zip(context, values)}
+return dict(zip(context, values))
def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
return d, None

View File

@ -264,7 +264,7 @@ def construct_version_maps(
) -> str:
version_map = torch._C._get_operator_version_map()
sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return]
-sorted_version_map = {name: lst for name, lst in sorted_version_map_}
+sorted_version_map = dict(sorted_version_map_)
operator_list_in_version_map_part = []
for op_name in sorted_version_map: