Revert "torch.compiler public namespace (#102182)"

This reverts commit b5840f99c3f2ae01b7831fd32b99758180fc22c3.

Reverted https://github.com/pytorch/pytorch/pull/102182 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/102182#issuecomment-1576144551))
This commit is contained in:
PyTorch MergeBot
2023-06-05 06:52:37 +00:00
parent 6ac3352a37
commit 258d398eec
9 changed files with 17 additions and 239 deletions

1
.gitignore vendored
View File

@ -44,7 +44,6 @@ docs/cpp/source/html/
docs/cpp/source/latex/
docs/source/compile/generated/
docs/source/generated/
docs/source/compile/generated/
log
usage_log.txt
test-reports/

13
docs/source/_dynamo.rst Normal file
View File

@ -0,0 +1,13 @@
.. _torch_dynamo:
torch._dynamo
--------------------------
.. warning ::
This module is an early prototype and is subject to change.
.. currentmodule:: torch._dynamo
.. automodule:: torch._dynamo
:members:
:member-order: bysource

View File

@ -1,19 +0,0 @@
torch.compiler
========================
.. currentmodule:: torch.compiler
.. automodule:: torch.compiler
torch.compiler API reference
------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
compile
reset
allow_in_graph
assume_constant_result
list_backends
disable

View File

@ -99,7 +99,7 @@ Features described in this documentation are classified by release status:
torch.distributed.tensor.parallel <distributed.tensor.parallel>
torch.distributed.checkpoint <distributed.checkpoint>
torch.distributions <distributions>
torch.compiler <compiler>
torch._dynamo <_dynamo>
torch.fft <fft>
torch.func <func>
futures

View File

@ -1,12 +1,10 @@
# Owner(s): ["module: dynamo"]
import inspect
import os
import tempfile
import unittest
import torch
import torch.compiler
from torch._dynamo.testing import CompileCounter
@ -71,41 +69,3 @@ class InPlaceCompilationTests(unittest.TestCase):
torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
loaded_model(torch.randn(1, 10))
# The private variants of the below functions are extensively tested
# So as long as the signatures match we're good
class PublicTorchCompilerTests(unittest.TestCase):
def check_signature(self, public_fn_name, private_fn_name, private_namespace):
public_fn = getattr(torch.compiler, public_fn_name)
private_fn = getattr(private_namespace, private_fn_name)
public_sig = inspect.signature(public_fn)
private_sig = inspect.signature(private_fn)
self.assertEqual(
public_sig,
private_sig,
f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
)
def test_is_enabled(self):
self.assertTrue(torch.compiler.is_enabled())
def test_dynamo_signatures(self):
function_names = [
"reset",
"allow_in_graph",
"list_backends",
"assume_constant_result",
"disable",
]
for fn_name in function_names:
self.check_signature(fn_name, fn_name, torch._dynamo)
def test_inductor_signatures(self):
function_names = ["list_options", "list_mode_options"]
for fn_name in function_names:
self.check_signature(fn_name, fn_name, torch._inductor)

View File

@ -55,7 +55,7 @@ __all__ = [
]
def reset() -> None:
def reset():
"""Clear all compile caches and restore initial state"""
for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
code = weak_code()

View File

