Revert "torch.compiler public namespace (#102182)"
This reverts commit b5840f99c3f2ae01b7831fd32b99758180fc22c3. Reverted https://github.com/pytorch/pytorch/pull/102182 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/102182#issuecomment-1576144551))
.gitignore (vendored, 1 line changed)
@@ -44,7 +44,6 @@ docs/cpp/source/html/
 docs/cpp/source/latex/
 docs/source/compile/generated/
 docs/source/generated/
-docs/source/compile/generated/
 log
 usage_log.txt
 test-reports/
docs/source/_dynamo.rst (new file, 13 lines added)
@@ -0,0 +1,13 @@
+.. _torch_dynamo:
+
+torch._dynamo
+--------------------------
+
+.. warning ::
+    This module is an early prototype and is subject to change.
+
+.. currentmodule:: torch._dynamo
+
+.. automodule:: torch._dynamo
+    :members:
+    :member-order: bysource
@@ -1,19 +0,0 @@
-torch.compiler
-========================
-
-.. currentmodule:: torch.compiler
-
-.. automodule:: torch.compiler
-
-torch.compiler API reference
-------------------------------
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    compile
-    reset
-    allow_in_graph
-    assume_constant_result
-    list_backends
-    disable
@@ -99,7 +99,7 @@ Features described in this documentation are classified by release status:
    torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.checkpoint <distributed.checkpoint>
    torch.distributions <distributions>
-   torch.compiler <compiler>
+   torch._dynamo <_dynamo>
    torch.fft <fft>
    torch.func <func>
    futures
@@ -1,12 +1,10 @@
 # Owner(s): ["module: dynamo"]

-import inspect
 import os
 import tempfile
 import unittest

 import torch
-import torch.compiler
 from torch._dynamo.testing import CompileCounter

@@ -71,41 +69,3 @@ class InPlaceCompilationTests(unittest.TestCase):
             torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
             loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
             loaded_model(torch.randn(1, 10))
-
-
-# The private variants of the below functions are extensively tested
-# So as long as the signatures match we're good
-class PublicTorchCompilerTests(unittest.TestCase):
-    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
-        public_fn = getattr(torch.compiler, public_fn_name)
-        private_fn = getattr(private_namespace, private_fn_name)
-
-        public_sig = inspect.signature(public_fn)
-        private_sig = inspect.signature(private_fn)
-
-        self.assertEqual(
-            public_sig,
-            private_sig,
-            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
-        )
-
-    def test_is_enabled(self):
-        self.assertTrue(torch.compiler.is_enabled())
-
-    def test_dynamo_signatures(self):
-        function_names = [
-            "reset",
-            "allow_in_graph",
-            "list_backends",
-            "assume_constant_result",
-            "disable",
-        ]
-
-        for fn_name in function_names:
-            self.check_signature(fn_name, fn_name, torch._dynamo)
-
-    def test_inductor_signatures(self):
-        function_names = ["list_options", "list_mode_options"]
-
-        for fn_name in function_names:
-            self.check_signature(fn_name, fn_name, torch._inductor)
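The PublicTorchCompilerTests removed above rest on one idea: a public wrapper should expose exactly the same signature as the private function it forwards to, which inspect.signature makes easy to assert. A minimal standalone sketch of that check (the wrapper and implementation names here are hypothetical, not part of this diff):

import inspect


def _private_impl(x, *, flag=True):
    return x if flag else None


def public_api(x, *, flag=True):
    # Thin public wrapper that forwards to the private implementation.
    return _private_impl(x, flag=flag)


# Signature objects compare equal when parameter names, kinds, defaults,
# and annotations all match, which is what the removed tests asserted.
assert inspect.signature(public_api) == inspect.signature(_private_impl)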
@@ -55,7 +55,7 @@ __all__ = [
 ]


-def reset() -> None:
+def reset():
     """Clear all compile caches and restore initial state"""
     for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
         code = weak_code()
@@ -64,7 +64,7 @@ def lookup_backend(compiler_fn):
     return compiler_fn


-def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
+def list_backends(exclude_tags=("debug", "experimental")):
    """
    Return valid strings that can be passed to:

@@ -3,7 +3,7 @@
 import functools


-def is_compiling() -> bool:
+def is_compiling():
     return False


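This hunk only drops the return annotation from the is_compiling() stub. Assuming the file is the torch._dynamo.external_utils helper (the surrounding import functools context suggests so), the stub returns False in eager mode while TorchDynamo evaluates it as True during tracing, so user code can branch on it. A small hedged sketch of that pattern:

import torch
from torch._dynamo.external_utils import is_compiling


def add_noise(x):
    # Eager mode: is_compiling() is the plain-Python stub and returns False.
    # Under torch.compile, TorchDynamo treats it as True while tracing,
    # so the noise branch is left out of the captured graph.
    if is_compiling():
        return x
    return x + torch.rand_like(x)


print(add_noise(torch.zeros(2)))                 # eager: noise added
print(torch.compile(add_noise)(torch.zeros(2)))  # compiled: returns x unchanged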
@@ -1,175 +0,0 @@
-import torch
-import torch._dynamo
-import torch._inductor
-from typing import Callable, Union, List, Set, Tuple, Any, Dict
-
-__all__ = [
-    "compile",
-    "is_enabled",
-    "reset",
-    "allow_in_graph",
-    "list_backends",
-    "disable",
-]
-
-
-def compile(*args, **kwargs):
-    """
-    See :func:`torch.compile` for details on the arguments for this function.
-    """
-
-    return torch.compile(*args, **kwargs)
-
-
-def reset() -> None:
-    """
-    This function clears all compilation caches and restores the system to its initial state.
-    It is recommended to call this function, especially after using operations like `torch.compile(...)`,
-    to ensure a clean state before subsequent compilation.
-
-    Usage:
-        1. Call `reset()` to clear all compilation caches and restore the initial state.
-        2. Perform any desired operations, such as `torch.compile(...)`.
-        3. If you need to start fresh or perform another `torch.compile(...)`, call `reset()` to ensure a clean state.
-
-    """
-
-    torch._dynamo.reset()
-
-
-def allow_in_graph(fn):
-    """
-    Customize which functions compilation will include in the generated graph.
-    It bypasses all introspection of the symbolic python code in favor of
-    directly writing it to the graph.
-
-    Arguments:
-    - fn: A callable representing the function to be included in the graph.
-
-    Returns:
-    - If `fn` is a single callable, it adds the function to the list of allowed functions
-      in compilation's internal storage and returns the function itself.
-    - If `fn` is a list or tuple of callables, it recursively applies the `allow_in_graph()`
-      function to each item in the list or tuple and returns a new list containing the
-      modified functions.
-
-    Note:
-    - The function assumes that `fn` is a callable. If it is not, an assertion error is raised.
-
-    Warning:
-    - `allow_in_graph` skips TorchDynamo completely on the decorated function,
-      skipping all TorchDynamo safety checks (graph breaks, handling closures, etc).
-    - Therefore, one has to be very careful with `allow_in_graph`.
-      Today, downstream components like AOT Autograd rely on TorchDynamo to take care of complex Python features,
-      but `allow_in_graph` bypasses TorchDynamo.
-    - If not careful, this could lead to soundness and really hard-to-debug issues.
-    """
-
-    return torch._dynamo.allow_in_graph(fn)
-
-
-def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
-    """
-    Return valid strings that can be passed to `torch.compile(..., backend="name")`.
-
-    Arguments:
-    - exclude_tags (optional): A tuple of strings representing tags to exclude.
-      Backends with any of the specified tags will not be included in the returned list.
-      By default, the tags "debug" and "experimental" are excluded.
-
-    Returns:
-    - A sorted list of backend names that can be passed to `torch.compile()`.
-
-    Example:
-    To retrieve a list of available backends excluding the tags "debug" and "experimental",
-    we can call the `list_backends()` function as follows:
-
-    ::
-        valid_backends = list_backends(exclude_tags=("debug", "experimental"))
-
-    """
-
-    return torch._dynamo.list_backends(exclude_tags)
-
-
-def assume_constant_result(fn):
-    """
-    This function is used to mark a function `fn` as having a constant result.
-    This allows the compiler to optimize away your function.
-
-    Arguments:
-    - fn: The function to be marked as having a constant result.
-
-    Returns:
-    - The same function `fn`.
-
-    Example:
-    To mark a function `my_function()` as having a constant result, we can call the
-    `assume_constant_result()` function as follows:
-
-    ::
-        marked_function = assume_constant_result(my_function)
-
-    Warning:
-    - `assume_constant_result` can, if invalid, cause safety and soundness issues; `torch.compile`
-      will not attempt to validate whether the constant assumption is true or not.
-
-    """
-
-    return torch._dynamo.assume_constant_result(fn)
-
-
-def disable(fn=None, recursive=True):
-    """
-    This function provides both a decorator and a context manager to disable compilation.
-
-    Arguments:
-    - fn (optional): The function to be decorated or used as a context manager.
-      If provided, compilation will be disabled for the decorated function frame and any
-      recursively invoked functions within it. If not provided, a context manager will be returned.
-    - recursive (optional): A boolean value indicating whether the disabling should be recursive.
-      If set to True (default), compilation is completely skipped on the decorated function frame
-      and any recursively invoked functions within it. If set to False, compilation skips frames
-      associated with the function code but still processes recursively invoked frames.
-
-    Returns:
-    - If `recursive=True` and `fn` is provided, a decorated version of the function `fn` is returned,
-      with compilation disabled for the decorated function frame and any recursively invoked functions.
-    - If `recursive=True` and `fn` is not provided, a context manager is returned, allowing compilation
-      to be disabled within a specific code block.
-    - If `recursive=False`, the `skip()` function is returned, which allows compilation to skip frames
-      associated with the function code but still process recursively invoked frames.
-
-    Note:
-    - When using the decorator or context manager, compilation processing is selectively disabled for
-      the decorated function frame and any recursive function calls, depending on the `recursive` flag.
-    - The function internally uses the `innermost_fn()` function to ensure that the innermost function
-      is decorated when `fn` is provided.
-    - The `skip()` function is used when `recursive=False` to skip frames associated with the function code
-      but still process recursively invoked frames.
-
-    Example:
-    1. Using the decorator with recursive disabling:
-
-    ::
-        @disable(recursive=True)
-        def my_function():
-
-    In this example, `my_function()` is decorated with compilation disabled, meaning that compilation
-    processing will be skipped for the function frame and any recursive function calls within it.
-
-    2. Using the context manager with recursive disabling:
-
-    ::
-        with disable(recursive=True):
-
-    In this example, the code block within the `with` statement will have compilation disabled, meaning
-    that compilation processing will be skipped for the code within the block and any recursive function
-    calls within that code.
-
-    3. Using the skip function with non-recursive disabling:
-
-    ::
-        disable(recursive=False)(my_function)
-
-    In this example, `my_function()` is wrapped with the `skip()` function, which disables compilation
-    processing for the function frame but still processes recursively invoked functions.
-
-    """
-
-    return torch._dynamo.disable(fn, recursive)
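For context, the module deleted above is a thin public facade over torch._dynamo (the torch.compiler namespace named in the commit title). A rough usage sketch of that surface, written against the API as it appears in the removed file; illustrative only, since this revert takes the namespace out rather than documenting a current release:

import torch
import torch.compiler  # the public namespace removed by this revert


@torch.compiler.allow_in_graph
def postprocess(x):
    # Traced into the graph verbatim, bypassing TorchDynamo introspection.
    return x * 2


def model(x):
    return postprocess(torch.relu(x))


print(torch.compiler.list_backends())        # backend names usable with torch.compile(backend=...)

compiled = torch.compile(model)
compiled(torch.randn(8))

torch.compiler.reset()                       # clear compile caches before starting fresh

eager_model = torch.compiler.disable(model)  # skip compilation for this function (recursive by default)
eager_model(torch.randn(8))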