Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[custom ops] convert string type annotation to real type (#128809)
Fixes #105157

Bug source: `from __future__ import annotations` converts type annotations into strings to make forward references easier. However, the existing custom ops code does not accept strings as valid types.

Fix: check whether each argument and return type annotation is a string; if so, try to `eval` it back into the real type.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128809
Approved by: https://github.com/zou3519
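For context, a minimal sketch of the failure mode (illustrative only, not part of the PR): under PEP 563 semantics, `inspect` reports annotations as plain strings, which is exactly what the pre-existing schema inference rejected.

# Illustrative repro of the underlying behavior (not part of the PR).
from __future__ import annotations  # PEP 563: annotations are stored as strings

import inspect

import torch


def foo_op(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


sig = inspect.signature(foo_op)
# Both annotations are the *string* "torch.Tensor", not the class torch.Tensor:
print(repr(sig.parameters["x"].annotation))  # 'torch.Tensor'
print(repr(sig.return_annotation))           # 'torch.Tensor'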
committed by PyTorch MergeBot
parent c35ffaf954
commit fbc7559ceb
test/custom_operator/test_infer_schema_annotation.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# Owner(s): ["module: pt2-dispatcher"]
from __future__ import annotations

import typing
from typing import List, Optional, Sequence, Union  # noqa: F401

import torch
import torch._custom_op.impl
from torch import Tensor, types
from torch.testing._internal.common_utils import run_tests, TestCase


mutates_args = {}


class TestInferSchemaWithAnnotation(TestCase):
    def test_tensor(self):
        def foo_op(x: torch.Tensor) -> torch.Tensor:
            return x.clone()

        result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
        self.assertEqual(result, "(Tensor x) -> Tensor")

        def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x.clone() + y

        result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
        self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor")

    def test_native_types(self):
        def foo_op(x: int) -> int:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
        self.assertEqual(result, "(SymInt x) -> SymInt")

        def foo_op_2(x: bool) -> bool:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
        self.assertEqual(result, "(bool x) -> bool")

        def foo_op_3(x: str) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
        self.assertEqual(result, "(str x) -> SymInt")

        def foo_op_4(x: float) -> float:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
        self.assertEqual(result, "(float x) -> float")

    def test_torch_types(self):
        def foo_op_1(x: torch.types.Number) -> torch.types.Number:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
        self.assertEqual(result, "(Scalar x) -> Scalar")

        def foo_op_2(x: torch.dtype) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
        self.assertEqual(result, "(ScalarType x) -> SymInt")

        def foo_op_3(x: torch.device) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
        self.assertEqual(result, "(Device x) -> SymInt")

    def test_type_variants(self):
        def foo_op_1(x: typing.Optional[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
        self.assertEqual(result, "(SymInt? x) -> SymInt")

        def foo_op_2(x: typing.Sequence[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
        self.assertEqual(result, "(SymInt[] x) -> SymInt")

        def foo_op_3(x: typing.List[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
        self.assertEqual(result, "(SymInt[] x) -> SymInt")

        def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
        self.assertEqual(result, "(SymInt[]? x) -> SymInt")

        def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
        self.assertEqual(result, "(SymInt[]? x) -> SymInt")

        def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
        self.assertEqual(result, "(Scalar x) -> Scalar")

        def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
        self.assertEqual(result, "(Scalar x) -> Scalar")

    def test_no_library_prefix(self):
        def foo_op(x: Tensor) -> Tensor:
            return x.clone()

        result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
        self.assertEqual(result, "(Tensor x) -> Tensor")

        def foo_op_2(x: Tensor) -> torch.Tensor:
            return x.clone()

        result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
        self.assertEqual(result, "(Tensor x) -> Tensor")

        def foo_op_3(x: torch.Tensor) -> Tensor:
            return x.clone()

        result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
        self.assertEqual(result, "(Tensor x) -> Tensor")

        def foo_op_4(x: List[int]) -> types.Number:
            return x[0]

        result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
        self.assertEqual(result, "(SymInt[] x) -> Scalar")

        def foo_op_5(x: Optional[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
        self.assertEqual(result, "(SymInt? x) -> SymInt")

        def foo_op_6(x: Sequence[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
        self.assertEqual(result, "(SymInt[] x) -> SymInt")

        def foo_op_7(x: List[int]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
        self.assertEqual(result, "(SymInt[] x) -> SymInt")

        def foo_op_8(x: Optional[Sequence[int]]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args)
        self.assertEqual(result, "(SymInt[]? x) -> SymInt")

        def foo_op_9(x: Optional[List[int]]) -> int:
            return 1

        result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args)
        self.assertEqual(result, "(SymInt[]? x) -> SymInt")

        def foo_op_10(x: Union[int, float, bool]) -> types.Number:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args)
        self.assertEqual(result, "(Scalar x) -> Scalar")

        def foo_op_11(x: Union[int, bool, float]) -> types.Number:
            return x

        result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args)
        self.assertEqual(result, "(Scalar x) -> Scalar")

    def test_unsupported_annotation(self):
        with self.assertRaisesRegex(
            ValueError,
            r"Unsupported type annotation D. It is not a type.",
        ):

            def foo_op(x: D) -> Tensor:  # noqa: F821
                return torch.Tensor(x)

            torch._custom_op.impl.infer_schema(foo_op, mutates_args)

        with self.assertRaisesRegex(
            ValueError,
            r"Unsupported type annotation E. It is not a type.",
        ):

            def foo_op_2(x: Tensor) -> E:  # noqa: F821
                return x

            torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)


if __name__ == "__main__":
    run_tests()
@@ -1,7 +1,9 @@
 # mypy: allow-untyped-defs
 import inspect
 import typing
+from typing import List, Optional, Sequence, Union  # noqa: F401
 
+import torch  # noqa: F401
 from .. import device, dtype, Tensor, types
 
 
@@ -12,6 +14,9 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
     write custom ops in real life:
     - none of the outputs alias any of the inputs or each other.
     - only the args listed in mutates_args are being mutated.
+    - string type annotations "device, dtype, Tensor, types" without library specification
+      are assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
+      without library specification are assumed to be typing.*.
 
     Callers (e.g. the custom ops API) are responsible for checking these assumptions.
     """
@@ -22,6 +27,14 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
             f"infer_schema(func): {what} " f"Got func with signature {sig})"
         )
 
+    def convert_type_string(annotation_type: str):
+        try:
+            return eval(annotation_type)
+        except Exception as e:
+            error_fn(
+                f"Unsupported type annotation {annotation_type}. It is not a type."
+            )
+
     params = []
     seen_args = set()
     saw_kwarg_only_arg = False
@@ -38,13 +51,19 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
         if param.annotation is inspect.Parameter.empty:
             error_fn(f"Parameter {name} must have a type annotation.")
 
-        if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
+        # The annotation might be converted to a string by annotation,
+        # we convert it to the actual type.
+        annotation_type = param.annotation
+        if type(annotation_type) == str:
+            annotation_type = convert_type_string(annotation_type)
+
+        if annotation_type not in SUPPORTED_PARAM_TYPES.keys():
             error_fn(
                 f"Parameter {name} has unsupported type {param.annotation}. "
                 f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
             )
 
-        schema_type = SUPPORTED_PARAM_TYPES[param.annotation]
+        schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
         if name in mutates_args:
             if not schema_type.startswith("Tensor"):
                 error_fn(
@@ -72,7 +91,10 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
             f"mutates_args should contain the names of all args that the "
             f"custom op mutates."
         )
-    ret = parse_return(sig.return_annotation, error_fn)
+    return_annotation = sig.return_annotation
+    if type(return_annotation) == str:
+        return_annotation = convert_type_string(return_annotation)
+    ret = parse_return(return_annotation, error_fn)
     return f"({', '.join(params)}) -> {ret}"
 
 
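To make the resolution step above concrete, here is a rough standalone sketch of the same technique outside of PyTorch's source. The helper name `convert_type_string` and the import list mirror the diff above, but error reporting is simplified to a plain `ValueError`. It also illustrates why the `# noqa: F401` imports were added: `eval` can only resolve names that are visible in the module where it runs, so bare annotations like "Tensor" or "Optional[int]" only round-trip if those names are in scope.

# Standalone sketch (not the PyTorch source): resolve a stringified annotation
# back to a real type object before looking it up in a schema table.
import typing
from typing import List, Optional, Sequence, Union  # noqa: F401

import torch  # noqa: F401
from torch import Tensor, device, dtype, types  # noqa: F401


def convert_type_string(annotation_type: str):
    # eval() sees this module's globals, so bare names like "Tensor" or
    # "Optional[int]" resolve exactly as if they were written as real types.
    try:
        return eval(annotation_type)
    except Exception:
        raise ValueError(
            f"Unsupported type annotation {annotation_type}. It is not a type."
        )


assert convert_type_string("torch.Tensor") is torch.Tensor
assert convert_type_string("Tensor") is torch.Tensor
assert convert_type_string("Optional[int]") == typing.Optional[int]
assert convert_type_string("List[int]") == typing.List[int]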