From 1595e755af83226321554188fb7992878d129cc4 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 1 Sep 2024 15:15:38 +0000 Subject: [PATCH] [Reland] [Torchgen] Pass mutable to cpp.valuetype_type (#134549) Reland of #121415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134549 Approved by: https://github.com/ezyang --- torchgen/api/cpp.py | 6 ++++-- torchgen/api/structured.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 19e5a11b8cdf..c657570ee3e2 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -94,6 +94,7 @@ def valuetype_type( t: Type, *, binds: ArgName, + mutable: bool = True, remove_non_owning_ref_types: bool = False, symint: bool = False, ) -> NamedCType | None: @@ -113,7 +114,7 @@ def valuetype_type( # All other BaseType currently map directly to BaseCppTypes. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem, binds=binds, symint=symint) + elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint) if elem is None: return None return NamedCType(binds, OptionalCType(elem.type)) @@ -143,6 +144,7 @@ def argumenttype_type( r = valuetype_type( t, binds=binds, + mutable=mutable, symint=symint, remove_non_owning_ref_types=remove_non_owning_ref_types, ) @@ -231,7 +233,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: # placeholder is ignored # NB: symint is ALWAYS respected for return types. So symint argument # here is IGNORED - r = valuetype_type(t, binds="__placeholder__", symint=True) + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) if r is not None: return r.type diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index a93d666114de..93a72eb2b4a5 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -48,7 +48,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: # CompositeExplicitAutograd and the meta function (which could # hypothetically be SymInt), but for simplicity we plan for these to just # be handled in Python - r = cpp.valuetype_type(t, symint=False, binds=binds) + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) if r is not None: return r