mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions that optimize creation of dicts from iterables Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
be0b415a5a
commit
5471621497
@ -113,8 +113,8 @@ def main():
|
||||
# Trains a model for n_inner_iter using the support and returns a loss
|
||||
# using the query.
|
||||
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
|
||||
params = {k: v for k, v in net.named_parameters()}
|
||||
buffers = {k: v for k, v in net.named_buffers()}
|
||||
params = dict(net.named_parameters())
|
||||
buffers = dict(net.named_buffers())
|
||||
querysz = x_qry.size(0)
|
||||
|
||||
def compute_loss(new_params, buffers, x, y):
|
||||
@ -139,8 +139,8 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
|
||||
|
||||
|
||||
def train(db, net, device, meta_opt, epoch, log):
|
||||
params = {k: v for k, v in net.named_parameters()}
|
||||
buffers = {k: v for k, v in net.named_buffers()}
|
||||
params = dict(net.named_parameters())
|
||||
buffers = dict(net.named_buffers())
|
||||
n_train_iter = db.x_train.shape[0] // db.batchsz
|
||||
|
||||
for batch_idx in range(n_train_iter):
|
||||
@ -186,8 +186,8 @@ def test(db, net, device, epoch, log):
|
||||
# Most research papers using MAML for this task do an extra
|
||||
# stage of fine-tuning here that should be added if you are
|
||||
# adapting this code for research.
|
||||
params = {k: v for k, v in net.named_parameters()}
|
||||
buffers = {k: v for k, v in net.named_buffers()}
|
||||
params = dict(net.named_parameters())
|
||||
buffers = dict(net.named_buffers())
|
||||
n_test_iter = db.x_test.shape[0] // db.batchsz
|
||||
|
||||
qry_losses = []
|
||||
|
@ -136,7 +136,7 @@ if True:
|
||||
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
|
||||
with open('count_ops.txt', 'r') as f:
|
||||
opinfo_counts = [i.strip() for i in f.readlines()]
|
||||
opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)})
|
||||
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
|
||||
|
||||
def count_fn(x):
|
||||
return opinfo_counts[x['full_name']]
|
||||
|
@ -216,7 +216,7 @@ class TestBenchmarkUtils(TestCase):
|
||||
|
||||
def __init__(self, stmt, setup, timer, globals):
|
||||
self._random_state = np.random.RandomState(seed=self._seed)
|
||||
self._mean_cost = {k: v for k, v in self._function_costs}[stmt]
|
||||
self._mean_cost = dict(self._function_costs)[stmt]
|
||||
|
||||
def sample(self, mean, noise_level):
|
||||
return max(self._random_state.normal(mean, mean * noise_level), 5e-9)
|
||||
|
@ -232,9 +232,7 @@ def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
|
||||
return cpp_args_construction_stmts, cpp_forward_args_symbols
|
||||
|
||||
def serialize_arg_dict_as_script_module(arg_dict):
|
||||
arg_dict_flat = {arg_name: arg_value
|
||||
for arg_name, arg_value in
|
||||
arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other']}
|
||||
arg_dict_flat = dict(arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other'])
|
||||
arg_dict_module = torch.nn.Module()
|
||||
for arg_name, arg_value in arg_dict_flat.items():
|
||||
assert isinstance(arg_value, torch.Tensor)
|
||||
|
@ -1098,7 +1098,7 @@ class AbstractCommTest:
|
||||
rank = dist.get_rank(process_group)
|
||||
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
|
||||
dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
|
||||
rank_to_seq_num = {rank: num for (rank, num) in obj_list}
|
||||
rank_to_seq_num = dict(obj_list)
|
||||
self.assertEqual(len(set(rank_to_seq_num.values())), 2)
|
||||
self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
|
||||
expected_same = {
|
||||
|
@ -529,7 +529,7 @@ class ModelOutput(collections.OrderedDict):
|
||||
|
||||
def __getitem__(self, k):
|
||||
if isinstance(k, str):
|
||||
inner_dict = {k: v for (k, v) in self.items()}
|
||||
inner_dict = dict(self.items())
|
||||
return inner_dict[k]
|
||||
else:
|
||||
return self.to_tuple()[k]
|
||||
|
@ -89,7 +89,7 @@ class TestTyping(JitTestCase):
|
||||
def fn():
|
||||
l1 = [1, 2, "foo", 3]
|
||||
l2 = ["foo", "bar", "baz", "qux"]
|
||||
d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
|
||||
d: Dict[int, str] = dict(zip(l1, l2))
|
||||
return d
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dicts may only "
|
||||
@ -102,7 +102,7 @@ class TestTyping(JitTestCase):
|
||||
def fn():
|
||||
l1 = ["foo", "bar", "baz", "qux"]
|
||||
l2 = [1, 2, "foo", 3]
|
||||
d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
|
||||
d: Dict[str, int] = dict(zip(l1, l2))
|
||||
return d
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
||||
|
@ -177,7 +177,7 @@ class TestIdentifyGradients(TestCase):
|
||||
|
||||
def test_extract_gradients_from_module(self) -> None:
|
||||
model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
|
||||
named_parameters = {name: p for name, p in model.named_parameters()}
|
||||
named_parameters = dict(model.named_parameters())
|
||||
self.assertEqual(len(named_parameters), 3)
|
||||
|
||||
def assert_only_gradients(prof: torch.profiler.profile):
|
||||
|
@ -432,9 +432,7 @@ def replay(filename):
|
||||
config.replay_record_enabled = False
|
||||
with open(filename, "rb") as in_file:
|
||||
record = ExecutionRecord.load(in_file)
|
||||
record.globals = {
|
||||
k: v for k, v in itertools.chain(record.globals.items(), globals().items())
|
||||
}
|
||||
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
|
||||
|
||||
try:
|
||||
_compile(
|
||||
|
@ -71,7 +71,7 @@ pytree._register_pytree_node(
|
||||
immutable_collections.immutable_dict,
|
||||
lambda x: (list(x.values()), list(x.keys())),
|
||||
lambda x, c: immutable_collections.immutable_dict(
|
||||
{key: value for key, value in zip(c, x)}
|
||||
dict(zip(c, x))
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -600,7 +600,7 @@ method_only_ops = [
|
||||
|
||||
|
||||
def get_nn_functional_top_list():
|
||||
top_nn_functional_ = {k: v for k, v in top_nn_functional}
|
||||
top_nn_functional_ = dict(top_nn_functional)
|
||||
for _, count, functional_name in top_nn_module:
|
||||
if functional_name is None:
|
||||
continue
|
||||
|
@ -660,7 +660,7 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
|
||||
if uuids is None:
|
||||
raise RuntimeError("Can't get device UUIDs")
|
||||
visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
|
||||
idx_map = {idx: real_idx for idx, real_idx in enumerate(cast(List[int], visible_devices))}
|
||||
idx_map = dict(enumerate(cast(List[int], visible_devices)))
|
||||
if idx not in idx_map:
|
||||
raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})")
|
||||
return idx_map[idx]
|
||||
|
@ -513,7 +513,7 @@ class MultiprocessContext(PContext):
|
||||
|
||||
def pids(self) -> Dict[int, int]:
|
||||
assert self._pc is not None # assertion for mypy type checking
|
||||
return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}
|
||||
return dict(enumerate(self._pc.pids()))
|
||||
|
||||
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
|
||||
if not self._pc:
|
||||
|
@ -1056,7 +1056,7 @@ def _get_param_id_to_param_from_optim_input(
|
||||
# Assume the standard case of passing `model.parameters()` to the optimizer
|
||||
# if `optim_input` is not specified
|
||||
if optim_input is None:
|
||||
return {pid: param for pid, param in enumerate(model.parameters())}
|
||||
return dict(enumerate(model.parameters()))
|
||||
try:
|
||||
params = cast(List[nn.Parameter], list(optim_input))
|
||||
except TypeError as e:
|
||||
@ -1076,7 +1076,7 @@ def _get_param_id_to_param_from_optim_input(
|
||||
if not all_tensors and not all_dicts:
|
||||
raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
|
||||
if all_tensors:
|
||||
return {pid: param for pid, param in enumerate(params)}
|
||||
return dict(enumerate(params))
|
||||
assert all_dicts
|
||||
param_id_to_param: List[nn.Parameter] = []
|
||||
for param_group in params:
|
||||
@ -1089,7 +1089,7 @@ def _get_param_id_to_param_from_optim_input(
|
||||
# Implicitly map `flat_param_id` (current length of the list) to
|
||||
# `param`
|
||||
param_id_to_param.append(param)
|
||||
return {pid: param for pid, param in enumerate(param_id_to_param)}
|
||||
return dict(enumerate(param_id_to_param))
|
||||
|
||||
|
||||
def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
|
||||
|
@ -46,7 +46,7 @@ def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> T
|
||||
if len(kwarg_keys) == 0:
|
||||
return flat_args, {}
|
||||
args = flat_args[: -len(kwarg_keys)]
|
||||
kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])}
|
||||
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
|
||||
return args, kwargs
|
||||
|
||||
def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
|
||||
|
@ -627,7 +627,7 @@ class Tracer(TracerBase):
|
||||
raise RuntimeError(
|
||||
f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
|
||||
)
|
||||
concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
|
||||
concrete_args = dict(zip(arg_names, concrete_args))
|
||||
args.extend(proxy_placeholder(names) for names in arg_names)
|
||||
|
||||
if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
|
||||
|
@ -259,12 +259,7 @@ def get_device_to_partitions_mapping(
|
||||
# Find devices for all the partitions without a device
|
||||
found_device = True
|
||||
for partition in no_device_partitions:
|
||||
device_to_left_mem_bytes = {
|
||||
d: left_mem_bytes
|
||||
for d, left_mem_bytes in sorted(
|
||||
device_to_left_mem_bytes.items(), key=lambda item: item[1]
|
||||
)
|
||||
}
|
||||
device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]))
|
||||
found_device = find_device_for(partition)
|
||||
if not found_device:
|
||||
break
|
||||
|
@ -39,7 +39,7 @@ def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
||||
return list(d.values()), list(d.keys())
|
||||
|
||||
def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
|
||||
return immutable_dict({key: value for key, value in zip(context, values)})
|
||||
return immutable_dict(dict(zip(context, values)))
|
||||
|
||||
def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
|
||||
return d, None
|
||||
|
@ -23,7 +23,7 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
|
||||
def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
|
||||
args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)]
|
||||
kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):]
|
||||
kwargs = {name: arg for (name, arg) in zip(kwarg_names, kwargs)}
|
||||
kwargs = dict(zip(kwarg_names, kwargs))
|
||||
|
||||
return conv_normalizer(*args, **kwargs)
|
||||
|
||||
|
@ -21,7 +21,7 @@ def standard_kwargs(kwarg_names, expanded_args):
|
||||
'''
|
||||
kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names):]
|
||||
expanded_args_without_kwargs = expanded_args[:len(expanded_args) - len(kwarg_names)]
|
||||
expanded_kwargs = {name: value for (name, value) in zip(kwarg_names, kwarg_values)}
|
||||
expanded_kwargs = dict(zip(kwarg_names, kwarg_values))
|
||||
return expanded_args_without_kwargs, expanded_kwargs
|
||||
|
||||
def forward_helper(func, expanded_args, expanded_kwargs):
|
||||
|
@ -136,7 +136,7 @@ class PackedSequence(PackedSequence_):
|
||||
return self
|
||||
else:
|
||||
# Does not forward device or dtype arg/kwargs, device is set from data.device
|
||||
kwargs = {k : v for k, v in filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())}
|
||||
kwargs = dict(filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items()))
|
||||
sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
|
||||
unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
|
||||
return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
|
||||
|
@ -391,9 +391,8 @@ class Optimizer:
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Update the state
|
||||
id_map = {old_id: p for old_id, p in
|
||||
zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
||||
chain.from_iterable((g['params'] for g in groups)))}
|
||||
id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
||||
chain.from_iterable((g['params'] for g in groups))))
|
||||
|
||||
def cast(param, value, key=None):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
|
@ -670,7 +670,7 @@ class MemoryProfile:
|
||||
allocation_times[(key, is_allocation)] = event.start_time_ns
|
||||
|
||||
snapshot = self._category_snapshot()
|
||||
last_version = {key: version for key, version in sorted(snapshot.keys())}
|
||||
last_version = dict(sorted(snapshot.keys()))
|
||||
|
||||
events: List[Tuple[int, Action, TensorAndID]] = [
|
||||
(-1, Action.PREEXISTING, (key, version))
|
||||
|
@ -767,9 +767,7 @@ class FSDPTest(MultiProcessTestCase):
|
||||
]
|
||||
for values in itertools.product(*subtest_config_values):
|
||||
# Map keyword to chosen value
|
||||
subtest_kwargs = {
|
||||
kwarg: value for kwarg, value in zip(subtest_config_keys, values)
|
||||
}
|
||||
subtest_kwargs = dict(zip(subtest_config_keys, values))
|
||||
with self.subTest(**subtest_kwargs):
|
||||
test_fn(*test_args, **test_kwargs, **subtest_kwargs)
|
||||
dist.barrier()
|
||||
|
@ -410,9 +410,7 @@ class parametrize(_TestParametrizer):
|
||||
'values and {} names for test "{}"'.format(
|
||||
len(values), len(self.arg_names), test.__name__))
|
||||
|
||||
param_kwargs = {
|
||||
name: value for name, value in zip(self.arg_names, values)
|
||||
}
|
||||
param_kwargs = dict(zip(self.arg_names, values))
|
||||
|
||||
test_name = self._get_subtest_name(values, explicit_name=maybe_name)
|
||||
|
||||
|
@ -628,7 +628,7 @@ class DistributedTest:
|
||||
|
||||
def _verify_buffers_equal(self, m1, m2):
|
||||
# verify buffers across models
|
||||
m1_buf_dict = {k: v for k, v in m1.module.named_buffers()}
|
||||
m1_buf_dict = dict(m1.module.named_buffers())
|
||||
for name, buf in m2.module.named_buffers():
|
||||
self.assertEqual(buf, m1_buf_dict[name])
|
||||
|
||||
|
@ -51,7 +51,7 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
||||
return list(d.values()), list(d.keys())
|
||||
|
||||
def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
|
||||
return {key: value for key, value in zip(context, values)}
|
||||
return dict(zip(context, values))
|
||||
|
||||
def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
|
||||
return d, None
|
||||
|
@ -264,7 +264,7 @@ def construct_version_maps(
|
||||
) -> str:
|
||||
version_map = torch._C._get_operator_version_map()
|
||||
sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return]
|
||||
sorted_version_map = {name: lst for name, lst in sorted_version_map_}
|
||||
sorted_version_map = dict(sorted_version_map_)
|
||||
|
||||
operator_list_in_version_map_part = []
|
||||
for op_name in sorted_version_map:
|
||||
|
Reference in New Issue
Block a user