diff --git a/.lintrunner.toml b/.lintrunner.toml
index 123581b273da..dafe6207a7c9 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1222,7 +1222,6 @@ exclude_patterns = [
     'torch/_export/trace.py',
     'torch/_export/verifier.py',
     'torch/_vendor/**',
-    'torch/compiler/__init__.py',
     'torch/contrib/__init__.py',
     'torch/contrib/_tensorboard_vis.py',
     "torch/cuda/_gpu_trace.py",
diff --git a/docs/source/torch.compiler_api.rst b/docs/source/torch.compiler_api.rst
index a377349f84ae..e1c05f71c146 100644
--- a/docs/source/torch.compiler_api.rst
+++ b/docs/source/torch.compiler_api.rst
@@ -16,6 +16,7 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
     compile
     reset
     allow_in_graph
+    substitute_in_graph
    assume_constant_result
    list_backends
    disable
diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py
index d822401d5a70..3938e78e139a 100644
--- a/test/dynamo/test_decorators.py
+++ b/test/dynamo/test_decorators.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import operator
 import os
 import unittest.mock as mock
 from unittest.mock import patch
@@ -8,6 +9,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.exc import IncorrectUsage
+from torch._dynamo.utils import counters
 
 
 def my_custom_function(x):
@@ -245,6 +247,51 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         opt_fn(torch.randn(4))
         self.assertEqual(cnts.frame_count, 2)
 
+    def test_substitute_in_graph(self):
+        counters.clear()
+
+        # NB: Choose another C function for test when we support operator.indexOf
+        # out of the box
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = operator.indexOf
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 1)
+
+        torch._dynamo.reset()
+        counters.clear()
+
+        @torch._dynamo.substitute_in_graph(operator.indexOf)
+        def polyfill(sequence, x):
+            for i, item in enumerate(sequence):
+                if item is x or item == x:
+                    return i
+            raise ValueError("sequence.index(x): x not in sequence")
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = operator.indexOf
+        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        torch._dynamo.reset()
+        counters.clear()
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = polyfill
+        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
     @patch.object(torch._dynamo.config, "suppress_errors", True)
     def test_nested_disable_decorator(self):
         cnts = torch._dynamo.testing.CompileCounter()
diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py
index 39e568ee19e0..8f7bab6362cc 100644
--- a/torch/_dynamo/__init__.py
+++ b/torch/_dynamo/__init__.py
@@ -17,6 +17,7 @@ from .decorators import (
     mark_static_address,
     maybe_mark_dynamic,
     run,
+    substitute_in_graph,
 )
 from .eval_frame import (
     _reset_guarded_backend_cache,
@@ -39,6 +40,7 @@ __all__ = [
     "assume_constant_result",
     "disallow_in_graph",
     "forbid_in_graph",
+    "substitute_in_graph",
     "graph_break",
     "mark_dynamic",
"maybe_mark_dynamic", diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 814f3e9d65dc..d5130b296eca 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs # ruff: noqa: TCH004 +import functools from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING, TypeVar import torch from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -11,6 +12,7 @@ from .comptime import comptime from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling +from .utils import is_function if TYPE_CHECKING: @@ -28,6 +30,9 @@ else: globals()[name] = getattr(torch._C._dynamo.eval_frame, name) +_F = TypeVar("_F", bound=Callable[..., Any]) + + def run(fn=None): """Don't do any dynamic compiles, just use prior optimizations""" if fn is not None: @@ -159,6 +164,106 @@ def forbid_in_graph(fn): return fn +def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers") + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... 
raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + if not is_function(original_fn): + raise TypeError( + f"substitute_in_graph expects a function but got {type(original_fn)!r}" + ) + + def wrapper(traceable_fn: _F) -> _F: + from torch._dynamo.guards import GuardBuilder + from torch._dynamo.trace_rules import get_torch_obj_rule_map + from torch._dynamo.variables import PolyfilledFunctionVariable + from torch._dynamo.variables.builder import VariableBuilder + + id_dispatch_map = VariableBuilder._id_dispatch() + if id(original_fn) in id_dispatch_map: + raise ValueError( + f"Duplicate dispatch rule for {original_fn}: " + "already registered in VariableBuilder's id dispatch map" + ) + + rule_map = get_torch_obj_rule_map() + if original_fn in rule_map: + raise ValueError( + f"Duplicate object {original_fn} with different rules: " + f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}" + ) + + polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() + if original_fn in polyfill_handlers: + raise ValueError( + f"Duplicate polyfill handlers for {original_fn}: " + f"already handled by {polyfill_handlers[original_fn]}" + ) + + # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a + # C++ function. + @functools.wraps(traceable_fn) + def wrapped(*args, **kwargs): + return original_fn(*args, **kwargs) + + def dispatch_fn(self, value): + return PolyfilledFunctionVariable( + value, + source=self.source, + **self.install_guards(GuardBuilder.CLOSURE_MATCH), + ) + + id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn + rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable + polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = traceable_fn + + wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] + + return wrapped # type: ignore[return-value] + + return wrapper + + # Helper function to flatten a tensor subclass and apply a function to # all inner tensors that match the outer dim. Used to reduce duplication # across the various marking APIs. 
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 404bd49e6767..cd82511c97e8 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -47,6 +47,7 @@ from .variables import (
     FunctionalCallVariable,
     FunctorchHigherOrderVariable,
     NestedUserFunctionVariable,
+    PolyfilledFunctionVariable,
     SkipFunctionVariable,
     TorchInGraphFunctionVariable,
     UserFunctionVariable,
@@ -3459,7 +3460,7 @@ def check_verbose(obj, is_inlined_call=False):
     rule = torch._dynamo.trace_rules.lookup_inner(
         fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
     )
-    if issubclass(rule, UserFunctionVariable):
+    if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
         return SkipResult(
             False,
             f"inlined according trace_rules.lookup {reasons.pop()}",
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 2b4d8f787d9d..d8f97148c106 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -30,6 +30,7 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar
 from .functions import (
     FunctoolsPartialVariable,
     NestedUserFunctionVariable,
+    PolyfilledFunctionVariable,
     SkipFunctionVariable,
     UserFunctionVariable,
     UserMethodVariable,
@@ -144,6 +145,7 @@ __all__ = [
     "NumpyVariable",
     "OptimizerVariable",
     "PlacementVariable",
+    "PolyfilledFunctionVariable",
     "PythonModuleVariable",
     "RangeVariable",
     "RegexPatternVariable",
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 8f1220193799..6f220a614ae9 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -927,6 +927,63 @@ class FunctoolsPartialVariable(VariableTracker):
         )
 
 
+class PolyfilledFunctionVariable(VariableTracker):
+    _nonvar_fields = {
+        "fn",
+        *BaseUserFunctionVariable._nonvar_fields,
+    }
+
+    @classmethod
+    @functools.lru_cache(None)
+    def _get_polyfill_handlers(cls):
+        return {}
+
+    @classmethod
+    def create_with_source(cls, value, source):
+        return cls(
+            value,
+            source=source,
+        )
+
+    def __init__(self, fn: VariableTracker, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.fn = fn
+
+    def get_function(self):
+        return self.as_python_constant()
+
+    def call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "List[VariableTracker]",
+        kwargs: "Dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        from torch._dynamo.variables.builder import SourcelessBuilder
+
+        handler = self._get_polyfill_handlers().get(self.fn)
+        if handler:
+            assert callable(handler)
+            return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
+
+        for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
+            handler = getattr(self.fn, candidate, None)
+            if handler:
+                assert callable(handler)
+                if self.source:
+                    source = AttrSource(self.source, candidate)
+                    return UserFunctionVariable.create_with_source(
+                        handler,
+                        source=source,
+                    ).call_function(tx, args, kwargs)
+                return SourcelessBuilder.create(
+                    tx,
+                    handler,
+                ).call_function(tx, args, kwargs)
+
+    def as_python_constant(self):
+        return self.fn
+
+
 from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier
diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py
index 812bbaa4c660..5f386215374a 100644
--- a/torch/compiler/__init__.py
+++ b/torch/compiler/__init__.py
@@ -1,12 +1,15 @@
 # mypy: allow-untyped-defs
+from typing import Any, Callable, List, TypeVar
+
 import torch
-from typing import List
+
 
 __all__ = [
     "compile",
"assume_constant_result", "reset", "allow_in_graph", + "substitute_in_graph", "list_backends", "disable", "cudagraph_mark_step_begin", @@ -15,12 +18,17 @@ __all__ = [ "is_dynamo_compiling", ] + +_F = TypeVar("_F", bound=Callable[..., Any]) + + def compile(*args, **kwargs): """ See :func:`torch.compile` for details on the arguments for this function. """ return torch.compile(*args, **kwargs) + def reset() -> None: """ This function clears all compilation caches and restores the system to its initial state. @@ -31,6 +39,7 @@ def reset() -> None: torch._dynamo.reset() + def allow_in_graph(fn): """ Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function @@ -111,6 +120,54 @@ def allow_in_graph(fn): return torch._dynamo.allow_in_graph(fn) +def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + ... # xdoctest: +SKIP("Long tracebacks") + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + import torch._dynamo + + return torch._dynamo.substitute_in_graph(original_fn) + + def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: """ Return valid strings that can be passed to `torch.compile(..., backend="name")`. @@ -122,6 +179,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: return torch._dynamo.list_backends(exclude_tags) + def assume_constant_result(fn): """ This function is used to mark a function `fn` as having a constant result. @@ -140,6 +198,7 @@ def assume_constant_result(fn): return torch._dynamo.assume_constant_result(fn) + def disable(fn=None, recursive=True): """ This function provides both a decorator and a context manager to disable compilation on a function @@ -153,6 +212,7 @@ def disable(fn=None, recursive=True): return torch._dynamo.disable(fn, recursive) + def cudagraph_mark_step_begin(): """ Indicates that a new iteration of inference or training is about to begin. @@ -178,6 +238,7 @@ def cudagraph_mark_step_begin(): cudagraph_trees.mark_step_begin() + def wrap_numpy(fn): r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function from ``torch.Tensor``s to ``torch.Tensor``s. 
@@ -206,10 +267,13 @@ def wrap_numpy(fn):
     tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0')
     """
     from torch._dynamo.external_utils import wrap_numpy as wrap
+
     return wrap(fn)
 
+
 _is_compiling_flag: bool = False
 
+
 def is_compiling() -> bool:
     """
     Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
@@ -231,6 +295,7 @@ def is_compiling() -> bool:
     else:
         return _is_compiling_flag
 
+
 def is_dynamo_compiling() -> bool:
     """
     Indicates whether a graph is traced via TorchDynamo.
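
Illustrative usage (not part of the patch): a minimal sketch of the new decorator, assuming it is exported as ``torch.compiler.substitute_in_graph`` as added above; it mirrors the doctest and test included in this diff.

import operator

import torch


# Register a pure-Python polyfill for operator.indexOf, a C function that
# Dynamo cannot trace. Dynamo inlines the polyfill when compiling; eager
# calls still go through the original C implementation.
@torch.compiler.substitute_in_graph(operator.indexOf)
def indexOf(a, b, /):
    for i, item in enumerate(a):
        if item is b or item == b:
            return i
    raise ValueError("sequence.index(x): x not in sequence")


print(torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3))  # prints 2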