mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[TorchGen] Use std::optional in generated code (#121454)
This PR changes TorchGen to generate std::optional. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121454 Approved by: https://github.com/ezyang
This commit is contained in:
@ -62,9 +62,9 @@ from torchgen.model import (
|
||||
# Note: the scattered TensorOptions fields are packed into 'options'.
|
||||
#
|
||||
# auto dispatch_empty =
|
||||
# [](IntArrayRef size, c10::optional<DimnameList> names,
|
||||
# [](IntArrayRef size, std::optional<DimnameList> names,
|
||||
# const TensorOptions & options,
|
||||
# c10::optional<MemoryFormat> memory_format) -> Tensor {
|
||||
# std::optional<MemoryFormat> memory_format) -> Tensor {
|
||||
# pybind11::gil_scoped_release no_gil;
|
||||
# return torch::empty(size, names, options, memory_format);
|
||||
# };
|
||||
@ -93,9 +93,9 @@ from torchgen.model import (
|
||||
# Where does 'names' come from? It involves special local init:
|
||||
#
|
||||
# auto __names = _r.toDimnameListOptional(1);
|
||||
# c10::optional<DimnameList> names =
|
||||
# __names ? c10::make_optional(DimnameList(__names.value()))
|
||||
# : c10::nullopt;
|
||||
# std::optional<DimnameList> names =
|
||||
# __names ? std::make_optional(DimnameList(__names.value()))
|
||||
# : std::nullopt;
|
||||
#
|
||||
# Where does 'options' come from? It involves special local init
|
||||
# for TensorOptions. Note that Python side has the additional
|
||||
@ -235,6 +235,8 @@ class PythonArgument:
|
||||
default = {
|
||||
"nullptr": "None",
|
||||
"c10::nullopt": "None",
|
||||
"::std::nullopt": "None",
|
||||
"std::nullopt": "None",
|
||||
"{}": "None",
|
||||
}.get(self.default, self.default)
|
||||
return f"{type_str} {name}={default}"
|
||||
@ -280,6 +282,8 @@ class PythonArgument:
|
||||
default = {
|
||||
"nullptr": "None",
|
||||
"c10::nullopt": "None",
|
||||
"::std::nullopt": "None",
|
||||
"std::nullopt": "None",
|
||||
"{}": "None",
|
||||
"MemoryFormat::Contiguous": "contiguous_format",
|
||||
"QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
|
||||
@ -697,9 +701,9 @@ def argument_type_str(
|
||||
return f"ScalarList[{size}]" if size is not None else "ScalarList"
|
||||
elif str(t.elem) == "Tensor?":
|
||||
if simple_type:
|
||||
return "c10::List<c10::optional<Tensor>>"
|
||||
return "c10::List<::std::optional<Tensor>>"
|
||||
else:
|
||||
return "const c10::List<c10::optional<Tensor>> &"
|
||||
return "const c10::List<::std::optional<Tensor>> &"
|
||||
elif str(t.elem) == "Dimname":
|
||||
return f"DimnameList[{size}]" if size is not None else "DimnameList"
|
||||
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
|
||||
@ -1308,7 +1312,13 @@ def arg_parser_unpack_method(
|
||||
return "generator"
|
||||
elif str(t.elem) == "Dimname[]":
|
||||
return "toDimnameListOptional"
|
||||
elif not has_default_init and default in (None, "None", "c10::nullopt"):
|
||||
elif not has_default_init and default in (
|
||||
None,
|
||||
"None",
|
||||
"c10::nullopt",
|
||||
"::std::nullopt",
|
||||
"std::nullopt",
|
||||
):
|
||||
# If default is None: append 'Optional' to elem's unpacking method
|
||||
return (
|
||||
arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
|
||||
@ -1430,7 +1440,7 @@ def dispatch_lambda_exprs(
|
||||
inits.extend(
|
||||
[
|
||||
f"auto __{name} = {arg_parser_expr};",
|
||||
f"c10::optional<DimnameList> {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;", # noqa: B950
|
||||
f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950
|
||||
]
|
||||
)
|
||||
lambda_args_exprs[name] = name
|
||||
|
Reference in New Issue
Block a user