[TorchGen] Use std::optional in generated code (#121454)

This PR changes TorchGen to generate std::optional.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121454
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-03-29 14:11:09 +00:00
committed by PyTorch MergeBot
parent 375a8041ed
commit fb90b4d4b2
21 changed files with 202 additions and 192 deletions

View File

@ -62,9 +62,9 @@ from torchgen.model import (
# Note: the scattered TensorOptions fields are packed into 'options'.
#
# auto dispatch_empty =
# [](IntArrayRef size, c10::optional<DimnameList> names,
# [](IntArrayRef size, std::optional<DimnameList> names,
# const TensorOptions & options,
# c10::optional<MemoryFormat> memory_format) -> Tensor {
# std::optional<MemoryFormat> memory_format) -> Tensor {
# pybind11::gil_scoped_release no_gil;
# return torch::empty(size, names, options, memory_format);
# };
@ -93,9 +93,9 @@ from torchgen.model import (
# Where does 'names' come from? It involves special local init:
#
# auto __names = _r.toDimnameListOptional(1);
# c10::optional<DimnameList> names =
# __names ? c10::make_optional(DimnameList(__names.value()))
# : c10::nullopt;
# std::optional<DimnameList> names =
# __names ? std::make_optional(DimnameList(__names.value()))
# : std::nullopt;
#
# Where does 'options' come from? It involves special local init
# for TensorOptions. Note that Python side has the additional
@ -235,6 +235,8 @@ class PythonArgument:
default = {
"nullptr": "None",
"c10::nullopt": "None",
"::std::nullopt": "None",
"std::nullopt": "None",
"{}": "None",
}.get(self.default, self.default)
return f"{type_str} {name}={default}"
@ -280,6 +282,8 @@ class PythonArgument:
default = {
"nullptr": "None",
"c10::nullopt": "None",
"::std::nullopt": "None",
"std::nullopt": "None",
"{}": "None",
"MemoryFormat::Contiguous": "contiguous_format",
"QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
@ -697,9 +701,9 @@ def argument_type_str(
return f"ScalarList[{size}]" if size is not None else "ScalarList"
elif str(t.elem) == "Tensor?":
if simple_type:
return "c10::List<c10::optional<Tensor>>"
return "c10::List<::std::optional<Tensor>>"
else:
return "const c10::List<c10::optional<Tensor>> &"
return "const c10::List<::std::optional<Tensor>> &"
elif str(t.elem) == "Dimname":
return f"DimnameList[{size}]" if size is not None else "DimnameList"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
@ -1308,7 +1312,13 @@ def arg_parser_unpack_method(
return "generator"
elif str(t.elem) == "Dimname[]":
return "toDimnameListOptional"
elif not has_default_init and default in (None, "None", "c10::nullopt"):
elif not has_default_init and default in (
None,
"None",
"c10::nullopt",
"::std::nullopt",
"std::nullopt",
):
# If default is None: append 'Optional' to elem's unpacking method
return (
arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
@ -1430,7 +1440,7 @@ def dispatch_lambda_exprs(
inits.extend(
[
f"auto __{name} = {arg_parser_expr};",
f"c10::optional<DimnameList> {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;", # noqa: B950
f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950
]
)
lambda_args_exprs[name] = name