mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Verify types in custom op schemas (#124520)
Before this PR, we didn't check that types in a schema were valid. This is because TorchScript treats unknown types as type variables. This PR checks types in a schema for the TORCH_LIBRARY APIs. To do this, we add an `allow_typevars` flag to parseSchema so that TorchScript can use allow_typevars=True. We also add some error messages for common mistakes (e.g. using int64_t or double in schema). Test Plan: - new tests Differential Revision: [D56432690](https://our.internmc.facebook.com/intern/diff/D56432690) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124520 Approved by: https://github.com/albanD
This commit is contained in:
@ -1740,6 +1740,17 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
|
||||
res = torch._library.utils.is_functional_schema(schema)
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
def test_incorrect_schema_types(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
|
||||
lib.define("foo12(Tensor a) -> asdfasdf")
|
||||
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
|
||||
lib.define("foo12(asdf a) -> Tensor")
|
||||
with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
|
||||
lib.define("foo12(int64_t a) -> Tensor")
|
||||
with self.assertRaisesRegex(RuntimeError, "Use `float`"):
|
||||
lib.define("foo12(double a) -> Tensor")
|
||||
|
||||
def test_is_tensorlist_like_type(self):
|
||||
tensorlists = [
|
||||
# Tensor[]
|
||||
|
Reference in New Issue
Block a user