Use tuples to have a deterministic ordering. (#164851)

When debugging I noticed some non-deterministic behavior and tracked it down to this literal set. Changed to be a tuple for determinism. Changed two other small literal sets also because using a set for a small lookup like that is slow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164851
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
This commit is contained in:
Aaron Orenstein
2025-10-07 10:53:47 -07:00
committed by PyTorch MergeBot
parent d444384003
commit ad7b2bebc6

View File

@ -51,11 +51,11 @@ def _extract_tensor_metadata(
memory_format = None
if include_contiguity and not is_sparse_any(result):
memory_formats = {
memory_formats = (
torch.contiguous_format,
torch.channels_last,
torch.channels_last_3d,
}
)
for query_format in memory_formats:
if is_contiguous_for_memory_format_or_false(
result, memory_format=query_format
@ -68,14 +68,14 @@ def _extract_tensor_metadata(
if is_quantized:
qscheme = result.qscheme()
qparams["qscheme"] = qscheme
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
if qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
qparams["scale"] = result.q_scale() # type: ignore[assignment]
qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
elif qscheme in {
elif qscheme in (
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
torch.per_channel_symmetric,
}:
):
# In this branch, scale and zero_point are expected to be tensors,
# we store the values as immutable_list in TensorMetadata for
# easier serialization downstream