mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
342 lines
17 KiB
Python
342 lines
17 KiB
Python
# Owner(s): ["module: unknown"]
|
|
|
|
import torch
|
|
from torch._C import parse_schema
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class TestFunctionSchema(TestCase):
    """Tests for parsing, hashing, and compatibility checks on function schemas.

    Schemas are built from their string form with ``torch._C.parse_schema`` and
    compared with ``is_backward_compatible_with`` (BC) and
    ``check_forward_compatible_with`` (FC).
    """

    def test_serialize_and_deserialize(self):
        """Every registered schema must round-trip through ``str``/``parse_schema``."""
        schemas = torch._C._jit_get_all_schemas()
        # so far we have around 1700 registered schemas
        self.assertGreater(len(schemas), 1000)
        for schema in schemas:
            schema_str = str(schema)
            # subTest keeps going past a failing schema and names it in the
            # report, instead of aborting on the first anonymous mismatch.
            with self.subTest(schema=schema_str):
                parsed_schema = parse_schema(schema_str)
                self.assertEqual(parsed_schema, schema)
                self.assertTrue(parsed_schema.is_backward_compatible_with(schema))

    def test_out_schema(self):
        """``is_out`` is set for a kw-only ``Tensor(a!)`` out arg, not a plain one."""
        schema_with_out = parse_schema(
            "any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        )
        self.assertTrue(schema_with_out.arguments[-1].is_out)
        schema_without_out = parse_schema(
            "any.not_out(Tensor self, Tensor b) -> Tensor"
        )
        self.assertFalse(schema_without_out.arguments[-1].is_out)

    def test_hash_schema(self):
        """Identical schemas hash equally; structural differences change the hash."""
        schema1 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
        schema2 = parse_schema("any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
        self.assertEqual(hash(schema1), hash(schema2))

        # different overload name
        schema3 = parse_schema(
            "any.not_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        )
        self.assertNotEqual(hash(schema2), hash(schema3))

        # different name and an extra kw-only argument
        schema4 = parse_schema(
            "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
        )
        self.assertNotEqual(hash(schema2), hash(schema4))

        # schemas with a different default value, or with a kw-only instead of a
        # positional arg, should have different hashes
        default_val_schema0 = parse_schema("foo(Tensor self, int a = 2) -> Tensor(a!)")
        default_val_schema1 = parse_schema("foo(Tensor self, int a = 3) -> Tensor(a!)")
        default_val_schema2 = parse_schema(
            "foo(Tensor self, *, int a = 2) -> Tensor(a!)"
        )
        self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema1))
        self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema2))

        # schemas with different alias annotations should have different hashes
        alias_schema = parse_schema("foo(Tensor(a!) self, int a = 2) -> Tensor(a!)")
        self.assertNotEqual(hash(default_val_schema0), hash(alias_schema))
        alias_schema2 = parse_schema("foo(Tensor(b!) self, int a = 2) -> Tensor(a!)")
        self.assertNotEqual(hash(alias_schema), hash(alias_schema2))

        # schemas whose arguments/returns carry different alias sets
        alias_schema3 = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)"
        )
        alias_schema4 = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(b!)"
        )
        alias_schema5 = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(b!) out, Tensor(a!) b) -> Tensor(a!)"
        )
        self.assertNotEqual(hash(alias_schema3), hash(alias_schema4))
        self.assertNotEqual(hash(alias_schema3), hash(alias_schema5))

    def test_backward_compatible_structure(self):
        """Name, overload name, varargs and return count all affect BC."""
        old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
        # BC: A new schema without changes.
        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema with a different name.
        new_schema = parse_schema("any_.over(Tensor self, *, Tensor b) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema with a different overload name.
        new_schema = parse_schema("any.other(Tensor self, *, Tensor b) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema that adds vararg.
        new_schema = parse_schema("any.over(Tensor self, *, Tensor b, ...) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema with a different number of outputs.
        new_schema = parse_schema(
            "any.over(Tensor self, *, Tensor b) -> (Tensor, Tensor)"
        )
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))

    def test_backward_compatible_outputs(self):
        """BC rules for changes to a schema's return type."""
        old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
        # No-BC: A new schema with output becoming of optional type.
        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor?")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # BC: (the opposite case) A schema where the output is not of optional type anymore.
        self.assertTrue(old_schema.is_backward_compatible_with(new_schema))
        # No-BC: A new schema with a different output type.
        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> int")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema whose output gains a name.
        new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor out")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))

    def test_backward_compatible_arguments(self):
        """BC rules for adding, removing, moving, renaming and retyping arguments."""
        old_schema = parse_schema("any(Tensor self, *, Tensor b, int c) -> Tensor")
        # No-BC: A new schema with fewer arguments.
        new_schema = parse_schema("any(Tensor self, *, Tensor b) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema with more arguments, appended, but no default value.
        new_schema = parse_schema(
            "any(Tensor self, *, Tensor b, int c, int d) -> Tensor"
        )
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # BC: A new schema with more arguments, appended, that have a default value.
        new_schema = parse_schema(
            "any(Tensor self, *, Tensor b, int c, int d=1) -> Tensor"
        )
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema with more arguments, not-appended, that have a default value.
        new_schema = parse_schema(
            "any(Tensor self, int d=1, *, Tensor b, int c) -> Tensor"
        )
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # BC: A new schema where an old kwarg becomes positional.
        new_schema = parse_schema("any(Tensor self, Tensor b, *, int c) -> Tensor")
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: (the opposite case) A new schema where an old positional argument becomes kwarg.
        self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
        # BC: A new schema where all old kwargs become positional.
        new_schema = parse_schema("any(Tensor self, Tensor b, int c) -> Tensor")
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: (the opposite case) A new schema where old positional arguments become kwargs.
        self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
        # No-BC: A new schema where old kwargs appear in a different order.
        new_schema = parse_schema("any(Tensor self, *, int c, Tensor b) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # BC: A new schema where an argument's type becomes optional.
        new_schema = parse_schema("any(Tensor self, *, Tensor b, int? c) -> Tensor")
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # BC: A new schema where an argument gains a default value.
        new_schema = parse_schema("any(Tensor self, *, Tensor b, int c=1) -> Tensor")
        self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema where an argument is "renamed".
        new_schema = parse_schema(
            "any(Tensor self, *, Tensor b, int renamed) -> Tensor"
        )
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
        # No-BC: A new schema where an argument's type changes to an incompatible type.
        new_schema = parse_schema("any(Tensor self, *, Tensor b, int[] c) -> Tensor")
        self.assertFalse(new_schema.is_backward_compatible_with(old_schema))

    def test_backward_compatible_with_smart_serialization(self):
        """BC when defaulted args are inserted before out args, or appended if none."""
        # cases where out arg is provided
        old_schema = parse_schema(
            "foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
        )
        new_schema_same_out = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)"
        )
        new_schema_wrong_default = parse_schema(
            "foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)"
        )
        new_schema_more_out = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)"
        )
        new_schema_wrong_pos = parse_schema(
            "foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)"
        )
        self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema))
        self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema))
        self.assertFalse(
            new_schema_wrong_default.is_backward_compatible_with(old_schema)
        )
        self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema))

        # cases where out arg is not provided
        old_schema_without_arg = parse_schema("foo(Tensor self, int a, int b=1) -> int")
        new_schema_without_arg = parse_schema(
            "foo(Tensor self, int a, int b=1, int c=2) -> int"
        )
        new_schema_without_arg_multiple_default = parse_schema(
            "foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int"
        )
        new_schema_without_arg_wrong_pos = parse_schema(
            "foo(Tensor self, int a, int c=2, int b=1) -> int"
        )
        self.assertTrue(
            new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg)
        )
        self.assertTrue(
            new_schema_without_arg_multiple_default.is_backward_compatible_with(
                old_schema_without_arg
            )
        )
        self.assertFalse(
            new_schema_without_arg_wrong_pos.is_backward_compatible_with(
                old_schema_without_arg
            )
        )

    def test_string_optional_parameter_default_value(self):
        """A ``str?`` argument with a string default round-trips through ``str``."""
        schema_a = parse_schema('example::op(str? order="NCHW") -> (Tensor)')
        schema_b = parse_schema(str(schema_a))
        self.assertEqual(schema_a, schema_b)

    def test_forward_compatible_arguments_without_out(self):
        """FC rules for defaulted arguments in schemas without out arguments."""
        old_schema = parse_schema("any(Tensor self, int a, int b=1) -> Tensor")
        # deleting default arg is FC compatible
        new_schema = parse_schema("any(Tensor self, int a) -> Tensor")
        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
        self.assertTrue(is_fc)
        # adding default arg is FC compatible
        new_schema = parse_schema("any(Tensor self, int a, int b=1, int c=1) -> Tensor")
        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
        self.assertTrue(is_fc)
        # adding default arg with container type is NOT FC compatible
        new_schema = parse_schema(
            "any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor"
        )
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason,
            "Function schema is not forward compatible since the new argument"
            " 'c' of type int[] has a container type as its default value.",
        )
        # updating the default value of a default arg is NOT FC compatible
        new_schema = parse_schema("any(Tensor self, int a, int b=4) -> Tensor")
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "'b' is not forward compatible with the older version of the schema"
        )
        # updating the arg name of a default arg is NOT FC compatible
        new_schema = parse_schema("any(Tensor self, int a, int c=1) -> Tensor")
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "'c' is not forward compatible with the older version of the schema"
        )
        # not adding default arg in the end is NOT FC compatible
        new_schema = parse_schema("any(Tensor self, int a, int c=1, int b=1) -> Tensor")
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "'c' is not forward compatible with the older version of the schema"
        )
        # making default arg into positional arg is NOT FC compatible
        new_schema = parse_schema("any(Tensor self, int a, int b) -> Tensor")
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "'b' is not forward compatible with the older version of the schema"
        )
        # making positional arg into default arg is NOT FC compatible
        new_schema = parse_schema("any(Tensor self, int a=1, int b=1) -> Tensor")
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "'a' is not forward compatible with the older version of the schema"
        )

    def test_forward_compatible_arguments_real_use_case(self):
        """Regression: making ``slice`` start/end optional broke FC."""
        # this change introduced forward incompatibility in the past
        old_slice_schema = parse_schema(
            "slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)"
        )
        new_slice_schema = parse_schema(
            "slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)"
        )
        is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason,
            "'start' is not forward compatible with the older version of the schema",
        )

    def test_forward_compatible_arguments_with_out(self):
        """FC rules for defaulted arguments placed before out arguments."""
        old_schema = parse_schema(
            "any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)"
        )
        # dropping a defaulted kwarg before the out arg is FC compatible
        new_schema = parse_schema(
            "any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)"
        )
        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
        self.assertTrue(is_fc)
        # adding a defaulted kwarg before the out arg is FC compatible
        new_schema = parse_schema(
            "any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)"
        )
        is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
        self.assertTrue(is_fc)
        # adding another mutable Tensor argument is NOT FC compatible
        new_schema = parse_schema(
            "any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) out) -> Tensor(a!)"
        )
        is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
        self.assertFalse(is_fc)
        self.assertEqual(
            reason, "Function schema should have the same number of out arguments"
        )

    def test_schema_error(self):
        """Schemas with varargs cannot also declare default-valued arguments."""
        with self.assertRaisesRegex(
            RuntimeError, r"schemas with vararg \(...\) can't have default value args"
        ):
            parse_schema("any.foo(int arg1, int arg2=0, ...)")

    def test_tensor_list_alias_annotation_properly_parsed(self):
        """An alias annotation on a tensor list is parsed and printed back."""
        schema_str = "foo(Tensor self, *, Tensor(a!)[] out) -> ()"
        schema = parse_schema(schema_str)
        self.assertTrue(schema.arguments[-1].alias_info.is_write)
        self.assertEqual(str(schema), schema_str)

    def test_tensor_option_arguments_properly_parsed(self):
        """TensorOptions-style args expose fake types but print their real types."""
        schema_str = (
            "_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
            "bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"
        )
        schema = parse_schema(schema_str)
        # fake type of MemoryFormat? is int?
        self.assertEqual(schema.arguments[-1].type.str(), "int?")
        # fake type of Layout? is int?
        self.assertEqual(schema.arguments[2].type.str(), "int?")
        # fake type of Device? is Device?
        self.assertEqual(schema.arguments[3].type.str(), "Device?")
        # print real types in FunctionSchema
        self.assertEqual(str(schema), schema_str)

    def test_sym_int_argument_properly_parsed(self):
        """A ``SymInt`` return has fake type ``int`` and real type ``SymInt``."""
        schema_str = "sym_size.int(Tensor self, int dim) -> SymInt"
        schema = parse_schema(schema_str)
        # fake type of SymInt is int
        self.assertEqual(schema.returns[-1].type.str(), "int")
        # real type of SymInt is SymInt
        self.assertEqual(schema.returns[-1].real_type.str(), "SymInt")
        # print real types in FunctionSchema
        self.assertEqual(str(schema), schema_str)
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test entry point
    # (imported from torch.testing._internal.common_utils).
    run_tests()