pytorch/torchgen/shape_functions/gen_jit_shape_functions.py
Edward Yang 36420b5e8c Rename tools/codegen to torchgen (#76275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275

In preparation for addressing
https://github.com/pytorch/pytorch/issues/73212

Diff was generated with:

```
git mv tools/codegen torchgen
git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g'
sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt
```

and manual edits to:

* tools/test/test_gen_backend_stubs.py
* torchgen/build.bzl
* torchgen/gen_backend_stubs.py

aka this diff:

```
 diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 3dc26c6d2d..104054575e 100644
 --- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run
 from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE  # noqa: F401

 path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py')

 # gen_backend_stubs.py is an integration point that is called directly by external backends.
 # The tests here are to confirm that badly formed inputs result in reasonable error messages.
 diff --git a/torchgen/build.bzl b/torchgen/build.bzl
index ed04e35a43..d00078a3cf 100644
 --- a/torchgen/build.bzl
+++ b/torchgen/build.bzl
@@ -1,6 +1,6 @@
 def define_targets(rules):
     rules.py_library(
-        name = "codegen",
+        name = "torchgen",
         srcs = rules.glob(["**/*.py"]),
         deps = [
             rules.requirement("PyYAML"),
@@ -11,6 +11,6 @@ def define_targets(rules):

     rules.py_binary(
         name = "gen",
-        srcs = [":codegen"],
+        srcs = [":torchgen"],
         visibility = ["//visibility:public"],
     )
 diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index c1a672a655..beee7a15e0 100644
 --- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -474,7 +474,7 @@ def run(
 ) -> None:

     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
-    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
     template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

     def make_file_manager(install_dir: str) -> FileManager:
```
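
The `gen_backend_stubs.py` hunk above drops one `.parent` because the file now sits one directory closer to the repo root. A minimal sketch of why the hop count changes (paths below are illustrative, not taken from the diff):

```python
import pathlib

# Illustrative paths only. Before the rename the file lived at
# PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py, so reaching the repo root
# took three .parent hops; after the move to PYTORCH_ROOT/torchgen/ it takes two.
old_file = pathlib.Path("/pytorch/tools/codegen/gen_backend_stubs.py")
new_file = pathlib.Path("/pytorch/torchgen/gen_backend_stubs.py")

assert old_file.parent.parent.parent == pathlib.Path("/pytorch")
assert new_file.parent.parent == pathlib.Path("/pytorch")
```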

run_all_fbandroid_tests

Test Plan: sandcastle

Reviewed By: albanD, ngimel

Differential Revision: D35770317

fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf
(cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
2022-04-25 01:38:06 +00:00

124 lines · 3.3 KiB · Python

#!/usr/bin/env python3
import os
from pathlib import Path

from torch.jit._shape_functions import shape_compute_graph_mapping

SHAPE_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python
 * torchgen/shape_functions/gen_jit_shape_functions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
#include <torch/csrc/jit/runtime/operator.h>

// clang-format off

namespace torch {
namespace jit {

std::string shape_funcs = ""
"""

DECOMP_CENTER = r"""

const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}

const OperatorMap<std::string>& GetShapeFunctionMappings() {
  static const OperatorMap<std::string> shape_mappings {
"""

DECOMP_END = r"""
  };

  return shape_mappings;
}

// clang-format on

} // namespace jit
} // namespace torch
"""

SERIALIZED_SHAPE_UTIL_FILE_NAME = "serialized_shape_function_registry.cpp"


def gen_serialized_decompisitions() -> str:
    already_serialized_names = set()
    unique_funcs = []
    for scripted_func in shape_compute_graph_mapping.values():
        if scripted_func.name in already_serialized_names:
            continue
        already_serialized_names.add(scripted_func.name)
        unique_funcs.append(scripted_func)

    output_strs = []
    curr_str = ""
    for scripted_func in unique_funcs:
        serialized_code = scripted_func.code
        # technically its higher but give a buffer bc there are weird rules
        # around some characters
        # TODO: this was the limit I found by googling but it seems way
        # too short ?
        MAX_MSFT_STR_LEN = 2000
        if len(curr_str) + len(serialized_code) <= MAX_MSFT_STR_LEN:
            curr_str += "\n" + serialized_code
        else:
            output_strs.append(curr_str)
            curr_str = scripted_func.code
    output_strs.append(curr_str)

    final_output = ""
    # Windows compiler doesnt correctly handle adjacent
    # string literals
    for output_str in output_strs:
        start = '+ std::string(R"=====('
        end = '\n)=====")\n'
        final_output += start + output_str + end
    final_output += ";"
    return final_output


def gen_shape_mappings() -> str:
    shape_mappings = []
    for schema, scripted_func in shape_compute_graph_mapping.items():
        shape_mappings.append(' {"' + schema + '", "' + scripted_func.name + '"},')
    return "\n".join(shape_mappings)


def write_decomposition_util_file(path: str) -> None:
    decomposition_str = gen_serialized_decompisitions()
    shape_mappings = gen_shape_mappings()
    file_components = [
        SHAPE_HEADER,
        decomposition_str,
        DECOMP_CENTER,
        shape_mappings,
        DECOMP_END,
    ]
    print("writing file to : ", path + "/" + SERIALIZED_SHAPE_UTIL_FILE_NAME)
    with open(os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME), "wb") as out_file:
        final_output = "".join(file_components)
        out_file.write(final_output.encode("utf-8"))


def main() -> None:
    pytorch_dir = Path(__file__).resolve().parents[3]
    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
    write_decomposition_util_file(str(upgrader_path))


if __name__ == "__main__":
    main()
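
For context, a hedged sketch of how the mapping this script serializes can be inspected from Python. It assumes a local torch install; the exact operator schemas and shape-function names depend on that build:

```python
# Illustrative only: peek at the mapping that gen_jit_shape_functions.py serializes.
from torch.jit._shape_functions import shape_compute_graph_mapping

for schema, scripted_func in list(shape_compute_graph_mapping.items())[:3]:
    # Each key is an operator schema string; each value is a scripted function
    # whose .name and .code end up in serialized_shape_function_registry.cpp.
    print(schema, "->", scripted_func.name)
```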