Tighten torch.library.infer_schema input types (#130705)

Made the following changes:
- mutates_args is now keyword-only and mandatory, aligning with
  torch.library.custom_op (which requires it because it is easy to miss).
- op_name is now keyword-only, which improves the readability of the API.
- Updated all usages of infer_schema accordingly.

This change is not BC-breaking because torch.library.infer_schema was
introduced only a couple of days ago.
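
For illustration, a minimal sketch of the updated calling convention (the
function name example_op is made up for this sketch; the signature it
exercises is the one added to torch/_library/infer_schema.py in this diff):

```python
import torch
from torch import Tensor


def example_op(x: Tensor) -> Tensor:
    return x.clone()


# mutates_args is now keyword-only and required; pass an empty tuple when the
# operator does not mutate any of its inputs.
schema = torch.library.infer_schema(example_op, mutates_args=())
assert schema == "(Tensor x) -> Tensor"

# Passing mutates_args positionally is rejected, since it now sits after the
# keyword-only marker: torch.library.infer_schema(example_op, ()) raises TypeError.
```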

Test Plan:
- updated existing infer_schema tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705
Approved by: https://github.com/yushangdi
Author: rzou
Date: 2024-07-14 18:04:14 -07:00
Committed by: PyTorch MergeBot
Parent: 9df4bc6a0d
Commit: ca2d424c6e
6 changed files with 59 additions and 50 deletions

View File

@@ -5,7 +5,6 @@ import typing
from typing import List, Optional, Sequence, Union # noqa: F401
import torch
import torch._custom_op.impl
from torch import Tensor, types
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -18,167 +17,167 @@ class TestInferSchemaWithAnnotation(TestCase):
def foo_op(x: torch.Tensor) -> torch.Tensor:
return x.clone()
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x.clone() + y
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor")
def test_native_types(self):
def foo_op(x: int) -> int:
return x
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt x) -> SymInt")
def foo_op_2(x: bool) -> bool:
return x
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(bool x) -> bool")
def foo_op_3(x: str) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(str x) -> SymInt")
def foo_op_4(x: float) -> float:
return x
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
self.assertEqual(result, "(float x) -> float")
def test_torch_types(self):
def foo_op_1(x: torch.types.Number) -> torch.types.Number:
return x
result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args)
self.assertEqual(result, "(Scalar x) -> Scalar")
def foo_op_2(x: torch.dtype) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(ScalarType x) -> SymInt")
def foo_op_3(x: torch.device) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(Device x) -> SymInt")
def test_type_variants(self):
def foo_op_1(x: typing.Optional[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt? x) -> SymInt")
def foo_op_2(x: typing.Sequence[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_3(x: typing.List[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number:
return x
result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
self.assertEqual(result, "(Scalar x) -> Scalar")
def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number:
return x
result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
self.assertEqual(result, "(Scalar x) -> Scalar")
def test_no_library_prefix(self):
def foo_op(x: Tensor) -> Tensor:
return x.clone()
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
def foo_op_2(x: Tensor) -> torch.Tensor:
return x.clone()
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
def foo_op_3(x: torch.Tensor) -> Tensor:
return x.clone()
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
def foo_op_4(x: List[int]) -> types.Number:
return x[0]
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> Scalar")
def foo_op_5(x: Optional[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt? x) -> SymInt")
def foo_op_6(x: Sequence[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_7(x: List[int]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_8(x: Optional[Sequence[int]]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args)
result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_9(x: Optional[List[int]]) -> int:
return 1
result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args)
result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_10(x: Union[int, float, bool]) -> types.Number:
return x
result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args)
result = torch.library.infer_schema(foo_op_10, mutates_args=mutates_args)
self.assertEqual(result, "(Scalar x) -> Scalar")
def foo_op_11(x: Union[int, bool, float]) -> types.Number:
return x
result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args)
result = torch.library.infer_schema(foo_op_11, mutates_args=mutates_args)
self.assertEqual(result, "(Scalar x) -> Scalar")
def test_unsupported_annotation(self):
@@ -190,7 +189,7 @@ class TestInferSchemaWithAnnotation(TestCase):
def foo_op(x: D) -> Tensor: # noqa: F821
return torch.Tensor(x)
torch._custom_op.impl.infer_schema(foo_op, mutates_args)
torch.library.infer_schema(foo_op, mutates_args=mutates_args)
with self.assertRaisesRegex(
ValueError,
@@ -200,7 +199,7 @@ class TestInferSchemaWithAnnotation(TestCase):
def foo_op_2(x: Tensor) -> E: # noqa: F821
return x
torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
if __name__ == "__main__":

View File

@@ -608,19 +608,24 @@ class TestCustomOp(CustomOpTestCaseBase):
def a(x: Tensor) -> Tensor:
return torch.empty([])
self.assertExpectedInline(infer_schema(a), """(Tensor x) -> Tensor""")
self.assertExpectedInline(
infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""
)
def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
return torch.empty([])
self.assertExpectedInline(
infer_schema(kwonly1), """(Tensor x, *, SymInt y, float z) -> Tensor"""
infer_schema(kwonly1, mutates_args=()),
"""(Tensor x, *, SymInt y, float z) -> Tensor""",
)
def kwonly2(*, y: Tensor) -> Tensor:
return torch.empty([])
self.assertExpectedInline(infer_schema(kwonly2), """(*, Tensor y) -> Tensor""")
self.assertExpectedInline(
infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""
)
def b(
x: Tensor,
@@ -634,7 +639,7 @@ class TestCustomOp(CustomOpTestCaseBase):
return torch.empty([]), 1, 0.1, True
self.assertExpectedInline(
infer_schema(b),
infer_schema(b, mutates_args=()),
"""(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
)
@@ -647,7 +652,7 @@ class TestCustomOp(CustomOpTestCaseBase):
return [torch.empty([])]
self.assertExpectedInline(
infer_schema(c),
infer_schema(c, mutates_args=()),
"""(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
)
@@ -655,18 +660,20 @@ class TestCustomOp(CustomOpTestCaseBase):
return [torch.empty([])], torch.empty([])
self.assertExpectedInline(
infer_schema(d), """(Tensor x) -> (Tensor[], Tensor)"""
infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""
)
def e() -> Tensor:
return torch.empty([])
self.assertExpectedInline(infer_schema(e), """() -> Tensor""")
self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")
def f(x: Tensor) -> None:
pass
self.assertExpectedInline(infer_schema(f), """(Tensor x) -> ()""")
self.assertExpectedInline(
infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""
)
def g(
x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
@@ -674,7 +681,8 @@ class TestCustomOp(CustomOpTestCaseBase):
pass
self.assertExpectedInline(
infer_schema(g), """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()"""
infer_schema(g, mutates_args=()),
"""(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
)
self.assertExpectedInline(
@@ -703,7 +711,7 @@ class TestCustomOp(CustomOpTestCaseBase):
pass
self.assertExpectedInline(
infer_schema(h),
infer_schema(h, mutates_args=()),
(
"""(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
"""ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
@@ -722,28 +730,28 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo(*args):
raise NotImplementedError
infer_schema(foo)
infer_schema(foo, mutates_args=())
with self.assertRaisesRegex(ValueError, "varkwargs"):
def foo(**kwargs):
raise NotImplementedError
infer_schema(foo)
infer_schema(foo, mutates_args=())
with self.assertRaisesRegex(ValueError, "must have a type annotation"):
def foo(x):
raise NotImplementedError
infer_schema(foo)
infer_schema(foo, mutates_args=())
with self.assertRaisesRegex(ValueError, "unsupported"):
def foo(x: Tensor) -> Tuple[Tensor, ...]:
raise NotImplementedError
infer_schema(foo)
infer_schema(foo, mutates_args=())
with self.assertRaisesRegex(ValueError, "can be mutated"):

View File

@@ -73,7 +73,7 @@ def custom_op(
f"is passed to `custom_op`"
)
schema = infer_schema(func) if manual_schema is None else manual_schema
schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema
schema_str = f"{name}{schema}"
function_schema = FunctionSchema.parse(schema_str)
validate_schema(function_schema)

View File

@@ -104,7 +104,7 @@ def custom_op(qualname, func_or_schema=None):
f"is passed to `custom_op`"
)
schema = infer_schema(func)
schema = infer_schema(func, mutates_args=())
_custom_op_with_schema(qualname, schema)
return func

View File

@@ -128,9 +128,7 @@ def custom_op(
import torch
if schema is None:
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(fn, mutates_args)
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
else:
schema_str = schema

View File

@@ -11,7 +11,11 @@ from .. import device, dtype, Tensor, types
@exposed_in("torch.library")
def infer_schema(
prototype_function: typing.Callable, mutates_args=(), op_name: Optional[str] = None
prototype_function: typing.Callable,
/,
*,
mutates_args,
op_name: Optional[str] = None,
) -> str:
r"""Parses the schema of a given function with type hints. The schema is inferred from the
function's type hints, and can be used to define a new operator.
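
As the docstring notes, the inferred schema can be used to define a new
operator. Below is a minimal sketch of that workflow using the public
torch.library.Library API; the mylib namespace and my_sin name are
hypothetical, chosen purely for illustration:

```python
import torch
from torch import Tensor


def my_sin(x: Tensor) -> Tensor:
    return torch.sin(x)


# Infer "(Tensor x) -> Tensor" from the type hints, then register the operator
# under a name in a (hypothetical) "mylib" namespace.
schema = torch.library.infer_schema(my_sin, mutates_args=())
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("my_sin" + schema)
lib.impl("my_sin", my_sin, "CompositeExplicitAutograd")

out = torch.ops.mylib.my_sin(torch.randn(3))
```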