mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] Use std::string_view in torchgen (#158625)
Reland of #157050, which is incidentally closed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158625 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
52af91e4c1
commit
972e409829
@ -940,8 +940,6 @@ def generate_tensor_like_override_tests(cls):
|
||||
return None
|
||||
elif arg_type == "ScalarType":
|
||||
return torch.float32
|
||||
elif arg_type == "c10::string_view":
|
||||
return ""
|
||||
elif arg_type in ("std::string_view", "::std::string_view"):
|
||||
return ""
|
||||
elif arg_type == "SymInt":
|
||||
|
@ -969,7 +969,7 @@ def saved_variables(
|
||||
if nctype.type == OptionalCType(BaseCType(stringT)):
|
||||
formula = re.sub(
|
||||
rf"\b{name}\b",
|
||||
f"{name}.has_value() ? std::optional<std::string_view>({name}.value()) : std::nullopt",
|
||||
f"{name}.has_value() ? std::optional<::std::string_view>({name}.value()) : std::nullopt",
|
||||
formula,
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,6 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
{"DeviceIndex", ParameterType::INT64},
|
||||
{"Stream", ParameterType::STREAM},
|
||||
{"std::string", ParameterType::STRING},
|
||||
{"c10::string_view", ParameterType::STRING},
|
||||
{"std::string_view", ParameterType::STRING},
|
||||
{"::std::string_view", ParameterType::STRING},
|
||||
{"Dimname", ParameterType::DIMNAME},
|
||||
|
@ -683,7 +683,7 @@ def argument_type_str(
|
||||
elif t.name == BaseTy.float:
|
||||
return "double"
|
||||
elif t.name == BaseTy.str:
|
||||
return "c10::string_view"
|
||||
return "std::string_view"
|
||||
elif t.name in [
|
||||
BaseTy.Tensor,
|
||||
BaseTy.bool,
|
||||
|
@ -52,7 +52,7 @@ float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
|
||||
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
|
||||
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
|
||||
float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu")
|
||||
stringT = BaseCppType("c10", "string_view")
|
||||
stringT = BaseCppType("std", "string_view")
|
||||
generatorT = BaseCppType("at", "Generator")
|
||||
scalarTypeT = BaseCppType("at", "ScalarType")
|
||||
tensorT = BaseCppType("at", "Tensor")
|
||||
|
@ -81,6 +81,8 @@ class BaseCType(CType):
|
||||
type: BaseCppType
|
||||
|
||||
def cpp_type(self, *, strip_ref: bool = False) -> str:
|
||||
if self.type.ns == "std":
|
||||
return "::" + str(self.type)
|
||||
return str(self.type)
|
||||
|
||||
def remove_const_ref(self) -> CType:
|
||||
|
@ -256,7 +256,7 @@ class GenLazyIR(ABC):
|
||||
[
|
||||
# This code is just special casing the mapping from string_view -> strings
|
||||
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
|
||||
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
|
||||
if a.lazy_type.cpp_type() == "::std::optional<::std::string_view>"
|
||||
else f"{a.name}({a.name})"
|
||||
for a in scalar_args
|
||||
]
|
||||
@ -266,9 +266,9 @@ class GenLazyIR(ABC):
|
||||
scalar_decls = "\n ".join(
|
||||
[
|
||||
f"std::string {a.name};"
|
||||
if a.lazy_type.cpp_type() == "c10::string_view"
|
||||
if a.lazy_type.cpp_type() == "::std::string_view"
|
||||
else f"::std::optional<std::string> {a.name};"
|
||||
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
|
||||
if a.lazy_type.cpp_type() == "::std::optional<::std::string_view>"
|
||||
else f"{a.lazy_type.cpp_type()} {a.name};"
|
||||
for a in scalar_args
|
||||
]
|
||||
|
@ -323,8 +323,7 @@ def ivalue_type_conversion_method(
|
||||
),
|
||||
BaseTy.str: (
|
||||
(False, "toStringView()"),
|
||||
(False, "toOptional<c10::string_view>()"),
|
||||
(False, "toOptional<::std::string_view>()"),
|
||||
(False, "toOptional<std::string_view>()"),
|
||||
),
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user