mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo][itertools] refactor itertools.chain
and itertools.chain.from_iterable
to use polyfills (#133864)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864 Approved by: https://github.com/jansel ghstack dependencies: #133769, #133778, #133779
This commit is contained in:
committed by
PyTorch MergeBot
parent
eaa449fbf0
commit
1b70366957
@ -170,6 +170,8 @@ def substitute_in_graph(
|
||||
*,
|
||||
can_constant_fold_through: bool = False,
|
||||
skip_signature_check: bool = False,
|
||||
# type that is embedded in the Python interpreter
|
||||
is_embedded_type: bool = False, # internal use only
|
||||
) -> Callable[[_F], _F]:
|
||||
"""
|
||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||
@ -219,10 +221,22 @@ def substitute_in_graph(
|
||||
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
|
||||
2
|
||||
"""
|
||||
if not is_function(original_fn):
|
||||
if not is_function(original_fn) and not (
|
||||
is_embedded_type and inspect.isclass(original_fn)
|
||||
):
|
||||
raise TypeError(
|
||||
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
|
||||
)
|
||||
if is_embedded_type:
|
||||
if not inspect.isclass(original_fn):
|
||||
raise TypeError(
|
||||
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
|
||||
)
|
||||
|
||||
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
|
||||
|
||||
if id(original_fn) in ITERTOOLS_TYPE_IDS:
|
||||
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
|
||||
|
||||
def wrapper(traceable_fn: _F) -> _F:
|
||||
if not is_function(traceable_fn):
|
||||
|
@ -10,12 +10,31 @@ from typing import Iterable, Iterator, TypeVar
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
__all__ = ["tee"]
|
||||
__all__ = [
|
||||
"chain",
|
||||
"chain_from_iterable",
|
||||
"tee",
|
||||
]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
|
||||
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
|
||||
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
|
||||
for iterable in iterables:
|
||||
yield from iterable
|
||||
|
||||
|
||||
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
|
||||
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
|
||||
return itertools.chain(*iterable)
|
||||
|
||||
|
||||
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
||||
@substitute_in_graph(itertools.tee)
|
||||
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
||||
|
@ -32,3 +32,9 @@ for polyfill_module in POLYFILLED_MODULES:
|
||||
polyfill_handler = getattr(polyfill_module, polyfill_name)
|
||||
original_fn = polyfill_handler.__torch_dynamo_original__
|
||||
trace_rules._builtin_function_ids.remove(id(original_fn))
|
||||
|
||||
# Unregister the class object if the original function is its __new__ method
|
||||
if original_fn.__name__ == "__new__" and isinstance(
|
||||
getattr(original_fn, "__self__", None), type
|
||||
):
|
||||
trace_rules._builtin_function_ids.remove(id(original_fn.__self__))
|
||||
|
@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]:
|
||||
if not k.startswith("_") and callable(v)
|
||||
}
|
||||
)
|
||||
rv.update(
|
||||
{id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
|
||||
)
|
||||
rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)})
|
||||
rv.update(
|
||||
{
|
||||
id(cast): "typing.cast",
|
||||
@ -3471,9 +3469,7 @@ def check_verbose(obj, is_inlined_call=False):
|
||||
|
||||
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
|
||||
reasons: Set[str] = set()
|
||||
rule = torch._dynamo.trace_rules.lookup_inner(
|
||||
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
|
||||
)
|
||||
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
|
||||
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
||||
return SkipResult(
|
||||
False,
|
||||
|
@ -17,7 +17,17 @@ import sys
|
||||
import types
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from typing import (
|
||||
Any,
|
||||
FrozenSet,
|
||||
List,
|
||||
MutableMapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch import SymInt
|
||||
@ -319,6 +329,17 @@ class FrameStateSizeEntry:
|
||||
stride: Optional[List[int]]
|
||||
|
||||
|
||||
# All class-based iterators in itertools
|
||||
# NOTE: use id() because some objects are not hashable, it will raise error during lookup
|
||||
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
|
||||
id(member)
|
||||
for name, member in vars(itertools).items()
|
||||
if not name.startswith("_") and inspect.isclass(member)
|
||||
)
|
||||
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
|
||||
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()
|
||||
|
||||
|
||||
class VariableBuilder:
|
||||
"""Wrap a python value in a VariableTracker() instance"""
|
||||
|
||||
@ -874,7 +895,10 @@ class VariableBuilder:
|
||||
value,
|
||||
source=self.source,
|
||||
)
|
||||
elif istype(value, type) and value in itertools.__dict__.values():
|
||||
elif (
|
||||
id(value) in ITERTOOLS_TYPE_IDS
|
||||
and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
|
||||
):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return ItertoolsVariable(value, source=self.source)
|
||||
elif isinstance(value, torch.SymBool):
|
||||
|
@ -994,15 +994,6 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
if self.fn is dict and name == "fromkeys":
|
||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||
if self.fn is itertools.chain and name == "from_iterable":
|
||||
assert len(args) == 1
|
||||
assert len(kwargs) == 0
|
||||
obj = args[0]
|
||||
items = []
|
||||
for item in obj.unpack_var_sequence(tx):
|
||||
items.extend(item.unpack_var_sequence(tx))
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
||||
@ -1898,13 +1889,6 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
return variables.ListVariable(items)
|
||||
|
||||
def call_chain(self, tx: "InstructionTranslator", *args):
|
||||
if all(obj.has_unpack_var_sequence(tx) for obj in args):
|
||||
items = []
|
||||
for obj in args:
|
||||
items.extend(obj.unpack_var_sequence(tx))
|
||||
return variables.TupleVariable(items)
|
||||
|
||||
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
||||
if iterable.has_unpack_var_sequence(tx) and all(
|
||||
x.is_python_constant() for x in args
|
||||
|
@ -18,6 +18,7 @@ from ..utils import (
|
||||
check_constant_args,
|
||||
check_unspec_or_constant_args,
|
||||
identity,
|
||||
is_function,
|
||||
is_wrapper_or_member_descriptor,
|
||||
istype,
|
||||
make_cell,
|
||||
@ -992,6 +993,27 @@ class PolyfilledFunctionVariable(VariableTracker):
|
||||
handler,
|
||||
).call_function(tx, args, kwargs)
|
||||
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "__call__":
|
||||
return self.call_function(tx, args, kwargs)
|
||||
|
||||
method = getattr(self.fn, name, None)
|
||||
assert method is not None, f"Member {name} not found in {self.fn}"
|
||||
assert is_function(method), f"Member {name} is not callable in {self.fn}"
|
||||
options = {}
|
||||
if self.source:
|
||||
options["source"] = AttrSource(self.source, name)
|
||||
member_variable = PolyfilledFunctionVariable(method, **options)
|
||||
return member_variable.call_function(tx, args, kwargs)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.fn
|
||||
|
||||
|
@ -52,14 +52,6 @@ class ItertoolsVariable(VariableTracker):
|
||||
for item in itertools.product(*seqs):
|
||||
items.append(variables.TupleVariable(list(item)))
|
||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||||
elif (
|
||||
self.value is itertools.chain
|
||||
and not kwargs
|
||||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
||||
):
|
||||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
||||
items = list(itertools.chain.from_iterable(seqs))
|
||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||||
elif self.value is itertools.accumulate:
|
||||
from .builtin import BuiltinVariable
|
||||
|
||||
|
Reference in New Issue
Block a user