Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Support vmap + custom autograd function/improve DTensor constructor inefficiency (#162240)
This makes gemma3 exportable on transformers==4.55.4. In HF, there is a torch function mode called TransformGetItemToIndex which internally calls a custom autograd function. When that custom autograd function is called under vmap, it triggers CustomFunctionHigherOrderOP, which errored because there was no pre-dispatch proxy mode implementation. Since there have been a number of requests lately to add various operators to the pre-dispatch IR, I introduce a decorator in export that works similarly to `allow_in_graph`. Basically:
1) We intercept custom_autograd_function.apply at pre-dispatch mode when this decorator is applied.
2) We apply the `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds the restriction that the custom autograd function needs to take in fx-able types.
3) The subclass constructor decorator is implemented similarly, so we refactor it to share the implementation of this new decorator. Eventually we should delete the subclass constructor decorator.
4) Move some code in the subclass constructor decorator so that it exits early in a non-export environment, which should shave off some inefficiency (around 1% according to @swolchok's benchmark).
Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758
Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
Committed by: PyTorch MergeBot
Parent: 2f53395943
Commit: 463fbc8ca0
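For context, here is a minimal sketch of how the new decorator is meant to be used, modeled on the docstring added in this PR. `MyScale`, `M`, and the export call are illustrative assumptions rather than part of the change, and they presume a PyTorch build that already includes this PR.

import torch
import torch._export.wrappers


class MyScale(torch.autograd.Function):
    # Illustrative custom autograd function; its inputs must be fx-able types.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

    # Opt apply() into the pre-dispatch export graph, mirroring the ModIndex change below.
    @classmethod
    @torch._export.wrappers.allow_in_pre_dispatch_graph
    def apply(cls, *args, **kwargs):
        return super().apply(*args, **kwargs)


class M(torch.nn.Module):
    def forward(self, x):
        return MyScale.apply(x)


# Outside export this behaves like a plain autograd.Function call; under
# torch.export the apply() call is captured via the flat_apply HOP.
ep = torch.export.export(M(), (torch.randn(3),), strict=False)
print(ep.graph)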
@@ -26,6 +26,7 @@ import torch.utils._pytree as pytree
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._decomp import decomposition_table, get_decompositions
from torch._dynamo._trace_wrapped_higher_order_op import mod_index
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import normalize_gm
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
@@ -13615,6 +13616,52 @@ def forward(self, x):
        ):
            _ = export(Foo(), (torch.randn(4, 4),), strict=False)

    def test_vmap_custom_autograd_function(self):
        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

        class IndexingModule(torch.nn.Module):
            def __init__(self, base_size: int = 10):
                super().__init__()
                self.register_buffer("base", torch.arange(base_size))

            def forward(self, indices: torch.Tensor) -> torch.Tensor:
                with TransformGetItemToIndex():
                    # Each element of `indices` is a scalar tensor, so our override kicks in
                    return torch.vmap(lambda i: self.base[i])(indices)

        m = IndexingModule(10)
        idxs = torch.tensor([0, 3, 7, 9])
        ep = torch.export.export(m, (idxs,), strict=False)
        self.assertExpectedInline(
            ep.graph,
            """\
graph():
    %b_base : [num_users=1] = placeholder[target=b_base]
    %indices : [num_users=1] = placeholder[target=indices]
    %lazy_load_decompositions : [num_users=0] = call_function[target=torch._functorch.predispatch.lazy_load_decompositions](args = (), kwargs = {})
    %_vmap_increment_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_increment_nesting](args = (4, error), kwargs = {})
    %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%indices, 0, 1), kwargs = {})
    %torch__dynamo__trace_wrapped_higher_order_op_mod_index0 : [num_users=1] = get_attr[target=torch__dynamo__trace_wrapped_higher_order_op_ModIndex0]
    %function_const_func_spec0 : [num_users=1] = get_attr[target=function_const_func_spec0]
    %flat_apply : [num_users=1] = call_function[target=torch.ops.higher_order.flat_apply](args = (%function_const_func_spec0, %torch__dynamo__trace_wrapped_higher_order_op_mod_index0, torch._dynamo._trace_wrapped_higher_order_op.ModIndex, %b_base, %_add_batch_dim), kwargs = {})
    %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%flat_apply, 1, 4, 0), kwargs = {})
    %_vmap_decrement_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_decrement_nesting](args = (), kwargs = {})
    return (_remove_batch_dim,)""",
        )

        self.assertEqual(m(idxs), ep.module()(idxs))
        ep = ep.run_decompositions({})
        self.assertExpectedInline(
            ep.graph,
            """\
graph():
    %b_base : [num_users=1] = placeholder[target=b_base]
    %indices : [num_users=1] = placeholder[target=indices]
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%b_base, [%indices]), kwargs = {})
    return (index,)""",
        )
        self.assertEqual(m(idxs), ep.module()(idxs))

    def test_unbacked_deferred_runtime_retrace(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
@@ -14412,10 +14459,7 @@ graph():
            def forward(self, x):
                return x.cos()

        with self.assertRaisesRegex(
            RuntimeError, "TestExport.test_capture_subclass_wrong.<locals>.Foo"
        ):
            export(Foo(), (torch.randn(4, 4),))
        export(Foo(), (torch.randn(4, 4),))

    def test_capture_subclass_constructor_torch_ir(self):
        class Foo(torch.nn.Module):
@@ -116,6 +116,11 @@ class ModIndex(torch.autograd.Function):
            None,
        )

    @classmethod
    @torch._export.wrappers.allow_in_pre_dispatch_graph
    def apply(cls, *args, **kwargs):  # type: ignore[no-untyped-def]
        return super().apply(*args, **kwargs)


mod_index = ModIndex.apply
@@ -216,6 +216,7 @@ class Verifier(metaclass=_VerifierMeta):
            torch.sym_not,
            torch.sym_sqrt,
            torch.sym_sum,
            torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch,
            # TODO (tmanlaibaatar)
            # Predispatch export is able to contain autograd ops.
            # These will be modeled as HOO later
@@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
import inspect
from contextlib import contextmanager
from functools import wraps

import torch
import torch._custom_ops
@@ -15,7 +17,6 @@ from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    get_proxy_slot,
    PreDispatchTorchFunctionMode,
    ProxyTorchDispatchMode,
    track_tensor_tree,
@@ -129,7 +130,7 @@ def _mark_strict_experimental(cls):
    return cls


def _register_subclass_spec_proxy_in_tracer(tracer, name, spec):
def _register_func_spec_proxy_in_tracer(tracer, name, spec):
    """
    This is a wrapper utility method on top of tracer to cache the
    already registered subclass spec attribute. This is useful because
@@ -146,6 +147,41 @@ def _register_subclass_spec_proxy_in_tracer(tracer, name, spec):
    return tracer.create_proxy("get_attr", qualname, (), {})


def _emit_flat_apply_call(
    *,
    tracer,
    spec_name: str,
    const_target_for_apply,
    graphable_args,
    track_value,
    call_spec_cache_key: str,
):
    # Flatten to graphable form and record the spec on the FX root
    flat_args, in_spec = to_graphable(graphable_args)
    qualname = tracer.get_fresh_qualname(spec_name)  # type: ignore[union-attr]
    setattr(tracer.root, qualname, in_spec)  # type: ignore[union-attr]
    spec_proxy = tracer.create_proxy("get_attr", qualname, (), {})

    # Reuse/cached ConstantFunction spec on the root
    _, func_spec = pytree.tree_flatten(_ConstantFunction(const_target_for_apply))
    func_spec_proxy = _register_func_spec_proxy_in_tracer(
        tracer, f"{call_spec_cache_key}_const_func_spec", func_spec
    )

    # Map runtime args -> proxies (always via tracer.unwrap_proxy now)
    flat_proxy_args = pytree.tree_map(tracer.unwrap_proxy, flat_args)

    # Emit flat_apply and track result structure
    out_proxy = tracer.create_proxy(
        "call_function", flat_apply, (func_spec_proxy, spec_proxy, *flat_proxy_args), {}
    )
    track_tensor_tree(track_value, out_proxy, constant=None, tracer=tracer)


def _is_init(fn):
    return callable(fn) and fn.__name__ == "__init__"


def mark_subclass_constructor_exportable_experimental(constructor_subclass):
    """
    Experimental decorator that makes subclass to be traceable in export
@@ -167,10 +203,6 @@ mark_subclass_constructor_exportable_experimental(constructor_subclass):
        def __init__(self, elem, ...):
            # ...
    """

    def _is_init(fn):
        return callable(fn) and fn.__name__ == "__init__"

    if not _is_init(constructor_subclass):
        raise RuntimeError(
            f"torch._export.wrappers.mark_constructor_exportable_experimental can only be applied on subclass tensor.__init__"
@@ -179,14 +211,18 @@ mark_subclass_constructor_exportable_experimental(constructor_subclass):
        )

    def wrapper(*args, **kwargs):
        constructor_subclass(*args, **kwargs)

        if not torch.compiler.is_exporting():
            return

        if not is_traceable_wrapper_subclass_type(type(args[0])):
            assert constructor_subclass.__qualname__.endswith("__init__")
            obj_name = constructor_subclass.__qualname__[: -len("__init__")]
            raise RuntimeError(
                f"Applying mark_constructor_exportable_experimental on {obj_name} is not valid as it is not a traceable "
                f"Can't intercept {obj_name} in export because this object is not a traceable "
                f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
            )
        constructor_subclass(*args, **kwargs)

        mode = _maybe_find_pre_dispatch_tf_mode_for_export()
        if mode is None:
@@ -196,46 +232,106 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):

        tracer = mode.tracer
        subclass = args[0]
        graphable = (tuple(args[1:]), kwargs)

        flat_args, in_spec = to_graphable((tuple(args[1:]), kwargs))
        spec_name = "_".join(constructor_subclass.__qualname__.lower().split("."))
        call_spec_cache_key = type(subclass).__name__.lower()

        constructor_spec_name = "_".join(
            constructor_subclass.__qualname__.lower().split(".")
        _emit_flat_apply_call(
            tracer=tracer,
            spec_name=spec_name,
            const_target_for_apply=type(subclass),
            graphable_args=graphable,
            track_value=subclass,  # track the constructed subclass instance
            call_spec_cache_key=call_spec_cache_key,
        )
        qualname = tracer.get_fresh_qualname(constructor_spec_name)  # type: ignore[union-attr]
        setattr(tracer.root, qualname, in_spec)  # type: ignore[union-attr]
        spec_proxy = tracer.create_proxy("get_attr", qualname, (), {})
        flat_proxy_args = pytree.tree_map_only(
            torch.Tensor, lambda x: get_proxy_slot(x, tracer).proxy, flat_args
        )

        _, func_spec = torch.utils._pytree.tree_flatten(
            _ConstantFunction(type(subclass))
        )

        # We actually don't want to create a new spec for each instance
        # In fx graph, it will look like dtensor_const_func_spec
        # We can't directly shove DTensor.__init__ into fx as it is not
        # allowed type.
        fxable_constructor_call_spec_name = (
            type(subclass).__name__.lower() + "_const_func_spec"
        )

        # We should try to reuse the constructor call spec as it is guaranteed to be same
        # for each subclass type. This is different from proxy-ing the init arguments which
        # can't be reused because for example, DTensor can receive different DeviceMesh etc
        # as it's arguments
        func_spec_proxy = _register_subclass_spec_proxy_in_tracer(
            tracer, fxable_constructor_call_spec_name, func_spec
        )

        inner_proxy = tracer.create_proxy(
            "call_function",
            flat_apply,
            (func_spec_proxy, spec_proxy, *flat_proxy_args),
            {},
        )
        track_tensor_tree(subclass, inner_proxy, constant=None, tracer=tracer)
        return

    return wrapper


def allow_in_pre_dispatch_graph(func):
    """
    Experimental decorator that adds user function to export pre-dispatch graph. Note that
    we only support custom autograd function/subclass constructors today. To use this function:
    1. For subclasses:
        1. refer to instructions in mark_subclass_constructor_exportable_experimental
    2. Define apply method on your custom autograd function and apply this decorator.

    Example:

    class MyCoolCustomAutogradFunc(autograd.Function):
        @classmethod
        @torch._export.wrappers.allow_in_pre_dispatch_graph
        def apply(cls, *args, **kwargs):
            return super(MyCoolCustomAutogradFunc, cls).apply(*args, **kwargs)

    """
    if _is_init(func):
        return mark_subclass_constructor_exportable_experimental(func)

    if not (_is_init(func) or func.__name__ == "apply"):
        raise RuntimeError(
            f"torch._export.wrappers.allow_in_pre_dispatch_graph can only be applied on subclass tensor.__init_ "
            f"or custom_autograd_function.apply. "
            f"But, you are adding it on {func.__name__} which is not supported. "
            f"If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for example. "
            f"If you are adding it on custom autograd function, please add it on apply method. "
            f"If anything else, file an issue on github and we may consider extending our support. "
        )

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.compiler.is_exporting():
            return func(*args, **kwargs)

        if not inspect.isclass(args[0]):
            return func(*args, **kwargs)

        if not issubclass(args[0], torch.autograd.Function):
            return func(*args, **kwargs)

        from torch._ops import _get_dispatch_mode_pre_dispatch

        mode = _get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
        if mode is None:
            return func(*args, **kwargs)

        # Sometimes custom autograd functions can call into HOPs that don't have proxy impl
        # at PreDispatch level, so we just dispatch it below to get the concrete result.
        include_to_set = torch._C._dispatch_tls_local_include_set().remove(
            torch._C.DispatchKey.PreDispatch
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.PreDispatch)
        )

        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            out = func(*args, **kwargs)

        assert mode.pre_dispatch, "Should only do this in predispatch"
        tracer = mode.tracer

        function_cls_name = f"{args[0].__module__}.{args[0].__qualname__}"
        graphable = ((function_cls_name, *args[1:]), kwargs)

        from torch.export.custom_ops import (
            _call_custom_autograd_function_in_pre_dispatch,
        )

        spec_name = "_".join(function_cls_name.split("."))
        call_spec_cache_key = type(
            _call_custom_autograd_function_in_pre_dispatch
        ).__name__.lower()
        _emit_flat_apply_call(
            tracer=tracer,
            spec_name=spec_name,
            const_target_for_apply=_call_custom_autograd_function_in_pre_dispatch,
            graphable_args=graphable,
            track_value=out,
            call_spec_cache_key=call_spec_cache_key,
        )
        return out

    return wrapper
@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs
import importlib

import torch
@@ -24,3 +27,23 @@ def _access_subclass_inner_tensor(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val


def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
    """
    Import a custom autograd function by string name and call it. This is pretty bad
    because:
    1) There is no schema

    Ideally we should automatically wrap custom autograd functions with a custom op, but
    that is too much work because we need to schematize custom autograd functions. For now,
    we just hackily put it in the IR.
    """
    # Parse module and class name
    module_name, class_name = function_cls_name.rsplit(".", 1)

    # Import the module and get the class
    module = importlib.import_module(module_name)
    function_cls = getattr(module, class_name)
    assert hasattr(function_cls, "apply")
    return function_cls.apply(*args, **kwargs)