From ca2d424c6e5358f9fee8dc9ee7477de76b50f848 Mon Sep 17 00:00:00 2001 From: rzou Date: Sun, 14 Jul 2024 18:04:14 -0700 Subject: [PATCH] Tighten torch.library.infer_schema input types (#130705) Made the following changes: - mutates_args is now keyword-only and mandatory. This is to align with torch.library.custom_op (which makes it mandatory because it's easy to miss) - op_name is now keyword-only. This helps the readability of the API - updated all usages of infer_schema This change is not BC-breaking because we introduced torch.library.infer_schema a couple of days ago. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705 Approved by: https://github.com/yushangdi --- .../test_infer_schema_annotation.py | 59 +++++++++---------- test/test_custom_ops.py | 36 ++++++----- torch/_custom_op/impl.py | 2 +- torch/_custom_ops.py | 2 +- torch/_library/custom_ops.py | 4 +- torch/_library/infer_schema.py | 6 +- 6 files changed, 59 insertions(+), 50 deletions(-) diff --git a/test/custom_operator/test_infer_schema_annotation.py b/test/custom_operator/test_infer_schema_annotation.py index 9de44224f1c0..755a3364047a 100644 --- a/test/custom_operator/test_infer_schema_annotation.py +++ b/test/custom_operator/test_infer_schema_annotation.py @@ -5,7 +5,6 @@ import typing from typing import List, Optional, Sequence, Union # noqa: F401 import torch -import torch._custom_op.impl from torch import Tensor, types from torch.testing._internal.common_utils import run_tests, TestCase @@ -18,167 +17,167 @@ class TestInferSchemaWithAnnotation(TestCase): def foo_op(x: torch.Tensor) -> torch.Tensor: return x.clone() - result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) self.assertEqual(result, "(Tensor x) -> Tensor") def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x.clone() + y - result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor") def test_native_types(self): def foo_op(x: int) -> int: return x - result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) self.assertEqual(result, "(SymInt x) -> SymInt") def foo_op_2(x: bool) -> bool: return x - result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) self.assertEqual(result, "(bool x) -> bool") def foo_op_3(x: str) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) self.assertEqual(result, "(str x) -> SymInt") def foo_op_4(x: float) -> float: return x - result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) self.assertEqual(result, "(float x) -> float") def test_torch_types(self): def foo_op_1(x: torch.types.Number) -> torch.types.Number: return x - result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args) self.assertEqual(result, "(Scalar x) -> Scalar") def foo_op_2(x: torch.dtype) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) self.assertEqual(result, "(ScalarType x) -> SymInt") def foo_op_3(x: torch.device) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) self.assertEqual(result, "(Device x) -> SymInt") def test_type_variants(self): def foo_op_1(x: typing.Optional[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args) self.assertEqual(result, "(SymInt? x) -> SymInt") def foo_op_2(x: typing.Sequence[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[] x) -> SymInt") def foo_op_3(x: typing.List[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[] x) -> SymInt") def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[]? x) -> SymInt") def foo_op_5(x: typing.Optional[typing.List[int]]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[]? x) -> SymInt") def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number: return x - result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args) self.assertEqual(result, "(Scalar x) -> Scalar") def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number: return x - result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args) self.assertEqual(result, "(Scalar x) -> Scalar") def test_no_library_prefix(self): def foo_op(x: Tensor) -> Tensor: return x.clone() - result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + result = torch.library.infer_schema(foo_op, mutates_args=mutates_args) self.assertEqual(result, "(Tensor x) -> Tensor") def foo_op_2(x: Tensor) -> torch.Tensor: return x.clone() - result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) self.assertEqual(result, "(Tensor x) -> Tensor") def foo_op_3(x: torch.Tensor) -> Tensor: return x.clone() - result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args) self.assertEqual(result, "(Tensor x) -> Tensor") def foo_op_4(x: List[int]) -> types.Number: return x[0] - result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[] x) -> Scalar") def foo_op_5(x: Optional[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args) self.assertEqual(result, "(SymInt? x) -> SymInt") def foo_op_6(x: Sequence[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[] x) -> SymInt") def foo_op_7(x: List[int]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[] x) -> SymInt") def foo_op_8(x: Optional[Sequence[int]]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args) + result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[]? x) -> SymInt") def foo_op_9(x: Optional[List[int]]) -> int: return 1 - result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args) + result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args) self.assertEqual(result, "(SymInt[]? x) -> SymInt") def foo_op_10(x: Union[int, float, bool]) -> types.Number: return x - result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args) + result = torch.library.infer_schema(foo_op_10, mutates_args=mutates_args) self.assertEqual(result, "(Scalar x) -> Scalar") def foo_op_11(x: Union[int, bool, float]) -> types.Number: return x - result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args) + result = torch.library.infer_schema(foo_op_11, mutates_args=mutates_args) self.assertEqual(result, "(Scalar x) -> Scalar") def test_unsupported_annotation(self): @@ -190,7 +189,7 @@ class TestInferSchemaWithAnnotation(TestCase): def foo_op(x: D) -> Tensor: # noqa: F821 return torch.Tensor(x) - torch._custom_op.impl.infer_schema(foo_op, mutates_args) + torch.library.infer_schema(foo_op, mutates_args=mutates_args) with self.assertRaisesRegex( ValueError, @@ -200,7 +199,7 @@ class TestInferSchemaWithAnnotation(TestCase): def foo_op_2(x: Tensor) -> E: # noqa: F821 return x - torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + torch.library.infer_schema(foo_op_2, mutates_args=mutates_args) if __name__ == "__main__": diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 06483e8454cb..a5bad5ab82c8 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -608,19 +608,24 @@ class TestCustomOp(CustomOpTestCaseBase): def a(x: Tensor) -> Tensor: return torch.empty([]) - self.assertExpectedInline(infer_schema(a), """(Tensor x) -> Tensor""") + self.assertExpectedInline( + infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor""" + ) def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor: return torch.empty([]) self.assertExpectedInline( - infer_schema(kwonly1), """(Tensor x, *, SymInt y, float z) -> Tensor""" + infer_schema(kwonly1, mutates_args=()), + """(Tensor x, *, SymInt y, float z) -> Tensor""", ) def kwonly2(*, y: Tensor) -> Tensor: return torch.empty([]) - self.assertExpectedInline(infer_schema(kwonly2), """(*, Tensor y) -> Tensor""") + self.assertExpectedInline( + infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor""" + ) def b( x: Tensor, @@ -634,7 +639,7 @@ class TestCustomOp(CustomOpTestCaseBase): return torch.empty([]), 1, 0.1, True self.assertExpectedInline( - infer_schema(b), + infer_schema(b, mutates_args=()), """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""", ) @@ -647,7 +652,7 @@ class TestCustomOp(CustomOpTestCaseBase): return [torch.empty([])] self.assertExpectedInline( - infer_schema(c), + infer_schema(c, mutates_args=()), """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""", ) @@ -655,18 +660,20 @@ class TestCustomOp(CustomOpTestCaseBase): return [torch.empty([])], torch.empty([]) self.assertExpectedInline( - infer_schema(d), """(Tensor x) -> (Tensor[], Tensor)""" + infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)""" ) def e() -> Tensor: return torch.empty([]) - self.assertExpectedInline(infer_schema(e), """() -> Tensor""") + self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""") def f(x: Tensor) -> None: pass - self.assertExpectedInline(infer_schema(f), """(Tensor x) -> ()""") + self.assertExpectedInline( + infer_schema(f, mutates_args=()), """(Tensor x) -> ()""" + ) def g( x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]] @@ -674,7 +681,8 @@ class TestCustomOp(CustomOpTestCaseBase): pass self.assertExpectedInline( - infer_schema(g), """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""" + infer_schema(g, mutates_args=()), + """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""", ) self.assertExpectedInline( @@ -703,7 +711,7 @@ class TestCustomOp(CustomOpTestCaseBase): pass self.assertExpectedInline( - infer_schema(h), + infer_schema(h, mutates_args=()), ( """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """ """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()""" @@ -722,28 +730,28 @@ class TestCustomOp(CustomOpTestCaseBase): def foo(*args): raise NotImplementedError - infer_schema(foo) + infer_schema(foo, mutates_args=()) with self.assertRaisesRegex(ValueError, "varkwargs"): def foo(**kwargs): raise NotImplementedError - infer_schema(foo) + infer_schema(foo, mutates_args=()) with self.assertRaisesRegex(ValueError, "must have a type annotation"): def foo(x): raise NotImplementedError - infer_schema(foo) + infer_schema(foo, mutates_args=()) with self.assertRaisesRegex(ValueError, "unsupported"): def foo(x: Tensor) -> Tuple[Tensor, ...]: raise NotImplementedError - infer_schema(foo) + infer_schema(foo, mutates_args=()) with self.assertRaisesRegex(ValueError, "can be mutated"): diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index d9bde9a28008..c00e25ec7316 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -73,7 +73,7 @@ def custom_op( f"is passed to `custom_op`" ) - schema = infer_schema(func) if manual_schema is None else manual_schema + schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema schema_str = f"{name}{schema}" function_schema = FunctionSchema.parse(schema_str) validate_schema(function_schema) diff --git a/torch/_custom_ops.py b/torch/_custom_ops.py index b8231a186c0a..7f2f8f1bcd87 100644 --- a/torch/_custom_ops.py +++ b/torch/_custom_ops.py @@ -104,7 +104,7 @@ def custom_op(qualname, func_or_schema=None): f"is passed to `custom_op`" ) - schema = infer_schema(func) + schema = infer_schema(func, mutates_args=()) _custom_op_with_schema(qualname, schema) return func diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 153ce9071856..5cf05156f977 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -128,9 +128,7 @@ def custom_op( import torch if schema is None: - import torch._custom_op.impl - - schema_str = torch._custom_op.impl.infer_schema(fn, mutates_args) + schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args) else: schema_str = schema diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 756fc1505946..a0b241583e2b 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -11,7 +11,11 @@ from .. import device, dtype, Tensor, types @exposed_in("torch.library") def infer_schema( - prototype_function: typing.Callable, mutates_args=(), op_name: Optional[str] = None + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, ) -> str: r"""Parses the schema of a given function with type hints. The schema is inferred from the function's type hints, and can be used to define a new operator.