mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275 In preparation for addressing https://github.com/pytorch/pytorch/issues/73212 Diff was generated with: ``` git mv tools/codegen torchgen git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g' sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt ``` and a manual edits to: * tools/test/test_gen_backend_stubs.py * torchgen/build.bzl * torchgen/gen_backend_stubs.py aka this diff: ``` diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 3dc26c6d2d..104054575e 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) -gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py') +gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py') # gen_backend_stubs.py is an integration point that is called directly by external backends. # The tests here are to confirm that badly formed inputs result in reasonable error messages. 
diff --git a/torchgen/build.bzl b/torchgen/build.bzl index ed04e35a43..d00078a3cf 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -1,6 +1,6 @@ def define_targets(rules): rules.py_library( - name = "codegen", + name = "torchgen", srcs = rules.glob(["**/*.py"]), deps = [ rules.requirement("PyYAML"), @@ -11,6 +11,6 @@ def define_targets(rules): rules.py_binary( name = "gen", - srcs = [":codegen"], + srcs = [":torchgen"], visibility = ["//visibility:public"], ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c1a672a655..beee7a15e0 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -474,7 +474,7 @@ def run( ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py - pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute() + pytorch_root = pathlib.Path(__file__).parent.parent.absolute() template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") def make_file_manager(install_dir: str) -> FileManager: ``` run_all_fbandroid_tests Test Plan: sandcastle Reviewed By: albanD, ngimel Differential Revision: D35770317 fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf (cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
132 lines
4.1 KiB
Python
132 lines
4.1 KiB
Python
"""
|
|
To run this file by hand from the root of the PyTorch
|
|
repository, run:
|
|
|
|
python -m tools.autograd.gen_autograd \
|
|
build/aten/src/ATen/Declarations.yaml \
|
|
aten/src/ATen/native/native_functions.yaml \
|
|
$OUTPUT_DIR \
|
|
tools/autograd
|
|
|
|
Where $OUTPUT_DIR is where you would like the files to be
|
|
generated. In the full build system, OUTPUT_DIR is
|
|
torch/csrc/autograd/generated/
|
|
"""
|
|
|
|
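# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  load_derivatives.py: parses derivatives.yaml into differentiability infos
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#  gen_variable_type.py: generates VariableType.h which contains all tensor methods
#  gen_inplace_or_view_type.py: generates the inplace-or-view type sources
#  gen_trace_type.py: generates the tracing type sources
#  gen_variable_factories.py: generates variable_factories.h
#  gen_python_functions.py: generates Python bindings to THPVariable
#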
|
|
|
|
import argparse
|
|
import os
|
|
from torchgen.api import cpp
|
|
from torchgen.api.autograd import (
|
|
match_differentiability_info,
|
|
NativeFunctionWithDifferentiabilityInfo,
|
|
)
|
|
from torchgen.gen import parse_native_yaml
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
from typing import List
|
|
from . import gen_python_functions
|
|
from .gen_autograd_functions import (
|
|
gen_autograd_functions_lib,
|
|
gen_autograd_functions_python,
|
|
)
|
|
from .gen_trace_type import gen_trace_type
|
|
from .gen_variable_type import gen_variable_type
|
|
from .gen_inplace_or_view_type import gen_inplace_or_view_type
|
|
from .gen_variable_factories import gen_variable_factories
|
|
from .load_derivatives import load_derivatives
|
|
|
|
|
|
def gen_autograd(
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into the ``out`` directory.

    Args:
        native_functions_path: Path to ``native_functions.yaml``.
        out: Directory every generator below writes its files into.
        autograd_dir: Directory containing ``derivatives.yaml`` and the
            ``templates/`` subdirectory.
        operator_selector: Selective-build filter; only native functions it
            selects for training are matched against the derivative info.
        disable_autograd: When True, skip the VariableType, inplace-or-view
            and trace-type generators.  Functions.h/cpp and
            variable_factories.h are generated either way.
    """
    # Parse and load derivatives.yaml
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    native_funcs = parse_native_yaml(native_functions_path).native_functions
    # Keep only operators selected for training and order them by C++ name.
    # `sorted` already returns a fresh list, so no extra list() is needed.
    fns = sorted(
        filter(
            operator_selector.is_native_function_selected_for_training, native_funcs
        ),
        key=lambda f: cpp.name(f.func),
    )
    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo
    ] = match_differentiability_info(fns, differentiability_infos)

    # Generate VariableType.h/cpp
    if not disable_autograd:
        gen_variable_type(
            out, native_functions_path, fns_with_diff_infos, template_path
        )

        gen_inplace_or_view_type(
            out, native_functions_path, fns_with_diff_infos, template_path
        )

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)
    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, template_path)
|
|
|
|
|
|
def gen_autograd_python(
    native_functions_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    ``native_functions_path`` points at ``native_functions.yaml``; the
    ``derivatives.yaml``, ``deprecated.yaml`` and ``templates/`` inputs are
    all read from ``autograd_dir``.
    """
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    infos = load_derivatives(derivatives_yaml, native_functions_path)

    templates = os.path.join(autograd_dir, "templates")

    # Python side of the autograd function classes.  (The C++ Functions.h/cpp
    # are produced by gen_autograd_functions_lib inside gen_autograd, not here.)
    gen_autograd_functions_python(out, infos, templates)

    # Python bindings for the native functions; deprecated.yaml is an extra
    # input to this generator.
    gen_python_functions.gen(
        out,
        native_functions_path,
        os.path.join(autograd_dir, "deprecated.yaml"),
        templates,
    )
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse the three path arguments and run gen_autograd.

    The operator selector passed along is ``SelectiveBuilder.get_nop_selector()``.
    Only gen_autograd (C++ sources) is invoked; gen_autograd_python is not.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    # Positional arguments, declared in command-line order.
    positionals = (
        ("native_functions", "NATIVE", "path to native_functions.yaml"),
        ("out", "OUT", "path to output directory"),
        ("autograd", "AUTOGRAD", "path to autograd directory"),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)
    args = parser.parse_args()

    gen_autograd(
        native_functions_path=args.native_functions,
        out=args.out,
        autograd_dir=args.autograd,
        operator_selector=SelectiveBuilder.get_nop_selector(),
    )


if __name__ == "__main__":
    main()
|