[dynamo][itertools] refactor itertools.chain and itertools.chain.from_iterable to use polyfills (#133864)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864
Approved by: https://github.com/jansel
ghstack dependencies: #133769, #133778, #133779
This commit is contained in:
Xuehai Pan
2024-08-30 00:47:04 +08:00
committed by PyTorch MergeBot
parent eaa449fbf0
commit 1b70366957
8 changed files with 91 additions and 34 deletions

View File

@ -170,6 +170,8 @@ def substitute_in_graph(
*,
can_constant_fold_through: bool = False,
skip_signature_check: bool = False,
# type that is embedded in the Python interpreter
is_embedded_type: bool = False, # internal use only
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
@ -219,10 +221,22 @@ def substitute_in_graph(
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
"""
if not is_function(original_fn):
if not is_function(original_fn) and not (
is_embedded_type and inspect.isclass(original_fn)
):
raise TypeError(
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
)
if is_embedded_type:
if not inspect.isclass(original_fn):
raise TypeError(
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
)
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
if id(original_fn) in ITERTOOLS_TYPE_IDS:
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
def wrapper(traceable_fn: _F) -> _F:
if not is_function(traceable_fn):

View File

@ -10,12 +10,31 @@ from typing import Iterable, Iterator, TypeVar
from ..decorators import substitute_in_graph
__all__ = ["tee"]
__all__ = [
"chain",
"chain_from_iterable",
"tee",
]
_T = TypeVar("_T")
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
for iterable in iterables:
yield from iterable
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
return itertools.chain(*iterable)
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
@substitute_in_graph(itertools.tee)
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:

View File

@ -32,3 +32,9 @@ for polyfill_module in POLYFILLED_MODULES:
polyfill_handler = getattr(polyfill_module, polyfill_name)
original_fn = polyfill_handler.__torch_dynamo_original__
trace_rules._builtin_function_ids.remove(id(original_fn))
# Unregister the class object if the original function is its __new__ method
if original_fn.__name__ == "__new__" and isinstance(
getattr(original_fn, "__self__", None), type
):
trace_rules._builtin_function_ids.remove(id(original_fn.__self__))

View File

@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]:
if not k.startswith("_") and callable(v)
}
)
rv.update(
{id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
)
rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)})
rv.update(
{
id(cast): "typing.cast",
@ -3471,9 +3469,7 @@ def check_verbose(obj, is_inlined_call=False):
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
reasons: Set[str] = set()
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
)
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
return SkipResult(
False,

View File

@ -17,7 +17,17 @@ import sys
import types
import warnings
import weakref
from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import (
Any,
FrozenSet,
List,
MutableMapping,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Union,
)
import torch
from torch import SymInt
@ -319,6 +329,17 @@ class FrameStateSizeEntry:
stride: Optional[List[int]]
# All class-based iterators in itertools
# NOTE: use id() because some objects are not hashable, it will raise error during lookup
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
id(member)
for name, member in vars(itertools).items()
if not name.startswith("_") and inspect.isclass(member)
)
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()
class VariableBuilder:
"""Wrap a python value in a VariableTracker() instance"""
@ -874,7 +895,10 @@ class VariableBuilder:
value,
source=self.source,
)
elif istype(value, type) and value in itertools.__dict__.values():
elif (
id(value) in ITERTOOLS_TYPE_IDS
and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return ItertoolsVariable(value, source=self.source)
elif isinstance(value, torch.SymBool):

View File

@ -994,15 +994,6 @@ class BuiltinVariable(VariableTracker):
)
if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
if self.fn is itertools.chain and name == "from_iterable":
assert len(args) == 1
assert len(kwargs) == 0
obj = args[0]
items = []
for item in obj.unpack_var_sequence(tx):
items.extend(item.unpack_var_sequence(tx))
return variables.TupleVariable(items)
return super().call_method(tx, name, args, kwargs)
def _call_int_float(self, tx: "InstructionTranslator", arg):
@ -1898,13 +1889,6 @@ class BuiltinVariable(VariableTracker):
)
return variables.ListVariable(items)
def call_chain(self, tx: "InstructionTranslator", *args):
if all(obj.has_unpack_var_sequence(tx) for obj in args):
items = []
for obj in args:
items.extend(obj.unpack_var_sequence(tx))
return variables.TupleVariable(items)
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
if iterable.has_unpack_var_sequence(tx) and all(
x.is_python_constant() for x in args

View File

@ -18,6 +18,7 @@ from ..utils import (
check_constant_args,
check_unspec_or_constant_args,
identity,
is_function,
is_wrapper_or_member_descriptor,
istype,
make_cell,
@ -992,6 +993,27 @@ class PolyfilledFunctionVariable(VariableTracker):
handler,
).call_function(tx, args, kwargs)
return super().call_function(tx, args, kwargs)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__call__":
return self.call_function(tx, args, kwargs)
method = getattr(self.fn, name, None)
assert method is not None, f"Member {name} not found in {self.fn}"
assert is_function(method), f"Member {name} is not callable in {self.fn}"
options = {}
if self.source:
options["source"] = AttrSource(self.source, name)
member_variable = PolyfilledFunctionVariable(method, **options)
return member_variable.call_function(tx, args, kwargs)
def as_python_constant(self):
return self.fn

View File

@ -52,14 +52,6 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif (
self.value is itertools.chain
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = list(itertools.chain.from_iterable(seqs))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.accumulate:
from .builtin import BuiltinVariable