[torchgen] Add support for schema with namespace (#148038)

Fixes https://github.com/pytorch/executorch/issues/8711

In ExecuTorch when we try to parse the following schema:

```
aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor
```
Repro:

```python
from torchgen.model import FunctionSchema
native_schema = FunctionSchema.parse("aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor")
```
It's failing because `BaseOperatorName` categorizes it to be a
inplace operator.

I understand we are not supposed to pass in namespace "aten::" into
`FunctionSchema.parse()` but unfortunately ExecuTorch requires this
feature to work.

This PR adds a new `namespace` attribute to `BaseOperatorName` and makes
sure the rest of the stack works as before, if a schema without
namespace  is passed in
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148038
Approved by: https://github.com/bdhirsh
This commit is contained in:
Mengwei Liu
2025-02-28 16:41:50 +00:00
committed by PyTorch MergeBot
parent e593288859
commit b5cd4ac950
2 changed files with 45 additions and 11 deletions

View File

@ -12,6 +12,7 @@ import torchgen.gen as gen
from torchgen.gen import LineLoader, parse_native_yaml_struct
from torchgen.model import (
Annotation,
BaseOperatorName,
CustomClassType,
DispatchKey,
NativeFunctionsGroup,
@ -202,5 +203,20 @@ class TestAnnotation(expecttest.TestCase):
Annotation.parse("a|b -> c|d")
class TestBaseOperatorName(expecttest.TestCase):
def test_base_operator_name_with_ns_has_same_attributes_as_the_one_without_ns(
self,
) -> None:
op = "aten::__lshift__"
op_without_ns = "__lshift__"
op_name = BaseOperatorName.parse(op)
op_name_without_ns = BaseOperatorName.parse(op_without_ns)
self.assertEqual(op_name.base, op_name_without_ns.base)
self.assertEqual(op_name.inplace, op_name_without_ns.inplace)
self.assertEqual(op_name.dunder_method, op_name_without_ns.dunder_method)
if __name__ == "__main__":
unittest.main()

View File

@ -5,7 +5,7 @@ import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, TYPE_CHECKING
from typing import Callable, Optional, TYPE_CHECKING
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
@ -1708,9 +1708,11 @@ class FunctionSchema:
for a in itertools.chain(
# Order is important here (otherwise e.g. inplace with mutable args
# and out= with mutable args won't have the same signature)
[self.arguments.self_arg.argument]
if self.arguments.self_arg is not None
else [],
(
[self.arguments.self_arg.argument]
if self.arguments.self_arg is not None
else []
),
self.arguments.out,
self.arguments.post_self_positional,
)
@ -2302,9 +2304,11 @@ class Arguments:
pre_self_positional=tuple(
map(strip_arg_annotation, self.pre_self_positional)
),
self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
if self.self_arg is not None
else None,
self_arg=(
SelfArgument(strip_arg_annotation(self.self_arg.argument))
if self.self_arg is not None
else None
),
post_self_positional=tuple(
map(strip_arg_annotation, self.post_self_positional)
),
@ -2533,6 +2537,12 @@ class BaseOperatorName:
# Doing that is BC-breaking though, so we're stuck with the above modeling.
functional_overload: bool = False
# NB: We don't officially support namespace in FunctionSchema, we treat this prefix
# as part of the base operator name, for __str__() to consume.
# The canonical input (from the rest of the infra) will not contain namespace, but
# we have a usecase in ExecuTorch where we want to support BaseOperatorName with namespace.
namespace: Optional[str] = None
@staticmethod
def parse(op: str) -> BaseOperatorName:
assert op != ""
@ -2540,7 +2550,13 @@ class BaseOperatorName:
"_out suffix is reserved and not permitted for operator names; "
"did you mean to specify an out overload name instead?"
)
m = re.match(r"^__([^_]+)__$", op)
# Extract namespace out. Base operator name may or may not contain namespace.
# E.g., aten::__lshift__ is a valid base operator name, __lshift__ is also valid.
# We want to split the namespace out from the base operator name.
match = re.match(r"^(?:(.*)::)?(.*)$", op)
namespace = match.group(1) if match else ""
op_without_ns = match.group(2) if match else op
m = re.match(r"^__([^_]+)__$", op_without_ns)
if m is not None:
dunder_method = True
base = m.group(1)
@ -2556,7 +2572,7 @@ class BaseOperatorName:
assert base[0] != "i"
else:
dunder_method = False
base = op
base = op_without_ns
if base[-1] == "_":
inplace = True
base = base[:-1]
@ -2579,14 +2595,16 @@ class BaseOperatorName:
inplace=inplace,
dunder_method=dunder_method,
functional_overload=functional_overload,
namespace=namespace,
)
assert str(r) == op, f"{str(r)} != {op}"
return r
def __str__(self) -> str:
namespace_prefix = f"{self.namespace}::" if self.namespace else ""
if self.dunder_method:
i = "i" if self.inplace else ""
return f"__{i}{self.base}__"
return f"{namespace_prefix}__{i}{self.base}__"
else:
i = (
"_"
@ -2595,7 +2613,7 @@ class BaseOperatorName:
if self.functional_overload
else ""
)
return f"{self.base}{i}"
return f"{namespace_prefix}{self.base}{i}"
# Operator name is the base operator name along with the (typically not