Revert "Revert "[WIP] customize the C++ class for valueT"" (#77003)

This reverts commit ec841b0346ade6664d13d5d0263b8e6990bf4d95.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77003
Approved by: https://github.com/shunting314, https://github.com/JackCaoG
This commit is contained in:
Nikolay Korovaiko
2022-05-09 17:40:17 +00:00
committed by PyTorch MergeBot
parent a6341d2ce5
commit daf8c48a87
3 changed files with 39 additions and 11 deletions

View File

@ -31,7 +31,25 @@ from torchgen.api.types import (
SymIntT,
)
valueT = BaseCppType("torch::lazy", "Value")
_valueT = None
def getValueT() -> BaseCppType:
global _valueT
if not _valueT:
raise NotImplementedError(
"The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
)
return _valueT
def setValueT(val: BaseCppType) -> None:
global _valueT
_valueT = val
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")
@ -57,17 +75,17 @@ def process_ir_type(
"""
if isinstance(typ, BaseType):
if typ.name == BaseTy.Tensor:
return BaseCType(valueT)
return BaseCType(getValueT())
elif typ.name == BaseTy.Scalar:
# at::scalar has special handling,
# and is wrapped in an lazy::Value just like at::tensor
return BaseCType(valueT)
return BaseCType(getValueT())
elif typ.name == BaseTy.ScalarType:
return BaseCType(scalarTypeT)
elif typ.name == BaseTy.int:
return BaseCType(longT)
elif typ.name == BaseTy.SymInt:
return BaseCType(valueT)
return BaseCType(getValueT())
elif typ.name == BaseTy.bool:
return BaseCType(boolT)
elif typ.name == BaseTy.float:
@ -87,7 +105,7 @@ def process_ir_type(
elif isinstance(typ, ListType):
if str(typ.elem) == "Tensor?":
# TODO(whc) is this actually correct? or should it use a Vector like above
return ListCType(OptionalCType(BaseCType(valueT)))
return ListCType(OptionalCType(BaseCType(getValueT())))
elif str(typ.elem) == "Tensor":
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
@ -105,7 +123,7 @@ def isValueType(typ: CType) -> bool:
if isinstance(typ, BaseCType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
return typ.type == valueT or typ.type == scalarT or typ.type == SymIntT
return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem)
return False

View File

@ -13,6 +13,7 @@ import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
LazyIrSchema,
LazyArgument,
getValueT,
isValueType,
tensorListValueT,
)
@ -33,7 +34,11 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
elif arg.lazy_type.type is tensorListValueT:
return f"lazy_{arg.name}_tensorlist"
elif arg.is_symint_or_list:
return f"Value(std::dynamic_pointer_cast<torch::lazy::SymbolicIntNode>({arg.name}.toSymbolicIntNode())->node_, 0)"
cpp_type = arg.lazy_type.cpp_type()
return (
f"{cpp_type}(std::dynamic_pointer_cast<torch::lazy::SymbolicIntNode>"
f"({arg.name}.toSymbolicIntNode())->node_, 0)"
)
return f"lazy_{arg.name}->GetIrValue()"
elif isinstance(arg.lazy_type, OptionalCType):
if arg.is_wrapped_scalar:
@ -263,7 +268,6 @@ class GenLazyNativeFuncDefinition:
create_from_first_tensor: bool
create_aten_from_ltc_tensor: str
tuple_aten_from_ltc_tensors: str
lazy_value_class: str
lazy_tensor_ptr: str
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
@ -381,7 +385,7 @@ class GenLazyNativeFuncDefinition:
if returns_length > 1:
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
for (int i = 0; i < {returns_length}; i++) {{
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({self.lazy_value_class}(node, i), *common_device));
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
}}
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

View File

@ -16,12 +16,16 @@ from typing import (
Tuple,
Type,
)
from torchgen.api.types import BaseCppType
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
from torchgen.gen import (
get_grouped_native_functions,
parse_native_yaml,
NamespaceHelper,
)
from torchgen.api.lazy import setValueT
from torchgen.model import (
FunctionSchema,
NativeFunction,
@ -281,7 +285,10 @@ def run_gen_lazy_tensor(
lazy_value_class: str = "torch::lazy::Value",
lazy_tensor_ptr: str = "LazyTensorPtr",
) -> None:
lv_tokens = lazy_value_class.split("::")
lv_class = lv_tokens[-1]
lv_ns = "::".join(lv_tokens[:-1])
setValueT(BaseCppType(lv_ns, lv_class))
template_dir = os.path.join(aten_path, "templates")
def make_file_manager(install_dir: str) -> FileManager:
@ -483,7 +490,6 @@ def run_gen_lazy_tensor(
create_from_first_tensor,
create_aten_from_ltc_tensor,
tuple_aten_from_ltc_tensors,
lazy_value_class,
lazy_tensor_ptr,
),
grouped_native_functions,