[BE]: Simplify some list comps to generators C419 (#132578)

Simplifies some list comprehensions to generator which is more efficient. Automatically applied diffs for the most part with ruff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132578
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2024-08-04 17:46:24 +00:00
committed by PyTorch MergeBot
parent 4226ed1585
commit fd4b649e6c
7 changed files with 26 additions and 42 deletions

View File

@ -162,10 +162,8 @@ def _get_model_size(model):
for name, child in model.named_children():
if not isinstance(child, torch.nn.Embedding):
model_size += sum(
[
p.numel() * p.dtype.itemsize
for p in itertools.chain(child.parameters(), child.buffers())
]
p.numel() * p.dtype.itemsize
for p in itertools.chain(child.parameters(), child.buffers())
)
# Remove the inactivated experts from the model size if this is mixture of experts
@ -178,12 +176,10 @@ def _get_model_size(model):
):
model_size -= (
sum(
[
p.numel() * p.dtype.itemsize
for p in itertools.chain(
submodule.parameters(), child.buffers()
)
]
p.numel() * p.dtype.itemsize
for p in itertools.chain(
submodule.parameters(), child.buffers()
)
)
* (config.num_experts - config.num_activated_experts)
/ config.num_experts

View File

@ -149,7 +149,7 @@
"inps = [torch.randn(3), torch.randn(3)]\n",
"\n",
"def pass_checker(fx_g, inps):\n",
" return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))\n",
" return (torch.ops.aten.mul in {i.target for i in fx_g.graph.nodes})\n",
"\n",
"min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)"
]

View File

@ -127,7 +127,7 @@ class MemoryBudgetTest(TestCase):
def f(x, ws):
xs = [torch.mm(x, w).cos() for w in ws]
return sum([x.sum() for x in xs])
return sum(x.sum() for x in xs)
x = torch.randn(512, 512, requires_grad=True)
@ -272,7 +272,7 @@ class MemoryBudgetTest(TestCase):
def test_prioritize_cheaper_matmul(self):
def f(xs, ws):
xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
return sum([x.sum() for x in xs])
return sum(x.sum() for x in xs)
x1, w1 = create_pair(1, 4)
x2, w2 = create_pair(2, 2)
@ -311,7 +311,7 @@ class MemoryBudgetTest(TestCase):
def test_prioritize_cheaper_matmul2(self):
def f(xs, ws):
xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
return sum([x.sum() for x in xs])
return sum(x.sum() for x in xs)
data = [(4, 4), (6, 2), (2, 6)]
xs, ws = zip(*[create_pair(a, b) for a, b in data])

View File

@ -5148,16 +5148,12 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
expected_batch_size = len(tensor_list)
expected_contiguous = True
expected_min_seqlen = min(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,
@ -5197,16 +5193,12 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
expected_batch_size = len(tensor_list)
expected_contiguous = True
expected_min_seqlen = min(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,
@ -5244,16 +5236,12 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
expected_batch_size = len(tensor_list)
expected_min_seqlen = min(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
[
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
]
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,

View File

@ -1529,7 +1529,7 @@ def choose_saved_values_set(
return runtime_optimized_saved_values
def estimate_activations_size(saved_values: List[fx.Node]) -> float:
return sum([_size_of(i) for i in saved_values]) / 1e9
return sum(map(_size_of, saved_values)) / 1e9
min_act_size = estimate_activations_size(node_info.inputs)
max_act_size = estimate_activations_size(runtime_optimized_saved_values)

View File

@ -605,7 +605,7 @@ class FunctionEvent(FormattedTimesMixin):
return 0
if self.device_type == DeviceType.CPU:
return self.device_time_total - sum(
[child.device_time_total for child in self.cpu_children]
child.device_time_total for child in self.cpu_children
)
else:
assert self.device_type in [

View File

@ -414,8 +414,8 @@ def jagged_from_list(
)
# compute this now since it's easy
min_seqlen = min([t.shape[0] for t in tensors])
max_seqlen = max([t.shape[0] for t in tensors])
min_seqlen = min(t.shape[0] for t in tensors)
max_seqlen = max(t.shape[0] for t in tensors)
ret_nt = nested_view_from_values_offsets(
values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
)