Files
pytorch/torch/export/custom_ops.py
Tugsbayasgalan Manlaibaatar 463fbc8ca0 Support vmap + custom autograd function/improve DTensor constructor inefficiency (#162240)
This makes gemma3 exportable on transformers=4.55.4

In HF, there is a torch function mode called TransformGetItemToIndex which internally calls a custom autograd function. When this custom autograd function is called under vmap, it triggers CustomFunctionHigherOrderOP, which errored because there was no pre-dispatch proxy mode implementation.
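
For context, here is a minimal sketch of the failing pattern. The toy function below stands in for the one HF's mode calls (the real TransformGetItemToIndex is not shown), and the vmap support uses the standard `generate_vmap_rule` opt-in:

```python
import torch
from torch.func import vmap


class ToyScale(torch.autograd.Function):
    # Stand-in for the custom autograd function HF calls inside its
    # TransformGetItemToIndex torch function mode; not the real code.
    generate_vmap_rule = True  # opt in to an auto-generated vmap rule

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


# Calling .apply under vmap routes through the custom-function higher-order
# op; tracing this for export previously failed because that op had no
# pre-dispatch proxy mode implementation.
out = vmap(ToyScale.apply)(torch.randn(4, 3))
```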

Since there have been a number of requests lately to add various operators to the pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:
1) We intercept custom_autograd_function.apply in pre-dispatch mode when this decorator is applied
2) We apply the `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds the restriction that the custom autograd function needs to take in fx-able types (a sketch of this restriction follows the list).
3) The subclass constructor decorator is implemented similarly, so we refactor it to share an implementation with this new decorator. Eventually we should delete the subclass constructor decorator.
4) Move some code in the subclass constructor decorator to exit early in non-export environments, which should shave off some inefficiency (around 1% according to @swolchok's benchmark).
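
To make the fx-able-types restriction in (2) concrete, here is a toy sketch (these functions are illustrative only and not part of this PR): arguments made of tensors and plain Python scalars can be flattened and captured, whereas an opaque Python object cannot.

```python
import torch


class ScaleByFloat(torch.autograd.Function):
    # OK: inputs are a tensor and a plain float, both fx-able, so the
    # flattened call can be captured once the pytree spec is hidden
    # behind flat_apply.
    @staticmethod
    def forward(ctx, x, scale):
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None


class Opaque:
    def __init__(self, scale):
        self.scale = scale


class ScaleByObject(torch.autograd.Function):
    # Not OK for this decorator: `cfg` is an arbitrary Python object with
    # no fx-able representation, so it violates the restriction above.
    @staticmethod
    def forward(ctx, x, cfg):
        return x * cfg.scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None
```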

Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758

Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
2025-09-11 17:42:41 +00:00

50 lines
1.7 KiB
Python

# mypy: allow-untyped-defs
import importlib

import torch


lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901

lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)


@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
# When running under torch.inference_mode(), we seem to skip the Autograd key,
# so we should desugar this op as soon as we start tracing to post-dispatch.
@torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    assert is_traceable_wrapper_subclass(src_subclass_tensor)
    val = getattr(src_subclass_tensor, attr, None)
    if val is None or not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val
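

# Usage sketch: because the op is registered under the "export" namespace
# above, it can be invoked through the dispatcher, e.g.
#   inner = torch.ops.export.access_subclass_inner_tensor(subclass_tensor, "inner")
# where `subclass_tensor` is a traceable wrapper subclass holding an inner
# tensor attribute named "inner" (the names here are illustrative).
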
def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
    """
    Import a custom autograd function by string name and call it. This is pretty bad
    because:
    1) There is no schema

    Ideally we should automatically wrap custom autograd functions with a custom op, but
    that is too much work because we need to schematize custom autograd functions. For now,
    we just hackily put it in the IR.
    """
    # Parse module and class name
    module_name, class_name = function_cls_name.rsplit(".", 1)

    # Import the module and get the class
    module = importlib.import_module(module_name)
    function_cls = getattr(module, class_name)

    assert hasattr(function_cls, "apply")
    return function_cls.apply(*args, **kwargs)
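
A minimal usage sketch of the helper above, assuming a hypothetical user module `mypackage.my_fns` that defines a custom autograd function `MyFn` (both names are made up for illustration):

```python
import torch
from torch.export.custom_ops import _call_custom_autograd_function_in_pre_dispatch

# mypackage/my_fns.py (hypothetical) would define something like:
#
#     class MyFn(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x + 1
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             return grad_out
#
# The helper resolves the class from its dotted name via importlib and
# then calls its .apply with the remaining arguments.
out = _call_custom_autograd_function_in_pre_dispatch(
    "mypackage.my_fns.MyFn", torch.randn(3)
)
```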