Make lazy tensor ptr class customizable (#76476)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76476

Test Plan: Imported from OSS

Reviewed By: Krovatkin, bdhirsh

Differential Revision: D35980433

Pulled By: wconstab

fbshipit-source-id: 1d4d00a494bf8aea86278b007f7f353cd7a822f8
(cherry picked from commit a78655bef23b5fa8487ced13443ca0bfdec65e5c)
Commit d0cb31d5bc (parent 4cae57080a), committed by PyTorch MergeBot.
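Before this change the codegen hardcoded the pointer class in generated declarations as f"{self.tensor_class}Ptr"; the hunks below replace that with a dedicated lazy_tensor_ptr field (defaulting to "LazyTensorPtr") so an out-of-tree backend can point the generator at its own pointer class. A minimal runnable sketch of the pattern, with hypothetical names (GenSketch and GetLtcTensor stand in for the real torchgen/LTC machinery):

from dataclasses import dataclass

# Minimal sketch of the pattern this commit introduces; GenSketch and
# GetLtcTensor are illustrative names, not the real torchgen/LTC API.
@dataclass
class GenSketch:
    backend_namespace: str
    lazy_tensor_ptr: str  # pointer class name, now configurable

    def decl(self, name: str) -> str:
        # The generated C++ declaration uses the configurable pointer
        # type instead of a hardcoded f"{tensor_class}Ptr".
        return (
            f"{self.lazy_tensor_ptr} lazy_{name} = "
            f"{self.backend_namespace}::GetLtcTensor({name}, *common_device);"
        )

# In-tree default vs. a hypothetical out-of-tree backend:
print(GenSketch("torch::lazy", "LazyTensorPtr").decl("self"))
print(GenSketch("my_backend", "MyTensorPtr").decl("self"))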
@@ -253,6 +253,7 @@ class GenLazyNativeFuncDefinition:
     create_aten_from_ltc_tensor: str
     tuple_aten_from_ltc_tensors: str
     lazy_value_class: str
+    lazy_tensor_ptr: str
 
     def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
         value_args = schema.filtered_args(values=True, scalars=False)
@@ -272,14 +273,14 @@ class GenLazyNativeFuncDefinition:
                     )
                 else:
                     lazy_tensor_decls.append(
-                        f"{self.tensor_class}Ptr lazy_{arg.name} = "
+                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                         f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                     )
             elif isinstance(arg.lazy_type, OptionalCType):
                 # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                 # until we encounter a real world example.
                 lazy_tensor_decls.append(
-                    f"{self.tensor_class}Ptr lazy_{arg.name} = "
+                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                     f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                 )
             else:
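To make the substitution concrete, here is how the optional-arg template above might render. Only lazy_tensor_ptr's default is visible in this diff; the values for backend_namespace and try_get_tensor are assumptions for illustration:

# Rendering the optional-arg f-string above with assumed field values
# ("torch::lazy" and "TryGetLtcTensor" are hypothetical; only the
# lazy_tensor_ptr default appears in this diff).
lazy_tensor_ptr = "LazyTensorPtr"
backend_namespace = "torch::lazy"
try_get_tensor = "TryGetLtcTensor"
name = "weight"
print(
    f"{lazy_tensor_ptr} lazy_{name} = "
    f"{backend_namespace}::{try_get_tensor}({name}.value_or(at::Tensor()));"
)
# -> LazyTensorPtr lazy_weight = torch::lazy::TryGetLtcTensor(weight.value_or(at::Tensor()));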
@@ -367,7 +368,7 @@ class GenLazyNativeFuncDefinition:
                 {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
 
         if returns_length > 1:
-            bridge_str = f"""std::vector<{self.tensor_class}Ptr> lazy_tensors;
+            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
         for (int i = 0; i < {returns_length}; i++) {{
             lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({self.lazy_value_class}(node, i), *common_device));
         }}
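The multi-return bridge gets the same treatment: the vector's element type now comes from the field rather than f"{self.tensor_class}Ptr". With the default value, the first generated line would read:

# Rendering the multi-return header above with the default pointer type.
lazy_tensor_ptr = "LazyTensorPtr"
print(f"std::vector<{lazy_tensor_ptr}> lazy_tensors;")
# -> std::vector<LazyTensorPtr> lazy_tensors;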
@@ -279,6 +279,7 @@ def run_gen_lazy_tensor(
     create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
     tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
     lazy_value_class: str = "torch::lazy::Value",
+    lazy_tensor_ptr: str = "LazyTensorPtr",
 ) -> None:
 
     template_dir = os.path.join(aten_path, "templates")
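With the new keyword argument, a backend calling run_gen_lazy_tensor can override just the pointer class. A runnable stand-in sketch of that override pattern (the real function takes many more parameters than shown, and "XLATensorPtr" is a hypothetical backend choice):

# Stand-in with only the two keyword arguments visible in this hunk; the
# real run_gen_lazy_tensor takes many more parameters.
def run_gen_lazy_tensor_stub(
    lazy_value_class: str = "torch::lazy::Value",
    lazy_tensor_ptr: str = "LazyTensorPtr",
) -> None:
    print(f"generated decls will use: {lazy_tensor_ptr}")

run_gen_lazy_tensor_stub()                                # in-tree default
run_gen_lazy_tensor_stub(lazy_tensor_ptr="XLATensorPtr")  # hypothetical backend override

Keeping the knob as a plain string leaves the generator agnostic to how the backend actually defines the pointer type (typedef, alias, or class).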
@@ -482,6 +483,7 @@ def run_gen_lazy_tensor(
             create_aten_from_ltc_tensor,
             tuple_aten_from_ltc_tensors,
             lazy_value_class,
+            lazy_tensor_ptr,
         ),
         grouped_native_functions,
         codegenInplaceVariant=True,