Mirror of https://github.com/pytorch/pytorch.git
[BE] Enable C419 rule for any/all short-circuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT accept simple generator expressions, which lets us enable the rule that replaces unnecessary list comprehensions with generator expressions inside any()/all() calls. This was originally part of #99280, but it is split off into this PR so that it can be easily reverted should anything break.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
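For illustration (not part of the PR itself), here is the short-circuiting behavior the C419 rule targets, using a hypothetical predicate is_positive:

def is_positive(x: int) -> bool:
    print(f"checking {x}")
    return x > 0

values = [1, -2, 3, 4]

# C419 flags this form: the list comprehension evaluates the predicate for
# every element and builds a full list before all() ever runs.
all([is_positive(v) for v in values])   # prints "checking" four times

# The generator-expression form lets all() stop at the first falsy result,
# so evaluation short-circuits right after -2 fails the predicate.
all(is_positive(v) for v in values)     # prints "checking" only twice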
Committed by: PyTorch MergeBot
Parent: e43918b93a
Commit: e2a3817dfd
@@ -21,7 +21,7 @@ from torch.autograd.grad_mode import no_grad
 def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                        with_indices: Optional[bool] = False) -> \
         Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
-    assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
+    assert all(not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist), (
         "all specified tensorlists must match in length")
     per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
         lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
@@ -39,4 +39,4 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
     if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
         return False
-    return all([t is None or type(t) == torch.Tensor for t in tensors])
+    return all(t is None or type(t) == torch.Tensor for t in tensors)
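For context, a minimal usage sketch of the helper touched in the first hunk. The import path torch.utils._foreach_utils is an assumption about where this private utility lives; the call pattern follows the signature visible in the diff:

import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype  # assumed module path

# Two parallel tensor lists (e.g. params and grads); the changed assert
# verifies that every non-empty list matches the first list's length.
params = [torch.zeros(2), torch.zeros(3, dtype=torch.float64)]
grads = [torch.ones(2), torch.ones(3, dtype=torch.float64)]

grouped = _group_tensors_by_device_and_dtype([params, grads])
for (device, dtype), (p_group, g_group) in grouped.items():
    # One (device, dtype) bucket per combination, e.g. (cpu, float32)
    # and (cpu, float64), each holding the matching slices of both lists.
    print(device, dtype, len(p_group), len(g_group))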
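And a sketch of the TorchScript behavior the commit message relies on: assuming the simple-generator support from pytorch/pytorch#78142 described above, a scripted function can now use a generator expression directly inside all():

import torch
from typing import List

@torch.jit.script
def all_nonempty(tensors: List[torch.Tensor]) -> bool:
    # Before generator support landed, this body would have needed the
    # all([...]) list-comprehension form to compile under TorchScript.
    return all(t.numel() > 0 for t in tensors)

print(all_nonempty([torch.ones(2), torch.zeros(3)]))  # True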