Reapply "Make functionalization ViewMeta serializable with pickle. (#143712)" (#163769)

### Summary:
NOTE: This is a re-export of https://github.com/pytorch/pytorch/pull/161994; the changes between these two PRs are exclusively in the buck/build files

(Summary from #161994)
Attempted rebase of https://github.com/pytorch/pytorch/pull/143712.

This reverts commit 6c713ccb5e0df227dd5b630057cbccd373cbe7d6.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D81524507

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163769
Approved by: https://github.com/dolpm

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Brian Hirsh
2025-09-25 10:27:37 +00:00
committed by PyTorch MergeBot
parent 29cbcbac42
commit 7d710403b0
38 changed files with 981 additions and 425 deletions
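
For orientation before the per-file diffs: the generated `ViewMeta` specializations (see the `gen_functionalization_type.py` changes below) carry their state as an explicit `SerializableTuple` and are bound to Python with `create_binding_with_pickle`, so an instance can be rebuilt purely from its constructor arguments. A minimal Python sketch of that same pattern, with a hypothetical `NarrowViewMeta` standing in for a generated C++ class:

```python
import pickle

# Illustrative only: the real specializations are generated C++ classes
# (e.g. at::functionalization::<op>_ViewMeta) exposed through pybind11 via
# create_binding_with_pickle. NarrowViewMeta is a hypothetical stand-in.
class NarrowViewMeta:
    def __init__(self, has_symbolic_inputs, dim, start, length):
        self.has_symbolic_inputs = has_symbolic_inputs
        self.dim = dim
        self.start = start
        self.length = length

    # Analogue of to_serializable_tuple(): everything needed to rebuild the object.
    def to_serializable_tuple(self):
        return (self.has_symbolic_inputs, self.dim, self.start, self.length)

    # Pickle support mirrors the generated "constructor from SerializableTuple".
    def __reduce__(self):
        return (type(self), self.to_serializable_tuple())


meta = NarrowViewMeta(False, dim=0, start=1, length=2)
restored = pickle.loads(pickle.dumps(meta))
assert restored.to_serializable_tuple() == meta.to_serializable_tuple()
```

The generated classes do the equivalent in C++: `to_serializable_tuple()` packs the constructor arguments, and a constructor taking a `SerializableTuple` unpacks them again.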

View File

@@ -23,20 +23,13 @@ from torchgen.model import (
# This file describes the translation of JIT schema to API's used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These API's mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These API's mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse function
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
# Define some specific lambda input arguments.
base_binding = Binding(
name="base",
@@ -46,6 +39,18 @@ base_binding = Binding(
),
default=None,
)
has_symbolic_inputs_binding = Binding(
name="has_symbolic_inputs",
nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
argument=Argument(
name="has_symbolic_inputs",
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)
mutated_view_binding = Binding(
name="mutated_view",
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@@ -54,11 +59,11 @@ mutated_view_binding = Binding(
),
default=None,
)
mutated_view_idx_binding = Binding(
name="mutated_view_idx",
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
out_index_binding = Binding(
name="out_index",
nctype=NamedCType(name="out_index", type=BaseCType(longT)),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
),
default=None,
)
@@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
)
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
# Name of the `ViewMeta` specialization class created.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
namespace = "at::functionalization::" if with_namespace else ""
return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
# Name of the operation called inside the `forward`/`reverse` implementations.
def name(
g: NativeFunctionsViewGroup,
*,
@@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings
def returns_type(func: FunctionSchema) -> CType:
# Assertion: all view ops return tensor-like outputs
assert len(func.returns) >= 1
@@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
# Checks whether `func` might return more than one value.
def is_multi_output(func: FunctionSchema) -> bool:
return len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
)
def inner_call_index(func: FunctionSchema) -> Binding | None:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like()
):
return mutated_view_idx_binding
return None
# `ViewMeta` specialization constructor parameters.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
# All specializations are parameterized by the `has_symbolic_inputs` flag.
arguments = [has_symbolic_inputs_binding]
# If `func` might return more than 1 value, we also parameterize this specialization
# with the output index.
if is_multi_output(func):
arguments.append(out_index_binding)
return arguments
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization that the base class does not need.
# Same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
return attributes(func, owning=False)
# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions. You can
# think of them as values that should be captured from the functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
return [
reapply_views_binding,
inverse_return_mode_binding,
*[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
]
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
@@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# the reverse lambda does the same, but with an additional "mutated_view" arg
# additionally, we have a calling convention: for view ops that return multiple tensor outputs
# their corresponding view_inverse function takes in an additional index argument.
index_binding = inner_call_index(func)
if index_binding is not None:
if is_multi_output(func):
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
index_binding,
out_index_binding,
] + non_self_bindings
else:
return [
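
A hedged sketch of how the argument groups defined above line up for concrete schemas; the op names and argument lists are illustrative (the real code derives `Binding` objects from the parsed `FunctionSchema`):

```python
# Illustrative grouping only; real code builds torchgen Binding objects from a
# FunctionSchema, and the op/argument names here are hypothetical examples.
def ctor_and_attrs(non_self_args, multi_output):
    # base_ctor_arguments: every specialization takes has_symbolic_inputs;
    # ops that may return several values also take out_index.
    base_ctor = ["has_symbolic_inputs"] + (["out_index"] if multi_output else [])
    # attributes (and extra_ctor_arguments, their non-owning counterparts):
    # reapply_views, inverse_return_mode, plus all non-self schema arguments.
    attrs = ["reapply_views", "inverse_return_mode", *non_self_args]
    return base_ctor, attrs


# A single-output view op such as narrow:
print(ctor_and_attrs(["dim", "start", "length"], multi_output=False))
# A multi-output view op such as split:
print(ctor_and_attrs(["split_size", "dim"], multi_output=True))
```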

View File

@@ -300,83 +300,11 @@ class ViewInverseSignature:
return_type = functionalization.returns_type(self.g.view.func)
decls = [
a.decl()
for a in functionalization.inner_arguments(
self.g.view.func, is_reverse=True
)
for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
]
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
@dataclass(frozen=True)
class FunctionalizationLambda:
g: NativeFunctionsViewGroup
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
# allow_expensive_conversions is set because we want to convert
# some reference types (IntArrayRef) to value types (vector<int64_t>).
capture_exprs = translate.translate(
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
)
return capture_exprs
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
capture_str = ", ".join(
f"{val.type.name} = {val.expr}" for val in self.captures()
)
decls = [
a.decl()
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
capture_ctx = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
full_ctx = arg_ctx + capture_ctx
assert self.g.view_copy is not None
call_bindings = functionalization.inner_arguments(
self.g.view_copy.func, is_reverse=self.is_reverse
)
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
call_exprs = [
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
]
if not self.is_reverse and maybe_index is not None:
return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
else:
return f"{inner_call_name}({', '.join(call_exprs)});"
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse)
@dataclass(frozen=True)
class StructuredImplSignature:
g: NativeFunctionsGroup
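
For context on why `FunctionalizationLambda` goes away: state captured in a lambda has no stable serialized form, while the same state held as named fields of a class does. A small Python analogy (not the PyTorch implementation) of the difference the PR relies on:

```python
import pickle

# A closure holding the view arguments: pickle cannot serialize it.
def make_forward(dim, start, length):
    return lambda base: base.narrow(dim, start, length)  # hypothetical op call

try:
    pickle.dumps(make_forward(0, 1, 2))
except Exception as exc:
    print("lambda not picklable:", type(exc).__name__)

# The same state as named fields of a class pickles fine, which is the shape
# the generated ViewMeta specializations now have.
class ForwardViewMeta:
    def __init__(self, dim, start, length):
        self.dim, self.start, self.length = dim, start, length

    def forward(self, base):
        return base.narrow(self.dim, self.start, self.length)

print(len(pickle.dumps(ForwardViewMeta(0, 1, 2))) > 0)  # True
```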

View File

@@ -43,6 +43,8 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_functionalization_view_meta_classes_decl,
gen_functionalization_view_meta_classes_impl,
GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@@ -2493,48 +2495,48 @@ def gen_source_files(
},
)
def gen_op_headers(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel
headers = [
f"#include <ATen/ops/{g.view.root_name}_native.h>",
f"#include <ATen/ops/{g.view.root_name}_ops.h>",
]
if g.view_copy is not None:
headers += [
f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
]
return headers
elif isinstance(g, NativeFunctionsGroup):
headers = [
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
f"#include <ATen/ops/{g.out.root_name}_native.h>",
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
]
if g.inplace is not None:
headers += [
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
]
if g.mutable is not None:
headers += [
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
]
return headers
else:
return [
f"#include <ATen/ops/{g.root_name}_native.h>",
f"#include <ATen/ops/{g.root_name}_ops.h>",
]
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
def gen_op_headers(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel
headers = [
f"#include <ATen/ops/{g.view.root_name}_native.h>",
f"#include <ATen/ops/{g.view.root_name}_ops.h>",
]
if g.view_copy is not None:
headers += [
f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
]
return headers
elif isinstance(g, NativeFunctionsGroup):
headers = [
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
f"#include <ATen/ops/{g.out.root_name}_native.h>",
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
]
if g.inplace is not None:
headers += [
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
]
if g.mutable is not None:
headers += [
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
]
return headers
else:
return [
f"#include <ATen/ops/{g.root_name}_native.h>",
f"#include <ATen/ops/{g.root_name}_ops.h>",
]
return {
"ops_headers": gen_op_headers(g),
"func_definitions": gen_functionalization_definition(
@@ -2600,6 +2602,31 @@ def gen_source_files(
},
)
cpu_fm.write(
"ViewMetaClasses.h",
lambda: {
"view_meta_declarations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
view_groups,
)
)
},
)
cpu_fm.write(
"ViewMetaClasses.cpp",
lambda: {
"view_meta_implementations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
view_groups,
)
),
"op_headers": list(concatMap(gen_op_headers, view_groups)),
},
)
# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
# needs to have a corresponding non-aliasing {view}_copy variant.
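
The two new `cpu_fm.write` calls follow the existing `gen_source_files` pattern: each template placeholder gets the flattened output of a per-view-group generator. A small sketch of that flattening, where `concat_map` is a local stand-in assumed to behave like `torchgen.utils.concatMap` (map, then flatten one level):

```python
from itertools import chain

# Stand-in assumed to behave like torchgen.utils.concatMap: map, then flatten.
def concat_map(fn, xs):
    return chain.from_iterable(fn(x) for x in xs)

# Hypothetical view groups and a per-group declaration generator; the real
# generator returns zero or more C++ declarations per group (empty for
# composite groups or when the selective build excludes the operator).
view_groups = ["narrow", "slice", "split"]

def gen_decls(group):
    return [f"struct {group}_ViewMeta : public ViewMeta {{ ... }};"]

substitutions = {"view_meta_declarations": list(concat_map(gen_decls, view_groups))}
print(substitutions["view_meta_declarations"])
```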

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING
from typing import Callable, Optional, TYPE_CHECKING
from torchgen.api import cpp, dispatcher
from torchgen.api import cpp, dispatcher, functionalization
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
DispatcherSignature,
FunctionalizationLambda,
iTensorListRefT,
NativeSignature,
OptionalCType,
@@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import dataclass_repr
from torchgen.utils import concatMap, dataclass_repr, FileManager
if TYPE_CHECKING:
@@ -365,6 +364,8 @@ def emit_view_functionalization_body(
with native_function_manager(f):
call_sig = DispatcherSignature.from_schema(g.view_copy.func)
spec = ViewMetaSpecialization(g, f=f)
# the "view_copy" op name that the functionalization kernels need to call
api_name = g.view_copy.func.name.unambiguous_name()
# Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@@ -385,9 +386,6 @@
for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
]
forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
# The meta API call should use the same arguments, but convert all tensors to meta tensors first.
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
@@ -415,19 +413,7 @@
: at::functionalization::InverseReturnMode::NeverView
);
{symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname}
);
auto view_meta = {spec.new()};
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
@@ -455,7 +441,6 @@
"""
else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
{unwrap_tensor_args_str}
@@ -489,21 +474,7 @@
}}
}}
{symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname},
/*is_multi_output=*/{str(is_multi_output_view).lower()},
/*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
);
auto view_meta = {spec.new()};
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]
if (compute_reference_meta && !disable_meta_reference()) {{
@@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
return emit_decl_helper(g)
# Helper class for generating `ViewMeta` specializations.
@dataclass
class ViewMetaSpecialization:
g: NativeFunctionsViewGroup
f: NativeFunction
@property
def is_multi_output(self) -> bool:
return functionalization.is_multi_output(self.f.func)
@property
def is_as_strided(self) -> bool:
return str(self.f.func.name) == "as_strided"
@property
def out_index(self) -> str:
if self.is_multi_output:
return functionalization.out_index_binding.name
return "0"
@property
def classname(self) -> str:
return functionalization.classname(self.f.func)
def decl(self) -> list[str]:
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
attributes = functionalization.attributes(self.f.func)
# List of types for declaring the `SerializableTuple` type.
serializable_tuple_args = ",\n".join(
f" {binding.type} /* {binding.name} */"
for binding in (base_ctor_arguments + attributes)
)
# Arguments used for forwarding the tuple elements to the constructor.
destructure_tuple_args = ", ".join(
f"std::get<{i}>(tpl)"
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
)
# List of constructor parameters
ctor_parameters = ", ".join(
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
)
# Call the base class `ViewMeta` constructor.
#
# Both `is_multi_output` and `is_as_strided` are known values, given the
# operation schema.
is_multi_output_str = str(self.is_multi_output).lower()
is_as_strided_str = str(self.is_as_strided).lower()
base_ctor_bindings = ", ".join(
[
# `has_symbolic_inputs` is always taken as a parameter.
functionalization.has_symbolic_inputs_binding.name,
f"/*is_multi_output=*/{is_multi_output_str}",
f"/*is_as_strided=*/{is_as_strided_str}",
# `out_index` is known if the operation returns only one value. Otherwise,
# we also take it as a parameter.
f"/*out_index=*/{self.out_index}",
]
)
# Assignments of `extra_ctor_arguments` to their corresponding fields.
# These are extra fields to-be-declared in this specialization.
#
# We need to set `allow_expensive_conversions`, since we are storing owned versions
# of the non-owning arguments.
ctor_assignments = ",\n".join(
f" {e.type.name}({e.expr})"
for e in translate(
extra_ctor_arguments,
attributes,
method=False,
allow_expensive_conversions=True,
)
)
# List of arguments for constructing the `SerializableTuple` from an instance.
tuple_arguments = ", ".join(
binding.name for binding in (base_ctor_arguments + attributes)
)
# List of field declarations.
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
# Override `to_out_index` if this operation returns more than 1 value.
to_out_index_decl = ""
if self.is_multi_output:
to_out_index_decl = (
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
)
return [
f"""
struct TORCH_API {self.classname} : public ViewMeta {{
FUNCTIONALIZATION_VIEWMETA_NAME({self.classname})
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});
{self.classname}(const SerializableTuple& tpl)
: {self.classname}({destructure_tuple_args}) {{}}
{self.classname}({ctor_parameters})
: at::functionalization::ViewMeta({base_ctor_bindings}),
{ctor_assignments} {{}}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}
SerializableTuple to_serializable_tuple() {{
return std::make_tuple({tuple_arguments});
}}
{attr_declarations}
}};
"""
]
# Generate a call to the actual operation.
def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
opname = functionalization.name(
self.g,
is_reverse=is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
# Expected arguments for the operation.
assert self.g.view_copy is not None
op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)
# The context is composed of the constructor arguments (which are also
# the field variables stored in the instance), and the `base` tensor.
context = [functionalization.base_binding]
context += functionalization.base_ctor_arguments(self.f.func)
context += functionalization.attributes(self.f.func)
# If we are generating the call for the reverse function, we also have
# access to the `mutated_view` argument.
if is_reverse:
context.append(functionalization.mutated_view_binding)
arguments = ", ".join(
[e.expr for e in translate(context, op_arguments, method=False)]
)
# Index the result if this operation returns multiple values.
maybe_index = ""
if not is_reverse and self.is_multi_output:
maybe_index = f"[{self.out_index}]"
return f"{opname}({arguments}){maybe_index}"
def impl(self) -> list[str]:
functions = [
f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
if (reapply_views) {{
return {self.opcall(is_reverse=False, reapply_views=True)};
}} else {{
return {self.opcall(is_reverse=False, reapply_views=False)};
}}
}}""",
f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
]
# If this operation returns multiple values, also generate a `to_out_index`
# implementation.
if self.is_multi_output:
functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
return {self.new("out_index")};
}}
""")
return functions
# Create the Python binding for this specialized class.
def binding(self) -> list[str]:
name = functionalization.classname(self.f.func, with_namespace=True)
return [f" create_binding_with_pickle<{name}>(functionalization);"]
# Generate an instantiation of this specialized class.
def new(self, out_index: str = "0") -> str:
name = functionalization.classname(self.f.func, with_namespace=True)
ctor_arguments = functionalization.base_ctor_arguments(
self.f.func
) + functionalization.extra_ctor_arguments(self.f.func)
# Replace the `out_index` parameter with the given `out_index`.
arguments = ", ".join(
binding.name if binding.name != "out_index" else out_index
for binding in ctor_arguments
)
return f"std::make_shared<{name}>({arguments})"
# Run the function `run` for both the `view` and `view_inplace` functions.
@staticmethod
def map(
g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
) -> list[str]:
def maybe_run(f: Optional[NativeFunction]) -> list[str]:
if f is None:
return []
with native_function_manager(f):
return run(ViewMetaSpecialization(g, f))
return list(concatMap(maybe_run, (g.view, g.view_inplace)))
def gen_functionalization_view_meta_classes_base(
selector: SelectiveBuilder,
g: NativeFunctionsViewGroup,
run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
if not selector.include_all_operators:
return []
if g.composite:
return []
return ViewMetaSpecialization.map(g, run)
def gen_functionalization_view_meta_classes_decl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.decl
)
def gen_functionalization_view_meta_classes_impl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.impl
)
def gen_functionalization_view_meta_classes_binding(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.binding
)
# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
native_functions_path: str,
tags_path: str,
selector: SelectiveBuilder,
install_dir: str,
template_dir: str,
) -> None:
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
# Parse the native_functions.yaml.
# Then, group them into `NativeFunctionsViewGroup`.
#
# These are the same steps we do in gen.py (ATen codegen).
native_functions = parse_native_yaml(
native_functions_path, tags_path
).native_functions
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]
fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
fm.write(
"ViewMetaClassesPythonBinding.cpp",
lambda: {
"view_meta_bindings": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_binding(
selector, g
),
view_groups,
)
),
},
)
def gen_functionalization_registration(
selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
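
To close the loop on the multi-output case handled by `to_out_index` above: when a view op returns several tensors, the specialization stores which output it describes and can re-instantiate itself at a different index. A hypothetical Python-flavored sketch of that behaviour (`SplitViewMeta` is an illustrative name, not generated code):

```python
# Hypothetical sketch of the out_index behaviour for multi-output view ops
# (e.g. split), mirroring the generated to_out_index in ViewMetaSpecialization.impl.
class SplitViewMeta:
    def __init__(self, has_symbolic_inputs, out_index, split_size, dim):
        self.has_symbolic_inputs = has_symbolic_inputs
        self.out_index = out_index
        self.split_size = split_size
        self.dim = dim

    def forward(self, base):
        # Multi-output forward: call the op, then pick the tracked output.
        return base.split(self.split_size, self.dim)[self.out_index]

    def to_out_index(self, out_index):
        # Same view description, pointed at a different output of the op.
        return SplitViewMeta(self.has_symbolic_inputs, out_index,
                             self.split_size, self.dim)


meta = SplitViewMeta(False, out_index=0, split_size=2, dim=0)
print(meta.to_out_index(1).out_index)  # 1
```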