mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Revert "[WIP] customize the C++ class for valueT"" (#77003)
This reverts commit ec841b0346ade6664d13d5d0263b8e6990bf4d95. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/77003 Approved by: https://github.com/shunting314, https://github.com/JackCaoG
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6341d2ce5
commit
daf8c48a87
@ -16,12 +16,16 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
from torchgen.api.types import BaseCppType
|
||||
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
NamespaceHelper,
|
||||
)
|
||||
|
||||
from torchgen.api.lazy import setValueT
|
||||
|
||||
from torchgen.model import (
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
@ -281,7 +285,10 @@ def run_gen_lazy_tensor(
|
||||
lazy_value_class: str = "torch::lazy::Value",
|
||||
lazy_tensor_ptr: str = "LazyTensorPtr",
|
||||
) -> None:
|
||||
|
||||
lv_tokens = lazy_value_class.split("::")
|
||||
lv_class = lv_tokens[-1]
|
||||
lv_ns = "::".join(lv_tokens[:-1])
|
||||
setValueT(BaseCppType(lv_ns, lv_class))
|
||||
template_dir = os.path.join(aten_path, "templates")
|
||||
|
||||
def make_file_manager(install_dir: str) -> FileManager:
|
||||
@ -483,7 +490,6 @@ def run_gen_lazy_tensor(
|
||||
create_from_first_tensor,
|
||||
create_aten_from_ltc_tensor,
|
||||
tuple_aten_from_ltc_tensors,
|
||||
lazy_value_class,
|
||||
lazy_tensor_ptr,
|
||||
),
|
||||
grouped_native_functions,
|
||||
|
Reference in New Issue
Block a user