See added test for the case that this PR handles. In particular, the semantics for nested torch.compile with toggled fullgraph settings was strange before - `@torch.compile(fullgraph=True)` overrides the existing fullgraph setting, while `@torch.compile(fullgraph=False)` does not.
Note that this change will add an extra frame to any inlined torch.compile'd function (which I don't expect to happen frequently).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155166
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782
Implements https://github.com/pytorch/pytorch/issues/144908.
Implementation notes:
- `set_fullgraph` is implemented using `patch_config`, which changes config correctly during runtime and tracing.
- Moved setting `config.error_on_graph_break` from convert_frame.py to eval_frame.py. This is because this should only be done at the top-level decorated function. If we kept this in convert_frame.py, we would be changing `config.error_on_graph_break` on every top-level frame, which causes confusing behavior (see added test for example).
- InstructionTranslator reads from `config.error_on_graph_break` every `step()`. This is to determine the value of `config.error_on_graph_break` at the time of the graph break, because tracer cleanup will restore the value of `config.error_on_graph_break` .
- `convert_frame.py` determines whether we should abort tracing (fullgraph=True) or continue (fullgraph=False) by reading the value of the tracer's `error_on_graph_break`. If there is no tracer (failed to initialize), then default to reading `config.error_on_graph_break`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154289
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #154283
`torch.compile` now always goes through `torch._dynamo._optimize`. fullgraph is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305
See added test for the case that this PR handles. In particular, the semantics for nested torch.compile with toggled fullgraph settings was strange before - `@torch.compile(fullgraph=True)` overrides the existing fullgraph setting, while `@torch.compile(fullgraph=False)` does not.
Note that this change will add an extra frame to any inlined torch.compile'd function (which I don't expect to happen frequently).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155166
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782
Implements https://github.com/pytorch/pytorch/issues/144908.
Implementation notes:
- `set_fullgraph` is implemented using `patch_config`, which changes config correctly during runtime and tracing.
- Moved setting `config.error_on_graph_break` from convert_frame.py to eval_frame.py. This is because this should only be done at the top-level decorated function. If we kept this in convert_frame.py, we would be changing `config.error_on_graph_break` on every top-level frame, which causes confusing behavior (see added test for example).
- InstructionTranslator reads from `config.error_on_graph_break` every `step()`. This is to determine the value of `config.error_on_graph_break` at the time of the graph break, because tracer cleanup will restore the value of `config.error_on_graph_break` .
- `convert_frame.py` determines whether we should abort tracing (fullgraph=True) or continue (fullgraph=False) by reading the value of the tracer's `error_on_graph_break`. If there is no tracer (failed to initialize), then default to reading `config.error_on_graph_break`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154289
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #154283
`torch.compile` now always goes through `torch._dynamo._optimize`. fullgraph is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305
When running a distributed job with compiler collectives enabled, if one rank recompiles while others do not, this leads to a deadlock (as not everyone will rendezvous with the compiler collective from the recompile). Although there aren't any convenient ways to cheaply solve this problem, if you are willing to force everyone to sync when evaluating guards, you can just force everyone to recompile if anyone requires a recompile. So the way guard collectives work is:
1. Perform compiled code lookup (evaluating guards)
2. Run a collective, communicating if you found a compiled code or not
3. If anyone requires recompile, force everyone to recompile
One current deficiency in the implementation is we can't conveniently track the time it takes to run this collective.
I need to test if we actually successfully are running the collective on a separate stream, or if we have to wait for user collectives to all finish.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155558
Approved by: https://github.com/Microve
When running a distributed job with compiler collectives enabled, if one rank recompiles while others do not, this leads to a deadlock (as not everyone will rendezvous with the compiler collective from the recompile). Although there aren't any convenient ways to cheaply solve this problem, if you are willing to force everyone to sync when evaluating guards, you can just force everyone to recompile if anyone requires a recompile. So the way guard collectives work is:
1. Perform compiled code lookup (evaluating guards)
2. Run a collective, communicating if you found a compiled code or not
3. If anyone requires recompile, force everyone to recompile
One current deficiency in the implementation is we can't conveniently track the time it takes to run this collective.
I need to test if we actually successfully are running the collective on a separate stream, or if we have to wait for user collectives to all finish.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155558
Approved by: https://github.com/Microve
Adding a per torch.compile() object CompilePackage which tracks dynamo artifact. CompilePackage is considered a low level component and should not be directly exposed to end users. It has the following interface:
1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
a. when `dynamo` argument is None, it will contruct a brand new CompilePackage object.
b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
2. `package.save()` which dumps the dynamo states into _DynamoCacheEntry.
3. `package.install(backends)` which will handle all the side-effectful global scope updates with compiled functions and resume functions.
This diff focus on making the low level mechanism for precompile. It will be left to upper level interface to use these API to build more user-facing frontend.
Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118
Approved by: https://github.com/jamesjwu
Co-authored-by: James Wu <jjwu@meta.com>
#153622 introduced a hook for getting the relevant code objects after frame tracing. The idea is to have vLLM use this instead of monkey-patching `inline_call_()` to determine the source code files to hash. Unfortunately, the hook runs too late; the vLLM backend needs access to the set of source code filenames while it's running.
This PR replaces the newly-added hook with a utility function that a backend can call to get this information. I've made the change in vLLM and can verify that this allows the information to be queried at the right time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155249
Approved by: https://github.com/zou3519
In a precompiled bytecode, it looks like the following:
```
pre-graph bytecode
...
compiled graph code
...
post-graph bytecode
```
In pre-graph bytecode we have calls into helper functions like torch._dynamo.utils.call_size which will invoke @disable inside the bytecode.
Normally torch.compile() will handle these frames fine, but for precompile we will load bytecode from a clean state of dynamo and we want a way to assert recompile never happen, so the current way to ensure this is by doing set_stance("fail_on_recompile") (open to any other idea to test this, but IMO this is the closest thing we have today).
This approach doesn't work when util functions like call_size() is involved and this PR fixes a bunch of places to make sure "fail_on_recompile" can skip through the functions meant to be skipped during compilation.
Differential Revision: [D76156867](https://our.internmc.facebook.com/intern/diff/D76156867/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155363
Approved by: https://github.com/jamesjwu, https://github.com/jansel
ghstack dependencies: #155329
This PR:
* Expands `Hooks` with a new, optional `frame_traced_fn` field. It should be a callable receiving the list of traced code objects
* Maintains a list of `traced_code` objects in the `TracingContext` of an `OutputGraph`
* Whenever an `inline_call()` is encountered, the corresponding code object is added to this set
* `OutputGraph`'s associated `f_code` is added to the list just before the hook is called
I believe use of this hook should enable the source code hashing that vLLM does in a better way than monkey-patching `inline_call()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153622
Approved by: https://github.com/jansel
Basically adds native _IntWrapper support to dynamo. Here's my process of trying to make symint input support work on dynamo, and how I ended up with this approach [(doc)](https://docs.google.com/document/d/1GvNRQd8BnxlMay_hrEVgEta6VUeUW_hcFeRuB7q1nDY/edit?tab=t.0).
What I did was, before passing inputs to dynamo.export, I first wrap them with a class, `_IntWrapper`. When processing dynamic shapes, I will then add the corresponding dynamic shape specification to the `dynamism` field stored on the `_IntWrapper`. If there is no dynamism specified, then this will get unwrapped back to an integer. When dynamo tracing, when we encounter an `_IntWrapper`, we will convert this to a symint if the dynamism was specified as `Dim.DYNAMIC/AUTO`. Dynamo will then trace a graph that contains symint inputs, which will get passed to AOTAutograd and so on.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152677
Approved by: https://github.com/pianpwk
When we do `torch.compile(module)`, we eventually end up returning a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.
`OptimizedModule` also inherits `nn.module.__call__`, and thus
has its own hook logic. This is useful for torchao, which injects module
forward hooks to run in eager for quantization purposes.
However, this might create unexpected behavior for global module hooks,
because `torch.compile(module)` causes the hook to fire one extra time
for `OptimizedModule`, when compared to eager.
To preserve BC, we simply emit a warning for this behavior, and let
users decide what to do. This is reasonable because the global module
hooks are documented to be used for debugging/profiling purposes only.
Fixes#149502
Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
Toggling on `torch._dynamo.config.compiled_autograd = True` was erroring export (optimize_assert didn't have `rebuild_ctx` defined). Separately add a way to `rebuild_ctx` for `optimize_assert` since it is a public API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153193
Approved by: https://github.com/jansel
When we do `torch.compile(mod)`, we eventually end up returning a new
module instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) from the default `torch.nn.Module.__call__`.
As a result we can't reuse the inherited default `__call__` as is,
because we'd end up running the logic twice.
This patch makes the returned `OptimizedModule` override the default
`__call__`, and directly calls into its compiled `forward` method.
Fixes#149502
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305
When torch.compile is applied to a module via `mod.compile(...)`, it's equivalent to `torch.compile(mod._call_impl)` which takes a different path than `OptimizedModule`. This PR ensures that the `wrap_top_frame` config can also take effect for the `torch.compile(mod._call_impl)` use case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150209
Approved by: https://github.com/anijain2305
This preserves graph breaks in the case that one graph break directly causes another, e.g. graph breaks in generic context managers.
```python
import torch
class CtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
@torch.compile(backend="eager", fullgraph=True)
def fn():
with CtxMgr():
with CtxMgr():
pass
with CtxMgr():
with CtxMgr():
pass
torch._dynamo.graph_break()
fn()
```
Output:
```
torch._dynamo.exc.Unsupported: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/playground.py", line 23, in <module>
fn()
File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 664, in _fn
raise e.with_traceback(None) from e.__cause__
torch._dynamo.exc.Unsupported: Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)]
from user code:
File "/data/users/williamwen/pytorch/playground.py", line 20, in fn
torch._dynamo.graph_break()
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Note in particular that both graph breaks (torch._dynamo.graph_break and graph break in context manager) are present in the logs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149676
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
This allows for each device type to check current devices for Triton compatibility and ensure their Triton backend is present.
This PR replaces the `has_triton()` global method which was previously used for this task, and moves the initial check for each Inductor backend on to their associated `BaseScheduler` subclass. This means that other backends, such as Halide, can also implement their own availability checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171
Approved by: https://github.com/jansel