diff --git a/.gitignore b/.gitignore index 5a4daddff941..9ffab8fffac3 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ docs/cpp/source/html/ docs/cpp/source/latex/ docs/source/compile/generated/ docs/source/generated/ +docs/source/compile/generated/ log usage_log.txt test-reports/ diff --git a/docs/source/_dynamo.rst b/docs/source/_dynamo.rst deleted file mode 100644 index 5e16dcf52dde..000000000000 --- a/docs/source/_dynamo.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. _torch_dynamo: - -torch._dynamo --------------------------- - -.. warning :: - This module is an early prototype and is subject to change. - -.. currentmodule:: torch._dynamo - -.. automodule:: torch._dynamo - :members: - :member-order: bysource diff --git a/docs/source/compiler.rst b/docs/source/compiler.rst new file mode 100644 index 000000000000..31ed37171e19 --- /dev/null +++ b/docs/source/compiler.rst @@ -0,0 +1,19 @@ +torch.compiler +======================== + +.. currentmodule:: torch.compiler + +.. automodule:: torch.compiler + +torch.compiler API reference +------------------------------ +.. autosummary:: + :toctree: generated + :nosignatures: + + compile + reset + allow_in_graph + assume_constant_result + list_backends + disable \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 6956371bca94..3242293a4814 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,7 +99,7 @@ Features described in this documentation are classified by release status: torch.distributed.tensor.parallel torch.distributed.checkpoint torch.distributions - torch._dynamo <_dynamo> + torch.compiler torch.fft torch.func futures diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index 1241ad533b8b..d9ba4639e46e 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -1,10 +1,12 @@ # Owner(s): ["module: dynamo"] +import inspect import os import tempfile import unittest import torch +import torch.compiler from torch._dynamo.testing import CompileCounter @@ -69,3 +71,41 @@ class InPlaceCompilationTests(unittest.TestCase): torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt")) loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt")) loaded_model(torch.randn(1, 10)) + + +# The private variants of the below functions are extensively tested +# So as long as the signatures match we're good +class PublicTorchCompilerTests(unittest.TestCase): + def check_signature(self, public_fn_name, private_fn_name, private_namespace): + public_fn = getattr(torch.compiler, public_fn_name) + private_fn = getattr(private_namespace, private_fn_name) + + public_sig = inspect.signature(public_fn) + private_sig = inspect.signature(private_fn) + + self.assertEqual( + public_sig, + private_sig, + f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}", + ) + + def test_is_enabled(self): + self.assertTrue(torch.compiler.is_enabled()) + + def test_dynamo_signatures(self): + function_names = [ + "reset", + "allow_in_graph", + "list_backends", + "assume_constant_result", + "disable", + ] + + for fn_name in function_names: + self.check_signature(fn_name, fn_name, torch._dynamo) + + def test_inductor_signatures(self): + function_names = ["list_options", "list_mode_options"] + + for fn_name in function_names: + self.check_signature(fn_name, fn_name, torch._inductor) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 3c1e36992e7a..9bb4b350ba2a 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -46,7 +46,7 @@ __all__ = [ ] -def reset(): +def reset() -> None: """Clear all compile caches and restore initial state""" for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen: code = weak_code() diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 0423966bab96..b8f2a607c894 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -64,7 +64,7 @@ def lookup_backend(compiler_fn): return compiler_fn -def list_backends(exclude_tags=("debug", "experimental")): +def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: """ Return valid strings that can be passed to: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 149d3f76eaec..88fb27317ca2 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -3,7 +3,7 @@ import functools -def is_compiling(): +def is_compiling() -> bool: return False diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py new file mode 100644 index 000000000000..5335b62ae461 --- /dev/null +++ b/torch/compiler/__init__.py @@ -0,0 +1,175 @@ +import torch +import torch._dynamo +import torch._inductor +from typing import Callable, Union, List, Set, Tuple, Any, Dict + +__all__ = [ + "compile", + "is_enabled", + "reset", + "allow_in_graph", + "list_backends", + "disable", +] + +def compile(*args, **kwargs): + """ + See :func:`torch.compile` for details on the arguments for this function. + """ + + return torch.compile(*args, **kwargs) + +def reset() -> None: + """ + This function clears all compilation caches and restores the system to its initial state. + It is recommended to call this function, especially after using operations like `torch.compile(...)` + to ensure a clean state before subsequent compilation. + + Usage: + 1. Call `reset()` to clear all compilation caches and restore the initial state. + 2. Perform any desired operations, such as `torch.compile(...)`. + 3. If you need to start fresh or perform another `torch.compile(...)`, call `reset()` to ensure a clean state. + + """ + + torch._dynamo.reset() + +def allow_in_graph(fn): + """ + Customize which functions compilation will include in the generated graph. + It bypasses all introspection of the symbolic python code in favor of + directly writing it to the graph. + + Arguments: + - fn: A callable representing the function to be included in the graph. + + Returns: + - If `fn` is a single callable, it adds the function to the list of allowed functions + in compilations internal storage and returns the function itself. + - If `fn` is a list or tuple of callables, it recursively applies the `allow_in_graph()` + function to each item in the list or tuple and returns a new list containing the + modified functions. + + Note: + - The function assumes that `fn` is a callable. If it is not, an assertion error is raised. + + Warning: + - `allow_in_graph` skips TorchDynamo completely on the decorated function + skipping all TorchDynamo safety checks (graph breaks, handling closures, etc). + - Therefore, one has to be very careful with `allow_in_graph` + Today, downstream components like AOT Autograd rely on TorchDynamo to take care of complex Python features + but `allow_in_graph` bypasses TorchDynamo. + - If not careful, this could lead to soundness and really hard-to-debug issues. + """ + + return torch._dynamo.allow_in_graph(fn) + + +def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: + """ + Return valid strings that can be passed to `torch.compile(..., backend="name")`. + + Arguments: + - exclude_tags (optional): A tuple of strings representing tags to exclude. + Backends with any of the specified tags will not be included in the returned list. + By default, the tags "debug" and "experimental" are excluded. + + Returns: + - A sorted list of backend names that can be passed to `torch.compile()`. + + Example: + To retrieve a list of available backends excluding the tags "debug" and "experimental", + we can call the `list_backends()` function as follows: + + :: + valid_backends = list_backends(exclude_tags=("debug", "experimental")) + + """ + + return torch._dynamo.list_backends(exclude_tags) + +def assume_constant_result(fn): + """ + This function is used to mark a function `fn` as having a constant result. + This allows the compiler to optimize away your function + + Arguments: + - fn: The function to be marked as having a constant result. + + Returns: + - The same function `fn` + + Example: + To mark a function `my_function()` as having a constant result, we can call the + `assume_constant_result()` function as follows: + + :: + marked_function = assume_constant_result(my_function) + + Warning: + - `assume_constant_result` can if invalid cause safety and soundness issues, `torch.compile` + will not attempt to validate whether the constant assumption is true or not + + """ + + return torch._dynamo.assume_constant_result(fn) + +def disable(fn=None, recursive=True): + """ + This function provides both a decorator and a context manager to disable compilation. + + Arguments: + - fn (optional): The function to be decorated or used as a context manager. + If provided, compilation will be disabled for the decorated function frame and any + recursively invoked functions within it. If not provided, a context manager will be returned. + - recursive (optional): A boolean value indicating whether the disabling should be recursive. + If set to True (default), compilation is completely skipped on the decorated function frame + and any recursively invoked functions within it. If set to False, compilation skips frames + associated with the function code but still processes recursively invoked frames. + + Returns: + - If `recursive=True` and `fn` is provided, a decorated version of the function `fn` is returned, + with compilation disabled for the decorated function frame and any recursively invoked functions. + - If `recursive=True` and `fn` is not provided, a context manager is returned, allowing compilation + to be disabled within a specific code block. + - If `recursive=False`, the `skip()` function is returned, which allows compilation to skip frames + associated with the function code but still process recursively invoked frames. + + Note: + - When using the decorator or context manager compilation processing is selectively disabled for + the decorated function frame and any recursive function calls, depending on the `recursive` flag. + - The function internally uses the `innermost_fn()` function to ensure that the innermost function + is decorated when `fn` is provided. + - The `skip()` function is used when `recursive=False` to skip frames associated with the function code + but still process recursively invoked frames. + + Example: + 1. Using the decorator with recursive disabling: + + :: + @disable(recursive=True) + def my_function(): + + In this example, `my_function()` is decorated with compi disabled, meaning that compilations + processing will be skipped for the function frame and any recursive function calls within it. + + 2. Using the context manager with recursive disabling: + + :: + with disable(recursive=True): + + In this example, the code block within the `with` statement will have compilation disabled, meaning + that compilations processing will be skipped for the code within the block and any recursive function + calls within that code. + + 3. Using the skip function with non-recursive disabling: + + :: + disable(recursive=False)(my_function) + + In this example, `my_function()` is wrapped with the `skip()` function, which disables compilations + processing for the function frame but still processes recursively invoked functions. + + """ + + return torch._dynamo.disable(fn, recursive)