From 23b8414391b6a649e1a0ef56c841d5855eb42b43 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 11 Apr 2022 12:38:14 -0700 Subject: [PATCH] code-generate non-aliasing {view}_copy kernels (#73442) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73442 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D35016025 Pulled By: bdhirsh fbshipit-source-id: 2a7f303ec76f5913b744c7822a531d55a57589c9 (cherry picked from commit 3abe13c2a787bcbe9c41b0a335c96e5a3d3642fb) --- BUILD.bazel | 1 + aten/src/ATen/native/native_functions.yaml | 200 ++++++++++++++++++ aten/src/ATen/native/tags.yaml | 10 + .../templates/CompositeViewCopyKernels.cpp | 20 ++ test/test_view_ops.py | 18 ++ tools/autograd/load_derivatives.py | 45 +++- tools/codegen/api/autograd.py | 35 ++- tools/codegen/context.py | 39 +++- tools/codegen/gen.py | 137 ++++++++++-- tools/codegen/gen_functionalization_type.py | 193 ++++++++++++----- tools/codegen/model.py | 124 ++++++++++- torch/_torch_docs.py | 138 ++++++++++++ torch/overrides.py | 34 +++ 13 files changed, 909 insertions(+), 85 deletions(-) create mode 100644 aten/src/ATen/native/tags.yaml create mode 100644 aten/src/ATen/templates/CompositeViewCopyKernels.cpp diff --git a/BUILD.bazel b/BUILD.bazel index 197592f81e0d..dfc910c0e2be 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -50,6 +50,7 @@ generated_cpu_cpp = [ "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h", "aten/src/ATen/CompositeImplicitAutogradFunctions.h", "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h", + "aten/src/ATen/CompositeViewCopyKernels.cpp", "aten/src/ATen/FunctionalInverses.h", "aten/src/ATen/Functions.h", "aten/src/ATen/Functions.cpp", diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 768ba90c10dd..a892aa4e31f3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -145,6 +145,7 @@ - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) variants: method + tags: inplace_view - func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) variants: method @@ -3262,6 +3263,7 @@ CPU: narrow_copy_dense_cpu SparseCPU, SparseCUDA: narrow_copy_sparse CompositeExplicitAutograd: narrow_copy_dense + tags: view_copy - func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor variants: function, method @@ -11355,3 +11357,201 @@ - func: nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor variants: function + +- func: _fw_primal_copy(Tensor self, int level) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _fw_primal_copy + tags: view_copy + +- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _make_dual_copy + tags: view_copy + +- func: view_as_real_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: view_as_real_copy + tags: view_copy + +- func: view_as_complex_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: view_as_complex_copy + tags: view_copy + +- func: _conj_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _conj_copy + tags: view_copy + +- func: _neg_view_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _neg_view_copy + tags: view_copy + +- func: as_strided_copy(Tensor self, int[] size, int[] stride, int? storage_offset=None) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: as_strided_copy + tags: view_copy + +- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _sparse_broadcast_to_copy + tags: view_copy + +- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: diagonal_copy + tags: view_copy + +- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: expand_copy + tags: view_copy + +- func: permute_copy(Tensor self, int[] dims) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: permute_copy + tags: view_copy + +- func: _reshape_alias_copy(Tensor self, int[] size, int[] stride) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _reshape_alias_copy + tags: view_copy + +- func: select_copy.int(Tensor self, int dim, int index) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: select_copy_int + tags: view_copy + +- func: detach_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: detach_copy + tags: view_copy + +- func: slice_copy.Tensor(Tensor self, int dim=0, int? start=None, int? 
end=None, int step=1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: slice_copy_Tensor + tags: view_copy + +- func: split_copy.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutograd: split_copy_Tensor + tags: view_copy + +- func: split_with_sizes_copy(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutograd: split_with_sizes_copy + tags: view_copy + +- func: squeeze_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: squeeze_copy + tags: view_copy + +- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: squeeze_copy_dim + tags: view_copy + +- func: t_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: t_copy + tags: view_copy + +- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: transpose_copy_int + tags: view_copy + +- func: unsqueeze_copy(Tensor self, int dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: unsqueeze_copy + tags: view_copy + +- func: _indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _indices_copy + tags: view_copy + +- func: _values_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _values_copy + tags: view_copy + +- func: indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: indices_copy + tags: view_copy + +- func: values_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: values_copy + tags: view_copy + +- func: crow_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: crow_indices_copy + tags: view_copy + +- func: col_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: col_indices_copy + tags: view_copy + +- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutograd: unbind_copy_int + tags: view_copy + +- func: view_copy(Tensor self, int[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: view_copy + tags: view_copy + +- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: view_copy_dtype + tags: view_copy + +- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: unfold_copy + tags: view_copy + +- func: alias_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: alias_copy + tags: view_copy diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml new file mode 100644 index 000000000000..d79b13adae84 --- /dev/null +++ b/aten/src/ATen/native/tags.yaml @@ -0,0 +1,10 @@ +# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml` + +- tag: inplace_view + desc: | + This tag indicates if an operator *only* modifies the tensor metadata +- tag: view_copy + desc: | + This tag indicates operators that are *_copy* variants + of view/aliasing operators. If an operator has a view_copy tag, + then it should have the name {op}_copy, where {op} is a view operator. 
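
The new tags.yaml above fixes the convention that every operator tagged `view_copy` is named `{op}_copy` for some existing view operator `{op}`. For reference, that convention can be checked mechanically; the sketch below is illustrative only and not part of the patch — the `check_view_copy_names` helper and its `(operator_name, tags)` input format are hypothetical simplifications, whereas the real codegen works on parsed NativeFunction objects.

# Minimal sketch (not part of this patch): validate the naming convention that the
# view_copy tag encodes. `ops` is assumed to be a list of (operator_name, tags) pairs
# parsed from native_functions.yaml.
def check_view_copy_names(ops):
    base_names = {name.split(".")[0] for name, _ in ops}
    for name, tags in ops:
        if "view_copy" not in tags:
            continue
        base = name.split(".")[0]  # drop the overload name, e.g. "slice_copy.Tensor" -> "slice_copy"
        assert base.endswith("_copy"), f"{name} is tagged view_copy but is not named {{op}}_copy"
        assert base[: -len("_copy")] in base_names, \
            f"{name} has no corresponding view operator {base[: -len('_copy')]}"

# A few entries taken from the yaml above:
check_view_copy_names([
    ("transpose.int", set()),
    ("transpose_copy.int", {"view_copy"}),
    ("view", set()),
    ("view_copy", {"view_copy"}),
])
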
diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp new file mode 100644 index 000000000000..558802a7b7e8 --- /dev/null +++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp @@ -0,0 +1,20 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +$ops_headers +#endif + +namespace at { +namespace native { + + +${CompositeViewCopyKernel_Definitions} + +} // namespace native +} // namespace at diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 064d001727ab..3fcfa72cf45e 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -908,6 +908,24 @@ class TestViewOps(TestCase): op = partial(fn, source=0, destination=1) run_test(device, op) + # Testing that the generated view_copy kernel and its derivative are implemented correctly + def test_view_copy(self, device): + a = torch.randn(4, device=device, requires_grad=True) + a_ref = a.clone().detach().requires_grad_() + a_view = a_ref.view(2, 2) + a_view_copy = torch.view_copy(a, (2, 2)) + + # view_copy ops don't preserve view relationship + self.assertTrue(self.is_view_of(a_ref, a_view)) + self.assertFalse(self.is_view_of(a, a_view_copy)) + + a_view_copy.sum().backward() + a_view.sum().backward() + + # forward and backward give the same shape + result + self.assertEqual(a_view_copy, a_view) + self.assertEqual(a.grad, a_ref.grad) + class TestOldViewOps(TestCase): def test_ravel(self, device): diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index e62ab95c66d0..8ffaccaffb48 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -14,13 +14,38 @@ from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, Bas tensorGeometryT, scalarTypeT, SpecialArgName, OptionalCType, stringT) from tools.codegen.api import cpp -from tools.codegen.gen import parse_native_yaml +from tools.codegen.gen import parse_native_yaml, get_grouped_by_view_native_functions from tools.codegen.context import with_native_function -from tools.codegen.model import FunctionSchema, NativeFunction, Variant, Type -from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader +from tools.codegen.model import ( + FunctionSchema, NativeFunction, Variant, Type, + NativeFunctionsViewGroup, OperatorName +) +from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader, concatMap _GLOBAL_LOAD_DERIVATIVE_CACHE = {} +# This function directly adds derivative entries for {view}_copy variants of each view op. +# Since every {view} and {view}_copy op shares the same derivative formula, +# we generate them here instead of duplicating them in the yaml. 
+# See Note [Codegen'd {view}_copy Operators] +def add_view_copy_derivatives( + infos: List[DifferentiabilityInfo], + view_groups: List[NativeFunctionsViewGroup] +) -> List[DifferentiabilityInfo]: + # Get the map from each view op's name to its corresponding view group + view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = { + g.view.func.name: g for g in view_groups} + + view_copy_differentiability_infos = [] + for info in infos: + maybe_view_group = view_name_to_group.get(info.func.func.name, None) + if maybe_view_group is not None and maybe_view_group.view_copy is not None: + view_copy_info = info.create_view_copy_from_view_derivative(maybe_view_group) + if view_copy_info is not None: + view_copy_differentiability_infos.append(view_copy_info) + + return view_copy_differentiability_infos + def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]: # Do some caching as this is a deterministic function global _GLOBAL_LOAD_DERIVATIVE_CACHE @@ -30,7 +55,16 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque with open(derivatives_yaml_path, 'r') as f: definitions = yaml.load(f, Loader=YamlLoader) - functions = parse_native_yaml(native_yaml_path).native_functions + funcs = parse_native_yaml(native_yaml_path).native_functions + # From the parsed native functions, separate out the (generated) view_copy functions, + # so we can generate derivatives for them separately. + native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs) + native_functions_without_view_copies = concatMap( + # We need to pull out the view_inplace ops too, since they might have their own derivative entries. + lambda g: [g] if isinstance(g, NativeFunction) else list(g.functions(include_copy=False)), + native_functions_with_view_groups + ) + view_groups = [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)] # What's the difference between function schema v.s. signature? # function schema is the complete declaration including mutability annotation / default value and etc. @@ -38,7 +72,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque # that are semantically related. 
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list) functions_by_schema: Dict[str, NativeFunction] = dict() - for function in functions: + for function in native_functions_without_view_copies: functions_by_signature[function.func.signature()].append(function) assert str(function.func) not in functions_by_schema functions_by_schema[str(function.func)] = function @@ -50,6 +84,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque infos = [ create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter) for defn in definitions] + infos += add_view_copy_derivatives(infos, view_groups) _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos diff --git a/tools/codegen/api/autograd.py b/tools/codegen/api/autograd.py index 635ad927e8a2..edd1b0351f1c 100644 --- a/tools/codegen/api/autograd.py +++ b/tools/codegen/api/autograd.py @@ -4,7 +4,7 @@ from typing import Optional, Sequence, Set, List, Tuple, Match from tools.codegen.api import cpp from tools.codegen.api.types import Binding, NamedCType -from tools.codegen.model import NativeFunction, Type, SchemaKind +from tools.codegen.model import NativeFunction, Type, SchemaKind, NativeFunctionsViewGroup from tools.codegen.utils import IDENT_REGEX # Represents a saved attribute involved in backward calculation. @@ -149,6 +149,39 @@ class DifferentiabilityInfo: def has_derivatives(self) -> bool: return len(self.args_with_derivatives) > 0 + # Generates a new DifferentiabilityInfo using the exact same set of derivative information, + # but with a new operator name. + # This is used when generating "copy" variants of view ops, + # which are able to use the exact same derivative formula as the original view op + # See Note [Codegen'd {view}_copy Operators] + def create_view_copy_from_view_derivative(self, g: NativeFunctionsViewGroup) -> Optional['DifferentiabilityInfo']: + if g.view_copy is None: + return None + f = g.view_copy + + name_split_by_period = self.name.split('.', maxsplit=2) + # Append a "_copy" to the base name of the operator (but keep the overload name the same) + view_copy_name = f'{name_split_by_period[0]}_copy.' 
+ '.'.join(name_split_by_period[1:]) + view_copy_op_name = None if self.op is None else f'{self.op}_copy' + + return DifferentiabilityInfo( + # Use the "_copy" version of name/func/op + name=view_copy_name, + func=f, + op=view_copy_op_name, + # But keep all derivative info the same + derivatives=self.derivatives, + forward_derivatives=self.forward_derivatives, + all_saved_inputs=self.all_saved_inputs, + all_saved_outputs=self.all_saved_outputs, + available_named_gradients=self.available_named_gradients, + used_named_gradients=self.used_named_gradients, + args_with_derivatives=self.args_with_derivatives, + non_differentiable_arg_names=self.non_differentiable_arg_names, + output_differentiability=self.output_differentiability, + output_differentiability_conditions=self.output_differentiability_conditions, + ) + def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool: if info is None: return False diff --git a/tools/codegen/context.py b/tools/codegen/context.py index ba21c86c7934..cb607efa18ca 100644 --- a/tools/codegen/context.py +++ b/tools/codegen/context.py @@ -1,9 +1,12 @@ from tools.codegen.utils import S, T, context -from tools.codegen.model import (NativeFunction, NativeFunctionsGroup, BackendIndex, DispatchKey) +from tools.codegen.model import ( + NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, BackendIndex, DispatchKey +) import tools.codegen.local as local +from tools.codegen.selective_build.selector import SelectiveBuilder import functools -from typing import TypeVar, Union, Iterator, Callable, Dict +from typing import TypeVar, Union, Iterator, Callable, Dict, Optional import contextlib # Helper functions for defining generators on things in the model @@ -12,17 +15,28 @@ F = TypeVar( 'F', NativeFunction, NativeFunctionsGroup, + NativeFunctionsViewGroup, Union[NativeFunction, NativeFunctionsGroup], + Union[NativeFunction, NativeFunctionsViewGroup], +) + +F2 = TypeVar( + 'F2', + NativeFunction, + Optional[NativeFunction], ) @contextlib.contextmanager -def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> Iterator[None]: +def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]) -> Iterator[None]: if isinstance(g, NativeFunctionsGroup): # By default, we associate all errors with structured native functions # with the out variant. In some cases, it might be better to have # a more specific place to hang things; if so, use # native_function_manager again on the inside f = g.out + elif isinstance(g, NativeFunctionsViewGroup): + # We associate errors with the view operator + f = g.view else: f = g with context(lambda: f'in native_functions.yaml line {f.loc}:\n {f.func}'): @@ -41,6 +55,14 @@ def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]: return func(f) return wrapper +def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]: + @functools.wraps(func) + def wrapper(f: F, f2: F2) -> T: + # The first native_function is assumed to be the one with the appropriate context. 
+ with native_function_manager(f): + return func(f, f2) + return wrapper + def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]: @functools.wraps(func) def wrapper(slf: S, f: F) -> T: @@ -57,6 +79,17 @@ def with_native_function_and_index(func: Callable[[F, BackendIndex], T]) -> Call return func(f, backend_index) return wrapper +# Convenience decorator for functions that explicitly take in a BackendIndex, +# instead of indirectly taking one in as a closure +def with_native_function_and_selector_and_index( + func: Callable[[SelectiveBuilder, F, BackendIndex], T] +) -> Callable[[SelectiveBuilder, F, BackendIndex], T]: + @functools.wraps(func) + def wrapper(selector: SelectiveBuilder, f: F, backend_index: BackendIndex) -> T: + with native_function_manager(f): + return func(selector, f, backend_index) + return wrapper + def with_native_function_and_indices( func: Callable[[F, Dict[DispatchKey, BackendIndex]], T] ) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]: diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 101f1fbe96ae..6d34c86cd3d9 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -17,7 +17,10 @@ from tools.codegen.model import (Argument, DispatchKey, FunctionSchema, is_cuda_dispatch_key, is_generic_dispatch_key, is_ufunc_dispatch_key, - Tag, BaseOperatorName) + NativeFunctionsViewGroup, + ViewSchemaKind, + BaseOperatorName, + Tag) from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup, DispatcherSignature, NativeSignature) from tools.codegen.api import cpp @@ -40,7 +43,8 @@ from tools.codegen.gen_functionalization_type import ( needs_functionalization, gen_functionalization_definition, gen_functionalization_registration, - gen_functionalization_view_inverse_declaration + gen_functionalization_view_inverse_declaration, + gen_composite_view_copy_kernel, ) T = TypeVar('T') @@ -1002,6 +1006,43 @@ def pre_group_native_functions( d[f.func.kind()] = f return pre_grouped_native_functions +def get_grouped_by_view_native_functions( + native_functions: Sequence[NativeFunction] +) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]: + def maybe_create_view_group(d: Dict[ViewSchemaKind, NativeFunction]) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]: + funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = [] + if ViewSchemaKind.aliasing not in d: + # Case 1: this op / op group is not aliasing, so we don't create a view group. + # return the original (ungrouped) native functions instead. + for func in d.values(): + funcs.append(func) + else: + # Case 2: this op group contains an aliasing op, so we create a ViewGroup for it. + # The handling for out= ops here is unfortunate. + # out= ops don't really make sense for view operators. + # However, we have at least one existing {view}_copy.out operator in native_functions.yaml. + # It shouldn't be part of a view group, so we explicitly don't group it. + # There currently aren't any out= view ops (and there probably shouldn't be). 
+ # We also expect that when we hit this case, the `non_aliasing` op in the dict + # *must* be a view_copy op (this is asserted in the NativeFunctionsViewGroup constructor) + if ViewSchemaKind.out in d: + funcs.append(d[ViewSchemaKind.out]) + + funcs.append(NativeFunctionsViewGroup( + view=d[ViewSchemaKind.aliasing], + view_copy=d.get(ViewSchemaKind.non_aliasing, None), + view_inplace=d.get(ViewSchemaKind.inplace, None), + )) + return funcs + + grouped_by_views: Dict[FunctionSchema, Dict[ViewSchemaKind, NativeFunction]] = defaultdict(dict) + for f in native_functions: + schema = f.func.view_signature() + assert f.view_schema_kind not in grouped_by_views[schema] + grouped_by_views[schema][f.view_schema_kind] = f + + return list(concatMap(maybe_create_view_group, grouped_by_views.values())) + def get_grouped_native_functions( native_functions: Sequence[NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: @@ -1316,10 +1357,6 @@ def gen_headers( 'registration_declarations': [compute_registration_declarations(f, backend_indices) for f in native_functions], }) - cpu_fm.write('FunctionalInverses.h', lambda: { - 'view_inverse_declarations': list(mapMaybe(gen_functionalization_view_inverse_declaration, native_functions)) - }) - def gen_aten_interned_strings() -> Dict[str, str]: attrs = set() # All function argument names @@ -1354,6 +1391,7 @@ def gen_source_files( native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: Sequence[NativeFunctionsGroup], + native_functions_with_view_groups: Sequence[Union[NativeFunction, NativeFunctionsViewGroup]], selector: SelectiveBuilder, backend_indices: Dict[DispatchKey, BackendIndex], core_fm: FileManager, @@ -1528,7 +1566,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { else list(mapMaybe(RegisterSchema(schema_selector), native_functions)), }) - def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: + def key_func(fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str: return fn.root_name cpu_fm.write_sharded( @@ -1564,20 +1602,43 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { def functionalization_env_callable( - g: Union[NativeFunction, NativeFunctionsGroup] + g: Union[NativeFunction, NativeFunctionsViewGroup] ) -> Dict[str, List[str]]: - functions = [g] if isinstance(g, NativeFunction) else list(g.functions()) - functions_needing_functionalization = [ - fn for fn in functions if needs_functionalization(selector, fn)] + functions_needing_functionalization = [g] if needs_functionalization(selector, g) else [] + + def gen_op_headers(g: Union[NativeFunction, NativeFunctionsViewGroup]) -> List[str]: + if not needs_functionalization(selector, g): + return [] + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + f = g + return [ + f"#include ", + f"#include ", + ] + return { - 'ops_headers': ([ - f"#include ", - f"#include ", - ] if functions_needing_functionalization else []), - 'func_definitions': list(mapMaybe( - lambda f: gen_functionalization_definition(selector, f, to_functional_op[f.func.name]), + 'ops_headers': gen_op_headers(g), + 'func_definitions': list(concatMap( + lambda f: 
gen_functionalization_definition( + selector, + g, + # We need to manually map inplace ops to their out-of-place variants + # (we can't do this with NativeFunctionsGroup today because not all inplace ops have out= variants) + None if isinstance(g, NativeFunctionsViewGroup) else to_functional_op.get(g.func.name, None)), functions_needing_functionalization)), - 'func_registrations': list(mapMaybe( + 'func_registrations': list(concatMap( lambda f: gen_functionalization_registration( selector, f, backend_indices[DispatchKey.CompositeImplicitAutograd]), functions_needing_functionalization)), @@ -1586,13 +1647,46 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { cpu_fm.write_sharded( 'RegisterFunctionalization.cpp', - grouped_native_functions, + native_functions_with_view_groups, key_fn=key_func, env_callable=functionalization_env_callable, num_shards=4, sharded_keys={'ops_headers', 'func_definitions', 'func_registrations'} ) + cpu_fm.write('FunctionalInverses.h', lambda: { + 'view_inverse_declarations': list(mapMaybe( + lambda g: gen_functionalization_view_inverse_declaration(selector, g), + [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)])) + }) + + # Note [view_copy NativeFunctions] + # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd + # needs to have a corresponding non-aliasing {view}_copy variant. + # Backends that use functionalization and don't know how to handle aliasing ops + # are expected to implement kernels for these {view}_copy kernels instead. + # The code for {view}_copy operators in core is pretty boilerplate-heavy however, + # so we codegen the following: + # (1) A CompositeExplicitAutograd kernel for every {view}_copy operator. + # These are never explicitly invoked by the functionalization pass, + # but they could theoretically be called from user code (I added these kernels for completeness, + # since the ops are part of the public API). + # (2) A derivative formula for every {view}_copy operator + # {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts, + # so rather than stamping all of the entries out in derivatives.yaml, + # we codegen them in. + # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry. 
+ cpu_fm.write('CompositeViewCopyKernels.cpp', lambda: { + 'ops_headers': [ + '\n'.join(f'#include ' for f in + ([g.view] if g.view_copy is None else [g.view, g.view_copy])) + for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup) + ], + 'CompositeViewCopyKernel_Definitions': list(mapMaybe( + gen_composite_view_copy_kernel, + [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)])) + }) + def gen_declarations_yaml( cpu_fm: FileManager, @@ -1674,9 +1768,13 @@ def main() -> None: native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml') parsed_yaml = parse_native_yaml(native_yaml_path) native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices + grouped_native_functions = get_grouped_native_functions(native_functions) structured_native_functions = [g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)] + native_functions_with_view_groups = get_grouped_by_view_native_functions(native_functions) + + template_dir = os.path.join(options.source_path, "templates") # NB: It is mandatory to NOT use os.path.join here, as the install directory # will eventually be ingested by cmake, which does not respect Windows style @@ -1734,6 +1832,7 @@ def main() -> None: native_functions=native_functions, grouped_native_functions=grouped_native_functions, structured_native_functions=structured_native_functions, + native_functions_with_view_groups=native_functions_with_view_groups, selector=selector, backend_indices=backend_indices, core_fm=core_fm, diff --git a/tools/codegen/gen_functionalization_type.py b/tools/codegen/gen_functionalization_type.py index 06521836d733..28209e9e1b23 100644 --- a/tools/codegen/gen_functionalization_type.py +++ b/tools/codegen/gen_functionalization_type.py @@ -1,16 +1,75 @@ from tools.codegen.api import cpp from tools.codegen.api.types import ( - DispatcherSignature, Binding, FunctionalizationLambda, ViewInverseSignature + DispatcherSignature, Binding, FunctionalizationLambda, ViewInverseSignature, + NativeSignature ) from tools.codegen.api.translate import translate -from tools.codegen.context import with_native_function +from tools.codegen.context import ( + with_native_function, with_native_function_and_selector_and_index, + with_native_function_and +) from tools.codegen.model import ( Argument, NativeFunction, SchemaKind, BackendIndex, - Tag, FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy + Tag, FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy, + NativeFunctionsViewGroup, ListType ) from tools.codegen.selective_build.selector import SelectiveBuilder from typing import List, Optional, Union, Tuple +# This file contains codegen that relates to the functionalization pass. +# It includes: +# - gen_functionalization_definition +# Generates dispatcher kernel definitions for the functionalization pass. +# - gen_functionalization_registration +# Generates dispatcher kernel registrations for the functionalization pass. +# - gen_functionalization_view_inverse_declaration +# Generates a declaration for an "inverse view", for every view op +# that is needed in functionalization. We manually implement their definitions. +# - gen_composite_view_copy_kernel +# Generates view_copy() composite kernels for all view_copy operators. 
+ + +# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction +# See Note [view_copy NativeFunctions] +@with_native_function +def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]: + + if g.view_copy is None: + return None + # view_copy is a native signature, since we're generating an at::native:: kernel + view_copy_sig = NativeSignature(g.view_copy.func) + # view is a dispatcher signature, since we're calling into the at::_ops API + view_sig = DispatcherSignature(g.view.func) + + view_api_name = g.view.func.name.unambiguous_name() + exprs = ', '.join([e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]) + + # view ops today always return either a Tensor or a list of Tensors + assert len(g.view.func.returns) == 1 + assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \ + or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None) + + if g.view.func.returns[0].type == BaseType(BaseTy.Tensor): + return_cloned_output = '''\ + return output.clone();''' + else: + # If the return type is a list, we need to clone each tensor in the list. + return_cloned_output = f'''\ + {view_copy_sig.returns_type().cpp_type()} out_clone; + for (const auto i : c10::irange(output.size())) {{ + out_clone.push_back(output[i].clone()); + }} + return out_clone;''' + + # The default generated composite kernel for {view}_copy() operators just clones + # the input tensor, and runs the underlying view on the clone. + return f""" +{view_copy_sig.defn()} {{ + auto output = at::_ops::{view_api_name}::call({exprs}); + {return_cloned_output} +}} +""" + def modifies_arguments(f: NativeFunction) -> bool: return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] @@ -110,6 +169,7 @@ View operators with multiple aliasing inputs aren't supported yet. Found an oper # Generates the Functionalization kernel for: # - ops that create aliases (e.g. transpose()) # - ops that are views AND mutations (e.g. transpose_()) +@with_native_function_and def emit_view_functionalization_body( f: NativeFunction, functional_op: NativeFunction @@ -154,6 +214,7 @@ def emit_view_functionalization_body( if f.tag is Tag.inplace_view: # See Note [Functionalization Pass - Inplace View Ops] for more details return f""" + {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{ // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper. {unwrap_tensor_args_str} @@ -178,10 +239,12 @@ def emit_view_functionalization_body( // See Note [Propagating strides in the functionalization pass] at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output); return {view_tensor_name}; + }} """ else: return f""" + {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{ // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper. 
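
The gen_composite_view_copy_kernel helper in the hunk above emits a composite kernel that simply calls the underlying view op through the at::_ops API and clones the result. As a rough illustration of the emitted shape (not part of the patch), the standalone sketch below stands in for the codegen: `sketch_view_copy_kernel` and its signature-string arguments are hypothetical simplifications of NativeSignature/DispatcherSignature, and only the single-Tensor-return case is shown.

# Rough stand-in for the codegen (not part of this patch): given simplified signature
# strings, produce a composite {view}_copy kernel that calls the view op through
# at::_ops and clones its output, mirroring gen_composite_view_copy_kernel.
def sketch_view_copy_kernel(view_copy_defn, view_unambiguous_name, exprs):
    return f"""
{view_copy_defn} {{
  auto output = at::_ops::{view_unambiguous_name}::call({exprs});
  return output.clone();
}}
"""

print(sketch_view_copy_kernel(
    "at::Tensor transpose_copy_int(const at::Tensor & self, int64_t dim0, int64_t dim1)",
    "transpose_int",
    "self, dim0, dim1",
))
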
@@ -210,9 +273,11 @@ def emit_view_functionalization_body( // See Note [Propagating strides in the functionalization pass] at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output); return out; + }} """ # Generates the Functionalization kernel for inplace ops +@with_native_function_and def emit_inplace_functionalization_body( f: NativeFunction, functional_op: Optional[NativeFunction] @@ -262,6 +327,7 @@ Instead, it's calling the inplace/view operator directly. \ If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch." return f""" + {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{ TORCH_WARN("{warn_str}"); }} @@ -269,6 +335,7 @@ If this causes problems in your program, consider upstreaming the out-of-place o at::AutoDispatchSkipFunctionalize guard; // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops. {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)}); + }} """ else: # call the out-of-place variant of the op @@ -286,6 +353,7 @@ If this causes problems in your program, consider upstreaming the out-of-place o if a.annotation and a.annotation.is_write and a.type.is_tensor_like()]) return f""" + {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} if (!({check_all_mutated_args_are_functional})) {{ if (({check_any_non_mutated_args_are_functional})) {{ @@ -308,7 +376,8 @@ If this causes problems in your program, consider upstreaming the out-of-place o }} {mutable_input_post_processing} {return_str(f)}; - }}""" + }} + }}""" def emit_declaration_for_noncomposite_views(f: NativeFunction) -> str: @@ -322,26 +391,47 @@ def emit_declaration_for_noncomposite_views(f: NativeFunction) -> str: # These files provide the kernels that run the functionalization pass, which can be opted into # per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch). +# TODO: later, I'll generate separate kernels for "Functionalize", and "AddBackViews" +# That will probably be an enum, that this function will take in to know which ops get kernels. def needs_functionalization( selector: SelectiveBuilder, - f: NativeFunction, + g: Union[NativeFunction, NativeFunctionsViewGroup], ) -> bool: - return (selector.include_all_operators and - (f.is_view_op or modifies_arguments(f))) + # Don't generate kernels in mobile build + if not selector.include_all_operators: + return False + # Every view op needs *some* handling in functionalization: + # - non-composite view ops get a generated kernel + # - composite view ops get a generated registration (that directly registers the composite kernel) + if isinstance(g, NativeFunctionsViewGroup): + return True + # For non-view ops, only inplace operators get functionalized + assert isinstance(g, NativeFunction) + return modifies_arguments(g) +# See Note [Functionalization Pass: View Inverses]. +def gen_functionalization_view_inverse_declaration(selector: SelectiveBuilder, g: NativeFunctionsViewGroup) -> Optional[str]: + # For every (non-composite) view op, we need a corresponding "inverse view" function. + # This generates the declarations so we get a good compiler error when someone adds a new view. 
+ @with_native_function + def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]: + if g.view.has_composite_implicit_autograd_kernel: + return None + view_inverse_sig = ViewInverseSignature(g.view) + return view_inverse_sig.decl() + + return emit_decl_helper(g) + + +@with_native_function_and_selector_and_index def gen_functionalization_registration( selector: SelectiveBuilder, - f: NativeFunction, + g: Union[NativeFunction, NativeFunctionsViewGroup], composite_implicit_autograd_index: BackendIndex -) -> Optional[str]: - @with_native_function - def emit_registration_helper(f: NativeFunction) -> Optional[str]: - # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile. - # At some point, Vulkan we'll want to use functionalization and we'll need to change this. - if not needs_functionalization(selector, f): - return None - if f.is_view_op and f.has_composite_implicit_autograd_kernel: +) -> List[str]: + def emit_registration_helper(f: NativeFunction, *, is_view: bool) -> str: + if is_view and f.has_composite_implicit_autograd_kernel: metadata = composite_implicit_autograd_index.get_kernel(f) assert metadata is not None native_api_name = metadata.kernel @@ -353,49 +443,42 @@ def gen_functionalization_registration( # because we don't want to decompose non-view ops that are composite, like `at::ones`. registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})' else: + # non-composite view ops (and inplace ops) get a normal registration. registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})' - return f'm.impl("{f.func.name}", {registration_str});' - return emit_registration_helper(f) + # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile. + # At some point, Vulkan we'll want to use functionalization and we'll need to change this. + if not needs_functionalization(selector, g): + return [] + + if isinstance(g, NativeFunctionsViewGroup): + view_str = [emit_registration_helper(g.view, is_view=True)] + if g.view_inplace is not None: + view_str.append(emit_registration_helper(g.view_inplace, is_view=True)) + return view_str + else: + f = g + return [emit_registration_helper(f, is_view=False)] def gen_functionalization_definition( selector: SelectiveBuilder, - f: NativeFunction, + g: Union[NativeFunction, NativeFunctionsViewGroup], functional_op: Optional[NativeFunction] -) -> Optional[str]: - @with_native_function - def emit_definition_helper(f: NativeFunction) -> Optional[str]: - if not needs_functionalization(selector, f): - return None - if f.is_view_op and f.has_composite_implicit_autograd_kernel: - # See Note [Composite view ops in the functionalization pass] - return None - # order is important here, ops that are both views and mutations should hit the view path. - if f.is_view_op: - # Every view op is expected to have a functional counterpart (e.g. transpose_() -> transpose()) - assert functional_op is not None - body_str = emit_view_functionalization_body(f, functional_op) - else: - # inplace op - assert modifies_arguments(f) - body_str = emit_inplace_functionalization_body(f, functional_op) - sig = DispatcherSignature.from_schema(f.func) - return f""" - {sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ - {body_str} - }} - """ - - return emit_definition_helper(f) - -# See Note [Functionalization Pass: View Inverses]. 
-@with_native_function -def gen_functionalization_view_inverse_declaration(f: NativeFunction) -> Optional[str]: - # We only need to generate view_inverse declarations for view ops that: - # - aren't composite (since they'll decompose and we'll get them for free). - # - aren't inplace (since they should have a corresponding functional version, which we call instead). - if f.is_view_op and not f.has_composite_implicit_autograd_kernel and not modifies_arguments(f): - output = emit_declaration_for_noncomposite_views(f) - return output - return None +) -> List[str]: + if not needs_functionalization(selector, g): + return [] + if isinstance(g, NativeFunctionsViewGroup): + view_defs = [] + if not g.view.has_composite_implicit_autograd_kernel: + # For now I'm just asserting this - later, the view functionalization kernels + # should be updated to redispatch to view_copy(). + assert g.view_copy is not None + view_defs.append(emit_view_functionalization_body(g.view, g.view)) + if g.view_inplace is not None and not g.view_inplace.has_composite_implicit_autograd_kernel: + view_defs.append(emit_view_functionalization_body(g.view_inplace, g.view)) + return view_defs + else: + f = g + assert modifies_arguments(f) + return [emit_inplace_functionalization_body(f, functional_op)] diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 8536588848c4..e48c3753c4ee 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -274,6 +274,7 @@ class DeviceCheckType(Enum): class Tag(Enum): inplace_view = 0 + view_copy = 1 def __str__(self) -> str: return self.name @@ -285,6 +286,8 @@ class Tag(Enum): return v raise AssertionError(f'unknown tag {value}') +ViewSchemaKind = Enum('ViewSchemaKind', ('aliasing', 'inplace', 'out', 'non_aliasing')) + # The basic input to the code generation is native_functions.yaml. # The name "native", BTW, comes from the distinction between native # functions and legacy TH functions. The legacy TH functions are gone, @@ -653,6 +656,18 @@ class NativeFunction: inp.annotation.alias_set_after != "" for inp in self.func.schema_order_arguments()) return is_non_mutating_view or is_inplace_view or is_wildcard_view + @property + def view_schema_kind(self) -> ViewSchemaKind: + # This covers both "ordinary" inplace ops, and inplace_views + if self.func.name.name.inplace: + return ViewSchemaKind.inplace + elif self.func.is_out_fn(): + return ViewSchemaKind.out + elif self.is_view_op: + return ViewSchemaKind.aliasing + else: + return ViewSchemaKind.non_aliasing + @property def root_name(self) -> str: return self.func.name.name.base @@ -1043,7 +1058,7 @@ class FunctionSchema: else: return SchemaKind.functional - def signature(self, *, strip_default: bool = False) -> 'FunctionSchema': + def signature(self, *, strip_default: bool = False, strip_view_copy_name: bool = False) -> 'FunctionSchema': """ Certain schemas are 'related', in that they are simply inplace/out/functional versions of the same function. This method @@ -1060,6 +1075,10 @@ class FunctionSchema: because you cannot overload on mutability annotation) - Return names are stripped since they are not overloadable and some variants have return names but some not + + Finally, we want to be able to pair up related "view" and their + corresponding "view_copy" operators. We do this by optionally + stripping the trailing "_copy" from the base name. 
""" def strip_ret_annotation(r: Return) -> Return: @@ -1069,10 +1088,14 @@ class FunctionSchema: annotation=None, ) + base_name = self.name.name.base + if strip_view_copy_name and base_name.endswith('_copy'): + base_name = base_name.replace('_copy', '') + return FunctionSchema( name=OperatorName( name=BaseOperatorName( - base=self.name.name.base, + base=base_name, inplace=False, dunder_method=self.name.name.dunder_method, ), @@ -1082,6 +1105,14 @@ class FunctionSchema: returns=tuple(map(strip_ret_annotation, self.returns)), ) + def view_signature(self) -> 'FunctionSchema': + return self.signature(strip_view_copy_name=True) + + @property + def modifies_arguments(self) -> bool: + return self.kind() in [SchemaKind.inplace, SchemaKind.out] + + def __str__(self) -> str: all_arguments_str = str(self.arguments) if len(self.returns) == 1: @@ -1751,6 +1782,95 @@ def gets_generated_out_inplace_wrapper(f: NativeFunction, g: NativeFunctionsGrou not b.has_kernel(f) and \ b.has_kernel(g.functional) +# NativeFunction objects that are views (f.is_view_op returns True) +# are added into a `NativeFunctionsViewGroup`, which we can use to +# easily access the generated (optional) view_copy NativeFunction. +# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup. +# See Note [Codegen'd {view}_copy Operators] +# +# One property of this representation is that in order for a view-like op to be part of +# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist. +# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op, +# but don't have corresponding aliasing `narrow.out` op. +# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup. +@dataclass(frozen=True) +class NativeFunctionsViewGroup: + view: NativeFunction + # Note: the {view}_copy operator is optional because we currently don't generate copy variants + # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views + # (we already get them "for free" through decomposition) + view_copy: Optional[NativeFunction] + # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. + view_inplace: Optional[NativeFunction] + + def __post_init__(self) -> None: + assert self.view.is_view_op + if self.view_copy is None: + assert not gets_generated_view_copy(self.view), \ + f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." \ + " The codegen expects you to add a corresponding operator to native_functions.yaml:" \ + " {str(get_view_copy_name(self.view)}." \ + " See Note [view_copy NativeFunctions] for details." + else: + assert self.view_copy.func.name.name.base.endswith('_copy') + assert self.view.func.signature() == self.view_copy.func.signature(strip_view_copy_name=True) + assert self.view_copy.tag == Tag.view_copy, \ + f"{str(self.view_copy.func.name)} appears to be a view_copy operator. The codegen expects" \ + " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." \ + " See Note [view_copy NativeFunction] for details." 
+ if self.view_inplace is not None: + assert self.view.func.signature() == self.view_inplace.func.signature() + + def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: + yield self.view + if self.view_inplace is not None: + yield self.view_inplace + if self.view_copy is not None and include_copy: + yield self.view_copy + + @property + def root_name(self) -> str: + return self.view.root_name + +def gets_generated_view_copy(f: NativeFunction) -> bool: + # Only aliasing (view) operators get a copy variant. + if not f.is_view_op: + return False + # We don't need to bother generating copy variants for CompositeImplicitAutograd ops, + # because we can let them decompose into base view ops. + if f.has_composite_implicit_autograd_kernel: + return False + # We also don't need to generate copy variants for inplace views. + if f.tag == Tag.inplace_view: + return False + return True + +# Given a NativeFunction that corresponds to a view op, +# returns the OperatorName of the corresponding "copy" variant of the op. +def get_view_copy_name(f: NativeFunction) -> 'OperatorName': + # Right now, when asking for a view op's corresponding "view_copy" name + # we assert for sanity that the op is allowed to have a generated view_copy variant. + # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). + # However, narrow_copy() already exists as an op directly in native_functions.yaml. + # I'm hardcoding narrow_copy here for now to maintain the assert, + # But we could also just get rid of the assert. + list_of_ops_with_explicit_view_copy_operators = [ + 'narrow' + ] + if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators: + assert gets_generated_view_copy(f) + + base_name = f'{f.func.name.name.base}_copy' + view_copy_name = OperatorName( + name=BaseOperatorName( + base=base_name, + inplace=False, + dunder_method=f.func.name.name.dunder_method), + overload_name=f.func.name.overload_name + ) + return view_copy_name + + # Helper functions for parsing argument lists (both inputs and returns) def parse_returns(return_decl: str) -> Tuple[Return, ...]: diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 10626fe72aa1..8b312bfeaeb0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -11963,3 +11963,141 @@ Example:: tensor([[2, 3, 5], [2, 3, 5]]) """) + +add_docstr(torch.view_as_real_copy, + r""" +Performs the same operation as :func:`torch.view_as_real`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.view_as_complex_copy, + r""" +Performs the same operation as :func:`torch.view_as_complex`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.as_strided_copy, + r""" +Performs the same operation as :func:`torch.as_strided`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.diagonal_copy, + r""" +Performs the same operation as :func:`torch.diagonal`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.expand_copy, + r""" +Performs the same operation as :func:`torch.expand`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.permute_copy, + r""" +Performs the same operation as :func:`torch.permute`, but all output tensors +are freshly created instead of aliasing the input. 
+""") + +add_docstr(torch.select_copy, + r""" +Performs the same operation as :func:`torch.select`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.detach_copy, + r""" +Performs the same operation as :func:`torch.detach`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.slice_copy, + r""" +Performs the same operation as :func:`torch.slice`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.split_copy, + r""" +Performs the same operation as :func:`torch.split`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.split_with_sizes_copy, + r""" +Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.squeeze_copy, + r""" +Performs the same operation as :func:`torch.squeeze`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.t_copy, + r""" +Performs the same operation as :func:`torch.t`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.transpose_copy, + r""" +Performs the same operation as :func:`torch.transpose`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.unsqueeze_copy, + r""" +Performs the same operation as :func:`torch.unsqueeze`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.indices_copy, + r""" +Performs the same operation as :func:`torch.indices`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.values_copy, + r""" +Performs the same operation as :func:`torch.values`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.crow_indices_copy, + r""" +Performs the same operation as :func:`torch.crow_indices`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.col_indices_copy, + r""" +Performs the same operation as :func:`torch.col_indices`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.unbind_copy, + r""" +Performs the same operation as :func:`torch.unbind`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.view_copy, + r""" +Performs the same operation as :func:`torch.view`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.unfold_copy, + r""" +Performs the same operation as :func:`torch.unfold`, but all output tensors +are freshly created instead of aliasing the input. +""") + +add_docstr(torch.alias_copy, + r""" +Performs the same operation as :func:`torch.alias`, but all output tensors +are freshly created instead of aliasing the input. 
+""") diff --git a/torch/overrides.py b/torch/overrides.py index bf000c91ea1d..f9072f7bdc74 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1014,6 +1014,40 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.vstack: lambda tensors, out=None: -1, torch.where: lambda condition, x=None, y=None: -1, torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch._fw_primal_copy: lambda self, level: -1, + torch._make_dual_copy: lambda primal, tangent, level: -1, + torch.view_as_real_copy: lambda self: -1, + torch.view_as_complex_copy: lambda self: -1, + torch._conj_copy: lambda self: -1, + torch._neg_view_copy: lambda self: -1, + torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1, + torch._sparse_broadcast_to_copy: lambda self, size: -1, + torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1, + torch.expand_copy: lambda self, size, *, implicit=False: -1, + torch.narrow_copy: lambda self, dim, start, length: -1, + torch.permute_copy: lambda self, dims: -1, + torch._reshape_alias_copy: lambda self, size, stride: -1, + torch.select_copy: lambda self, dim, index: -1, + torch.detach_copy: lambda self: -1, + torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1, + torch.split_copy: lambda self, split_size, dim=0: -1, + torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1, + torch.squeeze_copy: lambda self: -1, + torch.squeeze_copy: lambda self, dim: -1, + torch.t_copy: lambda self: -1, + torch.transpose_copy: lambda self, dim0, dim1: -1, + torch.unsqueeze_copy: lambda self, dim: -1, + torch._indices_copy: lambda self: -1, + torch._values_copy: lambda self: -1, + torch.indices_copy: lambda self: -1, + torch.values_copy: lambda self: -1, + torch.crow_indices_copy: lambda self: -1, + torch.col_indices_copy: lambda self: -1, + torch.unbind_copy: lambda self, dim=0: -1, + torch.view_copy: lambda self, size: -1, + torch.view_copy: lambda self, dtype: -1, + torch.unfold_copy: lambda self, dimension, size, step: -1, + torch.alias_copy: lambda self: -1, Tensor.__floordiv__: lambda self, other: -1, Tensor.__rfloordiv__: lambda self, other: -1, Tensor.__ifloordiv__: lambda self, other: -1,
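
For context on how the new operators behave from Python, the short snippet below mirrors the test added in test_view_ops.py; it assumes a PyTorch build that includes this patch.

import torch

a = torch.randn(4, requires_grad=True)
out = torch.view_copy(a, (2, 2))   # same values and shape as a.view(2, 2), but freshly allocated
out.sum().backward()               # the codegen'd derivative matches the one for view()
print(out.shape, a.grad)
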