When we do `torch.compile(module)`, we eventually end up returning a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.
`OptimizedModule` also inherits `nn.module.__call__`, and thus
has its own hook logic. This is useful for torchao, which injects module
forward hooks to run in eager for quantization purposes.
However, this might create unexpected behavior for global module hooks,
because `torch.compile(module)` causes the hook to fire one extra time
for `OptimizedModule`, when compared to eager.
To preserve BC, we simply emit a warning for this behavior, and let
users decide what to do. This is reasonable because the global module
hooks are documented to be used for debugging/profiling purposes only.
Fixes#149502
Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
When we do `torch.compile(mod)`, we eventually end up returning a new
module instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) from the default `torch.nn.Module.__call__`.
As a result we can't reuse the inherited default `__call__` as is,
because we'd end up running the logic twice.
This patch makes the returned `OptimizedModule` override the default
`__call__`, and directly calls into its compiled `forward` method.
Fixes#149502
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305
When torch.compile is applied to a module via `mod.compile(...)`, it's equivalent to `torch.compile(mod._call_impl)` which takes a different path than `OptimizedModule`. This PR ensures that the `wrap_top_frame` config can also take effect for the `torch.compile(mod._call_impl)` use case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150209
Approved by: https://github.com/anijain2305
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127125
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123, #127124
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)
We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forwards runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
Fixes https://github.com/pytorch/pytorch/issues/112446
This is a doozy of a PR, there's a few important things to keep in mind here:
1) We MUST lift all tensors accessed via attrs to inputs, getattr is a no go in the graph, it violates the aot_autograd contract. Furthermore, aot_autograd does not know how to apply in-place ops to intermediary tensors that are attributes (aka from getattr) anyway. Views from ops are fine.
2) `.grad` access handling in dynamo peeks at the underlying value, the real tensor, because re-piping FakeTensors already made with this fake_mode through builder anew is a no go.
3) We have no proper mechanism for updating the hint / grapharg.example (the real value in (2) above) midway through trace
Therefore, what we need to do is reconcile the difference in grad stashed on grapharg.example. The easiest way to do this is lazily, upon .grad access, by reading the new value off the right fake tensors. We can then make a tensor using that data as a hint to VariableBuilder to make the right VariableTracker. Note that the example value used here (torch.zeros) in the PR, is a dummy value only used as a tracing hint, it does not leak out into real runtime code.
Alternatively, we could implement accumulate_grad_ in python...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112811
Approved by: https://github.com/jansel
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903 wherein we added support for register_hook input tensors.
The distinction between input and intermediary is due to implementation differences.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
**The plan:**
For tensors w/ a source: (The PR above)
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call register_hook. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound to do.
For tensors w/o a source: (This PR)
Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user provided function. This is necessary so that we can properly trace out and capture all the mutations within the user defined hook at backwards time.
This is too complicated - so, we limited the scope of this initial PR to a simple subset of hooks:
- Hooks must have a source (be known to us already, not a lambda or intermediary defined function)
- We must be tracing under compiled autograd
**The flow**:
We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below.
1) We intercept register_hook calls and wrap the user defined fn in the HOP
2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function that is invoked as a call_function in the dynamo graph
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST `->`LOAD_FAST` associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding into `STORE_FAST`, and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store a hook, we merely pop the produced value to preserve the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092