Verify types in custom op schemas (#124520)

Before this PR, we didn't check that types in a schema were valid. This
is because TorchScript treats unknown types as type variables.

This PR checks types in a schema for the TORCH_LIBRARY APIs. To do this,
we add an `allow_typevars` flag to parseSchema so that TorchScript can
use allow_typevars=True. We also add some error messages for common
mistakes (e.g. using int64_t or double in schema).

Test Plan:
- new tests

Differential Revision: [D56432690](https://our.internmc.facebook.com/intern/diff/D56432690)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124520
Approved by: https://github.com/albanD
This commit is contained in:
rzou
2024-04-22 12:14:03 -07:00
committed by PyTorch MergeBot
parent 107f944f22
commit 5b98d43488
9 changed files with 68 additions and 19 deletions

View File

@ -1740,6 +1740,17 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
res = torch._library.utils.is_functional_schema(schema)
self.assertEqual(res, expected)
def test_incorrect_schema_types(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
lib.define("foo12(Tensor a) -> asdfasdf")
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
lib.define("foo12(asdf a) -> Tensor")
with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
lib.define("foo12(int64_t a) -> Tensor")
with self.assertRaisesRegex(RuntimeError, "Use `float`"):
lib.define("foo12(double a) -> Tensor")
def test_is_tensorlist_like_type(self):
tensorlists = [
# Tensor[]