diff --git a/test/export/test_export.py b/test/export/test_export.py
index bdcc63034abd..c35cd8fee385 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -26,6 +26,7 @@ import torch.utils._pytree as pytree
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
 from torch._decomp import decomposition_table, get_decompositions
+from torch._dynamo._trace_wrapped_higher_order_op import mod_index
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import normalize_gm
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
@@ -13615,6 +13616,52 @@ def forward(self, x):
         ):
             _ = export(Foo(), (torch.randn(4, 4),), strict=False)

+    def test_vmap_custom_autograd_function(self):
+        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
+
+        class IndexingModule(torch.nn.Module):
+            def __init__(self, base_size: int = 10):
+                super().__init__()
+                self.register_buffer("base", torch.arange(base_size))
+
+            def forward(self, indices: torch.Tensor) -> torch.Tensor:
+                with TransformGetItemToIndex():
+                    # Each element of `indices` is a scalar tensor, so our override kicks in
+                    return torch.vmap(lambda i: self.base[i])(indices)
+
+        m = IndexingModule(10)
+        idxs = torch.tensor([0, 3, 7, 9])
+        ep = torch.export.export(m, (idxs,), strict=False)
+        self.assertExpectedInline(
+            ep.graph,
+            """\
+graph():
+    %b_base : [num_users=1] = placeholder[target=b_base]
+    %indices : [num_users=1] = placeholder[target=indices]
+    %lazy_load_decompositions : [num_users=0] = call_function[target=torch._functorch.predispatch.lazy_load_decompositions](args = (), kwargs = {})
+    %_vmap_increment_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_increment_nesting](args = (4, error), kwargs = {})
+    %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%indices, 0, 1), kwargs = {})
+    %torch__dynamo__trace_wrapped_higher_order_op_mod_index0 : [num_users=1] = get_attr[target=torch__dynamo__trace_wrapped_higher_order_op_ModIndex0]
+    %function_const_func_spec0 : [num_users=1] = get_attr[target=function_const_func_spec0]
+    %flat_apply : [num_users=1] = call_function[target=torch.ops.higher_order.flat_apply](args = (%function_const_func_spec0, %torch__dynamo__trace_wrapped_higher_order_op_mod_index0, torch._dynamo._trace_wrapped_higher_order_op.ModIndex, %b_base, %_add_batch_dim), kwargs = {})
+    %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%flat_apply, 1, 4, 0), kwargs = {})
+    %_vmap_decrement_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_decrement_nesting](args = (), kwargs = {})
+    return (_remove_batch_dim,)""",
+        )
+
+        self.assertEqual(m(idxs), ep.module()(idxs))
+        ep = ep.run_decompositions({})
+        self.assertExpectedInline(
+            ep.graph,
+            """\
+graph():
+    %b_base : [num_users=1] = placeholder[target=b_base]
+    %indices : [num_users=1] = placeholder[target=indices]
+    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%b_base, [%indices]), kwargs = {})
+    return (index,)""",
+        )
+        self.assertEqual(m(idxs), ep.module()(idxs))
+
     def test_unbacked_deferred_runtime_retrace(self):
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -14412,10 +14459,7 @@ graph():
             def forward(self, x):
                 return x.cos()

-        with self.assertRaisesRegex(
-            RuntimeError, "TestExport.test_capture_subclass_wrong..Foo"
-        ):
-            export(Foo(), (torch.randn(4, 4),))
+        export(Foo(), (torch.randn(4, 4),))

     def test_capture_subclass_constructor_torch_ir(self):
         class Foo(torch.nn.Module):
diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py
index 17b664fc5e0e..9b000ee926a1 100644
--- a/torch/_dynamo/_trace_wrapped_higher_order_op.py
+++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py
@@ -116,6 +116,11 @@ class ModIndex(torch.autograd.Function):
             None,
         )

+    @classmethod
+    @torch._export.wrappers.allow_in_pre_dispatch_graph
+    def apply(cls, *args, **kwargs):  # type: ignore[no-untyped-def]
+        return super().apply(*args, **kwargs)
+

 mod_index = ModIndex.apply

diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py
index 58c0f1771a1e..28593291b22c 100644
--- a/torch/_export/verifier.py
+++ b/torch/_export/verifier.py
@@ -216,6 +216,7 @@ class Verifier(metaclass=_VerifierMeta):
                 torch.sym_not,
                 torch.sym_sqrt,
                 torch.sym_sum,
+                torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch,
                 # TODO (tmanlaibaatar)
                 # Predispatch export is able to contain autograd ops.
                 # These will be modeled as HOO later
diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py
index b851847bada8..e02316940393 100644
--- a/torch/_export/wrappers.py
+++ b/torch/_export/wrappers.py
@@ -1,5 +1,7 @@
 # mypy: allow-untyped-defs
+import inspect
 from contextlib import contextmanager
+from functools import wraps

 import torch
 import torch._custom_ops
