Files
pytorch/torch/_export/wrappers.py
Tugsbayasgalan Manlaibaatar 463fbc8ca0 Support vmap + custom autograd function/improve DTensor constructor inefficiency (#162240)
This makes gemma3 exportable on transformers=4.55.4

In HF, there is a torch function mode called TransformGetItemToIndex which internally calls a custom autograd function. When this custom autograd function is called under vmap, it triggers CustomFunctionHigherOrderOP, which errored because there was no pre-dispatch proxy mode implementation.

Since there have been a number of requests lately to add various operators to the pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:
1) We intercept custom_autograd_function.apply at pre-dispatch mode when this decorator is applied
2) We apply the `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds the restriction that the custom autograd function must take FX-able types as inputs.
3) The subclass constructor decorator is implemented similarly, so we refactor it to share the implementation of this new decorator. Eventually we should delete the subclass constructor decorator.
4) Move some code in the subclass constructor decorator so it exits early in non-export environments, which should shave off some overhead (around 1% according to @swolchok's benchmark).

Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758

Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
2025-09-11 17:42:41 +00:00

338 lines
12 KiB
Python

# mypy: allow-untyped-defs
import inspect
from contextlib import contextmanager
from functools import wraps
import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
from torch._higher_order_ops.flat_apply import (
_ConstantFunction,
flat_apply,
to_graphable,
)
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
PreDispatchTorchFunctionMode,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
class ExportTracepoint(HigherOrderOperator):
def __init__(self):
super().__init__("_export_tracepoint")
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
_export_tracepoint = ExportTracepoint()
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    """Proxy-mode rule: record one ``_export_tracepoint`` node in the graph.

    Every input in ``args``/``kwargs`` is replaced by its proxy, a single
    ``call_function`` node is created, and the original positional ``args``
    are tracked as the outputs of that node.
    """
    tracer = mode.tracer
    proxied_args, proxied_kwargs = pytree.tree_map(
        tracer.unwrap_proxy, (args, kwargs)
    )
    node_proxy = tracer.create_proxy(
        "call_function", _export_tracepoint, proxied_args, proxied_kwargs
    )
    return track_tensor_tree(args, node_proxy, constant=None, tracer=tracer)
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    """Fake-tensor rule: identity on the positional inputs.

    The positional ``args`` are handed back (inside ``mode``); ``kwargs`` carry
    tracepoint metadata and are ignored here.
    """
    with mode:
        passthrough = args
    return passthrough
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    """Functionalization rule for ``_export_tracepoint``.

    The tensors in ``args``/``kwargs`` are unwrapped through ``ctx``, the
    tracepoint is re-invoked under ``ctx.redispatch_to_next()`` (its result is
    discarded), and the original, still-wrapped ``args`` are returned.
    """
    inner_args = ctx.unwrap_tensors(args)
    inner_kwargs = ctx.unwrap_tensors(kwargs)
    redispatch_guard = ctx.redispatch_to_next()
    with redispatch_guard:
        _export_tracepoint(*inner_args, **inner_kwargs)
    return args
# Autograd key: register the stock ``autograd_not_implemented`` kernel for the
# tracepoint with ``deferred_error=True``. NOTE(review): presumably this delays
# the "not implemented" error until a backward pass actually needs it — confirm
# against torch._higher_order_ops.utils.autograd_not_implemented.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    """CPU kernel: the tracepoint is a no-op that hands back its positional inputs.

    ``kwargs`` (tracepoint metadata such as ``kind``/``path``) are ignored.
    """
    unchanged_inputs = args
    return unchanged_inputs
def _wrap_submodule(mod, path, module_call_specs):
assert isinstance(mod, torch.nn.Module)
assert path != ""
submodule = torch.fx.graph_module._get_attr(mod, path)
def update_module_call_signatures(path, in_spec, out_spec):
if path in module_call_specs:
assert module_call_specs[path]["in_spec"] == in_spec
assert module_call_specs[path]["out_spec"] == out_spec
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
def check_flattened(flat_args):
for a in flat_args:
if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
raise AssertionError(
f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
)
def pre_hook(module, args, kwargs):
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
check_flattened(flat_args)
flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
return args, kwargs
def post_hook(module, args, kwargs, res):
_, in_spec = pytree.tree_flatten((args, kwargs))
flat_res, out_spec = pytree.tree_flatten(res)
check_flattened(flat_res)
flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
update_module_call_signatures(path, in_spec, out_spec)
return pytree.tree_unflatten(flat_res, out_spec)
pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
return pre_handle, post_handle
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
handles = []
try:
for path in preserve_signature:
handles.extend(_wrap_submodule(f, path, module_call_signatures))
yield
finally:
for handle in handles:
handle.remove()
def _mark_strict_experimental(cls):
def call(self, *args):
return strict_mode(self, args)
cls.__call__ = call
return cls
def _register_func_spec_proxy_in_tracer(tracer, name, spec):
    """
    Return a ``get_attr`` proxy for ``spec`` stored on ``tracer.root``,
    reusing an already registered attribute when one exists.

    The cache lookup only inspects the attribute ``name + "0"``. If it exists,
    its value must equal ``spec`` (asserted) and the new proxy points at it.
    Otherwise a fresh name is obtained from ``tracer.get_fresh_qualname(name)``,
    ``spec`` is stored there, and the proxy points at that. This avoids fx
    creating a new attribute/proxy every time the same spec is needed (e.g.
    once per subclass ``__init__`` call). NOTE(review): the lookup assumes
    ``get_fresh_qualname`` hands out the ``"0"`` suffix first; verify.
    """
    cached_attr = name + "0"
    if hasattr(tracer.root, cached_attr):
        assert getattr(tracer.root, cached_attr) == spec
        attr_name = cached_attr
    else:
        attr_name = tracer.get_fresh_qualname(name)
        setattr(tracer.root, attr_name, spec)
    return tracer.create_proxy("get_attr", attr_name, (), {})
def _emit_flat_apply_call(
    *,
    tracer,
    spec_name: str,
    const_target_for_apply,
    graphable_args,
    track_value,
    call_spec_cache_key: str,
):
    """Record one ``flat_apply`` call in ``tracer``'s graph.

    Args:
        tracer: fx tracer whose ``root`` receives the spec attributes.
        spec_name: prefix for the per-call input-spec attribute.
        const_target_for_apply: callable wrapped in ``_ConstantFunction``; its
            spec is the first ``flat_apply`` argument.
        graphable_args: inputs passed through ``to_graphable``.
        track_value: value whose tensors are tracked as the node's outputs.
        call_spec_cache_key: prefix under which the function spec is cached
            via ``_register_func_spec_proxy_in_tracer``.
    """
    leaves, args_spec = to_graphable(graphable_args)

    # Each call gets its own freshly named attribute holding the input spec.
    args_spec_attr = tracer.get_fresh_qualname(spec_name)  # type: ignore[union-attr]
    setattr(tracer.root, args_spec_attr, args_spec)  # type: ignore[union-attr]
    args_spec_proxy = tracer.create_proxy("get_attr", args_spec_attr, (), {})

    # The callee's spec is shared across calls under ``call_spec_cache_key``.
    callee_spec = pytree.tree_flatten(_ConstantFunction(const_target_for_apply))[1]
    callee_spec_proxy = _register_func_spec_proxy_in_tracer(
        tracer, f"{call_spec_cache_key}_const_func_spec", callee_spec
    )

    leaf_proxies = pytree.tree_map(tracer.unwrap_proxy, leaves)
    result_proxy = tracer.create_proxy(
        "call_function",
        flat_apply,
        (callee_spec_proxy, args_spec_proxy, *leaf_proxies),
        {},
    )
    track_tensor_tree(track_value, result_proxy, constant=None, tracer=tracer)
def _is_init(fn):
return callable(fn) and fn.__name__ == "__init__"
def mark_subclass_constructor_exportable_experimental(constructor_subclass):
    """
    Experimental decorator that makes a tensor subclass constructor traceable in
    export with pre-dispatch IR. To make your subclass traceable in export, you need to:
        1. Implement an ``__init__`` method for your subclass (look at the DTensor implementation).
        2. Decorate your ``__init__`` method with ``mark_subclass_constructor_exportable_experimental``.
        3. Put the ``torch._disable_dynamo`` decorator on top to prevent dynamo from peeking into its impl.

    Example:

        class FooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *, requires_grad=False):
                # ...
                return torch.Tensor._make_subclass(cls, elem, requires_grad=requires_grad)

            @torch._disable_dynamo
            @mark_subclass_constructor_exportable_experimental
            def __init__(self, elem, ...):
                # ...

    Args:
        constructor_subclass: the ``__init__`` function of a traceable wrapper
            tensor subclass.

    Returns:
        A wrapper (carrying the metadata of ``constructor_subclass``) that always
        runs the original ``__init__`` first. Outside of export it does nothing
        else. While exporting with a pre-dispatch torch-function mode active, it
        additionally records the construction in the graph as a ``flat_apply``
        call on the subclass type.

    Raises:
        RuntimeError: at decoration time if ``constructor_subclass`` is not an
            ``__init__`` function; at call time (only while exporting) if the
            constructed object is not a traceable wrapper tensor subclass.
    """
    if not _is_init(constructor_subclass):
        func_name = getattr(constructor_subclass, "__name__", repr(constructor_subclass))
        raise RuntimeError(
            "torch._export.wrappers.mark_subclass_constructor_exportable_experimental can only be applied on subclass tensor.__init__. "
            f"But, you are adding it on {func_name} which is not supported. "
            "If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for example"
        )

    # ``wraps`` keeps __name__/__qualname__/__doc__ so the decorated method still
    # looks like ``__init__`` (consistent with ``allow_in_pre_dispatch_graph``).
    @wraps(constructor_subclass)
    def wrapper(*args, **kwargs):
        # Always run the real constructor first so the instance is initialized.
        constructor_subclass(*args, **kwargs)
        # Exit early outside export so regular construction only pays for this check.
        if not torch.compiler.is_exporting():
            return

        subclass = args[0]
        if not is_traceable_wrapper_subclass_type(type(subclass)):
            # "FooTensor.__init__" -> "FooTensor" (no trailing dot).
            owner, _, _ = constructor_subclass.__qualname__.rpartition(".")
            obj_name = owner or constructor_subclass.__qualname__
            raise RuntimeError(
                f"Can't intercept {obj_name} in export because this object is not a traceable "
                f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
            )

        mode = _maybe_find_pre_dispatch_tf_mode_for_export()
        if mode is None:
            return
        assert isinstance(mode, PreDispatchTorchFunctionMode)

        _emit_flat_apply_call(
            tracer=mode.tracer,
            spec_name="_".join(constructor_subclass.__qualname__.lower().split(".")),
            const_target_for_apply=type(subclass),
            # Everything after ``self`` becomes the flat_apply inputs.
            graphable_args=(tuple(args[1:]), kwargs),
            # Track the constructed subclass instance as the node's output.
            track_value=subclass,
            call_spec_cache_key=type(subclass).__name__.lower(),
        )

    return wrapper
def allow_in_pre_dispatch_graph(func):
    """
    Experimental decorator that adds a user function to the export pre-dispatch
    graph. Note that we only support custom autograd functions and subclass
    constructors today. To use this decorator:

        1. For subclasses, refer to the instructions in
           ``mark_subclass_constructor_exportable_experimental`` (decorating an
           ``__init__`` simply forwards to it).
        2. For custom autograd functions, define an ``apply`` classmethod and
           apply this decorator underneath ``@classmethod``.

    Example:

        class MyCoolCustomAutogradFunc(autograd.Function):
            @classmethod
            @torch._export.wrappers.allow_in_pre_dispatch_graph
            def apply(cls, *args, **kwargs):
                return super(MyCoolCustomAutogradFunc, cls).apply(*args, **kwargs)

    Args:
        func: a tensor subclass ``__init__`` or a custom autograd ``apply``.

    Returns:
        For ``__init__``, the result of
        ``mark_subclass_constructor_exportable_experimental(func)``. For
        ``apply``, a wrapper that calls ``func`` normally outside export and,
        while exporting under a pre-dispatch proxy mode, runs ``func`` with the
        PreDispatch key excluded and records the call as a ``flat_apply`` node.

    Raises:
        RuntimeError: if ``func`` is named neither ``__init__`` nor ``apply``.
    """
    if _is_init(func):
        return mark_subclass_constructor_exportable_experimental(func)

    if getattr(func, "__name__", None) != "apply":
        func_name = getattr(func, "__name__", repr(func))
        raise RuntimeError(
            "torch._export.wrappers.allow_in_pre_dispatch_graph can only be applied on subclass tensor.__init__ "
            "or custom_autograd_function.apply. "
            f"But, you are adding it on {func_name} which is not supported. "
            "If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for example. "
            "If you are adding it on custom autograd function, please add it on apply method. "
            "If anything else, file an issue on github and we may consider extending our support. "
        )

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Outside export this decorator is a pass-through.
        if not torch.compiler.is_exporting():
            return func(*args, **kwargs)

        # Only intercept calls whose first argument is a torch.autograd.Function
        # subclass (the ``cls`` of a classmethod ``apply``). The ``args`` check
        # lets argument-less calls fall through instead of raising IndexError.
        if not (
            args
            and inspect.isclass(args[0])
            and issubclass(args[0], torch.autograd.Function)
        ):
            return func(*args, **kwargs)
        function_cls = args[0]

        from torch._ops import _get_dispatch_mode_pre_dispatch

        mode = _get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
        if mode is None:
            return func(*args, **kwargs)

        # Sometimes custom autograd functions can call into HOPs that don't have a
        # proxy impl at the PreDispatch level, so we dispatch below PreDispatch to
        # get the concrete result.
        include_to_set = torch._C._dispatch_tls_local_include_set().remove(
            torch._C.DispatchKey.PreDispatch
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.PreDispatch)
        )
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            out = func(*args, **kwargs)
        assert mode.pre_dispatch, "Should only do this in predispatch"

        # The class is recorded by its "<module>.<qualname>" string as the first
        # flat_apply input. NOTE(review): presumably the callee re-resolves it by
        # import path; verify that nested/local classes round-trip.
        function_cls_name = f"{function_cls.__module__}.{function_cls.__qualname__}"
        graphable = ((function_cls_name, *args[1:]), kwargs)

        from torch.export.custom_ops import (
            _call_custom_autograd_function_in_pre_dispatch,
        )

        # NOTE(review): this keys on the *type* of the callee (e.g. "function"
        # if it is a plain function), not the callee's own name. Kept as-is
        # because it determines the graph attribute names.
        call_spec_cache_key = type(
            _call_custom_autograd_function_in_pre_dispatch
        ).__name__.lower()
        _emit_flat_apply_call(
            tracer=mode.tracer,
            spec_name="_".join(function_cls_name.split(".")),
            const_target_for_apply=_call_custom_autograd_function_in_pre_dispatch,
            graphable_args=graphable,
            track_value=out,
            call_spec_cache_key=call_spec_cache_key,
        )
        return out

    return wrapper