Improve sort with non-constant keys error message (#151193)

Fixes https://github.com/pytorch/pytorch/issues/143505

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151193
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/williamwen42
This commit is contained in:
rzou
2025-04-13 17:02:08 -07:00
committed by PyTorch MergeBot
parent 46ce8f7df6
commit dea50b0778
2 changed files with 44 additions and 1 deletions

View File

@ -126,6 +126,33 @@ from user code:
return torch.equal(x, x)""",
)
def test_sort_with_nonconstant_keys(self):
lst = [
torch.tensor(4),
torch.tensor(1),
torch.tensor(2),
torch.tensor(3),
]
def fn(lst):
return sorted(lst)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(lst),
"""\
sort with non-constant keys
Explanation: Cannot perform sort with non-constant key. First non-constant key type: <class 'torch.Tensor'>. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.
Hint: Use something else as the key.
Developer debug context: TensorVariable()
from user code:
File "test_error_messages.py", line N, in fn
return sorted(lst)""",
)
def test_super_call_method(self):
def fn(it):
return [x + 1 for x in it]

View File

@ -485,7 +485,23 @@ class ListVariable(CommonListMethodsVariable):
keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items]
if not all(k.is_python_constant() for k in keys):
unimplemented("sort with non-constant keys")
first_non_constant_key = None
for k in keys:
if not k.is_python_constant():
first_non_constant_key = k
assert first_non_constant_key is not None
unimplemented_v2(
gb_type="sort with non-constant keys",
context=str(first_non_constant_key),
explanation=(
f"Cannot perform sort with non-constant key. "
f"First non-constant key type: {first_non_constant_key.python_type()}. "
f"Most notably, we cannot sort with Tensor or SymInt keys, but we can "
f"sort ints."
),
hints=["Use something else as the key."],
)
tx.output.side_effects.mutation(self)
sorted_items_with_keys = sorted(