[list] Implement list.count (#153969)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153969
Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
This commit is contained in:
Guilherme Leobas
2025-07-05 15:54:43 -03:00
committed by PyTorch MergeBot
parent 16c3b4143b
commit 034e996d37
2 changed files with 30 additions and 2 deletions

View File

@ -5,14 +5,18 @@ Python polyfills for operator
from __future__ import annotations
import operator
from typing import Any, Callable, overload, TypeVar
from typing import Any, Callable, overload, TYPE_CHECKING, TypeVar
from typing_extensions import TypeVarTuple, Unpack
from ..decorators import substitute_in_graph
if TYPE_CHECKING:
from collections.abc import Iterable
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
__all__ = ["attrgetter", "itemgetter", "methodcaller"]
__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"]
_T = TypeVar("_T")
@ -103,3 +107,9 @@ def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any
return getattr(obj, name)(*args, **kwargs)
return caller
# Reference: https://docs.python.org/3/library/operator.html#operator.countOf
@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc]
def countOf(a: Iterable[_T], b: _T, /) -> int:
return sum(it is b or it == b for it in a)

View File

@ -136,6 +136,10 @@ class BaseListVariable(VariableTracker):
if name == "__getitem__":
from .tensor import TensorVariable
if len(args) != 1:
msg = f"{name} takes exactly one argument ({len(args)} given)"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
assert not kwargs and len(args) == 1
if isinstance(args[0], TensorVariable):
value = get_fake_value(args[0].as_proxy().node, tx)
@ -152,6 +156,11 @@ class BaseListVariable(VariableTracker):
)
else:
value = args[0]
if value.python_type() not in (int, slice):
msg = f"indices must be integers or slices, not {value.python_type()}"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
return self.getitem_const(tx, value)
elif name == "__contains__":
assert len(args) == 1
@ -163,6 +172,15 @@ class BaseListVariable(VariableTracker):
[self] + list(args),
kwargs,
)
elif name == "count":
if len(args) != 1:
msg = f"{name} takes exactly one argument ({len(args)} given)"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
return VariableTracker.build(tx, operator.countOf).call_function(
tx,
[self, args[0]],
kwargs,
)
elif name in cmp_name_to_op_mapping:
left = self
right = args[0]