Files
pytorch/torchgen/api/functionalization.py
Edward Yang 36420b5e8c Rename tools/codegen to torchgen (#76275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275

In preparation for addressing
https://github.com/pytorch/pytorch/issues/73212

Diff was generated with:

```
git mv tools/codegen torchgen
git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g'
sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt
```

and manual edits to:

* tools/test/test_gen_backend_stubs.py
* torchgen/build.bzl
* torchgen/gen_backend_stubs.py

aka this diff:

```
 diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 3dc26c6d2d..104054575e 100644
 --- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run
 from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE  # noqa: F401

 path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py')

 # gen_backend_stubs.py is an integration point that is called directly by external backends.
 # The tests here are to confirm that badly formed inputs result in reasonable error messages.
 diff --git a/torchgen/build.bzl b/torchgen/build.bzl
index ed04e35a43..d00078a3cf 100644
 --- a/torchgen/build.bzl
+++ b/torchgen/build.bzl
@@ -1,6 +1,6 @@
 def define_targets(rules):
     rules.py_library(
-        name = "codegen",
+        name = "torchgen",
         srcs = rules.glob(["**/*.py"]),
         deps = [
             rules.requirement("PyYAML"),
@@ -11,6 +11,6 @@ def define_targets(rules):

     rules.py_binary(
         name = "gen",
-        srcs = [":codegen"],
+        srcs = [":torchgen"],
         visibility = ["//visibility:public"],
     )
 diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index c1a672a655..beee7a15e0 100644
 --- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -474,7 +474,7 @@ def run(
 ) -> None:

     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
-    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
     template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

     def make_file_manager(install_dir: str) -> FileManager:
```

run_all_fbandroid_tests

Test Plan: sandcastle

Reviewed By: albanD, ngimel

Differential Revision: D35770317

fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf
(cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
2022-04-25 01:38:06 +00:00

175 lines
6.8 KiB
Python

from torchgen.model import (
FunctionSchema,
BaseTy,
BaseType,
NativeFunctionsViewGroup,
Argument,
)
from torchgen.api.types import (
Binding,
NamedCType,
ConstRefCType,
BaseCType,
CType,
tensorT,
longT,
boolT,
)
from torchgen.api import dispatcher
from typing import List, Optional
# This file describes the translation of JIT schema to APIs used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
# Define some specific lambda input arguments.
def _placeholder_argument(arg_name: str, base_ty: BaseTy) -> Argument:
    """Build a schema Argument of the given base type with no default and no annotation."""
    return Argument(
        name=arg_name, type=BaseType(base_ty), default=None, annotation=None
    )


def _lambda_binding(binding_name: str, ctype: CType, argument: Argument) -> Binding:
    """Build a Binding whose C++ name and NamedCType name are both `binding_name`."""
    return Binding(
        name=binding_name,
        nctype=NamedCType(name=binding_name, type=ctype),
        argument=argument,
        default=None,
    )


# `base`: a `const Tensor&` lambda argument.
base_binding = _lambda_binding(
    "base",
    ConstRefCType(BaseCType(tensorT)),
    _placeholder_argument("base", BaseTy.Tensor),
)

# `mutated_view`: a `const Tensor&` lambda argument.
# NOTE(review): the backing Argument is named "base" rather than "mutated_view";
# presumably it is only a placeholder needed to construct a Binding -- confirm
# before relying on `.argument` of this binding.
mutated_view_binding = _lambda_binding(
    "mutated_view",
    ConstRefCType(BaseCType(tensorT)),
    _placeholder_argument("base", BaseTy.Tensor),
)

# `mutated_view_idx`: an integer (longT) lambda argument.
# NOTE(review): the backing Argument is again the Tensor-typed "base" placeholder,
# which does not match the longT C++ type -- confirm this is intentional.
mutated_view_idx_binding = _lambda_binding(
    "mutated_view_idx",
    BaseCType(longT),
    _placeholder_argument("base", BaseTy.Tensor),
)

# `reapply_views`: a bool argument, backed by a bool-typed Argument of the same name.
reapply_views_binding = _lambda_binding(
    "reapply_views",
    BaseCType(boolT),
    _placeholder_argument("reapply_views", BaseTy.bool),
)
# The lambda itself is anonymous; what we name here is the inner function
# that the body of the lambda calls.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: Optional[bool] = None,
) -> str:
    """Return the C++ name of the function invoked inside a view lambda.

    Args:
        g: the view group; its `view` and `view_copy` operators supply the names.
        is_reverse: True for the reverse (inverse) lambda, False for the forward one.
        include_namespace: whether to emit the fully qualified name. The forward
            lambda always needs it; the reverse case needs it for call-sites
            but not for declarations.
        reapply_views: forward-only switch between the `view` op (truthy) and the
            `view_copy` op (falsy). May be left as None only when `is_reverse`.

    Returns:
        Forward: `at::_ops::<op>::call`.
        Reverse: `<view_copy op>_inverse`, prefixed with
        `at::functionalization::FunctionalInverses::` when `include_namespace`.

    Raises:
        AssertionError: if `reapply_views` is None for a forward lambda, if a
            forward lambda is requested without the namespace, or if the group
            has no `view_copy`.
    """
    if reapply_views is None:
        # Only the forward lambda picks an op based on reapply_views; the reverse
        # function always receives the runtime "reapply_views" argument instead.
        assert is_reverse

    if not is_reverse:
        # Forward: call straight into the at::_ops API, which is always namespaced.
        assert include_namespace
        assert g.view_copy is not None
        target_op = g.view if reapply_views else g.view_copy
        return f"at::_ops::{target_op.func.name.unambiguous_name()}::call"

    # Reverse: the inverse function is always named after the "view_copy" op, and
    # the "reapply_views" flag is plumbed into it (this avoids having to write out
    # twice as many view inverse functions).
    assert g.view_copy is not None
    inverse_name = f"{g.view_copy.func.name.unambiguous_name()}_inverse"
    if include_namespace:
        # Call-sites need the full namespace...
        return f"at::functionalization::FunctionalInverses::{inverse_name}"
    # ...declarations do not.
    return inverse_name
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
    """Return the bindings captured by a view lambda.

    The capture list is `reapply_views` followed by every argument of `func`
    except the first, which is asserted to be a Tensor (i.e. `self`).

    Captured values must not be C++ reference types, otherwise the capture would
    hold a dangling reference; reference types (IntArrayRef) therefore need to
    become value types (vector<int64_t>). This is requested from the dispatcher
    API via `remove_non_owning_ref_types=True`.

    Args:
        func: schema of the view op.
        is_reverse: accepted for signature symmetry; not read by this function.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    captured = [reapply_views_binding]
    for arg in flat_args[1:]:
        captured.append(dispatcher.argument(arg, remove_non_owning_ref_types=True))
    return captured
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of a view lambda, which is always a single tensor.

    Asserts that `func` has at least one return and that every return type is
    tensor-like. Even for ops with several tensor outputs the lambda returns one
    individual tensor, because each output needs to be tracked individually.
    """
    # Invariant: all view ops return tensor-like outputs.
    assert len(func.returns) > 0
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> List[Binding]:
    """Return the parameter bindings of the lambda itself.

    Forward lambdas take `(base, mutated_view_idx)`; reverse lambdas take
    `(base, mutated_view, mutated_view_idx)`.
    """
    if not is_reverse:
        return [base_binding, mutated_view_idx_binding]
    return [base_binding, mutated_view_binding, mutated_view_idx_binding]
def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
    """Return the index binding used for multi-output view ops, else None.

    For view ops that return multiple tensors (like `split`), a separate lambda
    is generated per output, and replaying the op requires indexing into its
    result. The `mutated_view_idx` binding is returned when `func` has more than
    one return, or exactly one return whose type reports `is_list_like()`.
    """
    num_returns = len(func.returns)
    if num_returns > 1:
        return mutated_view_idx_binding
    if num_returns == 1 and func.returns[0].type.is_list_like():
        return mutated_view_idx_binding
    return None
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    """Return the bindings passed to the function called inside the lambda.

    The first argument of `func` is asserted to be a Tensor (`self`) and is
    replaced by lambda-level bindings; the remaining arguments are translated
    with the dispatcher API and appended unchanged.

    Forward: `[base, *rest]`.
    Reverse: `[base, mutated_view, reapply_views, *rest]`, with the index binding
    from `inner_call_index` inserted right after `reapply_views` when the op
    needs one.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    # The forward lambda calls the at::_ops API and the reverse lambda calls the
    # view inverse API; both follow the dispatcher API for the non-self arguments.
    trailing = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        # The forward lambda swaps the original tensor argument for the lambda
        # argument "base".
        return [base_binding] + trailing

    # The reverse lambda does the same, with an additional "mutated_view" argument
    # and the runtime "reapply_views" flag.
    leading = [base_binding, mutated_view_binding, reapply_views_binding]

    # Calling convention: for view ops that return multiple tensor outputs, the
    # corresponding view_inverse function takes an additional index argument.
    index_binding = inner_call_index(func)
    if index_binding is not None:
        leading.append(index_binding)
    return leading + trailing