Revert "Revert "[WIP] customize the C++ class for valueT"" (#77003)

This reverts commit ec841b0346ade6664d13d5d0263b8e6990bf4d95.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77003
Approved by: https://github.com/shunting314, https://github.com/JackCaoG
This commit is contained in:
Nikolay Korovaiko
2022-05-09 17:40:17 +00:00
committed by PyTorch MergeBot
parent a6341d2ce5
commit daf8c48a87
3 changed files with 39 additions and 11 deletions

View File

@ -16,12 +16,16 @@ from typing import (
Tuple,
Type,
)
from torchgen.api.types import BaseCppType
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
from torchgen.gen import (
get_grouped_native_functions,
parse_native_yaml,
NamespaceHelper,
)
from torchgen.api.lazy import setValueT
from torchgen.model import (
FunctionSchema,
NativeFunction,
@ -281,7 +285,10 @@ def run_gen_lazy_tensor(
lazy_value_class: str = "torch::lazy::Value",
lazy_tensor_ptr: str = "LazyTensorPtr",
) -> None:
lv_tokens = lazy_value_class.split("::")
lv_class = lv_tokens[-1]
lv_ns = "::".join(lv_tokens[:-1])
setValueT(BaseCppType(lv_ns, lv_class))
template_dir = os.path.join(aten_path, "templates")
def make_file_manager(install_dir: str) -> FileManager:
@ -483,7 +490,6 @@ def run_gen_lazy_tensor(
create_from_first_tensor,
create_aten_from_ltc_tensor,
tuple_aten_from_ltc_tensors,
lazy_value_class,
lazy_tensor_ptr,
),
grouped_native_functions,