mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Tighten torch.library.infer_schema input types (#130705)"
This reverts commit ca2d424c6e5358f9fee8dc9ee7477de76b50f848. Reverted https://github.com/pytorch/pytorch/pull/130705 on behalf of https://github.com/atalman due to Failing internal CI ([comment](https://github.com/pytorch/pytorch/pull/130705#issuecomment-2230821876))
This commit is contained in:
@ -5,6 +5,7 @@ import typing
|
||||
from typing import List, Optional, Sequence, Union # noqa: F401
|
||||
|
||||
import torch
|
||||
import torch._custom_op.impl
|
||||
from torch import Tensor, types
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
@ -17,167 +18,167 @@ class TestInferSchemaWithAnnotation(TestCase):
|
||||
def foo_op(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.clone()
|
||||
|
||||
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
|
||||
self.assertEqual(result, "(Tensor x) -> Tensor")
|
||||
|
||||
def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x.clone() + y
|
||||
|
||||
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor")
|
||||
|
||||
def test_native_types(self):
|
||||
def foo_op(x: int) -> int:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
|
||||
self.assertEqual(result, "(SymInt x) -> SymInt")
|
||||
|
||||
def foo_op_2(x: bool) -> bool:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
self.assertEqual(result, "(bool x) -> bool")
|
||||
|
||||
def foo_op_3(x: str) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
|
||||
self.assertEqual(result, "(str x) -> SymInt")
|
||||
|
||||
def foo_op_4(x: float) -> float:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
|
||||
self.assertEqual(result, "(float x) -> float")
|
||||
|
||||
def test_torch_types(self):
|
||||
def foo_op_1(x: torch.types.Number) -> torch.types.Number:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
|
||||
self.assertEqual(result, "(Scalar x) -> Scalar")
|
||||
|
||||
def foo_op_2(x: torch.dtype) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
self.assertEqual(result, "(ScalarType x) -> SymInt")
|
||||
|
||||
def foo_op_3(x: torch.device) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
|
||||
self.assertEqual(result, "(Device x) -> SymInt")
|
||||
|
||||
def test_type_variants(self):
|
||||
def foo_op_1(x: typing.Optional[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_1, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args)
|
||||
self.assertEqual(result, "(SymInt? x) -> SymInt")
|
||||
|
||||
def foo_op_2(x: typing.Sequence[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[] x) -> SymInt")
|
||||
|
||||
def foo_op_3(x: typing.List[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[] x) -> SymInt")
|
||||
|
||||
def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
|
||||
|
||||
def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
|
||||
|
||||
def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
|
||||
self.assertEqual(result, "(Scalar x) -> Scalar")
|
||||
|
||||
def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
|
||||
self.assertEqual(result, "(Scalar x) -> Scalar")
|
||||
|
||||
def test_no_library_prefix(self):
|
||||
def foo_op(x: Tensor) -> Tensor:
|
||||
return x.clone()
|
||||
|
||||
result = torch.library.infer_schema(foo_op, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op, mutates_args)
|
||||
self.assertEqual(result, "(Tensor x) -> Tensor")
|
||||
|
||||
def foo_op_2(x: Tensor) -> torch.Tensor:
|
||||
return x.clone()
|
||||
|
||||
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
self.assertEqual(result, "(Tensor x) -> Tensor")
|
||||
|
||||
def foo_op_3(x: torch.Tensor) -> Tensor:
|
||||
return x.clone()
|
||||
|
||||
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args)
|
||||
self.assertEqual(result, "(Tensor x) -> Tensor")
|
||||
|
||||
def foo_op_4(x: List[int]) -> types.Number:
|
||||
return x[0]
|
||||
|
||||
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[] x) -> Scalar")
|
||||
|
||||
def foo_op_5(x: Optional[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args)
|
||||
self.assertEqual(result, "(SymInt? x) -> SymInt")
|
||||
|
||||
def foo_op_6(x: Sequence[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[] x) -> SymInt")
|
||||
|
||||
def foo_op_7(x: List[int]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[] x) -> SymInt")
|
||||
|
||||
def foo_op_8(x: Optional[Sequence[int]]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
|
||||
|
||||
def foo_op_9(x: Optional[List[int]]) -> int:
|
||||
return 1
|
||||
|
||||
result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args)
|
||||
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
|
||||
|
||||
def foo_op_10(x: Union[int, float, bool]) -> types.Number:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_10, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args)
|
||||
self.assertEqual(result, "(Scalar x) -> Scalar")
|
||||
|
||||
def foo_op_11(x: Union[int, bool, float]) -> types.Number:
|
||||
return x
|
||||
|
||||
result = torch.library.infer_schema(foo_op_11, mutates_args=mutates_args)
|
||||
result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args)
|
||||
self.assertEqual(result, "(Scalar x) -> Scalar")
|
||||
|
||||
def test_unsupported_annotation(self):
|
||||
@ -189,7 +190,7 @@ class TestInferSchemaWithAnnotation(TestCase):
|
||||
def foo_op(x: D) -> Tensor: # noqa: F821
|
||||
return torch.Tensor(x)
|
||||
|
||||
torch.library.infer_schema(foo_op, mutates_args=mutates_args)
|
||||
torch._custom_op.impl.infer_schema(foo_op, mutates_args)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -199,7 +200,7 @@ class TestInferSchemaWithAnnotation(TestCase):
|
||||
def foo_op_2(x: Tensor) -> E: # noqa: F821
|
||||
return x
|
||||
|
||||
torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
|
||||
torch._custom_op.impl.infer_schema(foo_op_2, mutates_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -608,24 +608,19 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def a(x: Tensor) -> Tensor:
|
||||
return torch.empty([])
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""
|
||||
)
|
||||
self.assertExpectedInline(infer_schema(a), """(Tensor x) -> Tensor""")
|
||||
|
||||
def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
|
||||
return torch.empty([])
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(kwonly1, mutates_args=()),
|
||||
"""(Tensor x, *, SymInt y, float z) -> Tensor""",
|
||||
infer_schema(kwonly1), """(Tensor x, *, SymInt y, float z) -> Tensor"""
|
||||
)
|
||||
|
||||
def kwonly2(*, y: Tensor) -> Tensor:
|
||||
return torch.empty([])
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""
|
||||
)
|
||||
self.assertExpectedInline(infer_schema(kwonly2), """(*, Tensor y) -> Tensor""")
|
||||
|
||||
def b(
|
||||
x: Tensor,
|
||||
@ -639,7 +634,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
return torch.empty([]), 1, 0.1, True
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(b, mutates_args=()),
|
||||
infer_schema(b),
|
||||
"""(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
|
||||
)
|
||||
|
||||
@ -652,7 +647,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
return [torch.empty([])]
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(c, mutates_args=()),
|
||||
infer_schema(c),
|
||||
"""(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
|
||||
)
|
||||
|
||||
@ -660,20 +655,18 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
return [torch.empty([])], torch.empty([])
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""
|
||||
infer_schema(d), """(Tensor x) -> (Tensor[], Tensor)"""
|
||||
)
|
||||
|
||||
def e() -> Tensor:
|
||||
return torch.empty([])
|
||||
|
||||
self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")
|
||||
self.assertExpectedInline(infer_schema(e), """() -> Tensor""")
|
||||
|
||||
def f(x: Tensor) -> None:
|
||||
pass
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""
|
||||
)
|
||||
self.assertExpectedInline(infer_schema(f), """(Tensor x) -> ()""")
|
||||
|
||||
def g(
|
||||
x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
|
||||
@ -681,8 +674,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
pass
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(g, mutates_args=()),
|
||||
"""(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
|
||||
infer_schema(g), """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()"""
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
@ -711,7 +703,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
pass
|
||||
|
||||
self.assertExpectedInline(
|
||||
infer_schema(h, mutates_args=()),
|
||||
infer_schema(h),
|
||||
(
|
||||
"""(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
|
||||
"""ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
|
||||
@ -730,28 +722,28 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def foo(*args):
|
||||
raise NotImplementedError
|
||||
|
||||
infer_schema(foo, mutates_args=())
|
||||
infer_schema(foo)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "varkwargs"):
|
||||
|
||||
def foo(**kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
infer_schema(foo, mutates_args=())
|
||||
infer_schema(foo)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "must have a type annotation"):
|
||||
|
||||
def foo(x):
|
||||
raise NotImplementedError
|
||||
|
||||
infer_schema(foo, mutates_args=())
|
||||
infer_schema(foo)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "unsupported"):
|
||||
|
||||
def foo(x: Tensor) -> Tuple[Tensor, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
infer_schema(foo, mutates_args=())
|
||||
infer_schema(foo)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "can be mutated"):
|
||||
|
||||
|
@ -73,7 +73,7 @@ def custom_op(
|
||||
f"is passed to `custom_op`"
|
||||
)
|
||||
|
||||
schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema
|
||||
schema = infer_schema(func) if manual_schema is None else manual_schema
|
||||
schema_str = f"{name}{schema}"
|
||||
function_schema = FunctionSchema.parse(schema_str)
|
||||
validate_schema(function_schema)
|
||||
|
@ -104,7 +104,7 @@ def custom_op(qualname, func_or_schema=None):
|
||||
f"is passed to `custom_op`"
|
||||
)
|
||||
|
||||
schema = infer_schema(func, mutates_args=())
|
||||
schema = infer_schema(func)
|
||||
_custom_op_with_schema(qualname, schema)
|
||||
return func
|
||||
|
||||
|
@ -128,7 +128,9 @@ def custom_op(
|
||||
import torch
|
||||
|
||||
if schema is None:
|
||||
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
|
||||
import torch._custom_op.impl
|
||||
|
||||
schema_str = torch._custom_op.impl.infer_schema(fn, mutates_args)
|
||||
else:
|
||||
schema_str = schema
|
||||
|
||||
|
@ -11,11 +11,7 @@ from .. import device, dtype, Tensor, types
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def infer_schema(
|
||||
prototype_function: typing.Callable,
|
||||
/,
|
||||
*,
|
||||
mutates_args,
|
||||
op_name: Optional[str] = None,
|
||||
prototype_function: typing.Callable, mutates_args=(), op_name: Optional[str] = None
|
||||
) -> str:
|
||||
r"""Parses the schema of a given function with type hints. The schema is inferred from the
|
||||
function's type hints, and can be used to define a new operator.
|
||||
|
Reference in New Issue
Block a user