diff --git a/BUCK.oss b/BUCK.oss index 928de6339fb0..125868dc9ec5 100644 --- a/BUCK.oss +++ b/BUCK.oss @@ -149,6 +149,8 @@ ATEN_EXPORTED_HEADERS = { "CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]", "CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]", "CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]", + "CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]", + "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]", "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]", "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]", "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]", diff --git a/BUILD.bazel b/BUILD.bazel index eb77859fce77..192a444fb53a 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -47,12 +47,15 @@ generated_cpu_cpp = [ "aten/src/ATen/RegisterZeroTensor.cpp", "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp", "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp", + "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp", "aten/src/ATen/RegisterMeta.cpp", "aten/src/ATen/RegisterSchema.cpp", "aten/src/ATen/CPUFunctions.h", "aten/src/ATen/CPUFunctions_inl.h", "aten/src/ATen/CompositeExplicitAutogradFunctions.h", "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h", + "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h", + "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h", "aten/src/ATen/CompositeImplicitAutogradFunctions.h", "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h", "aten/src/ATen/CompositeViewCopyKernels.cpp", diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 31482a0ccb90..1d94fce90408 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -228,10 +228,13 @@ std::pair OperatorEntry::computeDispatchTab // For any dispatch key, it'll pick a kernel using the following order: // (1) Use kernel if it's directly registered to this key // (2) Handle runtime keys that have kernels available from alias keys - // (2.1) Use kernel from DispatchKey::CompositeExplicitAutograd if available. + // (2.1) Use kernel from DispatchKey::CompositeExplicitAutogradNonFunctional if available. + // This is used to register a kernel that works for all backends in inference, except "functional" backends + // like LazyTensor/XLA. But it requires separate registration for Autograd keys to support training. + // (2.2) Use kernel from DispatchKey::CompositeExplicitAutograd if available. // This is used to register a kernel that works for all backend in inference. But it requires // separate registration for Autograd keys to support training. - // (2.2) Use kernel from DispatchKey::CompositeImplicitAutograd if available. + // (2.3) Use kernel from DispatchKey::CompositeImplicitAutograd if available. // For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration // to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd]. 
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of @@ -240,13 +243,13 @@ std::pair OperatorEntry::computeDispatchTab // A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't // cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) // in this case. - // (2.3) Use kernel from DispatchKey::Autograd if available + // (2.4) Use kernel from DispatchKey::Autograd if available // The implementation of (2.2) relies on the invariant that for a given backend, // `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the // backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_] // (3) Use fallthrough kernel that are registered as fallback. // Alias Key Precedence: - // CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd + // CompositeExplicitAutogradNonFunctional > CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd // Note [CompositeExplicitAutograd and CompositeImplicitAutograd] // When there're registrations to both CompositeExplicitAutograd & CompositeImplicitAutograd & Autograd, from (2.2) we know CompositeExplicitAutograd // and Autograd kernels will be picked up and CompositeImplicitAutograd is overriden. @@ -258,7 +261,15 @@ std::pair OperatorEntry::computeDispatchTab return {*direct_registration, "kernel"}; } - // 2.1 Use CompositeExplicitAutograd kernel if available. + // 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available. + // See Note [Undefined in dispatchTable_] for the special handling for Undefined. + if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutogradNonFunctional)) { + if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutogradNonFunctional)) { + return {*default_backend_registration, "default backend kernel"}; + } + } + + // 2.2 Use CompositeExplicitAutograd kernel if available. // See Note [Undefined in dispatchTable_] for the special handling for Undefined. if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutograd)) { if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd)) { @@ -273,7 +284,7 @@ std::pair OperatorEntry::computeDispatchTab // See Note [No Alias Keys in DispatchKeySet] hasKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd); - // 2.2. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd + // 2.3. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd + // when there's no direct registration to its corresponding backend key or CompositeExplicitAutograd. // For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration // to any of its backends. @@ -289,7 +300,7 @@ std::pair OperatorEntry::computeDispatchTab } } - // 2.3. For autograd backend keys, use kernel from DispatchKey::Autograd if available + // 2.4.
For autograd backend keys, use kernel from DispatchKey::Autograd if available if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) { if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) { return {*autograd_registration, "autograd kernel"}; } } @@ -339,9 +350,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) { updateDispatchTableEntry_(dispatcher, k); } - // Registration to CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined. + // Registration to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined. // We cannot do this above since Undefined cannot be represented in DispatchKeySet. - if (dispatch_key == DispatchKey::CompositeImplicitAutograd || dispatch_key == DispatchKey::CompositeExplicitAutograd) { + if (dispatch_key == DispatchKey::CompositeImplicitAutograd + || dispatch_key == DispatchKey::CompositeExplicitAutograd + || dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) { updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined); } // Note [Refresh Runtime Autograd entries in dispatchTable_] @@ -375,7 +388,7 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) // no dispatch keys are available we just slide into the undefined handler which would then raise // the error message. // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to - // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd + // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd, // or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd) // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet. updateDispatchTable_(dispatcher, DispatchKey::Undefined); diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 3c10afef14fa..856493d6d2c2 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -292,7 +292,7 @@ to reuse the same function name in both cases. Available backend options can be found by searching `dispatch_keys` in [codegen](https://github.com/pytorch/pytorch/blob/master/torchgen/gen.py). -There are also two special "generic" backends: +There are also three special "generic" backends: - `CompositeExplicitAutograd` (previously known as `DefaultBackend`): implementations of kernels that work for all backends, but require an @@ -305,6 +305,18 @@ There are also two special "generic" backends: DispatchStub should NOT be registered as CompositeExplicitAutograd, as DispatchStub only works for `CPU, CUDA`) + - `CompositeExplicitAutogradNonFunctional`: + Similar to CompositeExplicitAutograd, but this key should be used if: + (1) Your kernel is written for a non-aliasing operator. + (2) *and* it calls internally into an aliasing operator. + An example of this is select_backward, which is non-aliasing, but decomposes into select. + We would like to distinguish between "ordinary" CompositeExplicitAutograd kernels + and these kernels, because some backends would not like + to decompose a non-aliasing op into an aliasing op.
+ LazyTensor + XLA are the two current examples of this - since they operate on a functional IR, + they would prefer to directly implement a non-aliasing operator with their own kernel, + instead of using a decomposition that results in more aliasing operators. + - `CompositeImplicitAutograd` (previously known as `Math`): implementations of kernels that work for all backends, and also can implicitly support autograd, because all of the operations it calls support autograd. Direct use of diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9b2c4133c508..6029f9dae065 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -715,7 +715,6 @@ structured_delegate: atanh.out variants: function, method dispatch: - CompositeExplicitAutograd: atanh SparseCPU, SparseCUDA: atanh_sparse SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr @@ -1160,7 +1159,6 @@ structured_delegate: ceil.out variants: function, method dispatch: - CompositeExplicitAutograd: ceil SparseCPU, SparseCUDA: ceil_sparse SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr @@ -1169,7 +1167,6 @@ structured_delegate: ceil.out variants: function, method dispatch: - CompositeExplicitAutograd: ceil_ SparseCPU, SparseCUDA: ceil_sparse_ SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_ @@ -1226,8 +1223,6 @@ variants: function, method cpp_no_default_args: ['min'] structured_delegate: clamp.out - dispatch: - CompositeExplicitAutograd: clamp_ - func: clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) variants: function, method @@ -2242,7 +2237,6 @@ structured_delegate: floor.out variants: function, method dispatch: - CompositeExplicitAutograd: floor SparseCPU, SparseCUDA: floor_sparse SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr @@ -2251,7 +2245,6 @@ structured_delegate: floor.out variants: function, method dispatch: - CompositeExplicitAutograd: floor_ SparseCPU, SparseCUDA: floor_sparse_ SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_ @@ -2822,8 +2815,6 @@ device_check: NoCheck # TensorIterator structured_delegate: log10.out variants: function, method - dispatch: - CompositeExplicitAutograd: log10 - func: log10_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2892,8 +2883,6 @@ - func: logaddexp(Tensor self, Tensor other) -> Tensor variants: method, function structured_delegate: logaddexp.out - dispatch: - CompositeExplicitAutograd: logaddexp - func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -2905,8 +2894,6 @@ - func: logaddexp2(Tensor self, Tensor other) -> Tensor variants: method, function structured_delegate: logaddexp2.out - dispatch: - CompositeExplicitAutograd: logaddexp2 - func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -3472,7 +3459,7 @@ dispatch: CPU: narrow_copy_dense_cpu SparseCPU, SparseCUDA: narrow_copy_sparse - CompositeExplicitAutograd: narrow_copy_dense + CompositeExplicitAutogradNonFunctional: narrow_copy_dense tags: view_copy - func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor @@ -4114,14 +4101,10 @@ - func: silu(Tensor self) -> Tensor structured_delegate: silu.out python_module: nn - dispatch: - CompositeExplicitAutograd: silu - func: silu_(Tensor(a!) self) -> Tensor(a!) structured_delegate: silu.out python_module: nn - dispatch: - CompositeExplicitAutograd: silu_ - func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
structured: True @@ -4148,14 +4131,10 @@ - func: mish(Tensor self) -> Tensor structured_delegate: mish.out python_module: nn - dispatch: - CompositeExplicitAutograd: mish - func: mish_(Tensor(a!) self) -> Tensor(a!) structured_delegate: mish.out python_module: nn - dispatch: - CompositeExplicitAutograd: mish_ - func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -4937,7 +4916,6 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CompositeExplicitAutograd: trunc SparseCPU, SparseCUDA: trunc_sparse SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr @@ -4946,7 +4924,6 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CompositeExplicitAutograd: trunc_ SparseCPU, SparseCUDA: trunc_sparse_ SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_ @@ -6596,15 +6573,11 @@ structured_delegate: eq.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: eq_ - func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: eq.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: eq_ - func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7110,15 +7083,11 @@ structured_delegate: ne.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: ne_ - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: ne.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: ne_ # not_equal, alias for torch.ne - func: not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) @@ -7205,15 +7174,11 @@ structured_delegate: ge.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: ge_ - func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: ge.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: ge_ # greater_equal, alias for torch.ge - func: greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) @@ -7268,15 +7233,11 @@ structured_delegate: le.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: le_ - func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: le.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: le_ # less_equal, alias for torch.le - func: less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) @@ -7331,15 +7292,11 @@ structured_delegate: gt.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: gt_ - func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: gt.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: gt_ # greater, alias for torch.gt - func: greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) @@ -7394,15 +7351,11 @@ structured_delegate: lt.Scalar_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: lt_ - func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
structured_delegate: lt.Tensor_out device_check: NoCheck # TensorIterator variants: method - dispatch: - CompositeExplicitAutograd: lt_ # less, alias for torch.lt - func: less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) @@ -7840,7 +7793,6 @@ structured_delegate: sign.out variants: function, method dispatch: - CompositeExplicitAutograd: sign SparseCPU, SparseCUDA: sign_sparse SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr @@ -7849,7 +7801,6 @@ structured_delegate: sign.out variants: method dispatch: - CompositeExplicitAutograd: sign_ SparseCPU, SparseCUDA: sign_sparse_ SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_ @@ -8032,8 +7983,6 @@ - func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: hypot.out variants: method - dispatch: - CompositeExplicitAutograd: hypot_ - func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8076,8 +8025,6 @@ - func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: nextafter.out variants: method - dispatch: - CompositeExplicitAutograd: nextafter_ - func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -9506,8 +9453,6 @@ structured_delegate: elu.out device_check: NoCheck # TensorIterator python_module: nn - dispatch: - CompositeExplicitAutograd: elu_ - func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11116,8 +11061,6 @@ python_module: special variants: function structured_delegate: special_zeta.out - dispatch: - CompositeExplicitAutograd: special_zeta - func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -12026,175 +11969,175 @@ - func: _fw_primal_copy(Tensor self, int level) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _fw_primal_copy + CompositeExplicitAutogradNonFunctional: _fw_primal_copy tags: view_copy - func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _make_dual_copy + CompositeExplicitAutogradNonFunctional: _make_dual_copy tags: view_copy - func: view_as_real_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: view_as_real_copy + CompositeExplicitAutogradNonFunctional: view_as_real_copy tags: view_copy - func: view_as_complex_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: view_as_complex_copy + CompositeExplicitAutogradNonFunctional: view_as_complex_copy tags: view_copy - func: _conj_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _conj_copy + CompositeExplicitAutogradNonFunctional: _conj_copy tags: view_copy - func: _neg_view_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _neg_view_copy + CompositeExplicitAutogradNonFunctional: _neg_view_copy tags: view_copy - func: as_strided_copy(Tensor self, int[] size, int[] stride, int? 
storage_offset=None) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: as_strided_copy + CompositeExplicitAutogradNonFunctional: as_strided_copy tags: view_copy - func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _sparse_broadcast_to_copy + CompositeExplicitAutogradNonFunctional: _sparse_broadcast_to_copy tags: view_copy - func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: diagonal_copy + CompositeExplicitAutogradNonFunctional: diagonal_copy tags: view_copy - func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: expand_copy + CompositeExplicitAutogradNonFunctional: expand_copy tags: view_copy - func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: expand_copy_SymInt + CompositeExplicitAutogradNonFunctional: expand_copy_SymInt tags: view_copy - func: permute_copy(Tensor self, int[] dims) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: permute_copy + CompositeExplicitAutogradNonFunctional: permute_copy tags: view_copy - func: _reshape_alias_copy(Tensor self, int[] size, int[] stride) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _reshape_alias_copy + CompositeExplicitAutogradNonFunctional: _reshape_alias_copy tags: view_copy - func: select_copy.int(Tensor self, int dim, int index) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: select_copy_int + CompositeExplicitAutogradNonFunctional: select_copy_int tags: view_copy - func: detach_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: detach_copy + CompositeExplicitAutogradNonFunctional: detach_copy tags: view_copy - func: slice_copy.Tensor(Tensor self, int dim=0, int? start=None, int? 
end=None, int step=1) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: slice_copy_Tensor + CompositeExplicitAutogradNonFunctional: slice_copy_Tensor tags: view_copy - func: split_copy.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[] variants: function dispatch: - CompositeExplicitAutograd: split_copy_Tensor + CompositeExplicitAutogradNonFunctional: split_copy_Tensor tags: view_copy - func: split_with_sizes_copy(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] variants: function dispatch: - CompositeExplicitAutograd: split_with_sizes_copy + CompositeExplicitAutogradNonFunctional: split_with_sizes_copy tags: view_copy - func: squeeze_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: squeeze_copy + CompositeExplicitAutogradNonFunctional: squeeze_copy tags: view_copy - func: squeeze_copy.dim(Tensor self, int dim) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: squeeze_copy_dim + CompositeExplicitAutogradNonFunctional: squeeze_copy_dim tags: view_copy - func: t_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: t_copy + CompositeExplicitAutogradNonFunctional: t_copy tags: view_copy - func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: transpose_copy_int + CompositeExplicitAutogradNonFunctional: transpose_copy_int tags: view_copy - func: unsqueeze_copy(Tensor self, int dim) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: unsqueeze_copy + CompositeExplicitAutogradNonFunctional: unsqueeze_copy tags: view_copy - func: _indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _indices_copy + CompositeExplicitAutogradNonFunctional: _indices_copy tags: view_copy - func: _values_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: _values_copy + CompositeExplicitAutogradNonFunctional: _values_copy tags: view_copy - func: indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: indices_copy + CompositeExplicitAutogradNonFunctional: indices_copy tags: view_copy - func: values_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: values_copy + CompositeExplicitAutogradNonFunctional: values_copy tags: view_copy - func: crow_indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: crow_indices_copy + CompositeExplicitAutogradNonFunctional: crow_indices_copy tags: view_copy - func: col_indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: col_indices_copy + CompositeExplicitAutogradNonFunctional: col_indices_copy tags: view_copy - func: ccol_indices_copy(Tensor self) -> Tensor @@ -12212,31 +12155,31 @@ - func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[] variants: function dispatch: - CompositeExplicitAutograd: unbind_copy_int + CompositeExplicitAutogradNonFunctional: unbind_copy_int tags: view_copy - func: view_copy(Tensor self, int[] size) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: view_copy + CompositeExplicitAutogradNonFunctional: view_copy tags: view_copy - func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: view_copy_dtype + CompositeExplicitAutogradNonFunctional: view_copy_dtype tags: view_copy - func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor 
variants: function dispatch: - CompositeExplicitAutograd: unfold_copy + CompositeExplicitAutogradNonFunctional: unfold_copy tags: view_copy - func: alias_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: alias_copy + CompositeExplicitAutogradNonFunctional: alias_copy tags: view_copy - func: _fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!) diff --git a/build.bzl b/build.bzl index 71ae9a7fa4e6..b7fcbdb81f29 100644 --- a/build.bzl +++ b/build.bzl @@ -155,6 +155,8 @@ GENERATED_H_CORE = [ "CPUFunctions_inl.h", "CompositeExplicitAutogradFunctions.h", "CompositeExplicitAutogradFunctions_inl.h", + "CompositeExplicitAutogradNonFunctionalFunctions.h", + "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", "CompositeImplicitAutogradFunctions.h", "CompositeImplicitAutogradFunctions_inl.h", "MetaFunctions.h", @@ -191,6 +193,7 @@ GENERATED_CPP = [ "RegisterZeroTensor.cpp", "RegisterMeta.cpp", "RegisterCompositeExplicitAutograd.cpp", + "RegisterCompositeExplicitAutogradNonFunctional.cpp", "CompositeViewCopyKernels.cpp", "RegisterSchema.cpp", "RegisterFunctionalization_0.cpp", diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 9cc0b1524cd9..c07ea8731489 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -197,6 +197,9 @@ const char* toString(DispatchKey t) { case DispatchKey::CompositeExplicitAutograd: return "CompositeExplicitAutograd"; + case DispatchKey::CompositeExplicitAutogradNonFunctional: + return "CompositeExplicitAutogradNonFunctional"; + case DispatchKey::TESTING_ONLY_GenericWrapper: return "TESTING_ONLY_GenericWrapper"; @@ -349,6 +352,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { c10::DispatchKey::CompositeImplicitAutograd}, {"CompositeExplicitAutograd", c10::DispatchKey::CompositeExplicitAutograd}, + {"CompositeExplicitAutogradNonFunctional", + c10::DispatchKey::CompositeExplicitAutogradNonFunctional}, }; auto it = key_map.find(k); TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k); diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index dcd60989d826..69c57ec89f5e 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -532,11 +532,14 @@ enum class DispatchKey : uint16_t { // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp CompositeExplicitAutograd, // registered at // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + // See Note [CompositeExplicitAutogradNonFunctional Key] + CompositeExplicitAutogradNonFunctional, // registered at + // build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp // Define an alias key to represent end of alias dispatch keys. // If you add new alias keys after Autograd, please also update it here. StartOfAliasKeys = Autograd, - EndOfAliasKeys = CompositeExplicitAutograd, // + EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, // // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // // The aliases exist for backwards compatibility reasons, they shouldn't diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 497ecf0f3bbb..14151fce6feb 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -10,6 +10,27 @@ namespace c10 { constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKeySet(DispatchKey::Dense); +// See Note [CompositeExplicitAutogradNonFunctional Key] +// We have several types of decompositions in aten, that each have their own +// alias key.
You should register your decomposition to the + `CompositeExplicitAutogradNonFunctional` key if: (1) It's an out-of-place op + (2) It decomposes into one or more mutation ops + (3) It has a derivative formula + (In theory we could also have a separate key for + "CompositeImplicitAutogradNonFunctional", but there isn't much of a use + case for it currently). + This key is important for "functional" backends like LazyTensor / XLA. + If you're a backend that only expects to deal with "functional ops", + then you don't want to decompose a functional op into an op that causes + aliasing. You should just directly write a kernel for that functional op + instead! +constexpr DispatchKeySet non_functional_backend_dispatch_keyset = + backend_dispatch_keyset + // XLA and LazyTensor are currently the only 2 backends in core + // that use the functionalization pass in eager mode. + .remove_backend(BackendComponent::XLABit) + .remove_backend(BackendComponent::LazyBit); + bool isBackendDispatchKey(DispatchKey t) { return t != DispatchKey::Undefined // See Note [No Alias Keys in DispatchKeySet] @@ -42,6 +63,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { return math_dispatch_keyset; case DispatchKey::CompositeExplicitAutograd: return backend_dispatch_keyset; + case DispatchKey::CompositeExplicitAutogradNonFunctional: + return non_functional_backend_dispatch_keyset; default: return DispatchKeySet(t); } @@ -58,6 +81,10 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { case DispatchKey::CompositeExplicitAutograd: // See Note [NestedTensor Not Included in Backend Keys] return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k); + case DispatchKey::CompositeExplicitAutogradNonFunctional: + // See Note [NestedTensor Not Included in Backend Keys] + return k != DispatchKey::NestedTensor && + non_functional_backend_dispatch_keyset.has(k); default: return t == k; } diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 174d24aceddb..656dbfc44129 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -367,6 +367,12 @@ class DispatchKeySet final { return DispatchKeySet( repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); } + // You're allowed to remove a backend bit from a DispatchKeySet, + // but you have to be explicit about it (remove_backend() instead of + // remove()). + constexpr DispatchKeySet remove_backend(BackendComponent b) const { + return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_)); + } // Is the set empty?
(AKA undefined tensor) bool empty() const { return repr_ == 0; diff --git a/pt_defs.oss.bzl b/pt_defs.oss.bzl index 76061efd37d4..879acb31f8b8 100644 --- a/pt_defs.oss.bzl +++ b/pt_defs.oss.bzl @@ -20,6 +20,7 @@ PT_BACKEND_HEADERS = [ "CPU", "CUDA", "CompositeExplicitAutograd", + "CompositeExplicitAutogradNonFunctional", "CompositeImplicitAutograd", "Meta", ] @@ -333,6 +334,7 @@ def get_aten_generated_files(enabled_backends): "RegisterBackendSelect.cpp", "RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", + "RegisterCompositeExplicitAutogradNonFunctional.cpp", "CompositeViewCopyKernels.cpp", "RegisterSchema.cpp", "Declarations.yaml", @@ -353,6 +355,8 @@ def get_aten_generated_files(enabled_backends): "CompositeImplicitAutogradFunctions_inl.h", "CompositeExplicitAutogradFunctions.h", "CompositeExplicitAutogradFunctions_inl.h", + "CompositeExplicitAutogradNonFunctionalFunctions.h", + "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", "core/ATenOpList.cpp", "core/TensorBody.h", "core/TensorMethods.cpp", @@ -523,7 +527,7 @@ def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends): def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends): return [ ":{}[{}]".format(aten_rule_name, f) - for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] + for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends) def get_aten_derived_type_srcs(enabled_backends): diff --git a/test/test_functionalization.py b/test/test_functionalization.py index f9c425fb01a7..59837372ccb5 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -625,7 +625,7 @@ $3 = torch._ops.aten.add.Tensor($2, 1)""") x1_not_functional = torch.ones(4) x2_functional = torch._to_functional_tensor(torch.ones(4)) - # When dealing with mixed functional + nonfunctional tensors, + # When dealing with mixed functional + non functional tensors, # normal_tensor.add_(functional_tensor) is not valid # because normal_tensor would need to be "promoted" to a functional tensor. with self.assertRaises(RuntimeError): diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index ea5cad54d53a..4d6f9fb0702c 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -87,7 +87,7 @@ def gen_empty_impl_names( empty_impl = f"at::detail::empty_{dispatch}" empty_strided_impl = f"at::detail::empty_strided_{dispatch}" elif backend_index.dispatch_key in ( - DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA, ): @@ -139,6 +139,10 @@ c10::optional maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]: + if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: + # The function isn't used by this key (since only functional ops have a kernel for this key), + # so we need to not include it to avoid a defined-but-not-used error. 
+ return [] return [ """ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { @@ -332,7 +336,10 @@ class RegisterDispatchKey: "Do not explicitly specify Meta dispatch key on structured " "functions, they will be automatically generated for you" ) - elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd: + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): assert not self.backend_index.has_kernel(g.out), ( "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured " "functions, they will be automatically generated for you" @@ -566,7 +573,7 @@ void set_output_{name}( if self.backend_index.dispatch_key in [ DispatchKey.CUDA, DispatchKey.MPS, - DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, ]: maybe_set_guard = """ auto current_device = guard_.current_device(); @@ -597,7 +604,7 @@ if (C10_UNLIKELY(maybe_proxy.has_value())) { DispatchKey.CPU, DispatchKey.CUDA, DispatchKey.MPS, - DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, ) return f"""{maybe_set_guard_line} outputs_[output_idx] = create_out(sizes, strides, options);""" @@ -664,7 +671,10 @@ resize_out(out, sizes, strides, options); guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;" else: guard_field = "c10::cuda::OptionalCUDAGuard guard_;" - elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd: + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): guard_field = "c10::OptionalDeviceGuard guard_;" elif self.backend_index.dispatch_key == DispatchKey.MPS: # TODO: Move to OptionalMPSGuard. @@ -699,7 +709,7 @@ resize_out(out, sizes, strides, options); return None # TODO: Now, there is something interesting going on here. In the code below, - # we generate CompositeExplicitAutograd implementations of functional and inplace + # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace # based on the out implementation. But in fact, out is definable by # functional too (just not very efficiently), and this is honestly the # MORE likely situation for a backend implementor. How do we pick? @@ -710,7 +720,8 @@ resize_out(out, sizes, strides, options); # of work to not register one of these "weak" definitions unless there # is a strong definition somewhere in the DAG! So it's not implemented yet. 
if ( - self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional and f.func.kind() is SchemaKind.out ): # Never generate a default implementation for out, that's what you @@ -766,7 +777,8 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si class_name = f"structured_{meta.name(self.g)}_meta_{k.name}" parent_class = f"at::meta::structured_{meta.name(self.g)}" elif ( - self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd + self.backend_index.dispatch_key + is DispatchKey.CompositeExplicitAutogradNonFunctional ): # TODO: dedup this branch class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}" @@ -858,7 +870,10 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si # With the expanded context, do the impl call (if not a meta # function) - if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd: + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): # TODO: https://github.com/pytorch/pytorch/issues/53023 out_sig_group = CppSignatureGroup.from_native_function( self.g.out, method=False, fallback_binding=f.manual_cpp_binding diff --git a/torchgen/gen.py b/torchgen/gen.py index 9c9e2a622b76..87e47bdd7ebd 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -318,6 +318,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]: return [backend.dispatch_key for backend in backends] + [ DispatchKey.CompositeImplicitAutograd, DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, ] @@ -332,6 +333,8 @@ def get_static_dispatch_backend( return backend_index.dispatch_key elif f.has_composite_explicit_autograd_kernel: return DispatchKey.CompositeExplicitAutograd + elif f.has_composite_explicit_autograd_non_functional_kernel: + return DispatchKey.CompositeExplicitAutogradNonFunctional elif f.has_composite_implicit_autograd_kernel: return DispatchKey.CompositeImplicitAutograd return None @@ -420,6 +423,8 @@ def generate_static_dispatch_fallback_call( exprs = translate_args_dispatcher_to_cpp(f) if f.has_composite_explicit_autograd_kernel: return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" + elif f.has_composite_explicit_autograd_non_functional_kernel: + return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});" elif f.has_composite_implicit_autograd_kernel: return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});" else: @@ -1902,16 +1907,21 @@ def gen_source_files( ): is_registered = True # TODO: this condition is a bit questionable + # (It has to do with the fact that structured kernels get generated kernels + # to the Meta + CompositeExplicitAutogradNonFunctional keys). 
elif g.structured and dispatch_key in ( DispatchKey.Meta, - DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, ): is_registered = True if not is_registered: continue headers.append(f"#include ") - if dispatch_key == DispatchKey.CompositeExplicitAutograd: + if ( + dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): headers.append(f"#include ") if dispatch_key in functions_keys: headers.append( @@ -1924,7 +1934,7 @@ def gen_source_files( def operator_headers() -> List[str]: headers = ["#include "] - if dispatch_key == DispatchKey.CompositeExplicitAutograd: + if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: headers.append("#include ") if dispatch_key in functions_keys: headers.append(f"#include ") @@ -2279,7 +2289,7 @@ TORCH_LIBRARY({custom_namespace}, m) {{ # are expected to implement kernels for these {view}_copy kernels instead. # The code for {view}_copy operators in core is pretty boilerplate-heavy however, # so we codegen the following: - # (1) A CompositeExplicitAutograd kernel for every {view}_copy operator. + # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator. # These are never explicitly invoked by the functionalization pass, # but they could theoretically be called from user code (I added these kernels for completeness, # since the ops are part of the public API). @@ -2512,6 +2522,7 @@ def main() -> None: DispatchKey.CUDA, DispatchKey.CompositeImplicitAutograd, DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.Meta, } if options.backend_whitelist: diff --git a/torchgen/model.py b/torchgen/model.py index f0c34eb6cc30..a3090a16f9a5 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -123,7 +123,8 @@ class DispatchKey(Enum): Autograd = auto() CompositeImplicitAutograd = auto() CompositeExplicitAutograd = auto() - EndOfAliasKeys = CompositeExplicitAutograd + CompositeExplicitAutogradNonFunctional = auto() + EndOfAliasKeys = CompositeExplicitAutogradNonFunctional CPUTensorId = CPU CUDATensorId = CUDA @@ -162,6 +163,7 @@ dispatch_keys = [ DispatchKey.QuantizedCUDA, DispatchKey.CompositeImplicitAutograd, DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.NestedTensorCPU, DispatchKey.NestedTensorCUDA, # Meta is a magic key: it is automatically generated for structured @@ -175,6 +177,7 @@ dispatch_keys = [ def is_generic_dispatch_key(dk: DispatchKey) -> bool: return dk in { DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.CompositeImplicitAutograd, } @@ -422,6 +425,7 @@ class NativeFunction: # Whether or not the NativeFunction contains a backend-agnostic kernel has_composite_implicit_autograd_kernel: bool has_composite_explicit_autograd_kernel: bool + has_composite_explicit_autograd_non_functional_kernel: bool # Tags are used to describe semantic information about (groups of) operators, # That aren't easily inferrable directly from the operator's schema. 
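As an illustrative aside before the next torchgen/model.py hunk: the following is a minimal sketch in plain Python, not the real torchgen loader; the `composite_flags` helper and the trimmed-down `DispatchKey` enum are hypothetical. It shows roughly what the new bookkeeping amounts to: a `dispatch:` table may name at most one composite alias key, and naming `CompositeExplicitAutogradNonFunctional` sets the new `has_composite_explicit_autograd_non_functional_kernel` flag.

```python
# Hypothetical, simplified sketch of the bookkeeping added to torchgen/model.py.
# Not the real loader; names are chosen to mirror it.
from enum import Enum, auto
from typing import Dict


class DispatchKey(Enum):
    CPU = auto()
    CompositeImplicitAutograd = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()


COMPOSITE_KEYS = {
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
}


def composite_flags(dispatch: Dict[DispatchKey, str]) -> Dict[str, bool]:
    # An operator may be registered to at most one composite alias key.
    composites = [k for k in dispatch if k in COMPOSITE_KEYS]
    assert len(composites) <= 1, "at most one composite alias key per operator"
    return {
        "has_composite_implicit_autograd_kernel":
            DispatchKey.CompositeImplicitAutograd in dispatch,
        "has_composite_explicit_autograd_kernel":
            DispatchKey.CompositeExplicitAutograd in dispatch,
        "has_composite_explicit_autograd_non_functional_kernel":
            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch,
    }


# e.g. the narrow_copy dispatch table changed earlier in this diff:
flags = composite_flags({
    DispatchKey.CPU: "narrow_copy_dense_cpu",
    DispatchKey.CompositeExplicitAutogradNonFunctional: "narrow_copy_dense",
})
assert flags["has_composite_explicit_autograd_non_functional_kernel"]
```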
@@ -613,11 +617,17 @@ class NativeFunction: cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE ) - assert not ( - DispatchKey.CompositeExplicitAutograd in dispatch - and DispatchKey.CompositeImplicitAutograd in dispatch - ), ( - "cannot specify both CompositeExplicitAutograd and CompositeImplicitAutograd on a single kernel; each " + composites_in_dispatch = [ + d + for d in dispatch + if d == DispatchKey.CompositeExplicitAutograd + or d == DispatchKey.CompositeExplicitAutogradNonFunctional + or d == DispatchKey.CompositeImplicitAutograd + ] + + assert len(composites_in_dispatch) <= 1, ( + "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " + "or CompositeImplicitAutograd on a single kernel; each " "strictly subsumes the other. If you wanted to provide an explicit autograd " "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" ) @@ -673,6 +683,9 @@ class NativeFunction: has_composite_explicit_autograd_kernel = ( DispatchKey.CompositeExplicitAutograd in dispatch.keys() ) + has_composite_explicit_autograd_non_functional_kernel = ( + DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch.keys() + ) # We aren't going to store dispatch metadata inline in NativeFunctions; # instead it is separately indexed by backend (so other backends can @@ -715,6 +728,7 @@ class NativeFunction: is_abstract=is_abstract, has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, + has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, tags=tags, namespace=namespace, ), @@ -789,6 +803,7 @@ class NativeFunction: return ( self.has_composite_implicit_autograd_kernel or self.has_composite_explicit_autograd_kernel + or self.has_composite_explicit_autograd_non_functional_kernel ) @property diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 42b70b753ff5..fe7f7182ec05 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -251,6 +251,7 @@ def generate_function( is_abstract=f.is_abstract, has_composite_implicit_autograd_kernel=False, has_composite_explicit_autograd_kernel=True, + has_composite_explicit_autograd_non_functional_kernel=False, # Every generated NativeFunction gets a "generated" tag, so it's easy to tell # which NativeFunction objects did not come directly from native_functions.yaml. tags=set(["generated"]),
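For reference, a hedged usage sketch (not part of this patch, and assuming a build that already contains it): the new key can also be targeted from Python through `torch.library`, and a kernel registered there serves every non-functional backend while XLA/Lazy, which are excluded from the alias' runtime key set, are expected to supply their own implementation. The `myops` namespace and the op below are invented for the example.

```python
import torch

# Define a toy non-aliasing op in a made-up namespace and register its kernel
# under the new CompositeExplicitAutogradNonFunctional alias key.
lib = torch.library.Library("myops", "DEF")
lib.define("first_row_copy(Tensor self) -> Tensor")


def first_row_copy(self: torch.Tensor) -> torch.Tensor:
    # Internally calls an aliasing op (select) but returns a fresh tensor,
    # which is exactly the kind of kernel this key is meant for.
    return self.select(0, 0).clone()


lib.impl("first_row_copy", first_row_copy, "CompositeExplicitAutogradNonFunctional")

x = torch.arange(6.0).reshape(2, 3)
# CPU, CUDA, and other non-functional backends dispatch to the kernel above;
# XLA/Lazy would need their own registration for this op.
print(torch.ops.myops.first_row_copy(x))
```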