[dynamo] match implementation for sorted(...) with CPython (#141227)

```python
def sorted(iterable, /, *, key=None, reverse=False):
    seq = list(iterable)
    seq.sort(key=key, reverse=reverse)
    return seq
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141227
Approved by: https://github.com/jansel, https://github.com/Skylion007
ghstack dependencies: #141224
This commit is contained in:
Xuehai Pan
2024-11-21 19:05:32 +08:00
committed by PyTorch MergeBot
parent 259a00b727
commit 675735cfc9
2 changed files with 38 additions and 25 deletions

View File

@ -1896,29 +1896,12 @@ class BuiltinVariable(VariableTracker):
if obj.has_force_unpack_var_sequence(tx) and not isinstance(
obj, variables.TensorVariable
):
unpacked = obj.force_unpack_var_sequence(tx)
if not all(x.is_python_constant() for x in unpacked):
# TODO: support `key(x)` is Python constant and sortable. The `key` function should
# be a pure function and should not have any side effects.
return # try next handler
key_fn = kwargs.pop("key", ConstantVariable.create(None))
reverse = kwargs.pop(
"reverse", ConstantVariable.create(False)
).as_python_constant()
assert len(kwargs) == 0
if key_fn.is_python_constant() and key_fn.as_python_constant() is None:
def key(x):
return x.as_python_constant()
else:
def key(x):
return key_fn.call_function(tx, [x], {}).as_python_constant()
items = sorted(unpacked, key=key, reverse=reverse)
return variables.ListVariable(items)
list_var = variables.ListVariable(
obj.force_unpack_var_sequence(tx),
mutation_type=ValueMutationNew(),
)
list_var.call_method(tx, "sort", [], kwargs)
return list_var
# neg is a constant fold function, so we only get here if constant fold is not valid
def call_neg(self, tx: "InstructionTranslator", a):

View File

@ -432,8 +432,38 @@ class ListVariable(CommonListMethodsVariable):
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
if name == "sort" and self.is_mutable():
assert len(args) == 0
key_fn_var = kwargs.pop("key", ConstantVariable.create(None))
reverse = kwargs.pop(
"reverse", ConstantVariable.create(False)
).as_python_constant()
assert len(kwargs) == 0
if not all(x.is_python_constant() for x in self.items):
# TODO: support `key(x)` is Python constant and sortable. The `key` function should
# be a pure function and should not have any side effects.
return super().call_method(tx, name, args, kwargs) # try next handler
if (
key_fn_var.is_python_constant()
and key_fn_var.as_python_constant() is None
):
def key_fn(x):
return x.as_python_constant()
else:
def key_fn(x):
return key_fn_var.call_function(tx, [x], {}).as_python_constant()
tx.output.side_effects.mutation(self)
self.items.sort(key=key_fn, reverse=reverse)
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
if name == "__class__":