mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "torch.compiler public namespace (#102182)"
This reverts commit b5840f99c3f2ae01b7831fd32b99758180fc22c3. Reverted https://github.com/pytorch/pytorch/pull/102182 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/102182#issuecomment-1576144551))
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -44,7 +44,6 @@ docs/cpp/source/html/
|
|||||||
docs/cpp/source/latex/
|
docs/cpp/source/latex/
|
||||||
docs/source/compile/generated/
|
docs/source/compile/generated/
|
||||||
docs/source/generated/
|
docs/source/generated/
|
||||||
docs/source/compile/generated/
|
|
||||||
log
|
log
|
||||||
usage_log.txt
|
usage_log.txt
|
||||||
test-reports/
|
test-reports/
|
||||||
|
13
docs/source/_dynamo.rst
Normal file
13
docs/source/_dynamo.rst
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
.. _torch_dynamo:
|
||||||
|
|
||||||
|
torch._dynamo
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
This module is an early prototype and is subject to change.
|
||||||
|
|
||||||
|
.. currentmodule:: torch._dynamo
|
||||||
|
|
||||||
|
.. automodule:: torch._dynamo
|
||||||
|
:members:
|
||||||
|
:member-order: bysource
|
@ -1,19 +0,0 @@
|
|||||||
torch.compiler
|
|
||||||
========================
|
|
||||||
|
|
||||||
.. currentmodule:: torch.compiler
|
|
||||||
|
|
||||||
.. automodule:: torch.compiler
|
|
||||||
|
|
||||||
torch.compiler API reference
|
|
||||||
------------------------------
|
|
||||||
.. autosummary::
|
|
||||||
:toctree: generated
|
|
||||||
:nosignatures:
|
|
||||||
|
|
||||||
compile
|
|
||||||
reset
|
|
||||||
allow_in_graph
|
|
||||||
assume_constant_result
|
|
||||||
list_backends
|
|
||||||
disable
|
|
@ -99,7 +99,7 @@ Features described in this documentation are classified by release status:
|
|||||||
torch.distributed.tensor.parallel <distributed.tensor.parallel>
|
torch.distributed.tensor.parallel <distributed.tensor.parallel>
|
||||||
torch.distributed.checkpoint <distributed.checkpoint>
|
torch.distributed.checkpoint <distributed.checkpoint>
|
||||||
torch.distributions <distributions>
|
torch.distributions <distributions>
|
||||||
torch.compiler <compiler>
|
torch._dynamo <_dynamo>
|
||||||
torch.fft <fft>
|
torch.fft <fft>
|
||||||
torch.func <func>
|
torch.func <func>
|
||||||
futures
|
futures
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.compiler
|
|
||||||
from torch._dynamo.testing import CompileCounter
|
from torch._dynamo.testing import CompileCounter
|
||||||
|
|
||||||
|
|
||||||
@ -71,41 +69,3 @@ class InPlaceCompilationTests(unittest.TestCase):
|
|||||||
torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
|
torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
|
||||||
loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
|
loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
|
||||||
loaded_model(torch.randn(1, 10))
|
loaded_model(torch.randn(1, 10))
|
||||||
|
|
||||||
|
|
||||||
# The private variants of the below functions are extensively tested
|
|
||||||
# So as long as the signatures match we're good
|
|
||||||
class PublicTorchCompilerTests(unittest.TestCase):
|
|
||||||
def check_signature(self, public_fn_name, private_fn_name, private_namespace):
|
|
||||||
public_fn = getattr(torch.compiler, public_fn_name)
|
|
||||||
private_fn = getattr(private_namespace, private_fn_name)
|
|
||||||
|
|
||||||
public_sig = inspect.signature(public_fn)
|
|
||||||
private_sig = inspect.signature(private_fn)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
public_sig,
|
|
||||||
private_sig,
|
|
||||||
f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_is_enabled(self):
|
|
||||||
self.assertTrue(torch.compiler.is_enabled())
|
|
||||||
|
|
||||||
def test_dynamo_signatures(self):
|
|
||||||
function_names = [
|
|
||||||
"reset",
|
|
||||||
"allow_in_graph",
|
|
||||||
"list_backends",
|
|
||||||
"assume_constant_result",
|
|
||||||
"disable",
|
|
||||||
]
|
|
||||||
|
|
||||||
for fn_name in function_names:
|
|
||||||
self.check_signature(fn_name, fn_name, torch._dynamo)
|
|
||||||
|
|
||||||
def test_inductor_signatures(self):
|
|
||||||
function_names = ["list_options", "list_mode_options"]
|
|
||||||
|
|
||||||
for fn_name in function_names:
|
|
||||||
self.check_signature(fn_name, fn_name, torch._inductor)
|
|
||||||
|
@ -55,7 +55,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def reset() -> None:
|
def reset():
|
||||||
"""Clear all compile caches and restore initial state"""
|
"""Clear all compile caches and restore initial state"""
|
||||||
for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
|
for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
|
||||||
code = weak_code()
|
code = weak_code()
|
||||||
|
@ -64,7 +64,7 @@ def lookup_backend(compiler_fn):
|
|||||||
return compiler_fn
|
return compiler_fn
|
||||||
|
|
||||||
|
|
||||||
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
|
def list_backends(exclude_tags=("debug", "experimental")):
|
||||||
"""
|
"""
|
||||||
Return valid strings that can be passed to:
|
Return valid strings that can be passed to:
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
|
||||||
def is_compiling() -> bool:
|
def is_compiling():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,175 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch._dynamo
|
|
||||||
import torch._inductor
|
|
||||||
from typing import Callable, Union, List, Set, Tuple, Any, Dict
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"compile",
|
|
||||||
"is_enabled",
|
|
||||||
"reset",
|
|
||||||
"allow_in_graph",
|
|
||||||
"list_backends",
|
|
||||||
"disable",
|
|
||||||
]
|
|
||||||
|
|
||||||
def compile(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
See :func:`torch.compile` for details on the arguments for this function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return torch.compile(*args, **kwargs)
|
|
||||||
|
|
||||||
def reset() -> None:
|
|
||||||
"""
|
|
||||||
This function clears all compilation caches and restores the system to its initial state.
|
|
||||||
It is recommended to call this function, especially after using operations like `torch.compile(...)`
|
|
||||||
to ensure a clean state before subsequent compilation.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
1. Call `reset()` to clear all compilation caches and restore the initial state.
|
|
||||||
2. Perform any desired operations, such as `torch.compile(...)`.
|
|
||||||
3. If you need to start fresh or perform another `torch.compile(...)`, call `reset()` to ensure a clean state.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
torch._dynamo.reset()
|
|
||||||
|
|
||||||
def allow_in_graph(fn):
|
|
||||||
"""
|
|
||||||
Customize which functions compilation will include in the generated graph.
|
|
||||||
It bypasses all introspection of the symbolic python code in favor of
|
|
||||||
directly writing it to the graph.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
- fn: A callable representing the function to be included in the graph.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- If `fn` is a single callable, it adds the function to the list of allowed functions
|
|
||||||
in compilations internal storage and returns the function itself.
|
|
||||||
- If `fn` is a list or tuple of callables, it recursively applies the `allow_in_graph()`
|
|
||||||
function to each item in the list or tuple and returns a new list containing the
|
|
||||||
modified functions.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- The function assumes that `fn` is a callable. If it is not, an assertion error is raised.
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
- `allow_in_graph` skips TorchDynamo completely on the decorated function
|
|
||||||
skipping all TorchDynamo safety checks (graph breaks, handling closures, etc).
|
|
||||||
- Therefore, one has to be very careful with `allow_in_graph`
|
|
||||||
Today, downstream components like AOT Autograd rely on TorchDynamo to take care of complex Python features
|
|
||||||
but `allow_in_graph` bypasses TorchDynamo.
|
|
||||||
- If not careful, this could lead to soundness and really hard-to-debug issues.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return torch._dynamo.allow_in_graph(fn)
|
|
||||||
|
|
||||||
|
|
||||||
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
|
|
||||||
"""
|
|
||||||
Return valid strings that can be passed to `torch.compile(..., backend="name")`.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
- exclude_tags (optional): A tuple of strings representing tags to exclude.
|
|
||||||
Backends with any of the specified tags will not be included in the returned list.
|
|
||||||
By default, the tags "debug" and "experimental" are excluded.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- A sorted list of backend names that can be passed to `torch.compile()`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To retrieve a list of available backends excluding the tags "debug" and "experimental",
|
|
||||||
we can call the `list_backends()` function as follows:
|
|
||||||
|
|
||||||
::
|
|
||||||
valid_backends = list_backends(exclude_tags=("debug", "experimental"))
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
return torch._dynamo.list_backends(exclude_tags)
|
|
||||||
|
|
||||||
def assume_constant_result(fn):
|
|
||||||
"""
|
|
||||||
This function is used to mark a function `fn` as having a constant result.
|
|
||||||
This allows the compiler to optimize away your function
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
- fn: The function to be marked as having a constant result.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- The same function `fn`
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To mark a function `my_function()` as having a constant result, we can call the
|
|
||||||
`assume_constant_result()` function as follows:
|
|
||||||
|
|
||||||
::
|
|
||||||
marked_function = assume_constant_result(my_function)
|
|
||||||
|
|
||||||
Warning:
|
|
||||||
- `assume_constant_result` can if invalid cause safety and soundness issues, `torch.compile`
|
|
||||||
will not attempt to validate whether the constant assumption is true or not
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
return torch._dynamo.assume_constant_result(fn)
|
|
||||||
|
|
||||||
def disable(fn=None, recursive=True):
|
|
||||||
"""
|
|
||||||
This function provides both a decorator and a context manager to disable compilation.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
- fn (optional): The function to be decorated or used as a context manager.
|
|
||||||
If provided, compilation will be disabled for the decorated function frame and any
|
|
||||||
recursively invoked functions within it. If not provided, a context manager will be returned.
|
|
||||||
- recursive (optional): A boolean value indicating whether the disabling should be recursive.
|
|
||||||
If set to True (default), compilation is completely skipped on the decorated function frame
|
|
||||||
and any recursively invoked functions within it. If set to False, compilation skips frames
|
|
||||||
associated with the function code but still processes recursively invoked frames.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- If `recursive=True` and `fn` is provided, a decorated version of the function `fn` is returned,
|
|
||||||
with compilation disabled for the decorated function frame and any recursively invoked functions.
|
|
||||||
- If `recursive=True` and `fn` is not provided, a context manager is returned, allowing compilation
|
|
||||||
to be disabled within a specific code block.
|
|
||||||
- If `recursive=False`, the `skip()` function is returned, which allows compilation to skip frames
|
|
||||||
associated with the function code but still process recursively invoked frames.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- When using the decorator or context manager compilation processing is selectively disabled for
|
|
||||||
the decorated function frame and any recursive function calls, depending on the `recursive` flag.
|
|
||||||
- The function internally uses the `innermost_fn()` function to ensure that the innermost function
|
|
||||||
is decorated when `fn` is provided.
|
|
||||||
- The `skip()` function is used when `recursive=False` to skip frames associated with the function code
|
|
||||||
but still process recursively invoked frames.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
1. Using the decorator with recursive disabling:
|
|
||||||
|
|
||||||
::
|
|
||||||
@disable(recursive=True)
|
|
||||||
def my_function():
|
|
||||||
|
|
||||||
In this example, `my_function()` is decorated with compi disabled, meaning that compilations
|
|
||||||
processing will be skipped for the function frame and any recursive function calls within it.
|
|
||||||
|
|
||||||
2. Using the context manager with recursive disabling:
|
|
||||||
|
|
||||||
::
|
|
||||||
with disable(recursive=True):
|
|
||||||
|
|
||||||
In this example, the code block within the `with` statement will have compilation disabled, meaning
|
|
||||||
that compilations processing will be skipped for the code within the block and any recursive function
|
|
||||||
calls within that code.
|
|
||||||
|
|
||||||
3. Using the skip function with non-recursive disabling:
|
|
||||||
|
|
||||||
::
|
|
||||||
disable(recursive=False)(my_function)
|
|
||||||
|
|
||||||
In this example, `my_function()` is wrapped with the `skip()` function, which disables compilations
|
|
||||||
processing for the function frame but still processes recursively invoked functions.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
return torch._dynamo.disable(fn, recursive)
|
|
Reference in New Issue
Block a user