Apply some safe comprehension optimizations (#94323)

Remove unnecessary collection casts, i.e. redundant wrapping calls to list, tuple, and dict, and simplify calls to the sorted builtin. These changes should strictly improve speed and readability.
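For context, the rewrites all follow a handful of equivalent before/after patterns. The snippet below is an illustrative sketch with made-up values, not code taken from the diff:

# Each pair is equivalent; the simplified form avoids at least one intermediate list.
xs = [3, 1, 2]

# sorted() already returns a new list, so an outer list() call is redundant.
assert list(sorted(xs)) == sorted(xs)

# Materializing reversed(sorted(...)) with list() is just a descending sort.
assert list(reversed(sorted(xs))) == sorted(xs, reverse=True)

# A list comprehension already builds a list; list(...) around it only copies it.
assert list([x * 2 for x in xs]) == [x * 2 for x in xs]

# sorted() accepts any iterable, so there is no need to cast a set to a list first.
assert sorted(list({"b", "a", "b"})) == sorted({"b", "a", "b"})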

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
Aaron Gokaslan
2023-02-07 23:53:42 +00:00
committed by PyTorch MergeBot
parent bef2483ed8
commit 3ce1ebb6fb
20 changed files with 39 additions and 43 deletions


@@ -15,7 +15,7 @@ _enabled = False
 def _enable_layers(dims):
     global _enabled
     assert not _enabled
-    input = list(sorted((d._level, d.size) for d in dims if not isinstance(d, int)))
+    input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
     n = len(input)
     try:
         _vmap_add_layers(input)


@@ -1327,12 +1327,10 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             (1, 5),
         )
-        tensors = list(
-            [
-                torch.empty(shape, dtype=dtype).fill_(17)
-                for shape, dtype in itertools.product(shapes, dtypes)
-            ]
-        )
+        tensors = [
+            torch.empty(shape, dtype=dtype).fill_(17)
+            for shape, dtype in itertools.product(shapes, dtypes)
+        ]
         x_vals = (5.0, *tensors)
         y_vals = (6.0, *tensors)


@@ -4550,7 +4550,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         def make_input(batch_size, layers, packed_sequence):
             batch_first = True if packed_sequence == 2 else False
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -9434,7 +9434,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -9501,7 +9501,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -9644,7 +9644,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]


@@ -424,7 +424,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -485,7 +485,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -540,7 +540,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         def make_input(batch_size):
             seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
-            seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
             inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
             inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
             inputs = [inputs]
@@ -581,7 +581,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
     def test_rnn_init_predict_split(self):
         model = nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 3, bidirectional=True)
         seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=7)
-        seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
+        seq_lengths = sorted(map(int, seq_lengths), reverse=True)
         input = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
         input = rnn_utils.pad_sequence(input)


@@ -777,8 +777,8 @@ class TestQuantizedTensor(TestCase):
             # change memory format
             qlast = qr.contiguous(memory_format=torch.channels_last)
-            self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
-            self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
+            self.assertEqual(qr.stride(), sorted(qr.stride(), reverse=True))
+            self.assertNotEqual(qlast.stride(), sorted(qlast.stride(), reverse=True))
             self.assertEqual(qr.int_repr(), qlast.int_repr())
             self.assertEqual(qr.q_scale(), qlast.q_scale())
             self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
@@ -804,8 +804,8 @@ class TestQuantizedTensor(TestCase):
             # but we can change memory format
             qlast = qr.contiguous(memory_format=torch.channels_last)
-            self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
-            self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
+            self.assertEqual(qr.stride(), sorted(qr.stride(), reverse=True))
+            self.assertNotEqual(qlast.stride(), sorted(qlast.stride(), reverse=True))
             self.assertEqual(qr.int_repr(), qlast.int_repr())
             self.assertEqual(scales.to(dtype=torch.float64), qlast.q_per_channel_scales())
             self.assertEqual(zero_points, qlast.q_per_channel_zero_points())


@@ -5430,8 +5430,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         self.assertEqual(F.cosine_similarity(input1, input2, dim=1).size(), expected_size)
         # Check numerical precision, issue #18057
-        vv1 = torch.tensor(list([float(i) for i in range(84)])).unsqueeze(0)
-        vv2 = torch.tensor(list([float(i) for i in range(84)])).unsqueeze(0)
+        vv1 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
+        vv2 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
         out = F.cosine_similarity(vv1, vv2)
         self.assertLessEqual(out, 1.0)


@@ -62,13 +62,11 @@ def gen_autograd(
     template_path = os.path.join(autograd_dir, "templates")
     native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
-    fns = list(
-        sorted(
-            filter(
-                operator_selector.is_native_function_selected_for_training, native_funcs
-            ),
-            key=lambda f: cpp.name(f.func),
-        )
-    )
+    fns = sorted(
+        filter(
+            operator_selector.is_native_function_selected_for_training, native_funcs
+        ),
+        key=lambda f: cpp.name(f.func),
+    )
     fns_with_diff_infos: List[
         NativeFunctionWithDifferentiabilityInfo


@@ -107,7 +107,7 @@ class Profiler:
         last_op_end_time = -1
         captured_region_end_time = -1
-        events = list(sorted(self.prof.events(), key=lambda x: x.time_range.start))
+        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
         for e in events:
             if e.name == "TORCHDYNAMO":
                 captured_region_end_time = e.time_range.end


@@ -945,7 +945,7 @@ class TritonKernel(Kernel):
         default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype))
         masks = {f"{tree.prefix}mask" for tree in self.range_trees}
         self.filter_masks(masks)
-        masks = sorted(list(masks))
+        masks = sorted(masks)
         if self._load_mask:
             masks.append(self._load_mask)
         sizes = [":" for _ in self.range_trees]


@@ -260,7 +260,7 @@ def _transpose_nvfuser(fd, a, dims):
 def _squeeze_nvfuser(fd, a, a_shape, dimensions):
-    for idx in reversed(sorted(dimensions)):
+    for idx in sorted(dimensions, reverse=True):
         a = fd.ops.squeeze(a, a_shape, idx)
         a_shape = a_shape[:idx] + a_shape[idx + 1 :]
     return a


@@ -320,7 +320,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
     # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
     # Sorts (length, stride) pairs by stride
     lengths_and_strides = sorted(
-        tuple(zip(a.shape, a.stride())), key=operator.itemgetter(1)
+        zip(a.shape, a.stride()), key=operator.itemgetter(1)
     )
     expected_stride = 1


@@ -355,8 +355,8 @@ class ModelReportVisualizer:
                     tensor_features.add(feature_name)
         # we make them lists for iteration purposes
-        tensor_features_list: List[str] = sorted(list(tensor_features))
-        channel_features_list: List[str] = sorted(list(channel_features))
+        tensor_features_list: List[str] = sorted(tensor_features)
+        channel_features_list: List[str] = sorted(channel_features)
         # get the tensor info
         tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list)


@@ -1293,7 +1293,7 @@ def _map_param_key_to_optim_keys(
         merge_all_optim_state_keys = [
            key for local_keys in all_keys for key in local_keys
        ]
-        all_optim_state_keys = sorted(list(set(merge_all_optim_state_keys)))
+        all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
     else:
         key_obj_list: List[Optional[List[_OptimStateKey]]] = (
             [all_optim_state_keys] if rank == 0 else [None]
@@ -1613,7 +1613,7 @@ def _all_gather_optim_state(
     gathered_state: Dict[str, Any] = {}
     all_tensor_states = sorted(
-        list(set([n for state in object_list for n in state.tensors.keys()]))
+        set([n for state in object_list for n in state.tensors.keys()])
     )
     empty_ranks: Set[int] = set()
     for name in all_tensor_states:


@@ -25,7 +25,7 @@ def tree_flatten_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]:
     return result
 def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
-    return list([d[k] for k in spec.context])
+    return [d[k] for k in spec.context]
 def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
     return [d[i] for i in range(len(spec.children_specs))]


@@ -447,7 +447,7 @@ def eval_is_non_overlapping_and_dense(sizes, strides):
     # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
     # Sorts (length, stride) pairs by stride
     lengths_and_strides = sorted(
-        tuple(zip(sizes, strides)), key=operator.itemgetter(1)
+        zip(sizes, strides), key=operator.itemgetter(1)
     )
     # Unlike the C++ code, we don't move the 0/1 size dimensions to the


@@ -19,7 +19,7 @@ def _gen_unsupported_methods_properties():
     properties = []
     methods = []
-    sorted_tensor_attrs = sorted(list(tensor_attrs), key=lambda x: x.lower())
+    sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower())
     for attr in sorted_tensor_attrs:
         funcs_str = funcs_template.format(op=attr)
         scope: Dict[str, Any] = {}


@@ -443,7 +443,7 @@ class MultiStepLR(LRScheduler):
                 for group in self.optimizer.param_groups]
     def _get_closed_form_lr(self):
-        milestones = list(sorted(self.milestones.elements()))
+        milestones = sorted(self.milestones.elements())
         return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
                 for base_lr in self.base_lrs]


@@ -180,7 +180,7 @@ def get_cudnn_version(run_lambda):
     if not files_set:
         return None
     # Alphabetize the result because the order is non-deterministic otherwise
-    files = list(sorted(files_set))
+    files = sorted(files_set)
     if len(files) == 1:
         return files[0]
     result = '\n'.join(files)


@@ -1790,7 +1790,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
         if arch.endswith('+PTX'):
             flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
-    return sorted(list(set(flags)))
+    return sorted(set(flags))
 def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:


@@ -356,7 +356,7 @@ class StreamWrapper:
     def __dir__(self):
         attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys())
         attrs += dir(self.file_obj)
-        return list(set(list(attrs)))
+        return list(set(attrs))
     def __del__(self):
         if not self.closed: