mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo][itertools] refactor itertools.chain
and itertools.chain.from_iterable
to use polyfills (#133864)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864 Approved by: https://github.com/jansel ghstack dependencies: #133769, #133778, #133779
This commit is contained in:
committed by
PyTorch MergeBot
parent
eaa449fbf0
commit
1b70366957
@ -170,6 +170,8 @@ def substitute_in_graph(
|
|||||||
*,
|
*,
|
||||||
can_constant_fold_through: bool = False,
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
|
# type that is embedded in the Python interpreter
|
||||||
|
is_embedded_type: bool = False, # internal use only
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[_F], _F]:
|
||||||
"""
|
"""
|
||||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||||
@ -219,10 +221,22 @@ def substitute_in_graph(
|
|||||||
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
|
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
|
||||||
2
|
2
|
||||||
"""
|
"""
|
||||||
if not is_function(original_fn):
|
if not is_function(original_fn) and not (
|
||||||
|
is_embedded_type and inspect.isclass(original_fn)
|
||||||
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
|
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
|
||||||
)
|
)
|
||||||
|
if is_embedded_type:
|
||||||
|
if not inspect.isclass(original_fn):
|
||||||
|
raise TypeError(
|
||||||
|
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
|
||||||
|
|
||||||
|
if id(original_fn) in ITERTOOLS_TYPE_IDS:
|
||||||
|
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
|
||||||
|
|
||||||
def wrapper(traceable_fn: _F) -> _F:
|
def wrapper(traceable_fn: _F) -> _F:
|
||||||
if not is_function(traceable_fn):
|
if not is_function(traceable_fn):
|
||||||
|
@ -10,12 +10,31 @@ from typing import Iterable, Iterator, TypeVar
|
|||||||
from ..decorators import substitute_in_graph
|
from ..decorators import substitute_in_graph
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["tee"]
|
__all__ = [
|
||||||
|
"chain",
|
||||||
|
"chain_from_iterable",
|
||||||
|
"tee",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
|
||||||
|
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
|
||||||
|
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
|
||||||
|
for iterable in iterables:
|
||||||
|
yield from iterable
|
||||||
|
|
||||||
|
|
||||||
|
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
|
||||||
|
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
|
||||||
|
return itertools.chain(*iterable)
|
||||||
|
|
||||||
|
|
||||||
|
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
||||||
@substitute_in_graph(itertools.tee)
|
@substitute_in_graph(itertools.tee)
|
||||||
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
||||||
|
@ -32,3 +32,9 @@ for polyfill_module in POLYFILLED_MODULES:
|
|||||||
polyfill_handler = getattr(polyfill_module, polyfill_name)
|
polyfill_handler = getattr(polyfill_module, polyfill_name)
|
||||||
original_fn = polyfill_handler.__torch_dynamo_original__
|
original_fn = polyfill_handler.__torch_dynamo_original__
|
||||||
trace_rules._builtin_function_ids.remove(id(original_fn))
|
trace_rules._builtin_function_ids.remove(id(original_fn))
|
||||||
|
|
||||||
|
# Unregister the class object if the original function is its __new__ method
|
||||||
|
if original_fn.__name__ == "__new__" and isinstance(
|
||||||
|
getattr(original_fn, "__self__", None), type
|
||||||
|
):
|
||||||
|
trace_rules._builtin_function_ids.remove(id(original_fn.__self__))
|
||||||
|
@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]:
|
|||||||
if not k.startswith("_") and callable(v)
|
if not k.startswith("_") and callable(v)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rv.update(
|
rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)})
|
||||||
{id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
|
|
||||||
)
|
|
||||||
rv.update(
|
rv.update(
|
||||||
{
|
{
|
||||||
id(cast): "typing.cast",
|
id(cast): "typing.cast",
|
||||||
@ -3471,9 +3469,7 @@ def check_verbose(obj, is_inlined_call=False):
|
|||||||
|
|
||||||
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
|
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
|
||||||
reasons: Set[str] = set()
|
reasons: Set[str] = set()
|
||||||
rule = torch._dynamo.trace_rules.lookup_inner(
|
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
|
||||||
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
|
|
||||||
)
|
|
||||||
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
||||||
return SkipResult(
|
return SkipResult(
|
||||||
False,
|
False,
|
||||||
|
@ -17,7 +17,17 @@ import sys
|
|||||||
import types
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
FrozenSet,
|
||||||
|
List,
|
||||||
|
MutableMapping,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import SymInt
|
from torch import SymInt
|
||||||
@ -319,6 +329,17 @@ class FrameStateSizeEntry:
|
|||||||
stride: Optional[List[int]]
|
stride: Optional[List[int]]
|
||||||
|
|
||||||
|
|
||||||
|
# All class-based iterators in itertools
|
||||||
|
# NOTE: use id() because some objects are not hashable, it will raise error during lookup
|
||||||
|
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
|
||||||
|
id(member)
|
||||||
|
for name, member in vars(itertools).items()
|
||||||
|
if not name.startswith("_") and inspect.isclass(member)
|
||||||
|
)
|
||||||
|
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
|
||||||
|
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()
|
||||||
|
|
||||||
|
|
||||||
class VariableBuilder:
|
class VariableBuilder:
|
||||||
"""Wrap a python value in a VariableTracker() instance"""
|
"""Wrap a python value in a VariableTracker() instance"""
|
||||||
|
|
||||||
@ -874,7 +895,10 @@ class VariableBuilder:
|
|||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
)
|
)
|
||||||
elif istype(value, type) and value in itertools.__dict__.values():
|
elif (
|
||||||
|
id(value) in ITERTOOLS_TYPE_IDS
|
||||||
|
and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
|
||||||
|
):
|
||||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return ItertoolsVariable(value, source=self.source)
|
return ItertoolsVariable(value, source=self.source)
|
||||||
elif isinstance(value, torch.SymBool):
|
elif isinstance(value, torch.SymBool):
|
||||||
|
@ -994,15 +994,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
if self.fn is dict and name == "fromkeys":
|
if self.fn is dict and name == "fromkeys":
|
||||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||||
if self.fn is itertools.chain and name == "from_iterable":
|
|
||||||
assert len(args) == 1
|
|
||||||
assert len(kwargs) == 0
|
|
||||||
obj = args[0]
|
|
||||||
items = []
|
|
||||||
for item in obj.unpack_var_sequence(tx):
|
|
||||||
items.extend(item.unpack_var_sequence(tx))
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
||||||
@ -1898,13 +1889,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
return variables.ListVariable(items)
|
return variables.ListVariable(items)
|
||||||
|
|
||||||
def call_chain(self, tx: "InstructionTranslator", *args):
|
|
||||||
if all(obj.has_unpack_var_sequence(tx) for obj in args):
|
|
||||||
items = []
|
|
||||||
for obj in args:
|
|
||||||
items.extend(obj.unpack_var_sequence(tx))
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
||||||
if iterable.has_unpack_var_sequence(tx) and all(
|
if iterable.has_unpack_var_sequence(tx) and all(
|
||||||
x.is_python_constant() for x in args
|
x.is_python_constant() for x in args
|
||||||
|
@ -18,6 +18,7 @@ from ..utils import (
|
|||||||
check_constant_args,
|
check_constant_args,
|
||||||
check_unspec_or_constant_args,
|
check_unspec_or_constant_args,
|
||||||
identity,
|
identity,
|
||||||
|
is_function,
|
||||||
is_wrapper_or_member_descriptor,
|
is_wrapper_or_member_descriptor,
|
||||||
istype,
|
istype,
|
||||||
make_cell,
|
make_cell,
|
||||||
@ -992,6 +993,27 @@ class PolyfilledFunctionVariable(VariableTracker):
|
|||||||
handler,
|
handler,
|
||||||
).call_function(tx, args, kwargs)
|
).call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
return super().call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
def call_method(
|
||||||
|
self,
|
||||||
|
tx,
|
||||||
|
name,
|
||||||
|
args: "List[VariableTracker]",
|
||||||
|
kwargs: "Dict[str, VariableTracker]",
|
||||||
|
) -> "VariableTracker":
|
||||||
|
if name == "__call__":
|
||||||
|
return self.call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
method = getattr(self.fn, name, None)
|
||||||
|
assert method is not None, f"Member {name} not found in {self.fn}"
|
||||||
|
assert is_function(method), f"Member {name} is not callable in {self.fn}"
|
||||||
|
options = {}
|
||||||
|
if self.source:
|
||||||
|
options["source"] = AttrSource(self.source, name)
|
||||||
|
member_variable = PolyfilledFunctionVariable(method, **options)
|
||||||
|
return member_variable.call_function(tx, args, kwargs)
|
||||||
|
|
||||||
def as_python_constant(self):
|
def as_python_constant(self):
|
||||||
return self.fn
|
return self.fn
|
||||||
|
|
||||||
|
@ -52,14 +52,6 @@ class ItertoolsVariable(VariableTracker):
|
|||||||
for item in itertools.product(*seqs):
|
for item in itertools.product(*seqs):
|
||||||
items.append(variables.TupleVariable(list(item)))
|
items.append(variables.TupleVariable(list(item)))
|
||||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||||||
elif (
|
|
||||||
self.value is itertools.chain
|
|
||||||
and not kwargs
|
|
||||||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
|
||||||
):
|
|
||||||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
|
||||||
items = list(itertools.chain.from_iterable(seqs))
|
|
||||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
|
||||||
elif self.value is itertools.accumulate:
|
elif self.value is itertools.accumulate:
|
||||||
from .builtin import BuiltinVariable
|
from .builtin import BuiltinVariable
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user