mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] Verify types in custom op schemas (#126861)
Summary: co-dev reland of https://github.com/pytorch/pytorch/pull/124520, which requires the removal of some executorch tests. Before this PR, we didn't check that types in a schema were valid. This is because TorchScript treats unknown types as type variables. This PR checks types in a schema for the TORCH_LIBRARY APIs. To do this, we add an `allow_typevars` flag to parseSchema so that TorchScript can use allow_typevars=True. We also add some error messages for common mistakes (e.g. using int64_t or double in schema). Test Plan: Wait for tests Differential Revision: D57666659 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126861 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c921c5cc77
commit
f8857cef45
@ -23,14 +23,14 @@ namespace torch::jit {
|
||||
|
||||
namespace {
|
||||
struct SchemaParser {
|
||||
explicit SchemaParser(const std::string& str)
|
||||
explicit SchemaParser(const std::string& str, bool allow_typevars)
|
||||
: L(std::make_shared<Source>(
|
||||
c10::string_view(str),
|
||||
c10::nullopt,
|
||||
0,
|
||||
nullptr,
|
||||
Source::DONT_COPY)),
|
||||
type_parser(L, /*parse_complete_tensor_types*/ false) {}
|
||||
type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
|
||||
|
||||
std::variant<OperatorName, FunctionSchema> parseDeclaration() {
|
||||
OperatorName name = parseName();
|
||||
@ -361,16 +361,24 @@ struct SchemaParser {
|
||||
}
|
||||
Lexer L;
|
||||
SchemaTypeParser type_parser;
|
||||
bool allow_typevars_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
|
||||
const std::string& schemaOrName) {
|
||||
return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
|
||||
const std::string& schemaOrName,
|
||||
bool allow_typevars) {
|
||||
// We're ignoring aten and prim for BC reasons
|
||||
if (schemaOrName.rfind("aten::", 0) == 0 ||
|
||||
schemaOrName.rfind("prim::", 0) == 0) {
|
||||
allow_typevars = true;
|
||||
}
|
||||
return SchemaParser(schemaOrName, allow_typevars)
|
||||
.parseExactlyOneDeclaration();
|
||||
}
|
||||
|
||||
FunctionSchema parseSchema(const std::string& schema) {
|
||||
auto parsed = parseSchemaOrName(schema);
|
||||
FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
|
||||
auto parsed = parseSchemaOrName(schema, allow_typevars);
|
||||
TORCH_CHECK(
|
||||
std::holds_alternative<FunctionSchema>(parsed),
|
||||
"Tried to parse a function schema but only the operator name was given");
|
||||
|
Reference in New Issue
Block a user