mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
Revert "Revert "[WIP] customize the C++ class for valueT"" (#77003)
This reverts commit ec841b0346ade6664d13d5d0263b8e6990bf4d95. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/77003 Approved by: https://github.com/shunting314, https://github.com/JackCaoG
This commit is contained in: daf8c48a87, committed by PyTorch MergeBot (parent: a6341d2ce5).
@ -31,7 +31,25 @@ from torchgen.api.types import (
|
||||
SymIntT,
|
||||
)
|
||||
|
||||
valueT = BaseCppType("torch::lazy", "Value")
|
||||
|
||||
_valueT = None
|
||||
|
||||
|
||||
def getValueT() -> BaseCppType:
    """Return the C++ type currently registered for lazy IR values.

    The type is configured by ``setValueT()``, which ``run_gen_lazy_tensor()``
    calls before any codegen runs.

    Raises:
        NotImplementedError: if ``setValueT()`` has not been called yet.
    """
    global _valueT
    # Compare against None explicitly (PEP 8): the sentinel is None, and a
    # registered-but-falsy value must not be mistaken for "unset".
    if _valueT is None:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )
    return _valueT
|
||||
|
||||
|
||||
def setValueT(val: BaseCppType) -> None:
    """Register *val* as the C++ type that ``getValueT()`` will return.

    Backends (e.g. torch/XLA) call this via ``run_gen_lazy_tensor()`` to
    customize the C++ class used for lazy IR values.
    """
    global _valueT
    _valueT = val
|
||||
|
||||
|
||||
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
|
||||
# making it easier to represent special properties of an arg.
|
||||
tensorListValueT = BaseCppType("torch::lazy", "Value")
|
||||
@ -57,17 +75,17 @@ def process_ir_type(
|
||||
"""
|
||||
if isinstance(typ, BaseType):
|
||||
if typ.name == BaseTy.Tensor:
|
||||
return BaseCType(valueT)
|
||||
return BaseCType(getValueT())
|
||||
elif typ.name == BaseTy.Scalar:
|
||||
# at::scalar has special handling,
|
||||
# and is wrapped in an lazy::Value just like at::tensor
|
||||
return BaseCType(valueT)
|
||||
return BaseCType(getValueT())
|
||||
elif typ.name == BaseTy.ScalarType:
|
||||
return BaseCType(scalarTypeT)
|
||||
elif typ.name == BaseTy.int:
|
||||
return BaseCType(longT)
|
||||
elif typ.name == BaseTy.SymInt:
|
||||
return BaseCType(valueT)
|
||||
return BaseCType(getValueT())
|
||||
elif typ.name == BaseTy.bool:
|
||||
return BaseCType(boolT)
|
||||
elif typ.name == BaseTy.float:
|
||||
@ -87,7 +105,7 @@ def process_ir_type(
|
||||
elif isinstance(typ, ListType):
|
||||
if str(typ.elem) == "Tensor?":
|
||||
# TODO(whc) is this actually correct? or should it use a Vector like above
|
||||
return ListCType(OptionalCType(BaseCType(valueT)))
|
||||
return ListCType(OptionalCType(BaseCType(getValueT())))
|
||||
elif str(typ.elem) == "Tensor":
|
||||
# this is a TensorList which comes in from GetTensorList as a Value
|
||||
return BaseCType(tensorListValueT)
|
||||
@ -105,7 +123,7 @@ def isValueType(typ: CType) -> bool:
|
||||
if isinstance(typ, BaseCType):
|
||||
# I am regretting my naming conventions, but now we are wrapping at::scalar in
|
||||
# lazy value, while preserving other 'scalar' types as scalars in the IR
|
||||
return typ.type == valueT or typ.type == scalarT or typ.type == SymIntT
|
||||
return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
|
||||
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
|
||||
return isValueType(typ.elem)
|
||||
return False
|
||||
|
||||
@ -13,6 +13,7 @@ import torchgen.api.dispatcher as dispatcher
|
||||
from torchgen.api.lazy import (
|
||||
LazyIrSchema,
|
||||
LazyArgument,
|
||||
getValueT,
|
||||
isValueType,
|
||||
tensorListValueT,
|
||||
)
|
||||
@ -33,7 +34,11 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
|
||||
elif arg.lazy_type.type is tensorListValueT:
|
||||
return f"lazy_{arg.name}_tensorlist"
|
||||
elif arg.is_symint_or_list:
|
||||
return f"Value(std::dynamic_pointer_cast<torch::lazy::SymbolicIntNode>({arg.name}.toSymbolicIntNode())->node_, 0)"
|
||||
cpp_type = arg.lazy_type.cpp_type()
|
||||
return (
|
||||
f"{cpp_type}(std::dynamic_pointer_cast<torch::lazy::SymbolicIntNode>"
|
||||
f"({arg.name}.toSymbolicIntNode())->node_, 0)"
|
||||
)
|
||||
return f"lazy_{arg.name}->GetIrValue()"
|
||||
elif isinstance(arg.lazy_type, OptionalCType):
|
||||
if arg.is_wrapped_scalar:
|
||||
@ -263,7 +268,6 @@ class GenLazyNativeFuncDefinition:
|
||||
create_from_first_tensor: bool
|
||||
create_aten_from_ltc_tensor: str
|
||||
tuple_aten_from_ltc_tensors: str
|
||||
lazy_value_class: str
|
||||
lazy_tensor_ptr: str
|
||||
|
||||
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
||||
@ -381,7 +385,7 @@ class GenLazyNativeFuncDefinition:
|
||||
if returns_length > 1:
|
||||
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
|
||||
for (int i = 0; i < {returns_length}; i++) {{
|
||||
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({self.lazy_value_class}(node, i), *common_device));
|
||||
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
|
||||
}}
|
||||
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
|
||||
|
||||
|
||||
@ -16,12 +16,16 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
from torchgen.api.types import BaseCppType
|
||||
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
NamespaceHelper,
|
||||
)
|
||||
|
||||
from torchgen.api.lazy import setValueT
|
||||
|
||||
from torchgen.model import (
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
@ -281,7 +285,10 @@ def run_gen_lazy_tensor(
|
||||
lazy_value_class: str = "torch::lazy::Value",
|
||||
lazy_tensor_ptr: str = "LazyTensorPtr",
|
||||
) -> None:
|
||||
|
||||
lv_tokens = lazy_value_class.split("::")
|
||||
lv_class = lv_tokens[-1]
|
||||
lv_ns = "::".join(lv_tokens[:-1])
|
||||
setValueT(BaseCppType(lv_ns, lv_class))
|
||||
template_dir = os.path.join(aten_path, "templates")
|
||||
|
||||
def make_file_manager(install_dir: str) -> FileManager:
|
||||
@ -483,7 +490,6 @@ def run_gen_lazy_tensor(
|
||||
create_from_first_tensor,
|
||||
create_aten_from_ltc_tensor,
|
||||
tuple_aten_from_ltc_tensors,
|
||||
lazy_value_class,
|
||||
lazy_tensor_ptr,
|
||||
),
|
||||
grouped_native_functions,
|
||||
|
||||
Reference in New Issue
Block a user