[RFC][dynamo] add decorator to register polyfill for unsupported C++ function to avoid graph break (#133712)

Add a decorator `torch.compiler.substitute_in_graph` to register a polyfill for an unsupported C++ function and avoid a graph break. This API provides an official way to add Dynamo support for third-party C extensions. It can also be used to simplify our implementation of `torch._dynamo.polyfill`.

5ee070266f/torch/_dynamo/variables/builtin.py (L97-L107)

Example:

```python
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2

>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Unsupported: ...

>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(sequence, x):
...     for i, item in enumerate(sequence):
...         if item is x or item == x:
...             return i
...     raise ValueError("sequence.index(x): x not in sequence")

>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
```
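
The same pattern extends to any other C-implemented callable, for example one exposed by a third-party C extension. Below is a minimal sketch (not part of this PR), assuming `operator.countOf` is another C function that Dynamo does not trace out of the box; it stands in for whatever function such an extension might expose:

```python
import operator

import torch


# Hedged sketch: operator.countOf stands in for any C function (e.g. from a
# third-party C extension) that would otherwise cause a graph break.
@torch.compiler.substitute_in_graph(operator.countOf)
def countOf(a, b, /):
    # Pure-Python re-implementation with the same signature and behavior as the
    # C function; Dynamo inlines this handler when it encounters operator.countOf.
    return sum(item is b or item == b for item in a)


# Eager calls still dispatch to the fast C implementation; under torch.compile
# the registered polyfill is inlined instead of graph-breaking.
assert operator.countOf([1, 2, 3, 2], 2) == 2
assert torch.compile(operator.countOf, fullgraph=True)([1, 2, 3, 2], 2) == 2
```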

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133712
Approved by: https://github.com/jansel
Xuehai Pan
2024-08-21 00:23:11 +08:00
committed by PyTorch MergeBot
parent 843fdf81c2
commit 022cd7c9aa
9 changed files with 283 additions and 4 deletions

View File

@@ -1222,7 +1222,6 @@ exclude_patterns = [
'torch/_export/trace.py',
'torch/_export/verifier.py',
'torch/_vendor/**',
'torch/compiler/__init__.py',
'torch/contrib/__init__.py',
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",

View File

@@ -16,6 +16,7 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
compile
reset
allow_in_graph
substitute_in_graph
assume_constant_result
list_backends
disable

View File

@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import functools
import operator
import os
import unittest.mock as mock
from unittest.mock import patch
@@ -8,6 +9,7 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import IncorrectUsage
from torch._dynamo.utils import counters
def my_custom_function(x):
@@ -245,6 +247,51 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
opt_fn(torch.randn(4))
self.assertEqual(cnts.frame_count, 2)
def test_substitute_in_graph(self):
counters.clear()
# NB: Choose another C function for this test once we support operator.indexOf
# out of the box
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch._dynamo.optimize(cnts)(fn)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 1)
torch._dynamo.reset()
counters.clear()
@torch._dynamo.substitute_in_graph(operator.indexOf)
def polyfill(sequence, x):
for i, item in enumerate(sequence):
if item is x or item == x:
return i
raise ValueError("sequence.index(x): x not in sequence")
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 0)
torch._dynamo.reset()
counters.clear()
cnts = torch._dynamo.testing.CompileCounter()
fn = polyfill
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 0)
@patch.object(torch._dynamo.config, "suppress_errors", True)
def test_nested_disable_decorator(self):
cnts = torch._dynamo.testing.CompileCounter()

View File

@@ -17,6 +17,7 @@ from .decorators import (
mark_static_address,
maybe_mark_dynamic,
run,
substitute_in_graph,
)
from .eval_frame import (
_reset_guarded_backend_cache,
@@ -39,6 +40,7 @@ __all__ = [
"assume_constant_result",
"disallow_in_graph",
"forbid_in_graph",
"substitute_in_graph",
"graph_break",
"mark_dynamic",
"maybe_mark_dynamic",

View File

@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
# ruff: noqa: TCH004
import functools
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING, TypeVar
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -11,6 +12,7 @@ from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling
from .utils import is_function
if TYPE_CHECKING:
@@ -28,6 +30,9 @@ else:
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
_F = TypeVar("_F", bound=Callable[..., Any])
def run(fn=None):
"""Don't do any dynamic compiles, just use prior optimizations"""
if fn is not None:
@@ -159,6 +164,106 @@ def forbid_in_graph(fn):
return fn
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from a C extension, to be
used in place of the original function when inlining the original function in the graph.
.. note::
The polyfill handler is only used when inlining the original function. It is not used when
the original function is called directly. In eager mode, the decorated function calls
the performant C function rather than the polyfill handler.
The polyfill handler is a function that will be called in place of the original function when
inlining the original function. The polyfill handler should have the same signature and the same
behavior as the original function.
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
Returns:
A decorator that registers the polyfill handler for the original function.
Example::
>>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers")
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
... for i, item in enumerate(a):
... if item is b or item == b:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
"""
if not is_function(original_fn):
raise TypeError(
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
)
def wrapper(traceable_fn: _F) -> _F:
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.trace_rules import get_torch_obj_rule_map
from torch._dynamo.variables import PolyfilledFunctionVariable
from torch._dynamo.variables.builder import VariableBuilder
id_dispatch_map = VariableBuilder._id_dispatch()
if id(original_fn) in id_dispatch_map:
raise ValueError(
f"Duplicate dispatch rule for {original_fn}: "
"already registered in VariableBuilder's id dispatch map"
)
rule_map = get_torch_obj_rule_map()
if original_fn in rule_map:
raise ValueError(
f"Duplicate object {original_fn} with different rules: "
f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}"
)
polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers()
if original_fn in polyfill_handlers:
raise ValueError(
f"Duplicate polyfill handlers for {original_fn}: "
f"already handled by {polyfill_handlers[original_fn]}"
)
# Need to wrap the function because we may not be able to assign __torch_dynamo_polyfill__
# to a C++ function.
@functools.wraps(traceable_fn)
def wrapped(*args, **kwargs):
return original_fn(*args, **kwargs)
def dispatch_fn(self, value):
return PolyfilledFunctionVariable(
value,
source=self.source,
**self.install_guards(GuardBuilder.CLOSURE_MATCH),
)
id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn
rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable
polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = traceable_fn
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
return wrapped # type: ignore[return-value]
return wrapper
# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.

View File

@@ -47,6 +47,7 @@ from .variables import (
FunctionalCallVariable,
FunctorchHigherOrderVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
TorchInGraphFunctionVariable,
UserFunctionVariable,
@@ -3459,7 +3460,7 @@ def check_verbose(obj, is_inlined_call=False):
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
)
if issubclass(rule, UserFunctionVariable):
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
return SkipResult(
False,
f"inlined according trace_rules.lookup {reasons.pop()}",

View File

@@ -30,6 +30,7 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar
from .functions import (
FunctoolsPartialVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
@@ -144,6 +145,7 @@ __all__ = [
"NumpyVariable",
"OptimizerVariable",
"PlacementVariable",
"PolyfilledFunctionVariable",
"PythonModuleVariable",
"RangeVariable",
"RegexPatternVariable",

View File

@@ -927,6 +927,63 @@ class FunctoolsPartialVariable(VariableTracker):
)
class PolyfilledFunctionVariable(VariableTracker):
_nonvar_fields = {
"fn",
*BaseUserFunctionVariable._nonvar_fields,
}
@classmethod
@functools.lru_cache(None)
def _get_polyfill_handlers(cls):
return {}
@classmethod
def create_with_source(cls, value, source):
return cls(
value,
source=source,
)
def __init__(self, fn: VariableTracker, **kwargs) -> None:
super().__init__(**kwargs)
self.fn = fn
def get_function(self):
return self.as_python_constant()
def call_function(
self,
tx: "InstructionTranslator",
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from torch._dynamo.variables.builder import SourcelessBuilder
handler = self._get_polyfill_handlers().get(self.fn)
if handler:
assert callable(handler)
return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
handler = getattr(self.fn, candidate, None)
if handler:
assert callable(handler)
if self.source:
source = AttrSource(self.source, candidate)
return UserFunctionVariable.create_with_source(
handler,
source=source,
).call_function(tx, args, kwargs)
return SourcelessBuilder.create(
tx,
handler,
).call_function(tx, args, kwargs)
def as_python_constant(self):
return self.fn
from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier

View File

@@ -1,12 +1,15 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, List, TypeVar
import torch
from typing import List
__all__ = [
"compile",
"assume_constant_result",
"reset",
"allow_in_graph",
"substitute_in_graph",
"list_backends",
"disable",
"cudagraph_mark_step_begin",
@@ -15,12 +18,17 @@ __all__ = [
"is_dynamo_compiling",
]
_F = TypeVar("_F", bound=Callable[..., Any])
def compile(*args, **kwargs):
"""
See :func:`torch.compile` for details on the arguments for this function.
"""
return torch.compile(*args, **kwargs)
def reset() -> None:
"""
This function clears all compilation caches and restores the system to its initial state.
@@ -31,6 +39,7 @@ def reset() -> None:
torch._dynamo.reset()
def allow_in_graph(fn):
"""
Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
@@ -111,6 +120,54 @@ def allow_in_graph(fn):
return torch._dynamo.allow_in_graph(fn)
def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from a C extension, to be
used in place of the original function when inlining the original function in the graph.
.. note::
The polyfill handler is only used when inlining the original function. It is not used when
the original function is called directly. In eager mode, the decorated function calls
the performant C function rather than the polyfill handler.
The polyfill handler is a function that will be called in place of the original function when
inlining the original function. The polyfill handler should have the same signature and the same
behavior as the original function.
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
Returns:
A decorator that registers the polyfill handler for the original function.
Example::
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
... # xdoctest: +SKIP("Long tracebacks")
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
... for i, item in enumerate(a):
... if item is b or item == b:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
"""
import torch._dynamo
return torch._dynamo.substitute_in_graph(original_fn)
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
"""
Return valid strings that can be passed to `torch.compile(..., backend="name")`.
@@ -122,6 +179,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
return torch._dynamo.list_backends(exclude_tags)
def assume_constant_result(fn):
"""
This function is used to mark a function `fn` as having a constant result.
@@ -140,6 +198,7 @@ def assume_constant_result(fn):
return torch._dynamo.assume_constant_result(fn)
def disable(fn=None, recursive=True):
"""
This function provides both a decorator and a context manager to disable compilation on a function
@@ -153,6 +212,7 @@ def disable(fn=None, recursive=True):
return torch._dynamo.disable(fn, recursive)
def cudagraph_mark_step_begin():
"""
Indicates that a new iteration of inference or training is about to begin.
@@ -178,6 +238,7 @@ def cudagraph_mark_step_begin():
cudagraph_trees.mark_step_begin()
def wrap_numpy(fn):
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
from ``torch.Tensor``s to ``torch.Tensor``s.
@@ -206,10 +267,13 @@ def wrap_numpy(fn):
tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0')
"""
from torch._dynamo.external_utils import wrap_numpy as wrap
return wrap(fn)
_is_compiling_flag: bool = False
def is_compiling() -> bool:
"""
Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
@@ -231,6 +295,7 @@ def is_compiling() -> bool:
else:
return _is_compiling_flag
def is_dynamo_compiling() -> bool:
"""
Indicates whether a graph is traced via TorchDynamo.