Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
This makes gemma3 exportable on transformers==4.55.4.

In HF, there is a torch function mode called `TransformGetItemToIndex` which internally calls a custom autograd function. When this custom autograd function is called under vmap, it triggers `CustomFunctionHigherOrderOP`, which errored because there was no pre-dispatch proxy mode implementation. Since there have been a number of requests lately to add various operators to the pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:

1) We intercept `custom_autograd_function.apply` at pre-dispatch mode when this decorator is applied.
2) We apply the `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds the restriction that the custom autograd function must take in fx-able types.
3) The subclass constructor decorator is implemented similarly, so we refactor it to share the implementation of this new decorator; eventually we should delete the subclass constructor decorator.
4) We move some code in the subclass constructor decorator to exit early in non-export environments, which should shave off some inefficiency (around 1% according to @swolchok's benchmark).

Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758

Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
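A rough sketch of the usage pattern described above. The decorator name `_allow_custom_autograd_function_in_pre_dispatch` is a placeholder, not the actual API added in this PR (see the diff for the real entry point); the rest uses only standard `torch.autograd.Function` and `torch.export` APIs.

```python
import torch
from torch.export import export


class MulBy2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


# Hypothetical registration; the real decorator added in this PR lives in the diff.
# _allow_custom_autograd_function_in_pre_dispatch(MulBy2)


class M(torch.nn.Module):
    def forward(self, x):
        # With the decorator applied, MulBy2.apply would be intercepted at
        # pre-dispatch time and recorded through the flat_apply HOP (its pytree
        # spec hidden), so its inputs must be fx-able types.
        return MulBy2.apply(x)


ep = export(M(), (torch.randn(3),))
print(ep)
```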
50 lines
1.7 KiB
Python
# mypy: allow-untyped-defs
import importlib

import torch


lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901

lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)


@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
# When running under torch.inference_mode(), we seem to skip the Autograd key,
# so we should desugar this op as soon as we start tracing to post-dispatch.
@torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    assert is_traceable_wrapper_subclass(src_subclass_tensor)
    val = getattr(src_subclass_tensor, attr, None)
    if val is None or not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val
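# Example usage (illustrative only; TwoTensor is an internal, test-only traceable
# wrapper subclass that stores inner tensors under attributes "a" and "b"):
#
#     from torch.testing._internal.two_tensor import TwoTensor
#     t = TwoTensor(torch.randn(3), torch.randn(3))
#     inner = torch.ops.export.access_subclass_inner_tensor(t, "a")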
def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
    """
    Import a custom autograd function by string name and call it. This is pretty bad
    because:
    1) There is no schema

    Ideally we should automatically wrap custom autograd functions with a custom op, but
    that is too much work because we need to schematize custom autograd functions. For now,
    we just hackily put it in the IR.
    """
    # Parse module and class name
    module_name, class_name = function_cls_name.rsplit(".", 1)

    # Import the module and get the class
    module = importlib.import_module(module_name)
    function_cls = getattr(module, class_name)
    assert hasattr(function_cls, "apply")
    return function_cls.apply(*args, **kwargs)
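# Example usage (illustrative; `my_pkg.my_mod.MyFn` is a made-up fully-qualified
# name): the helper above splits the name, imports the module, and invokes the
# autograd function's .apply with the given arguments:
#
#     # assuming my_pkg/my_mod.py defines `class MyFn(torch.autograd.Function)`
#     out = _call_custom_autograd_function_in_pre_dispatch(
#         "my_pkg.my_mod.MyFn", torch.randn(3)
#     )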