@@ -15,7 +17,6 @@ from torch._higher_order_ops.utils import autograd_not_implemented
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
-    get_proxy_slot,
     PreDispatchTorchFunctionMode,
     ProxyTorchDispatchMode,
     track_tensor_tree,
@@ -129,7 +130,7 @@ def _mark_strict_experimental(cls):
     return cls


-def _register_subclass_spec_proxy_in_tracer(tracer, name, spec):
+def _register_func_spec_proxy_in_tracer(tracer, name, spec):
     """
     This is a wrapper utility method on top of tracer to cache the
     already registered subclass spec attribute. This is useful because
@@ -146,6 +147,41 @@ def _register_subclass_spec_proxy_in_tracer(tracer, name, spec):
     return tracer.create_proxy("get_attr", qualname, (), {})


+def _emit_flat_apply_call(
+    *,
+    tracer,
+    spec_name: str,
+    const_target_for_apply,
+    graphable_args,
+    track_value,
+    call_spec_cache_key: str,
+):
+    # Flatten to graphable form and record the spec on the FX root
+    flat_args, in_spec = to_graphable(graphable_args)
+    qualname = tracer.get_fresh_qualname(spec_name)  # type: ignore[union-attr]
+    setattr(tracer.root, qualname, in_spec)  # type: ignore[union-attr]
+    spec_proxy = tracer.create_proxy("get_attr", qualname, (), {})
+
+    # Reuse the cached ConstantFunction spec on the root
+    _, func_spec = pytree.tree_flatten(_ConstantFunction(const_target_for_apply))
+    func_spec_proxy = _register_func_spec_proxy_in_tracer(
+        tracer, f"{call_spec_cache_key}_const_func_spec", func_spec
+    )
+
+    # Map runtime args -> proxies (always via tracer.unwrap_proxy now)
+    flat_proxy_args = pytree.tree_map(tracer.unwrap_proxy, flat_args)
+
+    # Emit flat_apply and track the result structure
+    out_proxy = tracer.create_proxy(
+        "call_function", flat_apply, (func_spec_proxy, spec_proxy, *flat_proxy_args), {}
+    )
+    track_tensor_tree(track_value, out_proxy, constant=None, tracer=tracer)
+
+
+def _is_init(fn):
+    return callable(fn) and fn.__name__ == "__init__"
+
+
 def mark_subclass_constructor_exportable_experimental(constructor_subclass):
     """
     Experimental decorator that makes subclass to be traceable in export
@@ -167,10 +203,6 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
         def __init__(self, elem, ...):
             # ...
     """
-
-    def _is_init(fn):
-        return callable(fn) and fn.__name__ == "__init__"
-
     if not _is_init(constructor_subclass):
         raise RuntimeError(
             f"torch._export.wrappers.mark_constructor_exportable_experimental can only be applied on subclass tensor.__init__"
@@ -179,14 +211,18 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
         )

     def wrapper(*args, **kwargs):
+        constructor_subclass(*args, **kwargs)
+
+        if not torch.compiler.is_exporting():
+            return
+
         if not is_traceable_wrapper_subclass_type(type(args[0])):
             assert constructor_subclass.__qualname__.endswith("__init__")
             obj_name = constructor_subclass.__qualname__[: -len("__init__")]
             raise RuntimeError(
-                f"Applying mark_constructor_exportable_experimental on {obj_name} is not valid as it is not a traceable "
+                f"Can't intercept {obj_name} in export because this object is not a traceable "
                 f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
             )
-        constructor_subclass(*args, **kwargs)

         mode = _maybe_find_pre_dispatch_tf_mode_for_export()
         if mode is None:
@@ -196,46 +232,106 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
         tracer = mode.tracer

         subclass = args[0]
+        graphable = (tuple(args[1:]), kwargs)

-        flat_args, in_spec = to_graphable((tuple(args[1:]), kwargs))
+        spec_name = "_".join(constructor_subclass.__qualname__.lower().split("."))
+        call_spec_cache_key = type(subclass).__name__.lower()

-        constructor_spec_name = "_".join(
-            constructor_subclass.__qualname__.lower().split(".")
+        _emit_flat_apply_call(
+            tracer=tracer,
+            spec_name=spec_name,
+            const_target_for_apply=type(subclass),
+            graphable_args=graphable,
+            track_value=subclass,  # track the constructed subclass instance
+            call_spec_cache_key=call_spec_cache_key,
         )
-        qualname = tracer.get_fresh_qualname(constructor_spec_name)  # type: ignore[union-attr]
-        setattr(tracer.root, qualname, in_spec)  # type: ignore[union-attr]
-        spec_proxy = tracer.create_proxy("get_attr", qualname, (), {})
-        flat_proxy_args = pytree.tree_map_only(
-            torch.Tensor, lambda x: get_proxy_slot(x, tracer).proxy, flat_args
-        )
-
-        _, func_spec = torch.utils._pytree.tree_flatten(
-            _ConstantFunction(type(subclass))
-        )
-
-        # We actually don't want to create a new spec for each instance
-        # In fx graph, it will look like dtensor_const_func_spec
-        # We can't directly shove DTensor.__init__ into fx as it is not
-        # allowed type.
-        fxable_constructor_call_spec_name = (
-            type(subclass).__name__.lower() + "_const_func_spec"
-        )
-
-        # We should try to reuse the constructor call spec as it is guaranteed to be same
-        # for each subclass type. This is different from proxy-ing the init arguments which
-        # can't be reused because for example, DTensor can receive different DeviceMesh etc
-        # as it's arguments
-        func_spec_proxy = _register_subclass_spec_proxy_in_tracer(
-            tracer, fxable_constructor_call_spec_name, func_spec
-        )
-
-        inner_proxy = tracer.create_proxy(
-            "call_function",
-            flat_apply,
-            (func_spec_proxy, spec_proxy, *flat_proxy_args),
-            {},
-        )
-        track_tensor_tree(subclass, inner_proxy, constant=None, tracer=tracer)
         return

     return wrapper
+
+
+def allow_in_pre_dispatch_graph(func):
+    """
+    Experimental decorator that adds a user function to the export pre-dispatch graph. Note that
+    only custom autograd functions and subclass constructors are supported today. To use it:
+    1. For subclasses:
+        refer to the instructions in mark_subclass_constructor_exportable_experimental.
+    2. For custom autograd functions: define an apply method and apply this decorator to it.
+
+    Example:
+
+    class MyCoolCustomAutogradFunc(autograd.Function):
+        @classmethod
+        @torch._export.wrappers.allow_in_pre_dispatch_graph
+        def apply(cls, *args, **kwargs):
+            return super(MyCoolCustomAutogradFunc, cls).apply(*args, **kwargs)

