# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
#
# This writes one file: variable_factories.h

import re

from typing import Optional, List

from torchgen.api.types import CppSignatureGroup
from torchgen.api import cpp
import torchgen.api.python as python
from torchgen.gen import parse_native_yaml
from torchgen.context import with_native_function
from torchgen.utils import mapMaybe, FileManager
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


# Add 'at::' to types defined in the ATen namespace, e.g. Tensor, TensorList, IntArrayRef, etc.
# TODO: maybe update the cpp argument API to take an optional namespace argument?
def fully_qualified_type(argument_type: str) -> str:
    def maybe_optional_type(type: str, is_opt: bool) -> str:
        return f"c10::optional<{type}>" if is_opt else type

    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    is_opt = opt_match is not None
    if opt_match:
        argument_type = argument_type[opt_match.start(1) : opt_match.end(1)]
    match = TYPE_PATTERN.match(argument_type)
    if match is None:
        return maybe_optional_type(argument_type, is_opt)
    index = match.start(1)
    qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}"
    return maybe_optional_type(qualified_type, is_opt)
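# Illustrative examples (not from the original file) of the mapping fully_qualified_type
# performs, derived from the two patterns above:
#   "Tensor"                     -> "at::Tensor"
#   "const Tensor &"             -> "const at::Tensor &"
#   "c10::optional<IntArrayRef>" -> "c10::optional<at::IntArrayRef>"
#   "int64_t"                    -> "int64_t"  (lowercase builtin types are left untouched)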


def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        },
    )
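# A hypothetical invocation (the paths below are illustrative assumptions; the real
# call site is the autograd codegen driver, which passes the actual output directory,
# native_functions.yaml path, and template directory):
#
#   gen_variable_factories(
#       out="torch/csrc/autograd/generated",
#       native_yaml_path="aten/src/ATen/native/native_functions.yaml",
#       template_path="tools/autograd/templates",
#   )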


@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
    if Variant.function not in f.variants:
        return False

    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    return has_tensor_options or name.endswith("_like")
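# For example (illustrative; the actual set depends on native_functions.yaml): `zeros`
# takes TensorOptions and `ones_like` ends in "_like", so both are treated as factory
# functions, while a pointwise op such as `add` is not.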


@with_native_function
def process_function(f: NativeFunction) -> Optional[str]:
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if Variant.function not in f.variants or not is_factory:
        return None

    sig = CppSignatureGroup.from_native_function(f, method=False).signature
    formals: List[str] = []
    exprs: List[str] = []
    requires_grad = "false"
    for arg in sig.arguments():
        qualified_type = fully_qualified_type(arg.type)
        if arg.default:
            formals.append(f"{qualified_type} {arg.name} = {arg.default}")
        else:
            formals.append(f"{qualified_type} {arg.name}")

        if isinstance(arg.argument, TensorOptionsArguments):
            # Note: we remove the requires_grad setting from the TensorOptions because
            # it is ignored anyway (and we actually have an assertion that it isn't set,
            # which would fail otherwise). We handle requires_grad explicitly here
            # instead of passing it through to the kernel.
            exprs.append(f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)")
            # Manually set the requires_grad bit on the result tensor.
            requires_grad = f"{arg.name}.requires_grad()"
        else:
            exprs.append(arg.name)

    return f"""\
inline at::Tensor {name}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{name}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""