From 258d398eecd4c215c238e3318ac7d0d14251cf4f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 5 Jun 2023 06:52:37 +0000 Subject: [PATCH] Revert "torch.compiler public namespace (#102182)" This reverts commit b5840f99c3f2ae01b7831fd32b99758180fc22c3. Reverted https://github.com/pytorch/pytorch/pull/102182 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/102182#issuecomment-1576144551)) --- .gitignore | 1 - docs/source/_dynamo.rst | 13 +++ docs/source/compiler.rst | 19 ---- docs/source/index.rst | 2 +- test/dynamo/test_compile.py | 40 ------- torch/_dynamo/__init__.py | 2 +- torch/_dynamo/backends/registry.py | 2 +- torch/_dynamo/external_utils.py | 2 +- torch/compiler/__init__.py | 175 ----------------------------- 9 files changed, 17 insertions(+), 239 deletions(-) create mode 100644 docs/source/_dynamo.rst delete mode 100644 docs/source/compiler.rst delete mode 100644 torch/compiler/__init__.py diff --git a/.gitignore b/.gitignore index 9ffab8fffac3..5a4daddff941 100644 --- a/.gitignore +++ b/.gitignore @@ -44,7 +44,6 @@ docs/cpp/source/html/ docs/cpp/source/latex/ docs/source/compile/generated/ docs/source/generated/ -docs/source/compile/generated/ log usage_log.txt test-reports/ diff --git a/docs/source/_dynamo.rst b/docs/source/_dynamo.rst new file mode 100644 index 000000000000..5e16dcf52dde --- /dev/null +++ b/docs/source/_dynamo.rst @@ -0,0 +1,13 @@ +.. _torch_dynamo: + +torch._dynamo +-------------------------- + +.. warning :: + This module is an early prototype and is subject to change. + +.. currentmodule:: torch._dynamo + +.. automodule:: torch._dynamo + :members: + :member-order: bysource diff --git a/docs/source/compiler.rst b/docs/source/compiler.rst deleted file mode 100644 index 31ed37171e19..000000000000 --- a/docs/source/compiler.rst +++ /dev/null @@ -1,19 +0,0 @@ -torch.compiler -======================== - -.. currentmodule:: torch.compiler - -.. automodule:: torch.compiler - -torch.compiler API reference ------------------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - compile - reset - allow_in_graph - assume_constant_result - list_backends - disable \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 3242293a4814..6956371bca94 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,7 +99,7 @@ Features described in this documentation are classified by release status: torch.distributed.tensor.parallel torch.distributed.checkpoint torch.distributions - torch.compiler + torch._dynamo <_dynamo> torch.fft torch.func futures diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index d9ba4639e46e..1241ad533b8b 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -1,12 +1,10 @@ # Owner(s): ["module: dynamo"] -import inspect import os import tempfile import unittest import torch -import torch.compiler from torch._dynamo.testing import CompileCounter @@ -71,41 +69,3 @@ class InPlaceCompilationTests(unittest.TestCase): torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt")) loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt")) loaded_model(torch.randn(1, 10)) - - -# The private variants of the below functions are extensively tested -# So as long as the signatures match we're good -class PublicTorchCompilerTests(unittest.TestCase): - def check_signature(self, public_fn_name, private_fn_name, private_namespace): - public_fn = getattr(torch.compiler, public_fn_name) - private_fn = getattr(private_namespace, private_fn_name) - - public_sig = inspect.signature(public_fn) - private_sig = inspect.signature(private_fn) - - self.assertEqual( - public_sig, - private_sig, - f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}", - ) - - def test_is_enabled(self): - self.assertTrue(torch.compiler.is_enabled()) - - def test_dynamo_signatures(self): - function_names = [ - "reset", - "allow_in_graph", - "list_backends", - "assume_constant_result", - "disable", - ] - - for fn_name in function_names: - self.check_signature(fn_name, fn_name, torch._dynamo) - - def test_inductor_signatures(self): - function_names = ["list_options", "list_mode_options"] - - for fn_name in function_names: - self.check_signature(fn_name, fn_name, torch._inductor) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 1c3977e4d208..ced5d117198c 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -55,7 +55,7 @@ __all__ = [ ] -def reset() -> None: +def reset(): """Clear all compile caches and restore initial state""" for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen: code = weak_code() diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index b8f2a607c894..0423966bab96 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -64,7 +64,7 @@ def lookup_backend(compiler_fn): return compiler_fn -def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: +def list_backends(exclude_tags=("debug", "experimental")): """ Return valid strings that can be passed to: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 88fb27317ca2..149d3f76eaec 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -3,7 +3,7 @@ import functools -def is_compiling() -> bool: +def is_compiling(): return False diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py deleted file mode 100644 index 5335b62ae461..000000000000 --- a/torch/compiler/__init__.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch._dynamo -import torch._inductor -from typing import Callable, Union, List, Set, Tuple, Any, Dict - -__all__ = [ - "compile", - "is_enabled", - "reset", - "allow_in_graph", - "list_backends", - "disable", -] - -def compile(*args, **kwargs): - """ - See :func:`torch.compile` for details on the arguments for this function. - """ - - return torch.compile(*args, **kwargs) - -def reset() -> None: - """ - This function clears all compilation caches and restores the system to its initial state. - It is recommended to call this function, especially after using operations like `torch.compile(...)` - to ensure a clean state before subsequent compilation. - - Usage: - 1. Call `reset()` to clear all compilation caches and restore the initial state. - 2. Perform any desired operations, such as `torch.compile(...)`. - 3. If you need to start fresh or perform another `torch.compile(...)`, call `reset()` to ensure a clean state. - - """ - - torch._dynamo.reset() - -def allow_in_graph(fn): - """ - Customize which functions compilation will include in the generated graph. - It bypasses all introspection of the symbolic python code in favor of - directly writing it to the graph. - - Arguments: - - fn: A callable representing the function to be included in the graph. - - Returns: - - If `fn` is a single callable, it adds the function to the list of allowed functions - in compilations internal storage and returns the function itself. - - If `fn` is a list or tuple of callables, it recursively applies the `allow_in_graph()` - function to each item in the list or tuple and returns a new list containing the - modified functions. - - Note: - - The function assumes that `fn` is a callable. If it is not, an assertion error is raised. - - Warning: - - `allow_in_graph` skips TorchDynamo completely on the decorated function - skipping all TorchDynamo safety checks (graph breaks, handling closures, etc). - - Therefore, one has to be very careful with `allow_in_graph` - Today, downstream components like AOT Autograd rely on TorchDynamo to take care of complex Python features - but `allow_in_graph` bypasses TorchDynamo. - - If not careful, this could lead to soundness and really hard-to-debug issues. - """ - - return torch._dynamo.allow_in_graph(fn) - - -def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: - """ - Return valid strings that can be passed to `torch.compile(..., backend="name")`. - - Arguments: - - exclude_tags (optional): A tuple of strings representing tags to exclude. - Backends with any of the specified tags will not be included in the returned list. - By default, the tags "debug" and "experimental" are excluded. - - Returns: - - A sorted list of backend names that can be passed to `torch.compile()`. - - Example: - To retrieve a list of available backends excluding the tags "debug" and "experimental", - we can call the `list_backends()` function as follows: - - :: - valid_backends = list_backends(exclude_tags=("debug", "experimental")) - - """ - - return torch._dynamo.list_backends(exclude_tags) - -def assume_constant_result(fn): - """ - This function is used to mark a function `fn` as having a constant result. - This allows the compiler to optimize away your function - - Arguments: - - fn: The function to be marked as having a constant result. - - Returns: - - The same function `fn` - - Example: - To mark a function `my_function()` as having a constant result, we can call the - `assume_constant_result()` function as follows: - - :: - marked_function = assume_constant_result(my_function) - - Warning: - - `assume_constant_result` can if invalid cause safety and soundness issues, `torch.compile` - will not attempt to validate whether the constant assumption is true or not - - """ - - return torch._dynamo.assume_constant_result(fn) - -def disable(fn=None, recursive=True): - """ - This function provides both a decorator and a context manager to disable compilation. - - Arguments: - - fn (optional): The function to be decorated or used as a context manager. - If provided, compilation will be disabled for the decorated function frame and any - recursively invoked functions within it. If not provided, a context manager will be returned. - - recursive (optional): A boolean value indicating whether the disabling should be recursive. - If set to True (default), compilation is completely skipped on the decorated function frame - and any recursively invoked functions within it. If set to False, compilation skips frames - associated with the function code but still processes recursively invoked frames. - - Returns: - - If `recursive=True` and `fn` is provided, a decorated version of the function `fn` is returned, - with compilation disabled for the decorated function frame and any recursively invoked functions. - - If `recursive=True` and `fn` is not provided, a context manager is returned, allowing compilation - to be disabled within a specific code block. - - If `recursive=False`, the `skip()` function is returned, which allows compilation to skip frames - associated with the function code but still process recursively invoked frames. - - Note: - - When using the decorator or context manager compilation processing is selectively disabled for - the decorated function frame and any recursive function calls, depending on the `recursive` flag. - - The function internally uses the `innermost_fn()` function to ensure that the innermost function - is decorated when `fn` is provided. - - The `skip()` function is used when `recursive=False` to skip frames associated with the function code - but still process recursively invoked frames. - - Example: - 1. Using the decorator with recursive disabling: - - :: - @disable(recursive=True) - def my_function(): - - In this example, `my_function()` is decorated with compi disabled, meaning that compilations - processing will be skipped for the function frame and any recursive function calls within it. - - 2. Using the context manager with recursive disabling: - - :: - with disable(recursive=True): - - In this example, the code block within the `with` statement will have compilation disabled, meaning - that compilations processing will be skipped for the code within the block and any recursive function - calls within that code. - - 3. Using the skip function with non-recursive disabling: - - :: - disable(recursive=False)(my_function) - - In this example, `my_function()` is wrapped with the `skip()` function, which disables compilations - processing for the function frame but still processes recursively invoked functions. - - """ - - return torch._dynamo.disable(fn, recursive)