Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
code-generate non-aliasing {view}_copy kernels (#73442)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73442

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D35016025

Pulled By: bdhirsh

fbshipit-source-id: 2a7f303ec76f5913b744c7822a531d55a57589c9
(cherry picked from commit 3abe13c2a787bcbe9c41b0a335c96e5a3d3642fb)
Committed by: PyTorch MergeBot
Parent: dfcb7035a0
Commit: 23b8414391
@@ -50,6 +50,7 @@ generated_cpu_cpp = [
     "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
+    "aten/src/ATen/CompositeViewCopyKernels.cpp",
     "aten/src/ATen/FunctionalInverses.h",
     "aten/src/ATen/Functions.h",
     "aten/src/ATen/Functions.cpp",
@@ -145,6 +145,7 @@
 
 - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
   variants: method
+  tags: inplace_view
 
 - func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
   variants: method
@@ -3262,6 +3263,7 @@
     CPU: narrow_copy_dense_cpu
     SparseCPU, SparseCUDA: narrow_copy_sparse
     CompositeExplicitAutograd: narrow_copy_dense
+  tags: view_copy
 
 - func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor
   variants: function, method
@@ -11355,3 +11357,201 @@
 
 - func: nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   variants: function
+
+- func: _fw_primal_copy(Tensor self, int level) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _fw_primal_copy
+  tags: view_copy
+
+- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _make_dual_copy
+  tags: view_copy
+
+- func: view_as_real_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: view_as_real_copy
+  tags: view_copy
+
+- func: view_as_complex_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: view_as_complex_copy
+  tags: view_copy
+
+- func: _conj_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _conj_copy
+  tags: view_copy
+
+- func: _neg_view_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _neg_view_copy
+  tags: view_copy
+
+- func: as_strided_copy(Tensor self, int[] size, int[] stride, int? storage_offset=None) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: as_strided_copy
+  tags: view_copy
+
+- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _sparse_broadcast_to_copy
+  tags: view_copy
+
+- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: diagonal_copy
+  tags: view_copy
+
+- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: expand_copy
+  tags: view_copy
+
+- func: permute_copy(Tensor self, int[] dims) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: permute_copy
+  tags: view_copy
+
+- func: _reshape_alias_copy(Tensor self, int[] size, int[] stride) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _reshape_alias_copy
+  tags: view_copy
+
+- func: select_copy.int(Tensor self, int dim, int index) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: select_copy_int
+  tags: view_copy
+
+- func: detach_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: detach_copy
+  tags: view_copy
+
+- func: slice_copy.Tensor(Tensor self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: slice_copy_Tensor
+  tags: view_copy
+
+- func: split_copy.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: split_copy_Tensor
+  tags: view_copy
+
+- func: split_with_sizes_copy(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: split_with_sizes_copy
+  tags: view_copy
+
+- func: squeeze_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: squeeze_copy
+  tags: view_copy
+
+- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: squeeze_copy_dim
+  tags: view_copy
+
+- func: t_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: t_copy
+  tags: view_copy
+
+- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: transpose_copy_int
+  tags: view_copy
+
+- func: unsqueeze_copy(Tensor self, int dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: unsqueeze_copy
+  tags: view_copy
+
+- func: _indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _indices_copy
+  tags: view_copy
+
+- func: _values_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _values_copy
+  tags: view_copy
+
+- func: indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: indices_copy
+  tags: view_copy
+
+- func: values_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: values_copy
+  tags: view_copy
+
+- func: crow_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: crow_indices_copy
+  tags: view_copy
+
+- func: col_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: col_indices_copy
+  tags: view_copy
+
+- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: unbind_copy_int
+  tags: view_copy
+
+- func: view_copy(Tensor self, int[] size) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: view_copy
+  tags: view_copy
+
+- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: view_copy_dtype
+  tags: view_copy
+
+- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: unfold_copy
+  tags: view_copy
+
+- func: alias_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: alias_copy
+  tags: view_copy
aten/src/ATen/native/tags.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
+# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`
+
+- tag: inplace_view
+  desc: |
+    This tag indicates if an operator *only* modifies the tensor metadata
+- tag: view_copy
+  desc: |
+    This tag indicates operators that are *_copy* variants
+    of view/aliasing operators. If an operator has a view_copy tag,
+    then it should have the name {op}_copy, where {op} is a view operator.
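The tag list is plain YAML, so it is easy to consume from tooling. A minimal sketch (not part of this commit; PyYAML assumed) of loading the file and recovering the set of valid tag names:

# Minimal sketch (not part of this commit): collect the valid tag names,
# e.g. to validate `tags:` entries in native_functions.yaml.
import yaml

with open('aten/src/ATen/native/tags.yaml') as f:
    entries = yaml.safe_load(f)

valid_tags = {entry['tag'] for entry in entries}
assert valid_tags == {'inplace_view', 'view_copy'}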
aten/src/ATen/templates/CompositeViewCopyKernels.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include <ATen/Tensor.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Operators.h>
+#else
+#include <ATen/ops/clone.h>
+$ops_headers
+#endif
+
+namespace at {
+namespace native {
+
+
+${CompositeViewCopyKernel_Definitions}
+
+} // namespace native
+} // namespace at
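To make the template concrete: the codegen fills $ops_headers with per-operator includes and ${CompositeViewCopyKernel_Definitions} with one kernel per view_copy operator. A hypothetical rendering (not taken from an actual build) of one substituted definition, built the same way gen_composite_view_copy_kernel below builds it:

# Hypothetical rendering (not from an actual build) of a definition that the
# codegen would substitute into ${CompositeViewCopyKernel_Definitions}:
# run the underlying view op, then clone its output.
expected_definition = r'''
at::Tensor transpose_copy_int(const at::Tensor & self, int64_t dim0, int64_t dim1) {
  auto output = at::_ops::transpose_int::call(self, dim0, dim1);
  return output.clone();
}
'''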
@@ -908,6 +908,24 @@ class TestViewOps(TestCase):
             op = partial(fn, source=0, destination=1)
             run_test(device, op)
 
+    # Testing that the generated view_copy kernel and its derivative are implemented correctly
+    def test_view_copy(self, device):
+        a = torch.randn(4, device=device, requires_grad=True)
+        a_ref = a.clone().detach().requires_grad_()
+        a_view = a_ref.view(2, 2)
+        a_view_copy = torch.view_copy(a, (2, 2))
+
+        # view_copy ops don't preserve view relationship
+        self.assertTrue(self.is_view_of(a_ref, a_view))
+        self.assertFalse(self.is_view_of(a, a_view_copy))
+
+        a_view_copy.sum().backward()
+        a_view.sum().backward()
+
+        # forward and backward give the same shape + result
+        self.assertEqual(a_view_copy, a_view)
+        self.assertEqual(a.grad, a_ref.grad)
+
 class TestOldViewOps(TestCase):
     def test_ravel(self, device):
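The aliasing behavior the test exercises can also be seen directly at the Python level; a short sketch (assuming a build that includes the generated ops):

# Sketch (assumes a build with the generated view_copy ops): a view aliases
# its input, while the copy variant returns a freshly allocated tensor.
import torch

a = torch.zeros(4)
v = a.view(2, 2)                 # aliases a
c = torch.view_copy(a, (2, 2))   # fresh tensor with the same values
a.add_(1)
assert v.sum().item() == 4.0     # the view observed the mutation
assert c.sum().item() == 0.0     # the copy did not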
@@ -14,13 +14,38 @@ from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, Bas
                                      tensorGeometryT, scalarTypeT, SpecialArgName,
                                      OptionalCType, stringT)
 from tools.codegen.api import cpp
-from tools.codegen.gen import parse_native_yaml
+from tools.codegen.gen import parse_native_yaml, get_grouped_by_view_native_functions
 from tools.codegen.context import with_native_function
-from tools.codegen.model import FunctionSchema, NativeFunction, Variant, Type
-from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader
+from tools.codegen.model import (
+    FunctionSchema, NativeFunction, Variant, Type,
+    NativeFunctionsViewGroup, OperatorName
+)
+from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader, concatMap
 
 _GLOBAL_LOAD_DERIVATIVE_CACHE = {}
 
+# This function directly adds derivative entries for {view}_copy variants of each view op.
+# Since every {view} and {view}_copy op shares the same derivative formula,
+# we generate them here instead of duplicating them in the yaml.
+# See Note [Codegen'd {view}_copy Operators]
+def add_view_copy_derivatives(
+    infos: List[DifferentiabilityInfo],
+    view_groups: List[NativeFunctionsViewGroup]
+) -> List[DifferentiabilityInfo]:
+    # Get the map from each view op's name to its corresponding view group
+    view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
+        g.view.func.name: g for g in view_groups}
+
+    view_copy_differentiability_infos = []
+    for info in infos:
+        maybe_view_group = view_name_to_group.get(info.func.func.name, None)
+        if maybe_view_group is not None and maybe_view_group.view_copy is not None:
+            view_copy_info = info.create_view_copy_from_view_derivative(maybe_view_group)
+            if view_copy_info is not None:
+                view_copy_differentiability_infos.append(view_copy_info)
+
+    return view_copy_differentiability_infos
+
 def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
     # Do some caching as this is a deterministic function
     global _GLOBAL_LOAD_DERIVATIVE_CACHE

@@ -30,7 +55,16 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
     with open(derivatives_yaml_path, 'r') as f:
         definitions = yaml.load(f, Loader=YamlLoader)
 
-    functions = parse_native_yaml(native_yaml_path).native_functions
+    funcs = parse_native_yaml(native_yaml_path).native_functions
+    # From the parsed native functions, separate out the (generated) view_copy functions,
+    # so we can generate derivatives for them separately.
+    native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
+    native_functions_without_view_copies = concatMap(
+        # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
+        lambda g: [g] if isinstance(g, NativeFunction) else list(g.functions(include_copy=False)),
+        native_functions_with_view_groups
+    )
+    view_groups = [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)]
 
     # What's the difference between a function schema and a signature?
     # A function schema is the complete declaration, including mutability annotations, default values, etc.

@@ -38,7 +72,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
     # that are semantically related.
     functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list)
     functions_by_schema: Dict[str, NativeFunction] = dict()
-    for function in functions:
+    for function in native_functions_without_view_copies:
         functions_by_signature[function.func.signature()].append(function)
         assert str(function.func) not in functions_by_schema
         functions_by_schema[str(function.func)] = function

@@ -50,6 +84,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
     infos = [
         create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
         for defn in definitions]
+    infos += add_view_copy_derivatives(infos, view_groups)
 
     _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
 
@@ -4,7 +4,7 @@ from typing import Optional, Sequence, Set, List, Tuple, Match
 
 from tools.codegen.api import cpp
 from tools.codegen.api.types import Binding, NamedCType
-from tools.codegen.model import NativeFunction, Type, SchemaKind
+from tools.codegen.model import NativeFunction, Type, SchemaKind, NativeFunctionsViewGroup
 from tools.codegen.utils import IDENT_REGEX
 
 # Represents a saved attribute involved in backward calculation.

@@ -149,6 +149,39 @@
     def has_derivatives(self) -> bool:
         return len(self.args_with_derivatives) > 0
 
+    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
+    # but with a new operator name.
+    # This is used when generating "copy" variants of view ops,
+    # which are able to use the exact same derivative formula as the original view op
+    # See Note [Codegen'd {view}_copy Operators]
+    def create_view_copy_from_view_derivative(self, g: NativeFunctionsViewGroup) -> Optional['DifferentiabilityInfo']:
+        if g.view_copy is None:
+            return None
+        f = g.view_copy
+
+        name_split_by_period = self.name.split('.', maxsplit=2)
+        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
+        view_copy_name = f'{name_split_by_period[0]}_copy.' + '.'.join(name_split_by_period[1:])
+        view_copy_op_name = None if self.op is None else f'{self.op}_copy'
+
+        return DifferentiabilityInfo(
+            # Use the "_copy" version of name/func/op
+            name=view_copy_name,
+            func=f,
+            op=view_copy_op_name,
+            # But keep all derivative info the same
+            derivatives=self.derivatives,
+            forward_derivatives=self.forward_derivatives,
+            all_saved_inputs=self.all_saved_inputs,
+            all_saved_outputs=self.all_saved_outputs,
+            available_named_gradients=self.available_named_gradients,
+            used_named_gradients=self.used_named_gradients,
+            args_with_derivatives=self.args_with_derivatives,
+            non_differentiable_arg_names=self.non_differentiable_arg_names,
+            output_differentiability=self.output_differentiability,
+            output_differentiability_conditions=self.output_differentiability_conditions,
+        )
+
 def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
     if info is None:
         return False
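The renaming in create_view_copy_from_view_derivative above is purely textual; a standalone sketch of the same split-and-rejoin logic (hypothetical helper name):

# Standalone sketch of the name transformation: append "_copy" to the base
# name while keeping the overload name unchanged.
def to_view_copy_name(name: str) -> str:
    name_split_by_period = name.split('.', maxsplit=2)
    return f'{name_split_by_period[0]}_copy.' + '.'.join(name_split_by_period[1:])

assert to_view_copy_name('transpose.int') == 'transpose_copy.int'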
@@ -1,9 +1,12 @@
 from tools.codegen.utils import S, T, context
-from tools.codegen.model import (NativeFunction, NativeFunctionsGroup, BackendIndex, DispatchKey)
+from tools.codegen.model import (
+    NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, BackendIndex, DispatchKey
+)
 import tools.codegen.local as local
 from tools.codegen.selective_build.selector import SelectiveBuilder
 
 import functools
-from typing import TypeVar, Union, Iterator, Callable, Dict
+from typing import TypeVar, Union, Iterator, Callable, Dict, Optional
 import contextlib
 
 # Helper functions for defining generators on things in the model

@@ -12,17 +15,28 @@ F = TypeVar(
     'F',
     NativeFunction,
     NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
     Union[NativeFunction, NativeFunctionsGroup],
+    Union[NativeFunction, NativeFunctionsViewGroup],
 )
 
+F2 = TypeVar(
+    'F2',
+    NativeFunction,
+    Optional[NativeFunction],
+)
+
 @contextlib.contextmanager
-def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> Iterator[None]:
+def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]) -> Iterator[None]:
     if isinstance(g, NativeFunctionsGroup):
         # By default, we associate all errors with structured native functions
        # with the out variant. In some cases, it might be better to have
         # a more specific place to hang things; if so, use
         # native_function_manager again on the inside
         f = g.out
+    elif isinstance(g, NativeFunctionsViewGroup):
+        # We associate errors with the view operator
+        f = g.view
     else:
         f = g
     with context(lambda: f'in native_functions.yaml line {f.loc}:\n  {f.func}'):

@@ -41,6 +55,14 @@ def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
         return func(f)
     return wrapper
 
+def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
+    @functools.wraps(func)
+    def wrapper(f: F, f2: F2) -> T:
+        # The first native_function is assumed to be the one with the appropriate context.
+        with native_function_manager(f):
+            return func(f, f2)
+    return wrapper
+
 def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
     @functools.wraps(func)
     def wrapper(slf: S, f: F) -> T:

@@ -57,6 +79,17 @@ def with_native_function_and_index(func: Callable[[F, BackendIndex], T]) -> Call
         return func(f, backend_index)
     return wrapper
 
+# Convenience decorator for functions that explicitly take in a BackendIndex,
+# instead of indirectly taking one in as a closure
+def with_native_function_and_selector_and_index(
+    func: Callable[[SelectiveBuilder, F, BackendIndex], T]
+) -> Callable[[SelectiveBuilder, F, BackendIndex], T]:
+    @functools.wraps(func)
+    def wrapper(selector: SelectiveBuilder, f: F, backend_index: BackendIndex) -> T:
+        with native_function_manager(f):
+            return func(selector, f, backend_index)
+    return wrapper
+
 def with_native_function_and_indices(
     func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
 ) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
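A hypothetical usage sketch of the new with_native_function_and decorator (the decorated function and its body are illustrative, not from this commit): error-reporting context comes from the first argument only, while the second argument is passed through unchanged.

# Hypothetical usage (illustrative only): exceptions raised inside emit_pair
# are reported against f's native_functions.yaml location.
from typing import Optional
from tools.codegen.context import with_native_function_and
from tools.codegen.model import NativeFunction

@with_native_function_and
def emit_pair(f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    return f'// kernel for {f.func.name}'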
@@ -17,7 +17,10 @@ from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
                                  is_cuda_dispatch_key,
                                  is_generic_dispatch_key,
                                  is_ufunc_dispatch_key,
-                                 Tag, BaseOperatorName)
+                                 NativeFunctionsViewGroup,
+                                 ViewSchemaKind,
+                                 BaseOperatorName,
+                                 Tag)
 from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup,
                                      DispatcherSignature, NativeSignature)
 from tools.codegen.api import cpp

@@ -40,7 +43,8 @@ from tools.codegen.gen_functionalization_type import (
     needs_functionalization,
     gen_functionalization_definition,
     gen_functionalization_registration,
-    gen_functionalization_view_inverse_declaration
+    gen_functionalization_view_inverse_declaration,
+    gen_composite_view_copy_kernel,
 )
 
 T = TypeVar('T')

@@ -1002,6 +1006,43 @@ def pre_group_native_functions(
             d[f.func.kind()] = f
     return pre_grouped_native_functions
 
+def get_grouped_by_view_native_functions(
+    native_functions: Sequence[NativeFunction]
+) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
+    def maybe_create_view_group(d: Dict[ViewSchemaKind, NativeFunction]) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
+        funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
+        if ViewSchemaKind.aliasing not in d:
+            # Case 1: this op / op group is not aliasing, so we don't create a view group.
+            # Return the original (ungrouped) native functions instead.
+            for func in d.values():
+                funcs.append(func)
+        else:
+            # Case 2: this op group contains an aliasing op, so we create a ViewGroup for it.
+            # The handling for out= ops here is unfortunate.
+            # out= ops don't really make sense for view operators.
+            # However, we have at least one existing {view}_copy.out operator in native_functions.yaml.
+            # It shouldn't be part of a view group, so we explicitly don't group it.
+            # There currently aren't any out= view ops (and there probably shouldn't be).
+            # We also expect that when we hit this case, the `non_aliasing` op in the dict
+            # *must* be a view_copy op (this is asserted in the NativeFunctionsViewGroup constructor)
+            if ViewSchemaKind.out in d:
+                funcs.append(d[ViewSchemaKind.out])
+
+            funcs.append(NativeFunctionsViewGroup(
+                view=d[ViewSchemaKind.aliasing],
+                view_copy=d.get(ViewSchemaKind.non_aliasing, None),
+                view_inplace=d.get(ViewSchemaKind.inplace, None),
+            ))
+        return funcs
+
+    grouped_by_views: Dict[FunctionSchema, Dict[ViewSchemaKind, NativeFunction]] = defaultdict(dict)
+    for f in native_functions:
+        schema = f.func.view_signature()
+        assert f.view_schema_kind not in grouped_by_views[schema]
+        grouped_by_views[schema][f.view_schema_kind] = f
+
+    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
+
 def get_grouped_native_functions(
     native_functions: Sequence[NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
     def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
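A sketch of what the grouping above produces for one operator family (it assumes `native_functions` has been parsed from native_functions.yaml): transpose, transpose_, and transpose_copy share a view signature, so they collapse into a single group.

# Sketch (assumes `native_functions` parsed from native_functions.yaml):
# the transpose family lands in one NativeFunctionsViewGroup.
groups = get_grouped_by_view_native_functions(native_functions)
transpose_group = next(
    g for g in groups
    if isinstance(g, NativeFunctionsViewGroup) and g.root_name == 'transpose')
assert transpose_group.view_inplace is not None  # transpose_
assert transpose_group.view_copy is not None     # transpose_copy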
@@ -1316,10 +1357,6 @@ def gen_headers(
         'registration_declarations': [compute_registration_declarations(f, backend_indices) for f in native_functions],
     })
 
-    cpu_fm.write('FunctionalInverses.h', lambda: {
-        'view_inverse_declarations': list(mapMaybe(gen_functionalization_view_inverse_declaration, native_functions))
-    })
-
 
 def gen_aten_interned_strings() -> Dict[str, str]:
     attrs = set()  # All function argument names

@@ -1354,6 +1391,7 @@ def gen_source_files(
     native_functions: Sequence[NativeFunction],
     grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
     structured_native_functions: Sequence[NativeFunctionsGroup],
+    native_functions_with_view_groups: Sequence[Union[NativeFunction, NativeFunctionsViewGroup]],
     selector: SelectiveBuilder,
     backend_indices: Dict[DispatchKey, BackendIndex],
     core_fm: FileManager,

@@ -1528,7 +1566,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
         else list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
     })
 
-    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
+    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
         return fn.root_name
 
     cpu_fm.write_sharded(

@@ -1564,20 +1602,43 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
 
 
     def functionalization_env_callable(
-        g: Union[NativeFunction, NativeFunctionsGroup]
+        g: Union[NativeFunction, NativeFunctionsViewGroup]
     ) -> Dict[str, List[str]]:
-        functions = [g] if isinstance(g, NativeFunction) else list(g.functions())
-        functions_needing_functionalization = [
-            fn for fn in functions if needs_functionalization(selector, fn)]
+        functions_needing_functionalization = [g] if needs_functionalization(selector, g) else []
+
+        def gen_op_headers(g: Union[NativeFunction, NativeFunctionsViewGroup]) -> List[str]:
+            if not needs_functionalization(selector, g):
+                return []
+            if isinstance(g, NativeFunctionsViewGroup):
+                # view ops always get a functionalization kernel
+                headers = [
+                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
+                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
+                ]
+                if g.view_copy is not None:
+                    headers += [
+                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
+                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
+                    ]
+                return headers
+            else:
+                f = g
+                return [
+                    f"#include <ATen/ops/{f.root_name}_native.h>",
+                    f"#include <ATen/ops/{f.root_name}_ops.h>",
+                ]
 
         return {
-            'ops_headers': ([
-                f"#include <ATen/ops/{functions[0].root_name}_native.h>",
-                f"#include <ATen/ops/{functions[0].root_name}_ops.h>",
-            ] if functions_needing_functionalization else []),
-            'func_definitions': list(mapMaybe(
-                lambda f: gen_functionalization_definition(selector, f, to_functional_op[f.func.name]),
+            'ops_headers': gen_op_headers(g),
+            'func_definitions': list(concatMap(
+                lambda f: gen_functionalization_definition(
+                    selector,
+                    g,
+                    # We need to manually map inplace ops to their out-of-place variants
+                    # (we can't do this with NativeFunctionsGroup today because not all inplace ops have out= variants)
+                    None if isinstance(g, NativeFunctionsViewGroup) else to_functional_op.get(g.func.name, None)),
                 functions_needing_functionalization)),
-            'func_registrations': list(mapMaybe(
+            'func_registrations': list(concatMap(
                 lambda f: gen_functionalization_registration(
                     selector, f, backend_indices[DispatchKey.CompositeImplicitAutograd]),
                 functions_needing_functionalization)),

@@ -1586,13 +1647,46 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
 
     cpu_fm.write_sharded(
         'RegisterFunctionalization.cpp',
-        grouped_native_functions,
+        native_functions_with_view_groups,
         key_fn=key_func,
         env_callable=functionalization_env_callable,
         num_shards=4,
         sharded_keys={'ops_headers', 'func_definitions', 'func_registrations'}
     )
 
+    cpu_fm.write('FunctionalInverses.h', lambda: {
+        'view_inverse_declarations': list(mapMaybe(
+            lambda g: gen_functionalization_view_inverse_declaration(selector, g),
+            [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)]))
+    })
+
+    # Note [view_copy NativeFunctions]
+    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
+    # needs to have a corresponding non-aliasing {view}_copy variant.
+    # Backends that use functionalization and don't know how to handle aliasing ops
+    # are expected to implement kernels for these {view}_copy kernels instead.
+    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
+    # so we codegen the following:
+    # (1) A CompositeExplicitAutograd kernel for every {view}_copy operator.
+    #     These are never explicitly invoked by the functionalization pass,
+    #     but they could theoretically be called from user code (I added these kernels for completeness,
+    #     since the ops are part of the public API).
+    # (2) A derivative formula for every {view}_copy operator
+    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
+    #     so rather than stamping all of the entries out in derivatives.yaml,
+    #     we codegen them in.
+    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
+    cpu_fm.write('CompositeViewCopyKernels.cpp', lambda: {
+        'ops_headers': [
+            '\n'.join(f'#include <ATen/ops/{f.root_name}_ops.h>' for f in
+                      ([g.view] if g.view_copy is None else [g.view, g.view_copy]))
+            for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)
+        ],
+        'CompositeViewCopyKernel_Definitions': list(mapMaybe(
+            gen_composite_view_copy_kernel,
+            [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)]))
+    })
+
 
 def gen_declarations_yaml(
     cpu_fm: FileManager,

@@ -1674,9 +1768,13 @@ def main() -> None:
     native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
     parsed_yaml = parse_native_yaml(native_yaml_path)
     native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
+
     grouped_native_functions = get_grouped_native_functions(native_functions)
     structured_native_functions = [g for g in grouped_native_functions
                                    if isinstance(g, NativeFunctionsGroup)]
+    native_functions_with_view_groups = get_grouped_by_view_native_functions(native_functions)
+
     template_dir = os.path.join(options.source_path, "templates")
 
     # NB: It is mandatory to NOT use os.path.join here, as the install directory
     # will eventually be ingested by cmake, which does not respect Windows style

@@ -1734,6 +1832,7 @@ def main() -> None:
             native_functions=native_functions,
             grouped_native_functions=grouped_native_functions,
             structured_native_functions=structured_native_functions,
+            native_functions_with_view_groups=native_functions_with_view_groups,
             selector=selector,
             backend_indices=backend_indices,
             core_fm=core_fm,
@@ -1,16 +1,75 @@
 from tools.codegen.api import cpp
 from tools.codegen.api.types import (
-    DispatcherSignature, Binding, FunctionalizationLambda, ViewInverseSignature
+    DispatcherSignature, Binding, FunctionalizationLambda, ViewInverseSignature,
+    NativeSignature
 )
 from tools.codegen.api.translate import translate
-from tools.codegen.context import with_native_function
+from tools.codegen.context import (
+    with_native_function, with_native_function_and_selector_and_index,
+    with_native_function_and
+)
 from tools.codegen.model import (
     Argument, NativeFunction, SchemaKind, BackendIndex,
-    Tag, FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy
+    Tag, FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy,
+    NativeFunctionsViewGroup, ListType
 )
 from tools.codegen.selective_build.selector import SelectiveBuilder
 from typing import List, Optional, Union, Tuple
 
+# This file contains codegen that relates to the functionalization pass.
+# It includes:
+# - gen_functionalization_definition
+#     Generates dispatcher kernel definitions for the functionalization pass.
+# - gen_functionalization_registration
+#     Generates dispatcher kernel registrations for the functionalization pass.
+# - gen_functionalization_view_inverse_declaration
+#     Generates a declaration for an "inverse view", for every view op
+#     that is needed in functionalization. We manually implement their definitions.
+# - gen_composite_view_copy_kernel
+#     Generates view_copy() composite kernels for all view_copy operators.
+
+
+# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
+# See Note [view_copy NativeFunctions]
+@with_native_function
+def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:
+
+    if g.view_copy is None:
+        return None
+    # view_copy is a native signature, since we're generating an at::native:: kernel
+    view_copy_sig = NativeSignature(g.view_copy.func)
+    # view is a dispatcher signature, since we're calling into the at::_ops API
+    view_sig = DispatcherSignature(g.view.func)
+
+    view_api_name = g.view.func.name.unambiguous_name()
+    exprs = ', '.join([e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())])
+
+    # view ops today always return either a Tensor or a list of Tensors
+    assert len(g.view.func.returns) == 1
+    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
+        or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)
+
+    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
+        return_cloned_output = '''\
+  return output.clone();'''
+    else:
+        # If the return type is a list, we need to clone each tensor in the list.
+        return_cloned_output = f'''\
+  {view_copy_sig.returns_type().cpp_type()} out_clone;
+  for (const auto i : c10::irange(output.size())) {{
+    out_clone.push_back(output[i].clone());
+  }}
+  return out_clone;'''
+
+    # The default generated composite kernel for {view}_copy() operators just clones
+    # the input tensor, and runs the underlying view on the clone.
+    return f"""
+{view_copy_sig.defn()} {{
+  auto output = at::_ops::{view_api_name}::call({exprs});
+{return_cloned_output}
+}}
+"""
+
 def modifies_arguments(f: NativeFunction) -> bool:
     return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
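For a view op that returns Tensor[] (e.g. split), the list branch above clones each element of the view's output. A hypothetical rendering (not taken from an actual build) of the kernel this produces for split_copy.Tensor:

# Hypothetical rendering of the generated list-return kernel: call the view,
# then clone each tensor in the returned list.
expected_definition = r'''
::std::vector<at::Tensor> split_copy_Tensor(const at::Tensor & self, int64_t split_size, int64_t dim) {
  auto output = at::_ops::split_Tensor::call(self, split_size, dim);
  ::std::vector<at::Tensor> out_clone;
  for (const auto i : c10::irange(output.size())) {
    out_clone.push_back(output[i].clone());
  }
  return out_clone;
}
'''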
@@ -110,6 +169,7 @@ View operators with multiple aliasing inputs aren't supported yet. Found an oper
 # Generates the Functionalization kernel for:
 # - ops that create aliases (e.g. transpose())
 # - ops that are views AND mutations (e.g. transpose_())
+@with_native_function_and
 def emit_view_functionalization_body(
     f: NativeFunction,
     functional_op: NativeFunction

@@ -154,6 +214,7 @@ def emit_view_functionalization_body(
     if f.tag is Tag.inplace_view:
         # See Note [Functionalization Pass - Inplace View Ops] for more details
         return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
       if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
         // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
         {unwrap_tensor_args_str}

@@ -178,10 +239,12 @@ def emit_view_functionalization_body(
       // See Note [Propagating strides in the functionalization pass]
       at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
       return {view_tensor_name};
+    }}
 """
 
     else:
         return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
       {unwrap_tensor_args_str}
       if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
         // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.

@@ -210,9 +273,11 @@ def emit_view_functionalization_body(
       // See Note [Propagating strides in the functionalization pass]
       at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
       return out;
+    }}
 """
 
 # Generates the Functionalization kernel for inplace ops
+@with_native_function_and
 def emit_inplace_functionalization_body(
     f: NativeFunction,
     functional_op: Optional[NativeFunction]

@@ -262,6 +327,7 @@ Instead, it's calling the inplace/view operator directly. \
 If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."
 
         return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
       if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
         TORCH_WARN("{warn_str}");
       }}

@@ -269,6 +335,7 @@ If this causes problems in your program, consider upstreaming the out-of-place o
       at::AutoDispatchSkipFunctionalize guard;
       // Redispatch as normal otherwise, since XLA has its own lowerings for special inplace ops.
       {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
+    }}
 """
     else:
         # call the out-of-place variant of the op

@@ -286,6 +353,7 @@ If this causes problems in your program, consider upstreaming the out-of-place o
                                                    if a.annotation and a.annotation.is_write and a.type.is_tensor_like()])
 
         return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
       {unwrap_tensor_args_str}
       if (!({check_all_mutated_args_are_functional})) {{
         if (({check_any_non_mutated_args_are_functional})) {{

@@ -308,6 +376,7 @@ If this causes problems in your program, consider upstreaming the out-of-place o
       }}
       {mutable_input_post_processing}
       {return_str(f)};
+    }}
     }}"""
 
 

@@ -322,26 +391,47 @@ def emit_declaration_for_noncomposite_views(f: NativeFunction) -> str:
 # These files provide the kernels that run the functionalization pass, which can be opted into
 # per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).
 
-# TODO: later, I'll generate separate kernels for "Functionalize", and "AddBackViews"
-# That will probably be an enum, that this function will take in to know which ops get kernels.
 def needs_functionalization(
     selector: SelectiveBuilder,
-    f: NativeFunction,
+    g: Union[NativeFunction, NativeFunctionsViewGroup],
 ) -> bool:
-    return (selector.include_all_operators and
-            (f.is_view_op or modifies_arguments(f)))
+    # Don't generate kernels in mobile build
+    if not selector.include_all_operators:
+        return False
+    # Every view op needs *some* handling in functionalization:
+    # - non-composite view ops get a generated kernel
+    # - composite view ops get a generated registration (that directly registers the composite kernel)
+    if isinstance(g, NativeFunctionsViewGroup):
+        return True
+    # For non-view ops, only inplace operators get functionalized
+    assert isinstance(g, NativeFunction)
+    return modifies_arguments(g)
+
+
+# See Note [Functionalization Pass: View Inverses].
+def gen_functionalization_view_inverse_declaration(selector: SelectiveBuilder, g: NativeFunctionsViewGroup) -> Optional[str]:
+    # For every (non-composite) view op, we need a corresponding "inverse view" function.
+    # This generates the declarations so we get a good compiler error when someone adds a new view.
+    @with_native_function
+    def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]:
+        if g.view.has_composite_implicit_autograd_kernel:
+            return None
+        view_inverse_sig = ViewInverseSignature(g.view)
+        return view_inverse_sig.decl()
+
+    return emit_decl_helper(g)
 
+
+@with_native_function_and_selector_and_index
 def gen_functionalization_registration(
     selector: SelectiveBuilder,
-    f: NativeFunction,
+    g: Union[NativeFunction, NativeFunctionsViewGroup],
     composite_implicit_autograd_index: BackendIndex
-) -> Optional[str]:
-    @with_native_function
-    def emit_registration_helper(f: NativeFunction) -> Optional[str]:
-        # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile.
-        # At some point, when Vulkan wants to use functionalization, we'll need to change this.
-        if not needs_functionalization(selector, f):
-            return None
-        if f.is_view_op and f.has_composite_implicit_autograd_kernel:
+) -> List[str]:
+    def emit_registration_helper(f: NativeFunction, *, is_view: bool) -> str:
+        if is_view and f.has_composite_implicit_autograd_kernel:
             metadata = composite_implicit_autograd_index.get_kernel(f)
             assert metadata is not None
             native_api_name = metadata.kernel

@@ -353,49 +443,42 @@
             # because we don't want to decompose non-view ops that are composite, like `at::ones`.
             registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})'
         else:
             # non-composite view ops (and inplace ops) get a normal registration.
             registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})'
 
         return f'm.impl("{f.func.name}", {registration_str});'
 
-    return emit_registration_helper(f)
+    # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile.
+    # At some point, when Vulkan wants to use functionalization, we'll need to change this.
+    if not needs_functionalization(selector, g):
+        return []
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        view_str = [emit_registration_helper(g.view, is_view=True)]
+        if g.view_inplace is not None:
+            view_str.append(emit_registration_helper(g.view_inplace, is_view=True))
+        return view_str
+    else:
+        f = g
+        return [emit_registration_helper(f, is_view=False)]
 
 def gen_functionalization_definition(
     selector: SelectiveBuilder,
-    f: NativeFunction,
+    g: Union[NativeFunction, NativeFunctionsViewGroup],
     functional_op: Optional[NativeFunction]
-) -> Optional[str]:
-    @with_native_function
-    def emit_definition_helper(f: NativeFunction) -> Optional[str]:
-        if not needs_functionalization(selector, f):
-            return None
-        if f.is_view_op and f.has_composite_implicit_autograd_kernel:
-            # See Note [Composite view ops in the functionalization pass]
-            return None
-        # order is important here, ops that are both views and mutations should hit the view path.
-        if f.is_view_op:
-            # Every view op is expected to have a functional counterpart (e.g. transpose_() -> transpose())
-            assert functional_op is not None
-            body_str = emit_view_functionalization_body(f, functional_op)
-        else:
-            body_str = emit_inplace_functionalization_body(f, functional_op)
-        sig = DispatcherSignature.from_schema(f.func)
-        return f"""
-{sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
-{body_str}
-}}
-"""
-
-    return emit_definition_helper(f)
-
-# See Note [Functionalization Pass: View Inverses].
-@with_native_function
-def gen_functionalization_view_inverse_declaration(f: NativeFunction) -> Optional[str]:
-    # We only need to generate view_inverse declarations for view ops that:
-    # - aren't composite (since they'll decompose and we'll get them for free).
-    # - aren't inplace (since they should have a corresponding functional version, which we call instead).
-    if f.is_view_op and not f.has_composite_implicit_autograd_kernel and not modifies_arguments(f):
-        output = emit_declaration_for_noncomposite_views(f)
-        return output
-    return None
+) -> List[str]:
+    if not needs_functionalization(selector, g):
+        return []
+    if isinstance(g, NativeFunctionsViewGroup):
+        view_defs = []
+        if not g.view.has_composite_implicit_autograd_kernel:
+            # For now I'm just asserting this - later, the view functionalization kernels
+            # should be updated to redispatch to view_copy().
+            assert g.view_copy is not None
+            view_defs.append(emit_view_functionalization_body(g.view, g.view))
+        if g.view_inplace is not None and not g.view_inplace.has_composite_implicit_autograd_kernel:
+            view_defs.append(emit_view_functionalization_body(g.view_inplace, g.view))
+        return view_defs
+    else:
+        # inplace op
+        f = g
+        assert modifies_arguments(f)
+        return [emit_inplace_functionalization_body(f, functional_op)]
@@ -274,6 +274,7 @@ class DeviceCheckType(Enum):
 
 class Tag(Enum):
     inplace_view = 0
+    view_copy = 1
 
     def __str__(self) -> str:
         return self.name

@@ -285,6 +286,8 @@ class Tag(Enum):
             return v
     raise AssertionError(f'unknown tag {value}')
 
+ViewSchemaKind = Enum('ViewSchemaKind', ('aliasing', 'inplace', 'out', 'non_aliasing'))
+
 # The basic input to the code generation is native_functions.yaml.
 # The name "native", BTW, comes from the distinction between native
 # functions and legacy TH functions. The legacy TH functions are gone,

@@ -653,6 +656,18 @@ class NativeFunction:
             inp.annotation.alias_set_after != "" for inp in self.func.schema_order_arguments())
         return is_non_mutating_view or is_inplace_view or is_wildcard_view
 
+    @property
+    def view_schema_kind(self) -> ViewSchemaKind:
+        # This covers both "ordinary" inplace ops, and inplace_views
+        if self.func.name.name.inplace:
+            return ViewSchemaKind.inplace
+        elif self.func.is_out_fn():
+            return ViewSchemaKind.out
+        elif self.is_view_op:
+            return ViewSchemaKind.aliasing
+        else:
+            return ViewSchemaKind.non_aliasing
+
     @property
     def root_name(self) -> str:
         return self.func.name.name.base

@@ -1043,7 +1058,7 @@ class FunctionSchema:
         else:
             return SchemaKind.functional
 
-    def signature(self, *, strip_default: bool = False) -> 'FunctionSchema':
+    def signature(self, *, strip_default: bool = False, strip_view_copy_name: bool = False) -> 'FunctionSchema':
         """
         Certain schemas are 'related', in that they are simply
         inplace/out/functional versions of the same function. This method

@@ -1060,6 +1075,10 @@ class FunctionSchema:
           because you cannot overload on mutability annotation)
         - Return names are stripped since they are not overloadable and
           some variants have return names but some not
+
+        Finally, we want to be able to pair up related "view" ops and their
+        corresponding "view_copy" operators. We do this by optionally
+        stripping the trailing "_copy" from the base name.
         """
 
         def strip_ret_annotation(r: Return) -> Return:

@@ -1069,10 +1088,14 @@ class FunctionSchema:
                 annotation=None,
             )
 
+        base_name = self.name.name.base
+        if strip_view_copy_name and base_name.endswith('_copy'):
+            base_name = base_name.replace('_copy', '')
+
         return FunctionSchema(
             name=OperatorName(
                 name=BaseOperatorName(
-                    base=self.name.name.base,
+                    base=base_name,
                     inplace=False,
                     dunder_method=self.name.name.dunder_method,
                 ),

@@ -1082,6 +1105,14 @@ class FunctionSchema:
             returns=tuple(map(strip_ret_annotation, self.returns)),
         )
 
+    def view_signature(self) -> 'FunctionSchema':
+        return self.signature(strip_view_copy_name=True)
+
+    @property
+    def modifies_arguments(self) -> bool:
+        return self.kind() in [SchemaKind.inplace, SchemaKind.out]
+
+
     def __str__(self) -> str:
         all_arguments_str = str(self.arguments)
         if len(self.returns) == 1:
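The new strip_view_copy_name knob is what lets a view op and its copy variant compare equal; a sketch using FunctionSchema.parse (the usual codegen parsing entry point; the schemas below are abbreviated):

# Sketch: stripping "_copy" makes the copy variant's signature match its view op.
view = FunctionSchema.parse('transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)')
copy = FunctionSchema.parse('transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor')
assert copy.signature(strip_view_copy_name=True) == view.signature()
assert copy.view_signature() == view.view_signature()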
@@ -1751,6 +1782,95 @@ def gets_generated_out_inplace_wrapper(f: NativeFunction, g: NativeFunctionsGrou
         not b.has_kernel(f) and \
         b.has_kernel(g.functional)
 
+# NativeFunction objects that are views (f.is_view_op returns True)
+# are added into a `NativeFunctionsViewGroup`, which we can use to
+# easily access the generated (optional) view_copy NativeFunction.
+# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
+# See Note [Codegen'd {view}_copy Operators]
+#
+# One property of this representation is that in order for a view-like op to be part of
+# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
+# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
+# but don't have a corresponding aliasing `narrow.out` op.
+# This means that `narrow_copy.out` won't appear in a NativeFunctionsViewGroup.
+@dataclass(frozen=True)
+class NativeFunctionsViewGroup:
+    view: NativeFunction
+    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
+    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
+    # (we already get them "for free" through decomposition)
+    view_copy: Optional[NativeFunction]
+    # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
+    view_inplace: Optional[NativeFunction]
+
+    def __post_init__(self) -> None:
+        assert self.view.is_view_op
+        if self.view_copy is None:
+            assert not gets_generated_view_copy(self.view), \
+                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." \
+                " The codegen expects you to add a corresponding operator to native_functions.yaml:" \
+                f" {str(get_view_copy_name(self.view))}." \
+                " See Note [view_copy NativeFunctions] for details."
+        else:
+            assert self.view_copy.func.name.name.base.endswith('_copy')
+            assert self.view.func.signature() == self.view_copy.func.signature(strip_view_copy_name=True)
+            assert self.view_copy.tag == Tag.view_copy, \
+                f"{str(self.view_copy.func.name)} appears to be a view_copy operator. The codegen expects" \
+                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." \
+                " See Note [view_copy NativeFunctions] for details."
+        if self.view_inplace is not None:
+            assert self.view.func.signature() == self.view_inplace.func.signature()
+
+    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
+        yield self.view
+        if self.view_inplace is not None:
+            yield self.view_inplace
+        if self.view_copy is not None and include_copy:
+            yield self.view_copy
+
+    @property
+    def root_name(self) -> str:
+        return self.view.root_name
+
+def gets_generated_view_copy(f: NativeFunction) -> bool:
+    # Only aliasing (view) operators get a copy variant.
+    if not f.is_view_op:
+        return False
+    # We don't need to bother generating copy variants for CompositeImplicitAutograd ops,
+    # because we can let them decompose into base view ops.
+    if f.has_composite_implicit_autograd_kernel:
+        return False
+    # We also don't need to generate copy variants for inplace views.
+    if f.tag == Tag.inplace_view:
+        return False
+    return True
+
+# Given a NativeFunction that corresponds to a view op,
+# returns the OperatorName of the corresponding "copy" variant of the op.
+def get_view_copy_name(f: NativeFunction) -> 'OperatorName':
+    # Right now, when asking for a view op's corresponding "view_copy" name
+    # we assert for sanity that the op is allowed to have a generated view_copy variant.
+    # (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op.)
+    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
+    # I'm hardcoding narrow_copy here for now to maintain the assert,
+    # but we could also just get rid of the assert.
+    list_of_ops_with_explicit_view_copy_operators = [
+        'narrow'
+    ]
+    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
+        assert gets_generated_view_copy(f)
+
+    base_name = f'{f.func.name.name.base}_copy'
+    view_copy_name = OperatorName(
+        name=BaseOperatorName(
+            base=base_name,
+            inplace=False,
+            dunder_method=f.func.name.name.dunder_method),
+        overload_name=f.func.name.overload_name
+    )
+    return view_copy_name
+
+
 # Helper functions for parsing argument lists (both inputs and returns)
 
 def parse_returns(return_decl: str) -> Tuple[Return, ...]:
@@ -11963,3 +11963,141 @@ Example::
     tensor([[2, 3, 5],
             [2, 3, 5]])
 """)
+
+add_docstr(torch.view_as_real_copy,
+           r"""
+Performs the same operation as :func:`torch.view_as_real`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.view_as_complex_copy,
+           r"""
+Performs the same operation as :func:`torch.view_as_complex`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.as_strided_copy,
+           r"""
+Performs the same operation as :func:`torch.as_strided`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.diagonal_copy,
+           r"""
+Performs the same operation as :func:`torch.diagonal`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.expand_copy,
+           r"""
+Performs the same operation as :func:`torch.expand`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.permute_copy,
+           r"""
+Performs the same operation as :func:`torch.permute`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.select_copy,
+           r"""
+Performs the same operation as :func:`torch.select`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.detach_copy,
+           r"""
+Performs the same operation as :func:`torch.detach`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.slice_copy,
+           r"""
+Performs the same operation as :func:`torch.slice`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.split_copy,
+           r"""
+Performs the same operation as :func:`torch.split`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.split_with_sizes_copy,
+           r"""
+Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.squeeze_copy,
+           r"""
+Performs the same operation as :func:`torch.squeeze`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.t_copy,
+           r"""
+Performs the same operation as :func:`torch.t`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.transpose_copy,
+           r"""
+Performs the same operation as :func:`torch.transpose`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.unsqueeze_copy,
+           r"""
+Performs the same operation as :func:`torch.unsqueeze`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.indices_copy,
+           r"""
+Performs the same operation as :func:`torch.indices`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.values_copy,
+           r"""
+Performs the same operation as :func:`torch.values`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.crow_indices_copy,
+           r"""
+Performs the same operation as :func:`torch.crow_indices`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.col_indices_copy,
+           r"""
+Performs the same operation as :func:`torch.col_indices`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.unbind_copy,
+           r"""
+Performs the same operation as :func:`torch.unbind`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.view_copy,
+           r"""
+Performs the same operation as :func:`torch.view`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.unfold_copy,
+           r"""
+Performs the same operation as :func:`torch.unfold`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
+
+add_docstr(torch.alias_copy,
+           r"""
+Performs the same operation as :func:`torch.alias`, but all output tensors
+are freshly created instead of aliasing the input.
+""")
@@ -1014,6 +1014,40 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.vstack: lambda tensors, out=None: -1,
     torch.where: lambda condition, x=None, y=None: -1,
     torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
+    torch._fw_primal_copy: lambda self, level: -1,
+    torch._make_dual_copy: lambda primal, tangent, level: -1,
+    torch.view_as_real_copy: lambda self: -1,
+    torch.view_as_complex_copy: lambda self: -1,
+    torch._conj_copy: lambda self: -1,
+    torch._neg_view_copy: lambda self: -1,
+    torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
+    torch._sparse_broadcast_to_copy: lambda self, size: -1,
+    torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
+    torch.expand_copy: lambda self, size, *, implicit=False: -1,
+    torch.narrow_copy: lambda self, dim, start, length: -1,
+    torch.permute_copy: lambda self, dims: -1,
+    torch._reshape_alias_copy: lambda self, size, stride: -1,
+    torch.select_copy: lambda self, dim, index: -1,
+    torch.detach_copy: lambda self: -1,
+    torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
+    torch.split_copy: lambda self, split_size, dim=0: -1,
+    torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
+    torch.squeeze_copy: lambda self: -1,
+    torch.squeeze_copy: lambda self, dim: -1,
+    torch.t_copy: lambda self: -1,
+    torch.transpose_copy: lambda self, dim0, dim1: -1,
+    torch.unsqueeze_copy: lambda self, dim: -1,
+    torch._indices_copy: lambda self: -1,
+    torch._values_copy: lambda self: -1,
+    torch.indices_copy: lambda self: -1,
+    torch.values_copy: lambda self: -1,
+    torch.crow_indices_copy: lambda self: -1,
+    torch.col_indices_copy: lambda self: -1,
+    torch.unbind_copy: lambda self, dim=0: -1,
+    torch.view_copy: lambda self, size: -1,
+    torch.view_copy: lambda self, dtype: -1,
+    torch.unfold_copy: lambda self, dimension, size, step: -1,
+    torch.alias_copy: lambda self: -1,
     Tensor.__floordiv__: lambda self, other: -1,
     Tensor.__rfloordiv__: lambda self, other: -1,
     Tensor.__ifloordiv__: lambda self, other: -1,