Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
### Summary

NOTE: This is a re-export of https://github.com/pytorch/pytorch/pull/161994; the changes between these two PRs are exclusively to the buck/build files.

(Summary from #161994) Attempted rebase of https://github.com/pytorch/pytorch/pull/143712. This reverts commit 6c713ccb5e0df227dd5b630057cbccd373cbe7d6.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D81524507

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163769
Approved by: https://github.com/dolpm

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Committed by: PyTorch MergeBot
Parent: 29cbcbac42
Commit: 7d710403b0
torchgen/api/functionalization.py

@@ -23,20 +23,13 @@ from torchgen.model import (

# This file describes the translation of JIT schema to API's used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These API's mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These API's mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse function
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
base_binding = Binding(
    name="base",
@@ -46,6 +39,18 @@ base_binding = Binding(
    ),
    default=None,
)

has_symbolic_inputs_binding = Binding(
    name="has_symbolic_inputs",
    nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
    argument=Argument(
        name="has_symbolic_inputs",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@@ -54,11 +59,11 @@ mutated_view_binding = Binding(
    ),
    default=None,
)
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
out_index_binding = Binding(
    name="out_index",
    nctype=NamedCType(name="out_index", type=BaseCType(longT)),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
        name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
    ),
    default=None,
)
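For orientation, a minimal sketch of how these Binding objects render into C++ parameter lists; this snippet is not part of the commit and assumes torchgen's Binding.decl() emits "<cpp type> <name>":

    # Minimal sketch, assuming a torchgen checkout on PYTHONPATH.
    from torchgen.api.types import BaseCType, Binding, NamedCType, longT
    from torchgen.model import Argument, BaseTy, BaseType

    out_index = Binding(
        name="out_index",
        nctype=NamedCType(name="out_index", type=BaseCType(longT)),
        argument=Argument(
            name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
        ),
        default=None,
    )
    print(out_index.decl())  # expected: "int64_t out_index"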
@@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
)


# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
# Name of the `ViewMeta` specialization class created.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
    namespace = "at::functionalization::" if with_namespace else ""
    return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"


# Name of the operation called inside the `forward`/`reverse` implementations.
def name(
    g: NativeFunctionsViewGroup,
    *,
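A hedged usage sketch of classname; the "narrow" schema below is copied (abridged) from native_functions.yaml purely for illustration:

    from torchgen.model import FunctionSchema

    func = FunctionSchema.parse(
        "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
    )
    print(classname(func))                       # -> "narrow_ViewMeta"
    print(classname(func, with_namespace=True))  # -> "at::functionalization::narrow_ViewMeta"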
@@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    return f"{api_name}_inverse"


def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    # capture arguments include all arguments except `self`.
    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
    # so any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>).
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    all_bindings = [
        inverse_return_mode_binding if is_reverse else reapply_views_binding
    ]
    all_bindings.extend(non_self_value_bindings)
    return all_bindings


def returns_type(func: FunctionSchema) -> CType:
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) >= 1
@@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
    return BaseCType(tensorT)


def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    if is_reverse:
        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
    else:
        return [base_binding, mutated_view_idx_binding]
# Checks whether `func` might return more than one value.
def is_multi_output(func: FunctionSchema) -> bool:
    return len(func.returns) > 1 or (
        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
    )


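Hedged examples of the predicate, using abridged schemas from native_functions.yaml (illustration only, not part of the commit):

    from torchgen.model import FunctionSchema

    # Single Tensor return: not multi-output.
    narrow = FunctionSchema.parse(
        "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
    )
    assert not is_multi_output(narrow)

    # Tensor[] return: multi-output, so its specialization also carries an out_index.
    split = FunctionSchema.parse(
        "split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"
    )
    assert is_multi_output(split)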
def inner_call_index(func: FunctionSchema) -> Binding | None:
    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
    if len(func.returns) > 1 or (
        len(func.returns) == 1 and func.returns[0].type.is_list_like()
    ):
        return mutated_view_idx_binding
    return None
# `ViewMeta` specialization constructor parameters.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    # All specializations are parameterized by the `has_symbolic_inputs` flag.
    arguments = [has_symbolic_inputs_binding]

    # If `func` might return more than 1 value, we also parameterize this specialization
    # with the output index.
    if is_multi_output(func):
        arguments.append(out_index_binding)

    return arguments


def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization, which the base class does not need.
# Same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    return attributes(func, owning=False)


# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions. You can
# think of them as values that should be captured from the functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    return [
        reapply_views_binding,
        inverse_return_mode_binding,
        *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
    ]


def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
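A hedged illustration of the owning/non-owning split above: attributes(func) stores owned value types so the ViewMeta instance cannot dangle, while attributes(func, owning=False), used by extra_ctor_arguments, keeps the cheap reference types for the constructor signature. The printed C++ types are assumptions about dispatcher.argument's conversion, and the "view" schema is abridged:

    from torchgen.model import FunctionSchema

    func = FunctionSchema.parse("view(Tensor(a) self, SymInt[] size) -> Tensor(a)")
    # Owned field, e.g. "::std::vector<c10::SymInt> size"
    print(attributes(func)[-1].decl())
    # Non-owning ctor parameter, e.g. "c10::SymIntArrayRef size"
    print(attributes(func, owning=False)[-1].decl())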
@@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
        # the reverse lambda does the same, but with an additional "mutated_view" arg
        # additionally, we have a calling convention: for view ops that return multiple tensor outputs
        # their corresponding view_inverse function takes in an additional index argument.
        index_binding = inner_call_index(func)
        if index_binding is not None:
        if is_multi_output(func):
            return [
                base_binding,
                mutated_view_binding,
                inverse_return_mode_binding,
                index_binding,
                out_index_binding,
            ] + non_self_bindings
        else:
            return [
torchgen/api/types/signatures.py

@@ -300,83 +300,11 @@ class ViewInverseSignature:
        return_type = functionalization.returns_type(self.g.view.func)
        decls = [
            a.decl()
            for a in functionalization.inner_arguments(
                self.g.view.func, is_reverse=True
            )
            for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
        ]
        return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"


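A sketch of the declaration string this method now emits, for a hypothetical single-output op "narrow"; the argument order follows op_arguments(is_reverse=True), but the exact C++ types are assumptions, and the real output is a single line:

    # Expected shape of ViewInverseSignature.decl() output (wrapped for readability):
    #
    #   static at::Tensor narrow_inverse(
    #       const at::Tensor & base, const at::Tensor & mutated_view,
    #       const at::functionalization::InverseReturnMode inverse_return_mode,
    #       int64_t dim, c10::SymInt start, c10::SymInt length);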
@dataclass(frozen=True)
class FunctionalizationLambda:
    g: NativeFunctionsViewGroup

    # are we generating the forward lambda or the reverse lambda?
    is_reverse: bool

    def captures(self) -> list[Expr]:
        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
        # and plumb it into the lambda.
        outer_ctx = dispatcher.arguments(self.g.view.func) + [
            functionalization.reapply_views_binding,
            functionalization.inverse_return_mode_binding,
        ]
        capture_bindings = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        # allow_expensive_conversions is set because we want to convert
        # some reference types (IntArrayRef) to value types (vector<int64_t>).
        capture_exprs = translate.translate(
            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
        )
        return capture_exprs

    def decl(self) -> str:
        return_type = functionalization.returns_type(self.g.view.func)
        capture_str = ", ".join(
            f"{val.type.name} = {val.expr}" for val in self.captures()
        )
        decls = [
            a.decl()
            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
        ]
        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"

    def inner_call(self, *, reapply_views: bool | None = None) -> str:
        inner_call_name = functionalization.name(
            self.g,
            is_reverse=self.is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
        capture_ctx = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        full_ctx = arg_ctx + capture_ctx

        assert self.g.view_copy is not None
        call_bindings = functionalization.inner_arguments(
            self.g.view_copy.func, is_reverse=self.is_reverse
        )
        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
        call_exprs = [
            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
        ]
        if not self.is_reverse and maybe_index is not None:
            return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
        else:
            return f"{inner_call_name}({', '.join(call_exprs)});"

    @staticmethod
    def from_func(
        g: NativeFunctionsViewGroup, *, is_reverse: bool
    ) -> FunctionalizationLambda:
        return FunctionalizationLambda(g, is_reverse)


@dataclass(frozen=True)
class StructuredImplSignature:
    g: NativeFunctionsGroup

torchgen/gen.py (105 changed lines)
@@ -43,6 +43,8 @@ from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    gen_functionalization_view_meta_classes_decl,
    gen_functionalization_view_meta_classes_impl,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@@ -2493,48 +2495,48 @@ def gen_source_files(
        },
    )

    def gen_op_headers(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> list[str]:
        if isinstance(g, NativeFunctionsViewGroup):
            # view ops always get a functionalization kernel
            headers = [
                f"#include <ATen/ops/{g.view.root_name}_native.h>",
                f"#include <ATen/ops/{g.view.root_name}_ops.h>",
            ]
            if g.view_copy is not None:
                headers += [
                    f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                ]
            return headers
        elif isinstance(g, NativeFunctionsGroup):
            headers = [
                f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                f"#include <ATen/ops/{g.out.root_name}_native.h>",
                f"#include <ATen/ops/{g.out.root_name}_ops.h>",
            ]
            if g.inplace is not None:
                headers += [
                    f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                    f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                ]
            if g.mutable is not None:
                headers += [
                    f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                    f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                ]
            return headers
        else:
            return [
                f"#include <ATen/ops/{g.root_name}_native.h>",
                f"#include <ATen/ops/{g.root_name}_ops.h>",
            ]

    def functionalization_env_callable(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> dict[str, list[str]]:
        def gen_op_headers(
            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
        ) -> list[str]:
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
                ]
                if g.view_copy is not None:
                    headers += [
                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                    ]
                return headers
            elif isinstance(g, NativeFunctionsGroup):
                headers = [
                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
                ]
                if g.inplace is not None:
                    headers += [
                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                    ]
                if g.mutable is not None:
                    headers += [
                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                    ]
                return headers
            else:
                return [
                    f"#include <ATen/ops/{g.root_name}_native.h>",
                    f"#include <ATen/ops/{g.root_name}_ops.h>",
                ]

        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
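A minimal sketch of the headers gen_op_headers produces for a view group; "narrow"/"narrow_copy" are illustrative root names, not taken from this commit:

    # For a NativeFunctionsViewGroup whose view root_name is "narrow" and whose
    # view_copy root_name is "narrow_copy", gen_op_headers(g) returns:
    #
    #   ["#include <ATen/ops/narrow_native.h>",
    #    "#include <ATen/ops/narrow_ops.h>",
    #    "#include <ATen/ops/narrow_copy_native.h>",
    #    "#include <ATen/ops/narrow_copy_ops.h>"]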
@@ -2600,6 +2602,31 @@ def gen_source_files(
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.h",
        lambda: {
            "view_meta_declarations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
                    view_groups,
                )
            )
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.cpp",
        lambda: {
            "view_meta_implementations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
                    view_groups,
                )
            ),
            "op_headers": list(concatMap(gen_op_headers, view_groups)),
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
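For readers unfamiliar with the helper used above: concatMap maps each element through a function and flattens the results one level. A minimal sketch, assuming torchgen.utils.concatMap behaves like a map-then-flatten generator:

    from torchgen.utils import concatMap

    assert list(concatMap(lambda x: [x, 10 * x], [1, 2])) == [1, 10, 2, 20]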
torchgen/gen_functionalization_type.py

@@ -1,16 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING
from typing import Callable, Optional, TYPE_CHECKING

from torchgen.api import cpp, dispatcher
from torchgen.api import cpp, dispatcher, functionalization
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    DispatcherSignature,
    FunctionalizationLambda,
    iTensorListRefT,
    NativeSignature,
    OptionalCType,
@@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import dataclass_repr
from torchgen.utils import concatMap, dataclass_repr, FileManager


if TYPE_CHECKING:
@@ -365,6 +364,8 @@ def emit_view_functionalization_body(
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        spec = ViewMetaSpecialization(g, f=f)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@@ -385,9 +386,6 @@ def emit_view_functionalization_body(
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
@@ -415,19 +413,7 @@ def emit_view_functionalization_body(
          : at::functionalization::InverseReturnMode::NeverView
      );
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
      );
      auto view_meta = {spec.new()};
      auto compute_reference_meta =
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
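A hedged sketch of what the new spec.new() call-site expands to for a hypothetical single-output op "narrow"; the argument list follows base_ctor_arguments + extra_ctor_arguments, but the exact names are assumptions:

    # Generated C++ (sketch):
    #
    #   auto view_meta = std::make_shared<at::functionalization::narrow_ViewMeta>(
    #       has_symbolic_inputs, reapply_views, inverse_return_mode, dim, start, length);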
@@ -455,7 +441,6 @@ def emit_view_functionalization_body(
      """

    else:
        is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
        return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      {unwrap_tensor_args_str}
@@ -489,21 +474,7 @@ def emit_view_functionalization_body(
        }}
      }}
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
        /*is_multi_output=*/{str(is_multi_output_view).lower()},
        /*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
      );
      auto view_meta = {spec.new()};
      auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      if (compute_reference_meta && !disable_meta_reference()) {{
@@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
        return emit_decl_helper(g)


# Helper class for generating `ViewMeta` specializations.
@dataclass
class ViewMetaSpecialization:
    g: NativeFunctionsViewGroup
    f: NativeFunction

    @property
    def is_multi_output(self) -> bool:
        return functionalization.is_multi_output(self.f.func)

    @property
    def is_as_strided(self) -> bool:
        return str(self.f.func.name) == "as_strided"

    @property
    def out_index(self) -> str:
        if self.is_multi_output:
            return functionalization.out_index_binding.name
        return "0"

    @property
    def classname(self) -> str:
        return functionalization.classname(self.f.func)

    def decl(self) -> list[str]:
        base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
        extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
        attributes = functionalization.attributes(self.f.func)

        # List of types for declaring the `SerializableTuple` type.
        serializable_tuple_args = ",\n".join(
            f"  {binding.type} /* {binding.name} */"
            for binding in (base_ctor_arguments + attributes)
        )

        # Arguments used for forwarding the tuple elements to the constructor.
        destructure_tuple_args = ", ".join(
            f"std::get<{i}>(tpl)"
            for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
        )

        # List of constructor parameters
        ctor_parameters = ", ".join(
            binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
        )

        # Call the base class `ViewMeta` constructor.
        #
        # Both `is_multi_output` and `is_as_strided` are known values, given the
        # operation schema.
        is_multi_output_str = str(self.is_multi_output).lower()
        is_as_strided_str = str(self.is_as_strided).lower()

        base_ctor_bindings = ", ".join(
            [
                # `has_symbolic_inputs` is always taken as a parameter.
                functionalization.has_symbolic_inputs_binding.name,
                f"/*is_multi_output=*/{is_multi_output_str}",
                f"/*is_as_strided=*/{is_as_strided_str}",
                # `out_index` is known if the operation returns only one value. Otherwise,
                # we also take it as a parameter.
                f"/*out_index=*/{self.out_index}",
            ]
        )

        # Assignments of `extra_ctor_arguments` to their corresponding fields.
        # These are extra fields to-be-declared in this specialization.
        #
        # We need to set `allow_expensive_conversions`, since we are storing owned versions
        # of the non-owning arguments.
        ctor_assignments = ",\n".join(
            f"        {e.type.name}({e.expr})"
            for e in translate(
                extra_ctor_arguments,
                attributes,
                method=False,
                allow_expensive_conversions=True,
            )
        )

        # List of arguments for constructing the `SerializableTuple` from an instance.
        tuple_arguments = ", ".join(
            binding.name for binding in (base_ctor_arguments + attributes)
        )

        # List of field declarations.
        attr_declarations = "\n".join(f"  {binding.decl()};" for binding in attributes)

        # Override `to_out_index` if this operation returns more than 1 value.
        to_out_index_decl = ""
        if self.is_multi_output:
            to_out_index_decl = (
                "  std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
            )

        return [
            f"""
struct TORCH_API {self.classname} : public ViewMeta {{
  FUNCTIONALIZATION_VIEWMETA_NAME({self.classname})
  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});

  {self.classname}(const SerializableTuple& tpl)
      : {self.classname}({destructure_tuple_args}) {{}}

  {self.classname}({ctor_parameters})
      : at::functionalization::ViewMeta({base_ctor_bindings}),
        {ctor_assignments} {{}}

  Tensor forward(const Tensor& base) override;
  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}

  SerializableTuple to_serializable_tuple() {{
    return std::make_tuple({tuple_arguments});
  }}

{attr_declarations}
}};
"""
        ]

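A small runnable sketch of the destructure_tuple_args string built in decl() above, with an illustrative argument count; this forwarding is what lets the SerializableTuple constructor delegate to the main constructor:

    # Illustrative count: one base ctor argument plus two extra ctor arguments.
    n = 3
    print(", ".join(f"std::get<{i}>(tpl)" for i in range(n)))
    # -> std::get<0>(tpl), std::get<1>(tpl), std::get<2>(tpl)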
    # Generate a call to the actual operation.
    def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
        opname = functionalization.name(
            self.g,
            is_reverse=is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        # Expected arguments for the operation.
        assert self.g.view_copy is not None
        op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)

        # The context is composed of the constructor arguments (which are also
        # the field variables stored in the instance) and the `base` tensor.
        context = [functionalization.base_binding]
        context += functionalization.base_ctor_arguments(self.f.func)
        context += functionalization.attributes(self.f.func)

        # If we are generating the call for the reverse function, we also have
        # access to the `mutated_view` argument.
        if is_reverse:
            context.append(functionalization.mutated_view_binding)

        arguments = ", ".join(
            [e.expr for e in translate(context, op_arguments, method=False)]
        )

        # Index the result if this operation returns multiple values.
        maybe_index = ""
        if not is_reverse and self.is_multi_output:
            maybe_index = f"[{self.out_index}]"

        return f"{opname}({arguments}){maybe_index}"

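Hedged examples of the call strings opcall builds for the hypothetical "narrow" op; the at::_ops call shape for the forward direction and the FunctionalInverses namespace for the reverse direction are assumptions based on the surrounding codegen, not guaranteed by this commit:

    # opcall(is_reverse=False, reapply_views=True), sketch:
    #   at::_ops::narrow::call(base, dim, start, length)
    #
    # opcall(is_reverse=True, reapply_views=True), sketch:
    #   at::functionalization::FunctionalInverses::narrow_copy_inverse(
    #       base, mutated_view, inverse_return_mode, dim, start, length)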
    def impl(self) -> list[str]:
        functions = [
            f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
  if (reapply_views) {{
    return {self.opcall(is_reverse=False, reapply_views=True)};
  }} else {{
    return {self.opcall(is_reverse=False, reapply_views=False)};
  }}
}}""",
            f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
  return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
        ]

        # If this operation returns multiple values, also generate a `to_out_index`
        # implementation.
        if self.is_multi_output:
            functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
  return {self.new("out_index")};
}}
""")

        return functions

    # Create the Python binding for this specialized class.
    def binding(self) -> list[str]:
        name = functionalization.classname(self.f.func, with_namespace=True)
        return [f"  create_binding_with_pickle<{name}>(functionalization);"]

    # Generate an instantiation of this specialized class.
    def new(self, out_index: str = "0") -> str:
        name = functionalization.classname(self.f.func, with_namespace=True)
        ctor_arguments = functionalization.base_ctor_arguments(
            self.f.func
        ) + functionalization.extra_ctor_arguments(self.f.func)
        # Replace the `out_index` parameter with the given `out_index`.
        arguments = ", ".join(
            binding.name if binding.name != "out_index" else out_index
            for binding in ctor_arguments
        )
        return f"std::make_shared<{name}>({arguments})"

    # Run the function `run` for both the `view` and `view_inplace` functions.
    @staticmethod
    def map(
        g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
    ) -> list[str]:
        def maybe_run(f: Optional[NativeFunction]) -> list[str]:
            if f is None:
                return []
            with native_function_manager(f):
                return run(ViewMetaSpecialization(g, f))

        return list(concatMap(maybe_run, (g.view, g.view_inplace)))


def gen_functionalization_view_meta_classes_base(
    selector: SelectiveBuilder,
    g: NativeFunctionsViewGroup,
    run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
    if not selector.include_all_operators:
        return []

    if g.composite:
        return []

    return ViewMetaSpecialization.map(g, run)


def gen_functionalization_view_meta_classes_decl(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.decl
    )


def gen_functionalization_view_meta_classes_impl(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.impl
    )


def gen_functionalization_view_meta_classes_binding(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.binding
    )


# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
    native_functions_path: str,
    tags_path: str,
    selector: SelectiveBuilder,
    install_dir: str,
    template_dir: str,
) -> None:
    from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml

    # Parse native_functions.yaml.
    # Then, group the functions into `NativeFunctionsViewGroup`.
    #
    # These are the same steps we run in gen.py (ATen codegen).
    native_functions = parse_native_yaml(
        native_functions_path, tags_path
    ).native_functions
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
    fm.write(
        "ViewMetaClassesPythonBinding.cpp",
        lambda: {
            "view_meta_bindings": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_binding(
                        selector, g
                    ),
                    view_groups,
                )
            ),
        },
    )


def gen_functionalization_registration(
    selector: SelectiveBuilder,
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,