[RFC][dynamo] add decorator to register polyfill for unsupported C++ function to avoid graph break (#133712)
Add a decorator `torch.compiler.substitute_in_graph` to register a polyfill for an unsupported C++ function and avoid the graph break. This API provides an official way to add Dynamo support for third-party C extensions. It can also be used to simplify our implementation of `torch._dynamo.polyfill`:
5ee070266f/torch/_dynamo/variables/builtin.py (L97-L107)
Example:
```python
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(sequence, x):
... for i, item in enumerate(sequence):
... if item is x or item == x:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
```
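For a third-party C extension the usage is the same: decorate a pure-Python reimplementation with the same signature, and eager calls keep hitting the fast C function while compiled code inlines the polyfill. A minimal sketch combining the example above with the behavior described in the new docstring (the name `index_of` is illustrative):

```python
import operator

import torch


@torch.compiler.substitute_in_graph(operator.indexOf)
def index_of(sequence, x):
    # Same signature and behavior as the C function operator.indexOf
    for i, item in enumerate(sequence):
        if item is x or item == x:
            return i
    raise ValueError("sequence.index(x): x not in sequence")


# Eager: the returned wrapper forwards to the performant C function.
assert index_of([1, 2, 3, 4, 5], 3) == 2

# Compiled: Dynamo inlines the Python polyfill instead of graph-breaking.
assert torch.compile(index_of, fullgraph=True)([1, 2, 3, 4, 5], 3) == 2
```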
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133712
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 843fdf81c2
Commit: 022cd7c9aa
`.lintrunner.toml`:

```diff
@@ -1222,7 +1222,6 @@ exclude_patterns = [
     'torch/_export/trace.py',
     'torch/_export/verifier.py',
     'torch/_vendor/**',
-    'torch/compiler/__init__.py',
     'torch/contrib/__init__.py',
     'torch/contrib/_tensorboard_vis.py',
     "torch/cuda/_gpu_trace.py",
```
`docs/source/torch.compiler_api.rst`:

```diff
@@ -16,6 +16,7 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
     compile
     reset
     allow_in_graph
+    substitute_in_graph
     assume_constant_result
     list_backends
     disable
```
`test/dynamo/test_decorators.py`:

```diff
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import operator
 import os
 import unittest.mock as mock
 from unittest.mock import patch
@@ -8,6 +9,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.exc import IncorrectUsage
+from torch._dynamo.utils import counters
 
 
 def my_custom_function(x):
@@ -245,6 +247,51 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         opt_fn(torch.randn(4))
         self.assertEqual(cnts.frame_count, 2)
 
+    def test_substitute_in_graph(self):
+        counters.clear()
+
+        # NB: Choose another C function for test when we support operator.indexOf
+        # out of the box
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = operator.indexOf
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 1)
+
+        torch._dynamo.reset()
+        counters.clear()
+
+        @torch._dynamo.substitute_in_graph(operator.indexOf)
+        def polyfill(sequence, x):
+            for i, item in enumerate(sequence):
+                if item is x or item == x:
+                    return i
+            raise ValueError("sequence.index(x): x not in sequence")
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = operator.indexOf
+        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+        torch._dynamo.reset()
+        counters.clear()
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        fn = polyfill
+        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        out = fn([1, 2, 3, 4, 5], 3)
+        opt_out = opt_fn([1, 2, 3, 4, 5], 3)
+        self.assertEqual(out, opt_out)
+        self.assertEqual(cnts.frame_count, 0)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
     @patch.object(torch._dynamo.config, "suppress_errors", True)
     def test_nested_disable_decorator(self):
         cnts = torch._dynamo.testing.CompileCounter()
```
`torch/_dynamo/__init__.py`:

```diff
@@ -17,6 +17,7 @@ from .decorators import (
     mark_static_address,
     maybe_mark_dynamic,
     run,
+    substitute_in_graph,
 )
 from .eval_frame import (
     _reset_guarded_backend_cache,
@@ -39,6 +40,7 @@ __all__ = [
     "assume_constant_result",
     "disallow_in_graph",
     "forbid_in_graph",
+    "substitute_in_graph",
     "graph_break",
     "mark_dynamic",
     "maybe_mark_dynamic",
```
`torch/_dynamo/decorators.py`:

```diff
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 # ruff: noqa: TCH004
+import functools
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING, TypeVar
 
 import torch
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -11,6 +12,7 @@ from .comptime import comptime
 from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
 from .exc import IncorrectUsage
 from .external_utils import is_compiling
+from .utils import is_function
 
 
 if TYPE_CHECKING:
@@ -28,6 +30,9 @@ else:
     globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
 
 
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 def run(fn=None):
     """Don't do any dynamic compiles, just use prior optimizations"""
     if fn is not None:
@@ -159,6 +164,106 @@ def forbid_in_graph(fn):
     return fn
 
 
+def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
+    """
+    Register a polyfill handler for a function, usually a C function from the C extension, to be
+    used in place of the original function when inlining the original function in the graph.
+
+    .. note::
+
+        The polyfill handler is only used when inlining the original function. It is not used when
+        the original function is called directly. In the eager mode, the decorated function calls
+        the performant C function rather than the polyfill handler.
+
+    The polyfill handler is a function that will be called in place of the original function when
+    inlining the original function. The polyfill handler should have the same signature and the same
+    behavior as the original function.
+
+    Args:
+        original_fn (callable): The original function, usually a C function, to register a polyfill
+            handler for.
+
+    Returns:
+        A decorator that registers the polyfill handler for the original function.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers")
+        >>> import operator
+        >>> operator.indexOf([1, 2, 3, 4, 5], 3)
+        2
+        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
+        Traceback (most recent call last):
+        ...
+        torch._dynamo.exc.Unsupported: ...
+
+        >>> @torch.compiler.substitute_in_graph(operator.indexOf)
+        ... def indexOf(a, b, /):
+        ...     for i, item in enumerate(a):
+        ...         if item is b or item == b:
+        ...             return i
+        ...     raise ValueError("sequence.index(x): x not in sequence")
+        >>>
+        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
+        2
+    """
+    if not is_function(original_fn):
+        raise TypeError(
+            f"substitute_in_graph expects a function but got {type(original_fn)!r}"
+        )
+
+    def wrapper(traceable_fn: _F) -> _F:
+        from torch._dynamo.guards import GuardBuilder
+        from torch._dynamo.trace_rules import get_torch_obj_rule_map
+        from torch._dynamo.variables import PolyfilledFunctionVariable
+        from torch._dynamo.variables.builder import VariableBuilder
+
+        id_dispatch_map = VariableBuilder._id_dispatch()
+        if id(original_fn) in id_dispatch_map:
+            raise ValueError(
+                f"Duplicate dispatch rule for {original_fn}: "
+                "already registered in VariableBuilder's id dispatch map"
+            )
+
+        rule_map = get_torch_obj_rule_map()
+        if original_fn in rule_map:
+            raise ValueError(
+                f"Duplicate object {original_fn} with different rules: "
+                f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}"
+            )
+
+        polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers()
+        if original_fn in polyfill_handlers:
+            raise ValueError(
+                f"Duplicate polyfill handlers for {original_fn}: "
+                f"already handled by {polyfill_handlers[original_fn]}"
+            )
+
+        # Need to wrap the function because we may not be able to assign
+        # __torch_dynamo_polyfill__ to a C++ function.
+        @functools.wraps(traceable_fn)
+        def wrapped(*args, **kwargs):
+            return original_fn(*args, **kwargs)
+
+        def dispatch_fn(self, value):
+            return PolyfilledFunctionVariable(
+                value,
+                source=self.source,
+                **self.install_guards(GuardBuilder.CLOSURE_MATCH),
+            )
+
+        id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn
+        rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable
+        polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = traceable_fn
+
+        wrapped.__torch_dynamo_original__ = original_fn  # type: ignore[attr-defined]
+        wrapped.__torch_dynamo_polyfill__ = traceable_fn  # type: ignore[attr-defined]
+
+        return wrapped  # type: ignore[return-value]
+
+    return wrapper
+
+
 # Helper function to flatten a tensor subclass and apply a function to
 # all inner tensors that match the outer dim. Used to reduce duplication
 # across the various marking APIs.
```
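One detail worth calling out: the decorator returns a Python-level `wrapped` function rather than mutating `original_fn`, because built-in (C) functions generally reject attribute assignment. A standalone sketch of that failure mode (illustrative only, not part of the diff):

```python
import operator


def python_index_of(a, b, /):
    for i, item in enumerate(a):
        if item is b or item == b:
            return i
    raise ValueError("sequence.index(x): x not in sequence")


try:
    # Attribute assignment on a builtin_function_or_method raises
    # AttributeError, so the polyfill cannot be attached to the C
    # function directly...
    operator.indexOf.__torch_dynamo_polyfill__ = python_index_of
except AttributeError:
    # ...which is why substitute_in_graph hands back a Python wrapper
    # that forwards to the C function and carries the
    # __torch_dynamo_original__ / __torch_dynamo_polyfill__ attributes.
    pass
```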
`torch/_dynamo/trace_rules.py`:

```diff
@@ -47,6 +47,7 @@ from .variables import (
     FunctionalCallVariable,
     FunctorchHigherOrderVariable,
     NestedUserFunctionVariable,
+    PolyfilledFunctionVariable,
     SkipFunctionVariable,
     TorchInGraphFunctionVariable,
     UserFunctionVariable,
@@ -3459,7 +3460,7 @@ def check_verbose(obj, is_inlined_call=False):
     rule = torch._dynamo.trace_rules.lookup_inner(
         fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
     )
-    if issubclass(rule, UserFunctionVariable):
+    if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
         return SkipResult(
             False,
             f"inlined according trace_rules.lookup {reasons.pop()}",
```
`torch/_dynamo/variables/__init__.py`:

```diff
@@ -30,6 +30,7 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar
 from .functions import (
     FunctoolsPartialVariable,
     NestedUserFunctionVariable,
+    PolyfilledFunctionVariable,
     SkipFunctionVariable,
     UserFunctionVariable,
     UserMethodVariable,
@@ -144,6 +145,7 @@ __all__ = [
     "NumpyVariable",
     "OptimizerVariable",
     "PlacementVariable",
+    "PolyfilledFunctionVariable",
     "PythonModuleVariable",
     "RangeVariable",
     "RegexPatternVariable",
```
`torch/_dynamo/variables/functions.py`:

```diff
@@ -927,6 +927,63 @@ class FunctoolsPartialVariable(VariableTracker):
         )
 
 
+class PolyfilledFunctionVariable(VariableTracker):
+    _nonvar_fields = {
+        "fn",
+        *BaseUserFunctionVariable._nonvar_fields,
+    }
+
+    @classmethod
+    @functools.lru_cache(None)
+    def _get_polyfill_handlers(cls):
+        return {}
+
+    @classmethod
+    def create_with_source(cls, value, source):
+        return cls(
+            value,
+            source=source,
+        )
+
+    def __init__(self, fn: VariableTracker, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.fn = fn
+
+    def get_function(self):
+        return self.as_python_constant()
+
+    def call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "List[VariableTracker]",
+        kwargs: "Dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        from torch._dynamo.variables.builder import SourcelessBuilder
+
+        handler = self._get_polyfill_handlers().get(self.fn)
+        if handler:
+            assert callable(handler)
+            return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
+
+        for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
+            handler = getattr(self.fn, candidate, None)
+            if handler:
+                assert callable(handler)
+                if self.source:
+                    source = AttrSource(self.source, candidate)
+                    return UserFunctionVariable.create_with_source(
+                        handler,
+                        source=source,
+                    ).call_function(tx, args, kwargs)
+                return SourcelessBuilder.create(
+                    tx,
+                    handler,
+                ).call_function(tx, args, kwargs)
+
+    def as_python_constant(self):
+        return self.fn
+
+
 from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier
```
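Note the lookup order in `call_function`: the registry filled in by `substitute_in_graph` wins, and only then does the variable fall back to a polyfill carried on the function object under either recognized attribute name. A pure-Python mirror of that resolution logic (`resolve_polyfill` and its arguments are stand-in names, not Dynamo internals):

```python
from typing import Any, Callable, Dict, Optional


def resolve_polyfill(
    fn: Callable[..., Any],
    registry: Dict[Callable[..., Any], Callable[..., Any]],
) -> Optional[Callable[..., Any]]:
    # 1. Handlers registered via substitute_in_graph take priority.
    handler = registry.get(fn)
    if handler is not None:
        return handler
    # 2. Otherwise, accept a polyfill attached to the function object,
    #    under either attribute name checked by PolyfilledFunctionVariable.
    for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
        handler = getattr(fn, candidate, None)
        if callable(handler):
            return handler
    return None
```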
`torch/compiler/__init__.py`:

```diff
@@ -1,12 +1,15 @@
 # mypy: allow-untyped-defs
+from typing import Any, Callable, List, TypeVar
+
 import torch
-from typing import List
 
 
 __all__ = [
     "compile",
     "assume_constant_result",
     "reset",
     "allow_in_graph",
+    "substitute_in_graph",
     "list_backends",
     "disable",
     "cudagraph_mark_step_begin",
@@ -15,12 +18,17 @@ __all__ = [
     "is_dynamo_compiling",
 ]
 
 
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 def compile(*args, **kwargs):
     """
     See :func:`torch.compile` for details on the arguments for this function.
     """
     return torch.compile(*args, **kwargs)
 
 
 def reset() -> None:
     """
     This function clears all compilation caches and restores the system to its initial state.
@@ -31,6 +39,7 @@ def reset() -> None:
 
     torch._dynamo.reset()
 
 
 def allow_in_graph(fn):
     """
     Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
@@ -111,6 +120,54 @@ def allow_in_graph(fn):
     return torch._dynamo.allow_in_graph(fn)
 
 
+def substitute_in_graph(original_fn: _F) -> Callable[[_F], _F]:
+    """
+    Register a polyfill handler for a function, usually a C function from the C extension, to be
+    used in place of the original function when inlining the original function in the graph.
+
+    .. note::
+
+        The polyfill handler is only used when inlining the original function. It is not used when
+        the original function is called directly. In the eager mode, the decorated function calls
+        the performant C function rather than the polyfill handler.
+
+    The polyfill handler is a function that will be called in place of the original function when
+    inlining the original function. The polyfill handler should have the same signature and the same
+    behavior as the original function.
+
+    Args:
+        original_fn (callable): The original function, usually a C function, to register a polyfill
+            handler for.
+
+    Returns:
+        A decorator that registers the polyfill handler for the original function.
+
+    Example::
+
+        >>> import operator
+        >>> operator.indexOf([1, 2, 3, 4, 5], 3)
+        2
+        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
+        ... # xdoctest: +SKIP("Long tracebacks")
+        Traceback (most recent call last):
+        ...
+        torch._dynamo.exc.Unsupported: ...
+
+        >>> @torch.compiler.substitute_in_graph(operator.indexOf)
+        ... def indexOf(a, b, /):
+        ...     for i, item in enumerate(a):
+        ...         if item is b or item == b:
+        ...             return i
+        ...     raise ValueError("sequence.index(x): x not in sequence")
+        >>>
+        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
+        2
+    """
+    import torch._dynamo
+
+    return torch._dynamo.substitute_in_graph(original_fn)
+
+
 def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
     """
     Return valid strings that can be passed to `torch.compile(..., backend="name")`.
@@ -122,6 +179,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
 
     return torch._dynamo.list_backends(exclude_tags)
 
 
 def assume_constant_result(fn):
     """
     This function is used to mark a function `fn` as having a constant result.
@@ -140,6 +198,7 @@ def assume_constant_result(fn):
 
     return torch._dynamo.assume_constant_result(fn)
 
 
 def disable(fn=None, recursive=True):
     """
     This function provides both a decorator and a context manager to disable compilation on a function
@@ -153,6 +212,7 @@ def disable(fn=None, recursive=True):
 
     return torch._dynamo.disable(fn, recursive)
 
 
 def cudagraph_mark_step_begin():
     """
     Indicates that a new iteration of inference or training is about to begin.
@@ -178,6 +238,7 @@ def cudagraph_mark_step_begin():
 
     cudagraph_trees.mark_step_begin()
 
 
 def wrap_numpy(fn):
     r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
     from ``torch.Tensor``s to ``torch.Tensor``s.
@@ -206,10 +267,13 @@ def wrap_numpy(fn):
     tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0')
     """
     from torch._dynamo.external_utils import wrap_numpy as wrap
 
     return wrap(fn)
 
 
 _is_compiling_flag: bool = False
 
 
 def is_compiling() -> bool:
     """
     Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
@@ -231,6 +295,7 @@ def is_compiling() -> bool:
     else:
         return _is_compiling_flag
 
 
 def is_dynamo_compiling() -> bool:
     """
     Indicates whether a graph is traced via TorchDynamo.
```