+    """
+    if _is_init(func):
+        return mark_subclass_constructor_exportable_experimental(func)
+
+    if not (_is_init(func) or func.__name__ == "apply"):
+        raise RuntimeError(
+            f"torch._export.wrappers.allow_in_pre_dispatch_graph can only be applied on subclass tensor.__init__ "
+            f"or custom_autograd_function.apply. "
+            f"But you are adding it on {func.__name__}, which is not supported. "
+            f"If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for an example. "
+            f"If you are adding it on a custom autograd function, please add it on the apply method. "
+            f"For anything else, file an issue on GitHub and we may consider extending our support."
" + ) + + @wraps(func) + def wrapper(*args, **kwargs): + if not torch.compiler.is_exporting(): + return func(*args, **kwargs) + + if not inspect.isclass(args[0]): + return func(*args, **kwargs) + + if not issubclass(args[0], torch.autograd.Function): + return func(*args, **kwargs) + + from torch._ops import _get_dispatch_mode_pre_dispatch + + mode = _get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY) + if mode is None: + return func(*args, **kwargs) + + # Sometimes custom autograd functions can call into HOPs that don't have proxy impl + # at PreDispatch level, so we just dispatch it below to get the concrete result. + include_to_set = torch._C._dispatch_tls_local_include_set().remove( + torch._C.DispatchKey.PreDispatch + ) + exclude_to_set = ( + torch._C._dispatch_tls_local_exclude_set() + | torch._C.DispatchKeySet(torch._C.DispatchKey.PreDispatch) + ) + + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + out = func(*args, **kwargs) + + assert mode.pre_dispatch, "Should only do this in predispatch" + tracer = mode.tracer + + function_cls_name = f"{args[0].__module__}.{args[0].__qualname__}" + graphable = ((function_cls_name, *args[1:]), kwargs) + + from torch.export.custom_ops import ( + _call_custom_autograd_function_in_pre_dispatch, + ) + + spec_name = "_".join(function_cls_name.split(".")) + call_spec_cache_key = type( + _call_custom_autograd_function_in_pre_dispatch + ).__name__.lower() + _emit_flat_apply_call( + tracer=tracer, + spec_name=spec_name, + const_target_for_apply=_call_custom_autograd_function_in_pre_dispatch, + graphable_args=graphable, + track_value=out, + call_spec_cache_key=call_spec_cache_key, + ) + return out + + return wrapper diff --git a/torch/export/custom_ops.py b/torch/export/custom_ops.py index 57288fa344c1..9df7988da931 100644 --- a/torch/export/custom_ops.py +++ b/torch/export/custom_ops.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs +import importlib + import torch @@ -24,3 +27,23 @@ def _access_subclass_inner_tensor( f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}" ) return val + + +def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs): + """ + Import a custom autograd function by string name and call it. This is pretty bad + because: + 1) There is no schema + + Ideally we should automatically wrap custom autograd functions with a custom op, but + that is too much work because we need to schematize custom autograd functions. For now, + we just hackily put it in the IR. + """ + # Parse module and class name + module_name, class_name = function_cls_name.rsplit(".", 1) + + # Import the module and get the class + module = importlib.import_module(module_name) + function_cls = getattr(module, class_name) + assert hasattr(function_cls, "apply") + return function_cls.apply(*args, **kwargs)