mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchgen] Add support for schema with namespace (#148038)
Fixes https://github.com/pytorch/executorch/issues/8711 In ExecuTorch when we try to parse the following schema: ``` aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor ``` Repro: ```python from torchgen.model import FunctionSchema native_schema = FunctionSchema.parse("aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor") ``` It's failing because `BaseOperatorName` categorizes it to be a inplace operator. I understand we are not supposed to pass in namespace "aten::" into `FunctionSchema.parse()` but unfortunately ExecuTorch requires this feature to work. This PR adds a new `namespace` attribute to `BaseOperatorName` and makes sure the rest of the stack works as before, if a schema without namespace is passed in Pull Request resolved: https://github.com/pytorch/pytorch/pull/148038 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
e593288859
commit
b5cd4ac950
@ -12,6 +12,7 @@ import torchgen.gen as gen
|
||||
from torchgen.gen import LineLoader, parse_native_yaml_struct
|
||||
from torchgen.model import (
|
||||
Annotation,
|
||||
BaseOperatorName,
|
||||
CustomClassType,
|
||||
DispatchKey,
|
||||
NativeFunctionsGroup,
|
||||
@ -202,5 +203,20 @@ class TestAnnotation(expecttest.TestCase):
|
||||
Annotation.parse("a|b -> c|d")
|
||||
|
||||
|
||||
class TestBaseOperatorName(expecttest.TestCase):
|
||||
def test_base_operator_name_with_ns_has_same_attributes_as_the_one_without_ns(
|
||||
self,
|
||||
) -> None:
|
||||
op = "aten::__lshift__"
|
||||
op_without_ns = "__lshift__"
|
||||
|
||||
op_name = BaseOperatorName.parse(op)
|
||||
op_name_without_ns = BaseOperatorName.parse(op_without_ns)
|
||||
|
||||
self.assertEqual(op_name.base, op_name_without_ns.base)
|
||||
self.assertEqual(op_name.inplace, op_name_without_ns.inplace)
|
||||
self.assertEqual(op_name.dunder_method, op_name_without_ns.dunder_method)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -5,7 +5,7 @@ import itertools
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Callable, Optional, TYPE_CHECKING
|
||||
|
||||
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
||||
|
||||
@ -1708,9 +1708,11 @@ class FunctionSchema:
|
||||
for a in itertools.chain(
|
||||
# Order is important here (otherwise e.g. inplace with mutable args
|
||||
# and out= with mutable args won't have the same signature)
|
||||
[self.arguments.self_arg.argument]
|
||||
if self.arguments.self_arg is not None
|
||||
else [],
|
||||
(
|
||||
[self.arguments.self_arg.argument]
|
||||
if self.arguments.self_arg is not None
|
||||
else []
|
||||
),
|
||||
self.arguments.out,
|
||||
self.arguments.post_self_positional,
|
||||
)
|
||||
@ -2302,9 +2304,11 @@ class Arguments:
|
||||
pre_self_positional=tuple(
|
||||
map(strip_arg_annotation, self.pre_self_positional)
|
||||
),
|
||||
self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
|
||||
if self.self_arg is not None
|
||||
else None,
|
||||
self_arg=(
|
||||
SelfArgument(strip_arg_annotation(self.self_arg.argument))
|
||||
if self.self_arg is not None
|
||||
else None
|
||||
),
|
||||
post_self_positional=tuple(
|
||||
map(strip_arg_annotation, self.post_self_positional)
|
||||
),
|
||||
@ -2533,6 +2537,12 @@ class BaseOperatorName:
|
||||
# Doing that is BC-breaking though, so we're stuck with the above modeling.
|
||||
functional_overload: bool = False
|
||||
|
||||
# NB: We don't officially support namespace in FunctionSchema, we treat this prefix
|
||||
# as part of the base operator name, for __str__() to consume.
|
||||
# The canonical input (from the rest of the infra) will not contain namespace, but
|
||||
# we have a usecase in ExecuTorch where we want to support BaseOperatorName with namespace.
|
||||
namespace: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def parse(op: str) -> BaseOperatorName:
|
||||
assert op != ""
|
||||
@ -2540,7 +2550,13 @@ class BaseOperatorName:
|
||||
"_out suffix is reserved and not permitted for operator names; "
|
||||
"did you mean to specify an out overload name instead?"
|
||||
)
|
||||
m = re.match(r"^__([^_]+)__$", op)
|
||||
# Extract namespace out. Base operator name may or may not contain namespace.
|
||||
# E.g., aten::__lshift__ is a valid base operator name, __lshift__ is also valid.
|
||||
# We want to split the namespace out from the base operator name.
|
||||
match = re.match(r"^(?:(.*)::)?(.*)$", op)
|
||||
namespace = match.group(1) if match else ""
|
||||
op_without_ns = match.group(2) if match else op
|
||||
m = re.match(r"^__([^_]+)__$", op_without_ns)
|
||||
if m is not None:
|
||||
dunder_method = True
|
||||
base = m.group(1)
|
||||
@ -2556,7 +2572,7 @@ class BaseOperatorName:
|
||||
assert base[0] != "i"
|
||||
else:
|
||||
dunder_method = False
|
||||
base = op
|
||||
base = op_without_ns
|
||||
if base[-1] == "_":
|
||||
inplace = True
|
||||
base = base[:-1]
|
||||
@ -2579,14 +2595,16 @@ class BaseOperatorName:
|
||||
inplace=inplace,
|
||||
dunder_method=dunder_method,
|
||||
functional_overload=functional_overload,
|
||||
namespace=namespace,
|
||||
)
|
||||
assert str(r) == op, f"{str(r)} != {op}"
|
||||
return r
|
||||
|
||||
def __str__(self) -> str:
|
||||
namespace_prefix = f"{self.namespace}::" if self.namespace else ""
|
||||
if self.dunder_method:
|
||||
i = "i" if self.inplace else ""
|
||||
return f"__{i}{self.base}__"
|
||||
return f"{namespace_prefix}__{i}{self.base}__"
|
||||
else:
|
||||
i = (
|
||||
"_"
|
||||
@ -2595,7 +2613,7 @@ class BaseOperatorName:
|
||||
if self.functional_overload
|
||||
else ""
|
||||
)
|
||||
return f"{self.base}{i}"
|
||||
return f"{namespace_prefix}{self.base}{i}"
|
||||
|
||||
|
||||
# Operator name is the base operator name along with the (typically not
|
||||
|
Reference in New Issue
Block a user