@ -64,7 +64,7 @@ def lookup_backend(compiler_fn):
return compiler_fn
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
def list_backends(exclude_tags=("debug", "experimental")):
"""
Return valid strings that can be passed to:

View File

@ -3,7 +3,7 @@
import functools
def is_compiling() -> bool:
def is_compiling():
return False

View File

@ -1,175 +0,0 @@
import torch
import torch._dynamo
import torch._inductor
from typing import Callable, Union, List, Set, Tuple, Any, Dict
__all__ = [
"compile",
"is_enabled",
"reset",
"allow_in_graph",
"list_backends",
"disable",
]
def compile(*args, **kwargs):
"""
See :func:`torch.compile` for details on the arguments for this function.
"""
return torch.compile(*args, **kwargs)
def reset() -> None:
"""
This function clears all compilation caches and restores the system to its initial state.
It is recommended to call this function, especially after using operations like `torch.compile(...)`
to ensure a clean state before subsequent compilation.
Usage:
1. Call `reset()` to clear all compilation caches and restore the initial state.
2. Perform any desired operations, such as `torch.compile(...)`.
3. If you need to start fresh or perform another `torch.compile(...)`, call `reset()` to ensure a clean state.
"""
torch._dynamo.reset()
def allow_in_graph(fn):
"""
Customize which functions compilation will include in the generated graph.
It bypasses all introspection of the symbolic python code in favor of
directly writing it to the graph.
Arguments:
- fn: A callable representing the function to be included in the graph.
Returns:
- If `fn` is a single callable, it adds the function to the list of allowed functions
in compilations internal storage and returns the function itself.
- If `fn` is a list or tuple of callables, it recursively applies the `allow_in_graph()`
function to each item in the list or tuple and returns a new list containing the
modified functions.
Note:
- The function assumes that `fn` is a callable. If it is not, an assertion error is raised.
Warning:
- `allow_in_graph` skips TorchDynamo completely on the decorated function
skipping all TorchDynamo safety checks (graph breaks, handling closures, etc).
- Therefore, one has to be very careful with `allow_in_graph`
Today, downstream components like AOT Autograd rely on TorchDynamo to take care of complex Python features
but `allow_in_graph` bypasses TorchDynamo.
- If not careful, this could lead to soundness and really hard-to-debug issues.
"""
return torch._dynamo.allow_in_graph(fn)
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
"""
Return valid strings that can be passed to `torch.compile(..., backend="name")`.
Arguments:
- exclude_tags (optional): A tuple of strings representing tags to exclude.
Backends with any of the specified tags will not be included in the returned list.
By default, the tags "debug" and "experimental" are excluded.
Returns:
- A sorted list of backend names that can be passed to `torch.compile()`.
Example:
To retrieve a list of available backends excluding the tags "debug" and "experimental",
we can call the `list_backends()` function as follows:
::
valid_backends = list_backends(exclude_tags=("debug", "experimental"))
"""
return torch._dynamo.list_backends(exclude_tags)
def assume_constant_result(fn):
"""
This function is used to mark a function `fn` as having a constant result.
This allows the compiler to optimize away your function
Arguments:
- fn: The function to be marked as having a constant result.
Returns:
- The same function `fn`
Example:
To mark a function `my_function()` as having a constant result, we can call the
`assume_constant_result()` function as follows:
::
marked_function = assume_constant_result(my_function)
Warning:
- `assume_constant_result` can if invalid cause safety and soundness issues, `torch.compile`
will not attempt to validate whether the constant assumption is true or not
"""
return torch._dynamo.assume_constant_result(fn)
def disable(fn=None, recursive=True):
"""
This function provides both a decorator and a context manager to disable compilation.
Arguments:
- fn (optional): The function to be decorated or used as a context manager.
If provided, compilation will be disabled for the decorated function frame and any
recursively invoked functions within it. If not provided, a context manager will be returned.
- recursive (optional): A boolean value indicating whether the disabling should be recursive.
If set to True (default), compilation is completely skipped on the decorated function frame
and any recursively invoked functions within it. If set to False, compilation skips frames
associated with the function code but still processes recursively invoked frames.
Returns:
- If `recursive=True` and `fn` is provided, a decorated version of the function `fn` is returned,
with compilation disabled for the decorated function frame and any recursively invoked functions.
- If `recursive=True` and `fn` is not provided, a context manager is returned, allowing compilation
to be disabled within a specific code block.
- If `recursive=False`, the `skip()` function is returned, which allows compilation to skip frames
associated with the function code but still process recursively invoked frames.
Note:
- When using the decorator or context manager compilation processing is selectively disabled for
the decorated function frame and any recursive function calls, depending on the `recursive` flag.
- The function internally uses the `innermost_fn()` function to ensure that the innermost function
is decorated when `fn` is provided.
- The `skip()` function is used when `recursive=False` to skip frames associated with the function code
but still process recursively invoked frames.
Example:
1. Using the decorator with recursive disabling:
::
@disable(recursive=True)
def my_function():
In this example, `my_function()` is decorated with compi disabled, meaning that compilations
processing will be skipped for the function frame and any recursive function calls within it.
2. Using the context manager with recursive disabling:
::
with disable(recursive=True):
In this example, the code block within the `with` statement will have compilation disabled, meaning
that compilations processing will be skipped for the code within the block and any recursive function
calls within that code.
3. Using the skip function with non-recursive disabling:
::
disable(recursive=False)(my_function)
In this example, `my_function()` is wrapped with the `skip()` function, which disables compilations
processing for the function frame but still processes recursively invoked functions.
"""
return torch._dynamo.disable(fn, recursive)