[Reland] [Torchgen] Pass mutable to cpp.valuetype_type (#134549)

Reland of #121415

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134549
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-09-01 15:15:38 +00:00
committed by PyTorch MergeBot
parent b1a00b7b6d
commit 1595e755af
2 changed files with 5 additions and 3 deletions

View File

@ -94,6 +94,7 @@ def valuetype_type(
t: Type,
*,
binds: ArgName,
mutable: bool = True,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType | None:
@ -113,7 +114,7 @@ def valuetype_type(
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds, symint=symint)
elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
@ -143,6 +144,7 @@ def argumenttype_type(
r = valuetype_type(
t,
binds=binds,
mutable=mutable,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
@ -231,7 +233,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
# placeholder is ignored
# NB: symint is ALWAYS respected for return types. So symint argument
# here is IGNORED
r = valuetype_type(t, binds="__placeholder__", symint=True)
r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
if r is not None:
return r.type

View File

@ -48,7 +48,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# CompositeExplicitAutograd and the meta function (which could
# hypothetically be SymInt), but for simplicity we plan for these to just
# be handled in Python
r = cpp.valuetype_type(t, symint=False, binds=binds)
r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
if r is not None:
return r