Make lazy tensor ptr class customizable (#76476)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76476

Test Plan: Imported from OSS

Reviewed By: Krovatkin, bdhirsh

Differential Revision: D35980433

Pulled By: wconstab

fbshipit-source-id: 1d4d00a494bf8aea86278b007f7f353cd7a822f8
(cherry picked from commit a78655bef23b5fa8487ced13443ca0bfdec65e5c)
This commit is contained in:
Will Constable
2022-04-27 20:16:46 -07:00
committed by PyTorch MergeBot
parent 4cae57080a
commit d0cb31d5bc
2 changed files with 6 additions and 3 deletions

View File

@ -253,6 +253,7 @@ class GenLazyNativeFuncDefinition:
create_aten_from_ltc_tensor: str
tuple_aten_from_ltc_tensors: str
lazy_value_class: str
lazy_tensor_ptr: str
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
@ -272,14 +273,14 @@ class GenLazyNativeFuncDefinition:
)
else:
lazy_tensor_decls.append(
f"{self.tensor_class}Ptr lazy_{arg.name} = "
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
)
elif isinstance(arg.lazy_type, OptionalCType):
# TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
# until we encounter a real world example.
lazy_tensor_decls.append(
f" {self.tensor_class}Ptr lazy_{arg.name} = "
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
)
else:
@ -367,7 +368,7 @@ class GenLazyNativeFuncDefinition:
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
if returns_length > 1:
bridge_str = f"""std::vector<{self.tensor_class}Ptr> lazy_tensors;
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
for (int i = 0; i < {returns_length}; i++) {{
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({self.lazy_value_class}(node, i), *common_device));
}}

View File

@ -279,6 +279,7 @@ def run_gen_lazy_tensor(
create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
lazy_value_class: str = "torch::lazy::Value",
lazy_tensor_ptr: str = "LazyTensorPtr",
) -> None:
template_dir = os.path.join(aten_path, "templates")
@ -482,6 +483,7 @@ def run_gen_lazy_tensor(
create_aten_from_ltc_tensor,
tuple_aten_from_ltc_tensors,
lazy_value_class,
lazy_tensor_ptr,
),
grouped_native_functions,
codegenInplaceVariant=True,