diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index 17c9bd4234f3..4ef380704de8 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -202,6 +202,7 @@ supported:
   - select_backward
   - _trilinear
   - linalg_pinv.atol_rtol_tensor
+  - svd
   - logsumexp.out
 symint:
   - empty.memory_format
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index 96ef6b3522ba..72e72f49a5e4 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
     // where we would like to support composite implicit kernels but not
     // explicit kernels therefore we manually add the key to the
     // math_dispatch_keyset
-    DispatchKeySet{DispatchKey::NestedTensor} |
-    // Functionalize should always reuse CompositeImplicit decomps.
-    DispatchKeySet{DispatchKey::Functionalize};
+    DispatchKeySet{DispatchKey::NestedTensor};
 
 constexpr DispatchKeySet nested_dispatch_keyset =
     DispatchKeySet(
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 080002999964..41b37a687fae 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -7207,7 +7207,6 @@ metadata incorrectly.
         aot_eager = torch.compile(backend="aot_eager")(fn)(x)
         self.assertEqual(eager, aot_eager, atol=0, rtol=0)
 
-    @unittest.expectedFailure
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_rms_norm(self):
         # Only CUDA rms norm fails to be decomposed
diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py
index 7c467dc62413..e4652a465d72 100644
--- a/test/lazy/test_ts_opinfo.py
+++ b/test/lazy/test_ts_opinfo.py
@@ -85,6 +85,7 @@ def init_lists():
         "linalg_inv_ex",
         "linalg_pinv.atol_rtol_tensor",
         "logsumexp",
+        "svd",
     }
     # For some ops, we don't support all variants. Here we use formatted_name
     # to uniquely identify the variant.
@@ -220,20 +221,15 @@ class TestLazyOpInfo(TestCase):
         torch._lazy.wait_device_ops()
         prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
         symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
-        found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
-            torch._lazy.metrics.counter_names()
-        )
+        metrics = remove_suffixes(torch._lazy.metrics.counter_names())
+        cands = [f"{prefix}::{op.name}{symint_suffix}"]
         # check aliases
-        if not found:
-            for alias in op.aliases:
-                alias_found = (
-                    f"{prefix}::{alias.name}{symint_suffix}"
-                    in remove_suffixes(torch._lazy.metrics.counter_names())
-                )
-                found = found or alias_found
-                if found:
-                    break
-        self.assertTrue(found)
+        for alias in op.aliases:
+            cands.append(f"{prefix}::{alias.name}{symint_suffix}")
+
+        self.assertTrue(
+            any(c in metrics for c in cands), f"none of {cands} found in {metrics}"
+        )
 
     @ops(
         [
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 610465db4c48..a534b643997b 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -1255,11 +1255,10 @@ class DecompOneOffTests(TestCase):
         )
 
         # check RMSNorm was fused with sinh
+        self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
         self.assertTrue(
-            "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
-        )
-        self.assertTrue(
-            "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
+            "triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
+            in generated_codes[1]
         )
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index c4396932818d..69ef0b901bed 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
         aten.max_unpool3d,
         aten.mish,
         aten.mish_,
+        aten.mish_backward,
         aten.mse_loss,
         aten.mse_loss_backward,
         aten.multi_margin_loss,
@@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
         aten.native_dropout_backward,
         aten.native_group_norm_backward,
         aten.native_layer_norm_backward,
+        aten._fused_rms_norm,
         aten._fused_rms_norm_backward,
         aten.new_empty,
         aten.new_full,
@@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
         aten.silu,
         aten.silu_,
         aten.silu_backward.grad_input,
+        aten.silu_backward,
         aten.sinc,
         aten.sinc_,
         aten.slice_backward,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 18c6ac5945e5..bdbdf21b0d4c 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1757,6 +1757,58 @@ def native_layer_norm_backward_out(
     return grad_input
 
 
+@register_decomposition(aten._fused_rms_norm.default)
+def _fused_rms_norm(
+    input: Tensor,
+    normalized_shape: list[int],
+    weight: Optional[Tensor],
+    eps: Optional[float],
+) -> tuple[Tensor, Tensor]:
+    dims_to_reduce: list[int] = []
+    for i in range(len(normalized_shape)):
+        dims_to_reduce.append(input.dim() - i - 1)
+
+    # upcast is needed for fp16 and bf16
+    computation_dtype = utils.get_computation_dtype(input.dtype)
+    upcasted_input = input.to(computation_dtype)
+
+    # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
+    if eps is None:
+        if computation_dtype in (torch.float32, torch.complex64):
+            eps_val = torch.finfo(torch.float32).eps
+        else:
+            eps_val = sys.float_info.epsilon
+    else:
+        eps_val = eps
+
+    rqrst_input = torch.rsqrt(
+        # NB: don't inplace here, will violate functional IR invariant
+        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
+    )
+
+    upcasted_result = upcasted_input.mul(rqrst_input)
+
+    if weight is not None:
+        upcasted_result = upcasted_result.mul(weight)
+
+    # NB: nested should be dead here, just here for fidelity
+    is_nested = input.is_nested or (weight is not None and weight.is_nested)
+    memory_format = utils.suggest_memory_format(input)
+    is_channels_last = memory_format in (
+        torch.channels_last,
+        torch.channels_last_3d,
+    )
+
+    if not is_nested and not is_channels_last:
+        upcasted_result = upcasted_result.contiguous()
+        rqrst_input = rqrst_input.contiguous()
+
+    # Cast normalized result back to original input type
+    result = upcasted_result.type_as(input)
+
+    return result, rqrst_input
+
+
 @register_decomposition(aten._fused_rms_norm_backward.default)
 def _fused_rms_norm_backward(
     grad_out: Tensor,
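
The decomposition above is the standard RMS normalization, out = x * rsqrt(mean(x^2, dims) + eps), optionally scaled by weight, computed in an upcast dtype, with the rsqrt term returned alongside the result for reuse in the backward. A minimal numerical sketch of that formula follows; it is not part of the patch, and the bf16 input, eps choice and tolerance are illustrative assumptions only.

import torch

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)
eps = torch.finfo(torch.float32).eps

x32 = x.float()  # mirror the fp16/bf16 upcast in the decomposition
rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
manual = (x32 * rstd * w.float()).to(x.dtype)

fused = torch.nn.functional.rms_norm(x, [8], weight=w, eps=eps)
torch.testing.assert_close(fused, manual, rtol=1e-2, atol=1e-2)  # loose bf16 tolerance
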
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 15ed56ddca3c..d3b9ac7858ce 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -15,6 +15,7 @@ from torch._subclasses.meta_utils import is_sparse_any
 from torch.utils._python_dispatch import (
     _detect_infra_mode,
     _disable_infra_mode,
+    autograd_would_have_decomposed,
     return_and_correct_aliasing,
     TorchDispatchMode,
 )
@@ -409,8 +410,13 @@ class FunctionalTensorMode(TorchDispatchMode):
                     return False
                 return True
 
-            # in normal torch.compile IR, we decompose functional composite ops
-            return True
+            # in normal torch.compile IR, we only decompose an op if autograd
+            # would have decomposed it (NB: autograd may have been skipped if
+            # we are in inference mode)
+            # TODO: the flatten here can potentially be deduped with the
+            # unwrapping pytree_map later
+            flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
+            return autograd_would_have_decomposed(func, flat_args_kwargs)
 
         if (
             func not in FunctionalTensor.metadata_fns
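
The net effect of this change: a CompositeImplicitAutograd op is only decomposed here when eager-mode autograd would also have fallen through to its decomposition, i.e. when the tensor's backend has no direct kernel for the op. A small sketch of the check in isolation, using the helper added in this patch; whether a given op has a direct CPU kernel, as assumed in the comments below, can vary by build.

import torch
from torch.utils._python_dispatch import autograd_would_have_decomposed

x = torch.randn(4, 4)

# matmul is CompositeImplicitAutograd and normally has no direct CPU kernel,
# so autograd would have decomposed it and the tracing layers should too.
print(autograd_would_have_decomposed(torch.ops.aten.matmul.default, [x]))  # expected: True

# add.Tensor has a real CPU kernel, so nothing would be decomposed for it.
print(autograd_would_have_decomposed(torch.ops.aten.add.Tensor, [x]))  # expected: False
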
diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
index 1bb720b810f9..f1f69e092591 100644
--- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
+++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
@@ -466,6 +466,14 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
       linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
+    const at::Tensor& self,
+    bool some,
+    bool compute_uv) {
+  return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
+      self, some, compute_uv);
+}
+
 // functionalize_aten_op can't handle out= ops directly.
 // Instead, we can call the composite kernel from core, and copy and mutations
 // back to the inputs.
diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py
index 8e7e14e1b2ec..932a11ab076c 100644
--- a/torch/export/decomp_utils.py
+++ b/torch/export/decomp_utils.py
@@ -21,6 +21,10 @@ backends are ready, this list allows opt-in one at a time.
 """
 PRESERVED_ATEN_CIA_OPS = {
     torch.ops.aten.upsample_bilinear2d.vec,
     torch.ops.aten.upsample_nearest2d.vec,
+    # NB: don't use the C++ decomp, because it is not functional!
+    torch.ops.aten.silu_backward.default,
+    torch.ops.aten.mish_backward.default,
+    torch.ops.aten._fused_rms_norm.default,
 }
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c9c6412ab4e7..c31a79cedfbc 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -63,6 +63,7 @@ from torch.utils._python_dispatch import (
     _disable_infra_mode,
     _push_mode,
     _unset_infra_mode,
+    autograd_would_have_decomposed,
     TorchDispatchMode,
 )
 from torch.utils._stats import count
@@ -908,11 +909,16 @@ def proxy_call(
             return r
 
     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
-    if not pre_dispatch and func not in [
-        torch.ops.aten.size.default,
-        torch.ops.aten.stride.default,
-        torch.ops.aten.storage_offset.default,
-    ]:
+    if (
+        not pre_dispatch
+        and func
+        not in [
+            torch.ops.aten.size.default,
+            torch.ops.aten.stride.default,
+            torch.ops.aten.storage_offset.default,
+        ]
+        and autograd_would_have_decomposed(func, flat_args_kwargs)
+    ):
         with proxy_mode:
             r = func.decompose(*args, **kwargs)
             if r is not NotImplemented:
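
With the extra condition in proxy_call, ops that eager mode would have run as a fused backend kernel stay in the traced graph, while ordinary CompositeImplicitAutograd ops such as matmul still decompose, because autograd would have decomposed them anyway. A rough way to see the latter, unchanged behavior; this is illustrative only and the exact graph contents depend on the build.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.matmul(x, x)

# matmul has no direct CPU kernel, so it is still traced as its decomposition.
gm = make_fx(f)(torch.randn(4, 4))
print(gm.graph)  # expect aten.mm rather than aten.matmul
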
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 520bc8e3ece1..f5b028b6df08 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import contextlib
 import functools
 import warnings
 from collections import deque
-from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, Union
+from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
 from typing_extensions import TypeIs
 
 import torch
@@ -20,6 +21,10 @@ from torch._C import (
 )
 
 
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
 # - We need a better user-facing api for _DisableTorchDispatch that
 # is able to selectively disable __torch_dispatch__ of a particular class.
@@ -414,7 +419,7 @@ class TensorWithFlatten(Protocol):
     @overload
     def to(
         self,
-        device: Optional["torch._prims_common.DeviceLikeType"] = None,
+        device: Optional[torch._prims_common.DeviceLikeType] = None,
         dtype: Optional[torch.types._dtype] = None,
         non_blocking: bool = False,
         copy: bool = False,
@@ -682,6 +687,61 @@ def get_alias_info(func) -> SchemaInfo:
     return schema_info
 
 
+def autograd_would_have_decomposed(
+    func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
+) -> bool:
+    """
+    Suppose that an operator has a CompositeImplicitAutograd decomp registered.
+    Would autograd have used this decomposition? It will only use it if there
+    isn't an explicit backend registration for the device as well. This function
+    tells you whether that would have occurred.
+
+    Why do we need to apply these decompositions later? When inference mode is
+    on, the autograd key is bypassed entirely, so a lower level mode cannot rely
+    on the decomposition having been applied. It's easy to accidentally never
+    apply the decomposition, resulting in an operator showing up in a graph
+    unexpectedly.
+
+    Why do we need to AVOID applying the decomposition when autograd wouldn't
+    have decomposed? If autograd doesn't decompose, this means in eager mode
+    we would have run the fused kernel. It must be possible to trace this
+    fused kernel directly into the graph for fidelity with eager (NB: a user
+    has the option of then further decomposing at proxy tensor mode via the
+    decomposition table, but we must preserve the op to proxy mode to have the
+    choice.)
+
+    Why does functionalization need to also perform the test here? This is
+    because some CompositeImplicitAutograd decompositions are not functional.
+    If we are eventually going to decompose, we need to do it while we can
+    still turn functionalization back on, so those decompositions get
+    functionalized. So an early decomposition in functionalization may still be
+    necessary. Note that if the proxy tensor decomposition process could turn
+    functionalization back on, this wouldn't be necessary, and maybe that is a
+    useful thing to do anyway, because the decomposition table is user specified
+    and a user could violate the functional decomp requirement with a bad
+    decomp. If that were possible, a bad decomp could always be passed through
+    functionalization again.
+    """
+    has_backend_registration = False
+    for a in flat_args:
+        if isinstance(a, torch.Tensor):
+            backend_key = torch._C._parse_dispatch_key(
+                torch._C._dispatch_key_for_device(a.device.type)
+            )
+            assert backend_key is not None
+            # TODO: use func.has_kernel_for_dispatch_key(backend_key) instead;
+            # that one also checks py_impl, but CompositeImplicitAutograd
+            # incorrectly shows up as a backend registration there
+            has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
+                func.name(), backend_key
+            )
+
+            # in theory we should collect all backend keys and take the highest
+            # priority one to properly mimic the dispatcher;
+            # this just grabs the first tensor and takes its device key
+            break
+    return not has_backend_registration
+
+
 # See NOTE[SchemaInfo int_tags] above.
 _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
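
The helper above ultimately reduces to one dispatcher query: is a kernel registered for the op under the tensor's backend key? The same query can be issued directly, as in this sketch; the results noted in the comments are typical but build-dependent.

import torch

key = torch._C._dispatch_key_for_device("cpu")  # e.g. "CPU"

# add.Tensor has a real CPU kernel -> True.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", key))
# rms_norm is CompositeImplicitAutograd with (typically) no direct CPU kernel
# -> False, which is what lets autograd fall through to its decomposition.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::rms_norm", key))
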
+ registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})" return f'm.impl("{f.func.name}", {registration_str});' # Don't generate kernels in mobile build @@ -1038,12 +1052,8 @@ def gen_functionalization_registration( if str(g.view.func.name) == "lift_fresh": return [] view_str = [] - if not g.view.has_composite_implicit_autograd_kernel: - view_str.append(emit_registration_helper(g.view)) - if ( - g.view_inplace is not None - and not g.view_inplace.has_composite_implicit_autograd_kernel - ): + view_str.append(emit_registration_helper(g.view)) + if g.view_inplace is not None: assert g.view_inplace.is_view_op view_str.append(emit_registration_helper(g.view_inplace)) return view_str