[Reland] Verify types in custom op schemas (#126861)

Summary:
co-dev reland of https://github.com/pytorch/pytorch/pull/124520, which requires
the removal of some executorch tests.

Before this PR, we didn't check that types in a schema were valid. This
is because TorchScript treats unknown types as type variables.

This PR checks types in a schema for the TORCH_LIBRARY APIs. To do this,
we add an `allow_typevars` flag to parseSchema so that TorchScript can
use allow_typevars=True. We also add some error messages for common
mistakes (e.g. using int64_t or double in schema).

Test Plan: Wait for tests

Differential Revision: D57666659

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126861
Approved by: https://github.com/albanD
This commit is contained in:
Richard Zou
2024-05-23 19:53:52 +00:00
committed by PyTorch MergeBot
parent c921c5cc77
commit f8857cef45
9 changed files with 75 additions and 19 deletions

View File

@ -23,14 +23,14 @@ namespace torch::jit {
namespace {
struct SchemaParser {
explicit SchemaParser(const std::string& str)
explicit SchemaParser(const std::string& str, bool allow_typevars)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
type_parser(L, /*parse_complete_tensor_types*/ false) {}
type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
std::variant<OperatorName, FunctionSchema> parseDeclaration() {
OperatorName name = parseName();
@ -361,16 +361,24 @@ struct SchemaParser {
}
Lexer L;
SchemaTypeParser type_parser;
bool allow_typevars_;
};
} // namespace
std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
const std::string& schemaOrName) {
return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
const std::string& schemaOrName,
bool allow_typevars) {
// We're ignoring aten and prim for BC reasons
if (schemaOrName.rfind("aten::", 0) == 0 ||
schemaOrName.rfind("prim::", 0) == 0) {
allow_typevars = true;
}
return SchemaParser(schemaOrName, allow_typevars)
.parseExactlyOneDeclaration();
}
FunctionSchema parseSchema(const std::string& schema) {
auto parsed = parseSchemaOrName(schema);
FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
auto parsed = parseSchemaOrName(schema, allow_typevars);
TORCH_CHECK(
std::holds_alternative<FunctionSchema>(parsed),
"Tried to parse a function schema but only the operator name was given");