26 Commits

Author SHA1 Message Date
463fbc8ca0 Support vmap + custom autograd function/improve DTensor constructor inefficiency (#162240)
This makes gemma3 exportable on transformers=4.55.4

In HF, there is a torch function mode called TransformGetItemToIndex which internally calls a custom autograd function. When this custom autograd function is called under vmap, it triggers CustomFunctionHigherOrderOp, which errored because there was no pre-dispatch proxy mode implementation.

Since there have been a number of requests lately to add various operators to pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:
1) We intercept `custom_autograd_function.apply` in pre-dispatch mode when this decorator is applied
2) We apply the `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds the restriction that the custom autograd function must take in fx-able types. (A usage sketch follows this list.)
3) The subclass constructor decorator is implemented similarly, so we refactor it to share the implementation of this new decorator. Eventually we should delete the subclass constructor decorator.
4) We move some code in the subclass constructor decorator to exit early in non-export environments, which should shave off some overhead (around 1% according to @swolchok's benchmark)
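
A rough sketch of the pattern this enables: a custom autograd function called under `torch.vmap` inside non-strict export. The new decorator's name isn't shown in this message, so it only appears as a comment; whether export succeeds depends on that decorator being applied:

```python
import torch

class ScaleBy2(torch.autograd.Function):
    # generate_vmap_rule lets functorch derive the vmap rule automatically
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

class M(torch.nn.Module):
    def forward(self, x):
        # custom autograd function under vmap: the case that previously
        # errored during pre-dispatch tracing
        return torch.vmap(ScaleBy2.apply)(x)

# The decorator from this PR (name not shown here) would be applied to
# ScaleBy2 so export intercepts ScaleBy2.apply via the flat_apply HOP.
ep = torch.export.export(M(), (torch.randn(4, 3),), strict=False)
```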

Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758

Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
2025-09-11 17:42:41 +00:00
dbef606631 Add support for tracing vmap in pre-dispatch export (#154650)
Summary: The ONNX team and a recent transformers upgrade ran into this error, and we also hit it during our export benchmarking. This diff makes it possible to trace through the vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR; in the future, we should desugar them to post-grad ops.

The implementation strategy is:
1. We add python wrappers around the vmap APIs so that we can attach a custom torch function handler that is only active during non-strict export (a usage sketch follows this list). The reason is that we don't want to add this to the default torch_function handler because it would break BC.
2. Some dynamo changes to make sure it picks up the new python wrapper APIs. The reason is that when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We could avoid this by special-casing export in dynamo to proxy different API calls, but I feel that is too chaotic because you would need to proxy two different variants of the same vmap API.
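
A minimal sketch of the user-facing behavior this enables, assuming a recent PyTorch build where `torch.export.export` accepts `strict=False` for non-strict (pre-dispatch) export:

```python
import torch

class VmapSum(torch.nn.Module):
    def forward(self, x):
        # torch.vmap maps torch.sum over the leading dim; pre-dispatch export
        # can now trace through it instead of erroring
        return torch.vmap(torch.sum)(x)

ep = torch.export.export(VmapSum(), (torch.randn(4, 3),), strict=False)
print(ep.graph_module.code)
```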

Test Plan: CI

Differential Revision: D75623875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-08-20 19:31:07 +00:00
162ca185ff [BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144551
Approved by: https://github.com/ezyang
ghstack dependencies: #148186
2025-06-25 06:16:06 +00:00
6b1b95ad2a Support subclass constructor capturing in export (#147014)
Notable TODOs:
1. Need to implement AutogradHOP to get rid of subclasses before serializing
2. Need to implement a mechanism to figure out which subclasses will be used in export when they are not expressed in the inputs

Differential Revision: [D69640673](https://our.internmc.facebook.com/intern/diff/D69640673)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147014
Approved by: https://github.com/bdhirsh
2025-03-16 18:19:19 +00:00
6cf360be04 fix lost input mutations with export_tracepoint (#148709)
Preserving module call signatures in the presence of input mutation caused incorrect results. The root cause turned out to be that export tracepoints would unwrap/wrap functional args in a way that lost mutation info on those args.
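
A sketch of the affected pattern, assuming the `preserve_module_call_signature` kwarg on `torch.export.export` (module names here are illustrative):

```python
import torch

class Child(torch.nn.Module):
    def forward(self, x):
        x.add_(1)  # input mutation that the tracepoints used to drop
        return x * 2

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def forward(self, x):
        return self.child(x)

# With the fix, the mutation of x survives the export_tracepoint wrapping
# and shows up in the exported program's graph signature.
ep = torch.export.export(
    Parent(),
    (torch.randn(3),),
    preserve_module_call_signature=("child",),
)
print(ep.graph_signature)
```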

Differential Revision: [D70734821](https://our.internmc.facebook.com/intern/diff/D70734821/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148709
Approved by: https://github.com/angelayi
2025-03-07 09:36:18 +00:00
51eacea8c4 graph module retracing without preserving MCS (#143676)
Retracing while preserving module call signatures used to be a problem because graph modules don't have submodules at the given paths. This led to a number of failing retraceability tests. By not trying to wrap modules with export tracepoints, we can pass most of these tests; the only exception is module swapping on retraced programs, which is still not possible.

Differential Revision: [D67539304](https://our.internmc.facebook.com/intern/diff/D67539304/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143676
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
ghstack dependencies: #143664
2024-12-21 07:57:43 +00:00
b07d0a22f5 [hop] require hops to override __call__. (#134352)
Fixes https://github.com/pytorch/pytorch/issues/133719 by making `__call__` of hops an abstractmethod.
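
A minimal sketch of the required shape of a HOP after this change (combined with the instantiation ban from #133645 below, you subclass and override `__call__` rather than instantiating `HigherOrderOperator` directly):

```python
import torch
from torch._ops import HigherOrderOperator

class MyHop(HigherOrderOperator):
    def __init__(self):
        # registers the op under torch.ops.higher_order.my_hop
        super().__init__("my_hop")

    # __call__ must be overridden; the base implementation dispatches
    def __call__(self, fn, *args, **kwargs):
        return super().__call__(fn, *args, **kwargs)

my_hop = MyHop()
```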

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134352
Approved by: https://github.com/zou3519
2024-08-28 19:56:40 +00:00
a23d86c178 [hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133645
Approved by: https://github.com/zou3519
2024-08-23 17:28:02 +00:00
1491a61769 Revert "[hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)"
This reverts commit 696107efcb83f9359aa669ab343c2cfa2a111372.

Reverted https://github.com/pytorch/pytorch/pull/133645 on behalf of https://github.com/ydwu4 due to breaking ci. probably due to land race ([comment](https://github.com/pytorch/pytorch/pull/133645#issuecomment-2302866106))
2024-08-21 19:33:14 +00:00
696107efcb [hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133645
Approved by: https://github.com/zou3519
ghstack dependencies: #133521
2024-08-21 17:34:21 +00:00
54efd43022 [BE] Simplify code interacting with get_proxy_mode/enable_tracing (#132675)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132675
Approved by: https://github.com/Skylion007, https://github.com/ydwu4, https://github.com/zou3519
ghstack dependencies: #132674
2024-08-08 12:03:00 +00:00
9d476fee53 Revert "[BE] Simplify code interacting with get_proxy_mode/enable_tracing (#132675)"
This reverts commit c2bccfd4311fe905ff78c0977281b8e642bb10d6.

Reverted https://github.com/pytorch/pytorch/pull/132675 on behalf of https://github.com/PaliC due to We need to now revert https://github.com/pytorch/pytorch/pull/132216 in OSS and there is a dependency on this pr ([comment](https://github.com/pytorch/pytorch/pull/132674#issuecomment-2274062785))
2024-08-07 18:25:33 +00:00
c2bccfd431 [BE] Simplify code interacting with get_proxy_mode/enable_tracing (#132675)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132675
Approved by: https://github.com/Skylion007, https://github.com/ydwu4, https://github.com/zou3519
ghstack dependencies: #132674
2024-08-06 18:13:22 +00:00
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
3f62b05d31 [export] Use forward hooks to capture module signatures. (#120468)
Summary:
When we export in strict mode and turn on preserve_module_call_signature, the following assertion error occurs today:
```
child_split[: len(parent_split)] == parent_split
```
This is because we're monkey-patching the forward call directly, which breaks attribute propagation in the tracer. It's better to implement this with a forward hook, because then we don't have to alter the original module structure at all during export.
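
For context, the forward-hook mechanism this switches to looks like the following (a generic sketch, not the export internals):

```python
import torch

lin = torch.nn.Linear(3, 3)

def capture_signature(module, args, kwargs, output):
    # observes the call without replacing module.forward, so the tracer's
    # attribute propagation over submodules stays intact
    print(type(module).__name__, [a.shape for a in args], list(kwargs))
    return output

handle = lin.register_forward_hook(capture_signature, with_kwargs=True)
lin(torch.randn(2, 3))
handle.remove()
```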

Test Plan: CI

Differential Revision: D54102714

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120468
Approved by: https://github.com/ydwu4
2024-02-27 17:44:06 +00:00
f316c35a34 [export] Support preserving submodule calling convention in non-strict export (#117796)
Summary: Title

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D52889236

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796
Approved by: https://github.com/angelayi
2024-01-19 17:16:45 +00:00
88197f2202 Rename experimental API (#116895)
Summary: Title

Test Plan: CI

Differential Revision: D52571286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116895
Approved by: https://github.com/zhxchen17
2024-01-06 08:01:09 +00:00
81f98f1082 Experimental non-strict mode (#114658)
This is a proof-of-concept implementation of how people can use a marker `mark_strict` to enable torchdynamo while exporting under non-strict mode. The main idea is that `mark_strict` will turn into an HOO which then utilizes dynamo to do correctness analysis in the same way torch.cond works today. There are some notable limitations:
1. This API is not meant for public use yet
2. Strict region can't work with arbitrary container inputs
3. We don't preserve `nn_module_stack` and other node metadata for the strict region.
4. The strict_mode HOO will show up in the final graph. This is undesirable in the long term, but for short-term experiments it should be good enough. Will fix this in a follow-up PR.
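
A rough sketch of the intended usage, with `mark_strict` stubbed out since this message doesn't show its import path (it's explicitly not a public API yet):

```python
import torch

def mark_strict(mod):
    # stub standing in for the experimental marker described above; the real
    # version turns the wrapped region into an HOO that dynamo analyzes
    return mod

class Inner(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = mark_strict(Inner())  # only this region runs under dynamo

    def forward(self, x):
        return self.inner(x) * 2

ep = torch.export.export(Outer(), (torch.randn(4),), strict=False)
```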

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114658
Approved by: https://github.com/ydwu4
2024-01-04 12:24:58 +00:00
cc1de49340 [HigherOrderOp] fallthrough some keys by default. (#110478)
Fixes #109253

Test Plan:
Added a new test showing that the default fallthrough keys can be overridden.
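
A sketch of what overriding a default-fallthrough key could look like, assuming `AutocastCPU` is among the keys that now fall through by default:

```python
import torch
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator

class DemoHop(HigherOrderOperator):
    def __init__(self):
        super().__init__("demo_hop")

    def __call__(self, fn, *args):
        return super().__call__(fn, *args)

demo_hop = DemoHop()

# The key falls through by default, but a HOP can still opt back in by
# registering an explicit rule for it:
@demo_hop.py_impl(DispatchKey.AutocastCPU)
def demo_hop_autocast(fn, *args):
    return fn(*args)
```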

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110478
Approved by: https://github.com/ezyang
2023-10-05 16:25:42 +00:00
238fb66085 python functionalization: support higher order ops (#108656)
We now have two types of functionalization: C++ functionalization (through the `Functionalize` dispatch key) and python functionalization (through the `FunctionalTensorMode` torch_dispatch mode).

This means that all higher order ops need custom functionalization rules for the python variant too. I added them here, as well as a helper function `dispatch_functionalize()` - equivalent to `torch.func.functionalize()`, except that it uses `FunctionalTensorMode`.

In theory we could have secretly switched `torch.func.functionalize` to use `FunctionalTensorMode`. This would be BC-breaking, though, since `FunctionalTensorMode` isn't composable with the other functorch transforms (the functorch layer-mode stack doesn't know how to re-order torch_dispatch modes arbitrarily).
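
For contrast, the pre-existing C++ path can be exercised with `torch.func.functionalize` plus `make_fx`; per the description above, `dispatch_functionalize()` is the `FunctionalTensorMode`-based equivalent:

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y.add_(1)  # in-place op that functionalization rewrites out-of-place
    return y

gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.code)  # contains aten.add, not aten.add_
```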

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108656
Approved by: https://github.com/zou3519
ghstack dependencies: #109024, #109248
2023-09-20 04:37:31 +00:00
518308a740 Trace through pytree API with dynamo. (#108533)
Fix: #107315

This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.

In summary, this PR:

- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables the call to `id` builtin function for every variable that implements
  `as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
    - Making `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
    - Removing the `functools.wraps` call, since it can't be inlined
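
A minimal sketch of what this unlocks: compiling a function that round-trips its inputs through pytree without graph breaks:

```python
import torch
import torch.utils._pytree as pytree

@torch.compile(fullgraph=True)  # fullgraph: any graph break would raise
def f(inputs):
    # dynamo inlines tree_flatten/tree_unflatten instead of graph-breaking
    leaves, spec = pytree.tree_flatten(inputs)
    return pytree.tree_unflatten([x + 1 for x in leaves], spec)

out = f({"a": torch.randn(2), "b": (torch.randn(2), torch.randn(2))})
```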

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
2023-09-20 00:04:56 +00:00
8a567bb59d [HigherOrderOp] Should automatically pop modes (#109157)
Fixes #108282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109157
Approved by: https://github.com/zou3519
2023-09-18 20:54:09 +00:00
07f2efa285 Revert "[HigherOrderOp] Should automatically pop modes (#109157)"
This reverts commit f03b8abd4706e53b3fb6aefbd4304884e537616d.

Reverted https://github.com/pytorch/pytorch/pull/109157 on behalf of https://github.com/clee2000 due to broke internal builds D49346922 ([comment](https://github.com/pytorch/pytorch/pull/109157#issuecomment-1722571262))
2023-09-17 21:19:52 +00:00
f03b8abd47 [HigherOrderOp] Should automatically pop modes (#109157)
Fixes #108282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109157
Approved by: https://github.com/zou3519
2023-09-14 20:46:26 +00:00
138fafe72d [export] Fix torch.export() issues for server use cases. (#108275)
Test Plan: In D48788843

Differential Revision: D48811793

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108275
Approved by: https://github.com/tugsbayasgalan
2023-08-31 07:19:18 +00:00
547ccae0db [export] Support preserving calling convention to some modules. (#106798)
Summary: APS uses this feature to swap out some submodules after unflattening.
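
A sketch of the swap-after-unflatten workflow, using today's public API names (`torch.export.unflatten` and the `preserve_module_call_signature` kwarg), which postdate this commit:

```python
import torch
from torch.export import export, unflatten

class Sub(torch.nn.Module):
    def forward(self, x):
        return x.relu()

class Top(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x) + 1

ep = export(Top(), (torch.randn(3),), preserve_module_call_signature=("sub",))
m = unflatten(ep)
# Because "sub" kept its calling convention, m.sub can be swapped for another
# module with the same forward signature.
```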

Test Plan: test_export_preserve_signature

Differential Revision: D48154341

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106798
Approved by: https://github.com/tugsbayasgalan
2023-08-11 21:17:45 +00:00