diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index d7c4cbdcb507..4338e3047b89 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -170,6 +170,8 @@ def substitute_in_graph( *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only ) -> Callable[[_F], _F]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be @@ -219,10 +221,22 @@ def substitute_in_graph( >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) 2 """ - if not is_function(original_fn): + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): raise TypeError( f"substitute_in_graph expects a function but got {type(original_fn)!r}" ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) def wrapper(traceable_fn: _F) -> _F: if not is_function(traceable_fn): diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 090df0f84b52..802a62a82c8c 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -10,12 +10,31 @@ from typing import Iterable, Iterator, TypeVar from ..decorators import substitute_in_graph -__all__ = ["tee"] +__all__ = [ + "chain", + "chain_from_iterable", + "tee", +] _T = TypeVar("_T") +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) + + +chain.from_iterable = chain_from_iterable # type: ignore[method-assign] + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 255486b6e016..d9367b55bf14 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -32,3 +32,9 @@ for polyfill_module in POLYFILLED_MODULES: polyfill_handler = getattr(polyfill_module, polyfill_name) original_fn = polyfill_handler.__torch_dynamo_original__ trace_rules._builtin_function_ids.remove(id(original_fn)) + + # Unregister the class object if the original function is its __new__ method + if original_fn.__name__ == "__new__" and isinstance( + getattr(original_fn, "__self__", None), type + ): + trace_rules._builtin_function_ids.remove(id(original_fn.__self__)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index e4f2e3786892..464b411a8b95 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update( - {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)} - ) + rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", @@ -3471,9 +3469,7 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: Set[str] = set() - rule = torch._dynamo.trace_rules.lookup_inner( - fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons - ) + rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): return SkipResult( False, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b4f35dbbd404..e0d355a7fb2b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -17,7 +17,17 @@ import sys import types import warnings import weakref -from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + FrozenSet, + List, + MutableMapping, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) import torch from torch import SymInt @@ -319,6 +329,17 @@ class FrameStateSizeEntry: stride: Optional[List[int]] +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set() + + class VariableBuilder: """Wrap a python value in a VariableTracker() instance""" @@ -874,7 +895,10 @@ class VariableBuilder: value, source=self.source, ) - elif istype(value, type) and value in itertools.__dict__.values(): + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): self.install_guards(GuardBuilder.FUNCTION_MATCH) return ItertoolsVariable(value, source=self.source) elif isinstance(value, torch.SymBool): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ca972b282e38..f01b4b6efbe5 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -994,15 +994,6 @@ class BuiltinVariable(VariableTracker): ) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) - if self.fn is itertools.chain and name == "from_iterable": - assert len(args) == 1 - assert len(kwargs) == 0 - obj = args[0] - items = [] - for item in obj.unpack_var_sequence(tx): - items.extend(item.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1898,13 +1889,6 @@ class BuiltinVariable(VariableTracker): ) return variables.ListVariable(items) - def call_chain(self, tx: "InstructionTranslator", *args): - if all(obj.has_unpack_var_sequence(tx) for obj in args): - items = [] - for obj in args: - items.extend(obj.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): if iterable.has_unpack_var_sequence(tx) and all( x.is_python_constant() for x in args diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index a9ec80cc3de5..239b720a5a75 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -18,6 +18,7 @@ from ..utils import ( check_constant_args, check_unspec_or_constant_args, identity, + is_function, is_wrapper_or_member_descriptor, istype, make_cell, @@ -992,6 +993,27 @@ class PolyfilledFunctionVariable(VariableTracker): handler, ).call_function(tx, args, kwargs) + return super().call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + assert method is not None, f"Member {name} not found in {self.fn}" + assert is_function(method), f"Member {name} is not callable in {self.fn}" + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + member_variable = PolyfilledFunctionVariable(method, **options) + return member_variable.call_function(tx, args, kwargs) + def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 6687611cf0aa..ed3ae786634e 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -52,14 +52,6 @@ class ItertoolsVariable(VariableTracker): for item in itertools.product(*seqs): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) - elif ( - self.value is itertools.chain - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] - items = list(itertools.chain.from_iterable(seqs)) - return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) elif self.value is itertools.accumulate: from .builtin import BuiltinVariable