mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[list] Implement list.count (#153969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153969 Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
This commit is contained in:
committed by
PyTorch MergeBot
parent
16c3b4143b
commit
034e996d37
@ -5,14 +5,18 @@ Python polyfills for operator
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Any, Callable, overload, TypeVar
|
||||
from typing import Any, Callable, overload, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
|
||||
__all__ = ["attrgetter", "itemgetter", "methodcaller"]
|
||||
__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -103,3 +107,9 @@ def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any
|
||||
return getattr(obj, name)(*args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/operator.html#operator.countOf
|
||||
@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc]
|
||||
def countOf(a: Iterable[_T], b: _T, /) -> int:
|
||||
return sum(it is b or it == b for it in a)
|
||||
|
@ -136,6 +136,10 @@ class BaseListVariable(VariableTracker):
|
||||
if name == "__getitem__":
|
||||
from .tensor import TensorVariable
|
||||
|
||||
if len(args) != 1:
|
||||
msg = f"{name} takes exactly one argument ({len(args)} given)"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
|
||||
assert not kwargs and len(args) == 1
|
||||
if isinstance(args[0], TensorVariable):
|
||||
value = get_fake_value(args[0].as_proxy().node, tx)
|
||||
@ -152,6 +156,11 @@ class BaseListVariable(VariableTracker):
|
||||
)
|
||||
else:
|
||||
value = args[0]
|
||||
|
||||
if value.python_type() not in (int, slice):
|
||||
msg = f"indices must be integers or slices, not {value.python_type()}"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
|
||||
return self.getitem_const(tx, value)
|
||||
elif name == "__contains__":
|
||||
assert len(args) == 1
|
||||
@ -163,6 +172,15 @@ class BaseListVariable(VariableTracker):
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
)
|
||||
elif name == "count":
|
||||
if len(args) != 1:
|
||||
msg = f"{name} takes exactly one argument ({len(args)} given)"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
return VariableTracker.build(tx, operator.countOf).call_function(
|
||||
tx,
|
||||
[self, args[0]],
|
||||
kwargs,
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
left = self
|
||||
right = args[0]
|
||||
|
Reference in New Issue
Block a user