pytorch/torch
Tugsbayasgalan Manlaibaatar 463fbc8ca0 Support vmap + custom autograd function/improve DTensor constructor inefficiency (#162240)
This makes gemma3 exportable on transformers=4.55.4

In HF transformers, there is a torch function mode called `TransformGetItemToIndex` that internally calls a custom autograd function. When this custom autograd function is called under vmap, it triggers CustomFunctionHigherOrderOP, which errored because it had no pre-dispatch proxy mode implementation.

Since there have been a number of requests lately to add various operators to the pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:
1) When this decorator is applied, we intercept `custom_autograd_function.apply` in pre-dispatch mode.
2) We apply the `flat_apply` HOP to hide this autograd function's pytree spec. Note that this adds the restriction that the custom autograd function must take fx-able types as inputs.
3) The subclass constructor decorator is implemented similarly, so we refactor it to share this new decorator's implementation. Eventually we should delete the subclass constructor decorator.
4) Move some code in the subclass constructor decorator so that it exits early in non-export environments, which should shave off some overhead (around 1% according to @swolchok's benchmark).
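The interception steps above can be sketched in plain Python. This is a hedged, framework-free illustration, not the code from this PR: `allow_in_export_graph`, `export_mode`, `flat_apply`, `tree_flatten`, `tree_unflatten`, `RECORDED_CALLS`, and `ScaleAndShift` are hypothetical names. A global flag stands in for pre-dispatch tracing, a list of recorded calls stands in for graph nodes, and primitive Python values stand in for fx-able types. It mirrors steps 1, 2 and 4: outside export the wrapper returns early, and inside export it flattens the inputs so the recorded call sees only flat leaves plus a spec.

```python
# Hypothetical sketch; none of these names are PyTorch APIs.
from contextlib import contextmanager
from typing import Any, Callable, List, Tuple

EXPORT_MODE = False  # stand-in for "inside pre-dispatch export tracing"
RECORDED_CALLS: List[Tuple[str, Any, List[Any]]] = []  # stand-in for graph nodes
ALLOWED_LEAF_TYPES = (int, float, bool, str, type(None))  # stand-in for fx-able types


@contextmanager
def export_mode():
    """Temporarily enable the (hypothetical) export flag."""
    global EXPORT_MODE
    previous = EXPORT_MODE
    EXPORT_MODE = True
    try:
        yield
    finally:
        EXPORT_MODE = previous


def tree_flatten(obj: Any) -> Tuple[List[Any], Any]:
    """Flatten nested tuples/lists/dicts into (leaves, spec)."""
    if isinstance(obj, dict):
        keys, children = list(obj), list(obj.values())
    elif isinstance(obj, (tuple, list)):
        keys, children = None, list(obj)
    else:
        return [obj], None  # a leaf has no spec
    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (type(obj), keys, child_specs)


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild the nested structure described by `spec` from flat leaves."""
    it = iter(leaves)

    def build(node_spec: Any) -> Any:
        if node_spec is None:
            return next(it)
        kind, keys, child_specs = node_spec
        children = [build(s) for s in child_specs]
        return dict(zip(keys, children)) if kind is dict else kind(children)

    return build(spec)


def flat_apply(fn: Callable, spec: Any, *leaves: Any) -> Any:
    """Call `fn` with args/kwargs rebuilt from flat leaves and a spec."""
    args, kwargs = tree_unflatten(list(leaves), spec)
    return fn(*args, **kwargs)


def allow_in_export_graph(cls):
    """Intercept `cls.apply` only while export mode is active."""
    original_apply = cls.apply

    def apply(*args, **kwargs):
        if not EXPORT_MODE:
            # Early exit: outside export, behave exactly like the original.
            return original_apply(*args, **kwargs)
        leaves, spec = tree_flatten((args, kwargs))
        for leaf in leaves:
            if not isinstance(leaf, ALLOWED_LEAF_TYPES):
                raise TypeError(f"unsupported input type: {type(leaf).__name__}")
        # The recorded call only sees flat leaves; the spec is kept alongside.
        RECORDED_CALLS.append((cls.__name__, spec, leaves))
        return flat_apply(original_apply, spec, *leaves)

    cls.apply = staticmethod(apply)
    return cls


@allow_in_export_graph
class ScaleAndShift:
    @staticmethod
    def apply(xs, scale=1, shift=0):
        return [x * scale + shift for x in xs]
```

Calling `ScaleAndShift.apply([1, 2, 3], scale=2, shift=1)` inside `export_mode()` returns `[3, 5, 7]` and records the flat leaves `[1, 2, 3, 2, 1]`; outside export the call passes straight through without recording anything. Passing a non-primitive leaf under export raises `TypeError`, loosely mirroring the fx-able-input restriction in step 2.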

Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758

Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
2025-09-11 17:42:41 +00:00