Revert "Verify types in custom op schemas (#124520)"

This reverts commit 5b98d43488bed0836b4da5996a50bafd0dd2c11c.

Reverted https://github.com/pytorch/pytorch/pull/124520 on behalf of https://github.com/zou3519 due to broke static runtime tests ([comment](https://github.com/pytorch/pytorch/pull/124520#issuecomment-2075111935))
This commit is contained in:
PyTorch MergeBot
2024-04-24 14:41:26 +00:00
parent 7d94f52a8a
commit 92295fbacd
9 changed files with 19 additions and 68 deletions

View File

@ -1740,17 +1740,6 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
res = torch._library.utils.is_functional_schema(schema)
self.assertEqual(res, expected)
def test_incorrect_schema_types(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
lib.define("foo12(Tensor a) -> asdfasdf")
with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
lib.define("foo12(asdf a) -> Tensor")
with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
lib.define("foo12(int64_t a) -> Tensor")
with self.assertRaisesRegex(RuntimeError, "Use `float`"):
lib.define("foo12(double a) -> Tensor")
def test_is_tensorlist_like_type(self):
tensorlists = [
# Tensor[]

View File

@ -23,14 +23,14 @@ namespace torch::jit {
namespace {
struct SchemaParser {
explicit SchemaParser(const std::string& str, bool allow_typevars)
explicit SchemaParser(const std::string& str)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
type_parser(L, /*parse_complete_tensor_types*/ false) {}
std::variant<OperatorName, FunctionSchema> parseDeclaration() {
OperatorName name = parseName();
@ -361,19 +361,16 @@ struct SchemaParser {
}
Lexer L;
SchemaTypeParser type_parser;
bool allow_typevars_;
};
} // namespace
std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
const std::string& schemaOrName,
bool allow_typevars) {
return SchemaParser(schemaOrName, allow_typevars)
.parseExactlyOneDeclaration();
const std::string& schemaOrName) {
return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
}
FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
auto parsed = parseSchemaOrName(schema, allow_typevars);
FunctionSchema parseSchema(const std::string& schema) {
auto parsed = parseSchemaOrName(schema);
TORCH_CHECK(
std::holds_alternative<FunctionSchema>(parsed),
"Tried to parse a function schema but only the operator name was given");

View File

@ -8,15 +8,9 @@
namespace torch {
namespace jit {
// allow_typevars: If true, we assume that lowercase types that we don't
// understand are type variables. This is only needed for TorchScript (and not
// not needed for custom ops).
TORCH_API std::variant<c10::OperatorName, c10::FunctionSchema> parseSchemaOrName(
const std::string& schemaOrName,
bool allow_typevars = true);
TORCH_API c10::FunctionSchema parseSchema(
const std::string& schema,
bool allow_typevars = true);
const std::string& schemaOrName);
TORCH_API c10::FunctionSchema parseSchema(const std::string& schema);
TORCH_API c10::OperatorName parseName(const std::string& name);
} // namespace jit

View File

@ -82,27 +82,12 @@ TypePtr SchemaTypeParser::parseBaseType() {
auto it = type_map.find(text);
if (it == type_map.end()) {
if (allow_typevars_ && !text.empty() && islower(text[0])) {
if (!text.empty() && islower(text[0])) {
// lower case identifiers that are not otherwise valid types
// are treated as type variables
return c10::TypeFactory::createNamed<VarType>(text);
}
if (text == "double") {
throw ErrorReport(tok.range)
<< "Use `float` instead of `double` in an operator's schema string. "
"`float` in schema corresponds to the double type in C++";
}
if (text == "int64_t") {
throw ErrorReport(tok.range)
<< "Use `SymInt` or `int` instead of `int64_t` in an operator's schema string. "
"`SymInt` corresponds to c10::SymInt in C++ while `int` in schema corresponds "
"to the int64_t type in C++.";
}
throw ErrorReport(tok.range)
<< "unknown type specifier. Common valid schema types include "
"Tensor, SymInt, int, float, bool, Scalar; "
"for a full list, please see "
"https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func ";
throw ErrorReport(tok.range) << "unknown type specifier";
}
return it->second;
}

View File

@ -20,13 +20,8 @@ struct TORCH_API SchemaTypeParser {
c10::optional<at::ScalarType> parseTensorDType(const std::string& dtype);
TypePtr parseRefinedTensor();
SchemaTypeParser(
Lexer& L,
bool parse_complete_tensor_types,
bool allow_typevars)
: complete_tensor_types(parse_complete_tensor_types),
L(L),
allow_typevars_(allow_typevars) {}
SchemaTypeParser(Lexer& L, bool parse_complete_tensor_types)
: complete_tensor_types(parse_complete_tensor_types), L(L) {}
private:
c10::optional<bool> tryToParseRequiresGrad();
@ -40,7 +35,6 @@ struct TORCH_API SchemaTypeParser {
bool complete_tensor_types;
Lexer& L;
size_t next_id = 0;
bool allow_typevars_;
};
} // namespace jit
} // namespace torch

View File

@ -35,10 +35,7 @@ class IRParser {
: L(std::make_shared<Source>(str)),
g(graph),
vmap(vmap),
type_parser(
L,
/*parse_complete_tensor_types*/ true,
/*allow_type_vars*/ true),
type_parser(L, /*parse_complete_tensor_types*/ true),
parse_tensor_constants_(parse_tensor_constants) {}
std::string parseVar();

View File

@ -1765,11 +1765,7 @@ void initJITBindings(PyObject* module) {
},
py::arg("input"),
py::arg("parse_tensor_constants") = false);
m.def(
"parse_schema",
&parseSchema,
py::arg("schema"),
py::arg("allow_typevars") = true);
m.def("parse_schema", parseSchema);
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
std::ostringstream s;
auto type = unifyTypeList(types, s);

View File

@ -1347,8 +1347,7 @@ bool isNoOpSlice(Node* node) {
void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
auto schema = torch::schema(
"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]",
/*allow_typevars*/ true);
"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]");
Node* node = nullptr;
std::vector<Node*> to_delete;
while ((node = it.next()) != nullptr) {

View File

@ -406,8 +406,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
/// ```
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, bool allow_typevars=false) {
c10::FunctionSchema s = torch::jit::parseSchema(str, /*allow_typevars*/allow_typevars);
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
c10::FunctionSchema s = torch::jit::parseSchema(str);
s.setAliasAnalysis(k);
return s;
}
@ -415,8 +415,8 @@ inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, boo
/// Function schemas can be directly constructed from string literals.
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(const char* s, bool allow_typevars=false) {
return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
inline c10::FunctionSchema schema(const char* s) {
return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}
/// \private