Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
### Summary

NOTE: This is a re-export of https://github.com/pytorch/pytorch/pull/161994; the changes between these two PRs are exclusively to the buck/build files.

(Summary from #161994) Attempted rebase of https://github.com/pytorch/pytorch/pull/143712. This reverts commit 6c713ccb5e0df227dd5b630057cbccd373cbe7d6.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

imported-using-ghimport

Test Plan: Imported from OSS
Differential Revision: D81524507
Pulled By: Lucaskabela
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163769
Approved by: https://github.com/dolpm
Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Committed by: PyTorch MergeBot · parent 29cbcbac42 · commit 7d710403b0
@@ -23,20 +23,13 @@ from torchgen.model import (
# This file describes the translation of JIT schema to API's used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These API's mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These API's mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse function
#   is responsible for generating both the call-site, and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
base_binding = Binding(
    name="base",
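For context, a minimal sketch (not part of the diff) of how a torchgen `Binding` like `base_binding` above renders as a C++ parameter via `.decl()`. Import paths are assumed from the current torchgen package layout.

```python
from torchgen.api.types import BaseCType, Binding, ConstRefCType, NamedCType, tensorT
from torchgen.model import Argument, BaseTy, BaseType

# A binding equivalent to `base_binding` in the hunk above.
base = Binding(
    name="base",
    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)

# .decl() renders the C++ parameter declaration used in generated signatures,
# roughly "const at::Tensor & base".
print(base.decl())
```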
@@ -46,6 +39,18 @@ base_binding = Binding(
    ),
    default=None,
)

has_symbolic_inputs_binding = Binding(
    name="has_symbolic_inputs",
    nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
    argument=Argument(
        name="has_symbolic_inputs",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@@ -54,11 +59,11 @@ mutated_view_binding = Binding(
    ),
    default=None,
)
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
out_index_binding = Binding(
    name="out_index",
    nctype=NamedCType(name="out_index", type=BaseCType(longT)),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
        name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
    ),
    default=None,
)
@@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
)


# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
# Name of the `ViewMeta` specialization class created.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
    namespace = "at::functionalization::" if with_namespace else ""
    return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"


# Name of the operation called inside the `forward`/`reverse` implementations.
def name(
    g: NativeFunctionsViewGroup,
    *,
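A minimal sketch (assuming `FunctionSchema.parse` and an illustrative `narrow` schema string) of the class names the `classname` helper above produces; the logic is mirrored rather than imported so the snippet runs on torchgen versions without this PR.

```python
from torchgen.model import FunctionSchema

# Illustrative schema string in native_functions.yaml format (assumption).
func = FunctionSchema.parse(
    "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
)

# Mirrors the `classname` helper in the hunk above.
print(f"{func.name.unambiguous_name()}_ViewMeta")
# -> narrow_ViewMeta
print(f"at::functionalization::{func.name.unambiguous_name()}_ViewMeta")
# -> at::functionalization::narrow_ViewMeta
```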
@@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    return f"{api_name}_inverse"


def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    # capture arguments include all arguments except `self`.
    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
    # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    all_bindings = [
        inverse_return_mode_binding if is_reverse else reapply_views_binding
    ]
    all_bindings.extend(non_self_value_bindings)
    return all_bindings


def returns_type(func: FunctionSchema) -> CType:
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) >= 1
@@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
    return BaseCType(tensorT)


def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    if is_reverse:
        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
    else:
        return [base_binding, mutated_view_idx_binding]
# Checks whether `func` might return more than one value.
def is_multi_output(func: FunctionSchema) -> bool:
    return len(func.returns) > 1 or (
        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
    )


def inner_call_index(func: FunctionSchema) -> Binding | None:
    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
    if len(func.returns) > 1 or (
        len(func.returns) == 1 and func.returns[0].type.is_list_like()
    ):
        return mutated_view_idx_binding
    return None
# `ViewMeta` specialization constructor parameters.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    # All specializations are parameterized by the `has_symbolic_inputs` flag.
    arguments = [has_symbolic_inputs_binding]

    # If `func` might return more than 1 value, we also parameterize this specialization
    # with the output index.
    if is_multi_output(func):
        arguments.append(out_index_binding)

    return arguments
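A sketch of the multi-output check, mirroring `is_multi_output` above with illustrative schema strings (assumptions): a single-tensor view op gets only the `has_symbolic_inputs` constructor parameter, while a list-returning op like `split` also gets `out_index`.

```python
from torchgen.model import FunctionSchema


def might_return_multiple_values(func: FunctionSchema) -> bool:
    # Same check as `is_multi_output` above: more than one return, or a single
    # list-like return such as Tensor[].
    return len(func.returns) > 1 or (
        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
    )


narrow = FunctionSchema.parse(
    "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
)
split = FunctionSchema.parse(
    "split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"
)

print(might_return_multiple_values(narrow))  # False: ctor params = [has_symbolic_inputs]
print(might_return_multiple_values(split))   # True:  ctor params = [has_symbolic_inputs, out_index]
```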
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization, that the base class does not need.
# Same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    return attributes(func, owning=False)


# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions. You can
# think of them as values that should be captured from the functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    return [
        reapply_views_binding,
        inverse_return_mode_binding,
        *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
    ]


def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
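A sketch of the owning vs. non-owning distinction that `attributes`/`extra_ctor_arguments` rely on, calling `dispatcher.argument` directly on an illustrative `expand` schema (an assumption, not taken from the diff): with `remove_non_owning_ref_types=True` a list argument becomes an owning vector-style type suitable for storing on the `ViewMeta` instance, otherwise it stays an ArrayRef-style reference type suitable for a constructor parameter.

```python
from torchgen.api import dispatcher
from torchgen.model import FunctionSchema

# Illustrative view-op schema whose non-self argument is a list (assumption).
expand = FunctionSchema.parse(
    "expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"
)
size_arg = expand.arguments.flat_all[1]

# Owning form (what `attributes` stores as member data): a vector-style type.
print(dispatcher.argument(size_arg, remove_non_owning_ref_types=True).decl())
# Non-owning form (what `extra_ctor_arguments` accepts): an ArrayRef-style type.
print(dispatcher.argument(size_arg, remove_non_owning_ref_types=False).decl())
```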
@@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    # the reverse lambda does the same, but with an additional "mutated_view" arg
    # additionally, we have a calling convention: for view ops that return multiple tensor outputs
    # their corresponding view_inverse function takes in an additional index argument.
    index_binding = inner_call_index(func)
    if index_binding is not None:
    if is_multi_output(func):
        return [
            base_binding,
            mutated_view_binding,
            inverse_return_mode_binding,
            index_binding,
            out_index_binding,
        ] + non_self_bindings
    else:
        return [
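A names-only sketch of the reverse calling convention described in the comments above, for a multi-output op such as `split_with_sizes` (the trailing argument names are illustrative, taken from that op's schema):

```python
# Order of parameters the generated inverse call receives for a multi-output view op.
reverse_call_params = [
    "base",                 # tensor the view was originally taken from
    "mutated_view",         # the (possibly mutated) view being propagated back
    "inverse_return_mode",  # how the inverse should produce its result
    "out_index",            # which output of the multi-output op this view is
    "split_sizes", "dim",   # the op's own non-self arguments, in schema order
]
```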
@@ -300,83 +300,11 @@ class ViewInverseSignature:
        return_type = functionalization.returns_type(self.g.view.func)
        decls = [
            a.decl()
            for a in functionalization.inner_arguments(
                self.g.view.func, is_reverse=True
            )
            for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
        ]
        return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"


@dataclass(frozen=True)
class FunctionalizationLambda:
    g: NativeFunctionsViewGroup

    # are we generating the forward lambda or the reverse lambda?
    is_reverse: bool

    def captures(self) -> list[Expr]:
        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
        # and plumb it into the lambda.
        outer_ctx = dispatcher.arguments(self.g.view.func) + [
            functionalization.reapply_views_binding,
            functionalization.inverse_return_mode_binding,
        ]
        capture_bindings = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        # allow_expensive_conversions is set because we want to convert
        # some reference types (IntArrayRef) to value types (vector<int64_t>).
        capture_exprs = translate.translate(
            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
        )
        return capture_exprs

    def decl(self) -> str:
        return_type = functionalization.returns_type(self.g.view.func)
        capture_str = ", ".join(
            f"{val.type.name} = {val.expr}" for val in self.captures()
        )
        decls = [
            a.decl()
            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
        ]
        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"

    def inner_call(self, *, reapply_views: bool | None = None) -> str:
        inner_call_name = functionalization.name(
            self.g,
            is_reverse=self.is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
        capture_ctx = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        full_ctx = arg_ctx + capture_ctx

        assert self.g.view_copy is not None
        call_bindings = functionalization.inner_arguments(
            self.g.view_copy.func, is_reverse=self.is_reverse
        )
        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
        call_exprs = [
            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
        ]
        if not self.is_reverse and maybe_index is not None:
            return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
        else:
            return f"{inner_call_name}({', '.join(call_exprs)});"

    @staticmethod
    def from_func(
        g: NativeFunctionsViewGroup, *, is_reverse: bool
    ) -> FunctionalizationLambda:
        return FunctionalizationLambda(g, is_reverse)
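For reference, a sketch (assumptions: a pre-PR torchgen where `FunctionalizationLambda` is still exported, and a `NativeFunctionsViewGroup` supplied by the caller) of how the removed class was typically consumed to assemble a view lambda string:

```python
from torchgen.api.types import FunctionalizationLambda  # pre-PR export (assumption)
from torchgen.model import NativeFunctionsViewGroup


def view_lambdas(g: NativeFunctionsViewGroup) -> tuple[str, str]:
    """Assemble the forward/reverse lambda strings for one view-op group (sketch)."""
    forward = FunctionalizationLambda.from_func(g, is_reverse=False)
    reverse = FunctionalizationLambda.from_func(g, is_reverse=True)
    # decl() renders "[captures](outer_args) -> return_type"; inner_call()
    # renders the redispatch into the at::_ops API.
    fwd = f"{forward.decl()} {{ return {forward.inner_call()} }}"
    rev = f"{reverse.decl()} {{ return {reverse.inner_call()} }}"
    return fwd, rev
```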


@dataclass(frozen=True)
class StructuredImplSignature:
    g: NativeFunctionsGroup