mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] match implementation for sorted(...)
with CPython (#141227)
```python def sorted(iterable, /, *, key=None, reverse=False): seq = list(iterable) seq.sort(key=key, reverse=reverse) return seq ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/141227 Approved by: https://github.com/jansel, https://github.com/Skylion007 ghstack dependencies: #141224
This commit is contained in:
committed by
PyTorch MergeBot
parent
259a00b727
commit
675735cfc9
@ -1896,29 +1896,12 @@ class BuiltinVariable(VariableTracker):
|
||||
if obj.has_force_unpack_var_sequence(tx) and not isinstance(
|
||||
obj, variables.TensorVariable
|
||||
):
|
||||
unpacked = obj.force_unpack_var_sequence(tx)
|
||||
if not all(x.is_python_constant() for x in unpacked):
|
||||
# TODO: support `key(x)` is Python constant and sortable. The `key` function should
|
||||
# be a pure function and should not have any side effects.
|
||||
return # try next handler
|
||||
key_fn = kwargs.pop("key", ConstantVariable.create(None))
|
||||
reverse = kwargs.pop(
|
||||
"reverse", ConstantVariable.create(False)
|
||||
).as_python_constant()
|
||||
assert len(kwargs) == 0
|
||||
|
||||
if key_fn.is_python_constant() and key_fn.as_python_constant() is None:
|
||||
|
||||
def key(x):
|
||||
return x.as_python_constant()
|
||||
|
||||
else:
|
||||
|
||||
def key(x):
|
||||
return key_fn.call_function(tx, [x], {}).as_python_constant()
|
||||
|
||||
items = sorted(unpacked, key=key, reverse=reverse)
|
||||
return variables.ListVariable(items)
|
||||
list_var = variables.ListVariable(
|
||||
obj.force_unpack_var_sequence(tx),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
list_var.call_method(tx, "sort", [], kwargs)
|
||||
return list_var
|
||||
|
||||
# neg is a constant fold function, so we only get here if constant fold is not valid
|
||||
def call_neg(self, tx: "InstructionTranslator", a):
|
||||
|
@ -432,8 +432,38 @@ class ListVariable(CommonListMethodsVariable):
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
if name == "sort" and self.is_mutable():
|
||||
assert len(args) == 0
|
||||
key_fn_var = kwargs.pop("key", ConstantVariable.create(None))
|
||||
reverse = kwargs.pop(
|
||||
"reverse", ConstantVariable.create(False)
|
||||
).as_python_constant()
|
||||
assert len(kwargs) == 0
|
||||
|
||||
if not all(x.is_python_constant() for x in self.items):
|
||||
# TODO: support `key(x)` is Python constant and sortable. The `key` function should
|
||||
# be a pure function and should not have any side effects.
|
||||
return super().call_method(tx, name, args, kwargs) # try next handler
|
||||
|
||||
if (
|
||||
key_fn_var.is_python_constant()
|
||||
and key_fn_var.as_python_constant() is None
|
||||
):
|
||||
|
||||
def key_fn(x):
|
||||
return x.as_python_constant()
|
||||
|
||||
else:
|
||||
|
||||
def key_fn(x):
|
||||
return key_fn_var.call_function(tx, [x], {}).as_python_constant()
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.sort(key=key_fn, reverse=reverse)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
if name == "__class__":
|
||||
|
Reference in New Issue
Block a user