Revert "Verify types in custom op schemas (#124520)"
This reverts commit 5b98d43488bed0836b4da5996a50bafd0dd2c11c. Reverted https://github.com/pytorch/pytorch/pull/124520 on behalf of https://github.com/zou3519 because it broke static runtime tests ([comment](https://github.com/pytorch/pytorch/pull/124520#issuecomment-2075111935)).
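For context, a minimal sketch of the user-visible behavior this revert changes, assuming a build from around this commit (the op name `mylib::foo12` comes from the test removed in the first hunk below):

```python
import torch

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    # With #124520 applied, this raised:
    #   RuntimeError: Use `float` instead of `double` in an operator's schema string. ...
    # After the revert, lowercase identifiers such as `double` appear to fall
    # back to being parsed as type variables, so the same call should no
    # longer raise.
    lib.define("foo12(double a) -> Tensor")
```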
```diff
@@ -1740,17 +1740,6 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
         res = torch._library.utils.is_functional_schema(schema)
         self.assertEqual(res, expected)
 
-    def test_incorrect_schema_types(self):
-        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
-            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
-                lib.define("foo12(Tensor a) -> asdfasdf")
-            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
-                lib.define("foo12(asdf a) -> Tensor")
-            with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
-                lib.define("foo12(int64_t a) -> Tensor")
-            with self.assertRaisesRegex(RuntimeError, "Use `float`"):
-                lib.define("foo12(double a) -> Tensor")
-
     def test_is_tensorlist_like_type(self):
         tensorlists = [
             # Tensor[]
```
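The removed error messages steered users toward schema-language type names. A minimal sketch of a definition that uses only those recommended types (hypothetical op `mylib::bar`) and parses the same way before and after the revert:

```python
import torch

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    # Per the messages removed below, `float` in a schema string corresponds
    # to C++ double, and `int` corresponds to C++ int64_t.
    lib.define("bar(Tensor x, int n, float scale, bool flag) -> Tensor")
```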
```diff
@@ -23,14 +23,14 @@ namespace torch::jit {
 
 namespace {
 struct SchemaParser {
-  explicit SchemaParser(const std::string& str, bool allow_typevars)
+  explicit SchemaParser(const std::string& str)
       : L(std::make_shared<Source>(
             c10::string_view(str),
             c10::nullopt,
             0,
             nullptr,
             Source::DONT_COPY)),
-        type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
+        type_parser(L, /*parse_complete_tensor_types*/ false) {}
 
   std::variant<OperatorName, FunctionSchema> parseDeclaration() {
     OperatorName name = parseName();
```
```diff
@@ -361,19 +361,16 @@ struct SchemaParser {
   }
   Lexer L;
   SchemaTypeParser type_parser;
-  bool allow_typevars_;
 };
 } // namespace
 
 std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
-    const std::string& schemaOrName,
-    bool allow_typevars) {
-  return SchemaParser(schemaOrName, allow_typevars)
-      .parseExactlyOneDeclaration();
+    const std::string& schemaOrName) {
+  return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
 }
 
-FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
-  auto parsed = parseSchemaOrName(schema, allow_typevars);
+FunctionSchema parseSchema(const std::string& schema) {
+  auto parsed = parseSchemaOrName(schema);
   TORCH_CHECK(
       std::holds_alternative<FunctionSchema>(parsed),
       "Tried to parse a function schema but only the operator name was given");
```
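These functions back the `torch._C.parse_schema` binding touched further down. A minimal usage sketch against a stock aten schema:

```python
import torch

schema = torch._C.parse_schema(
    "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
print(schema.name)                             # aten::add
print([arg.name for arg in schema.arguments])  # ['self', 'other', 'alpha']
```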
```diff
@@ -8,15 +8,9 @@
 namespace torch {
 namespace jit {
 
-// allow_typevars: If true, we assume that lowercase types that we don't
-// understand are type variables. This is only needed for TorchScript (and not
-// not needed for custom ops).
 TORCH_API std::variant<c10::OperatorName, c10::FunctionSchema> parseSchemaOrName(
-    const std::string& schemaOrName,
-    bool allow_typevars = true);
-TORCH_API c10::FunctionSchema parseSchema(
-    const std::string& schema,
-    bool allow_typevars = true);
+    const std::string& schemaOrName);
+TORCH_API c10::FunctionSchema parseSchema(const std::string& schema);
 TORCH_API c10::OperatorName parseName(const std::string& name);
 
 } // namespace jit
```
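`parseSchemaOrName` accepts either a bare operator name or a full schema, while `parseSchema` insists on the latter via the `TORCH_CHECK` shown above. A sketch of that failure mode from Python:

```python
import torch

# Passing only an operator name trips the TORCH_CHECK in parseSchema:
# RuntimeError: Tried to parse a function schema but only the operator name was given
torch._C.parse_schema("aten::relu")
```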
```diff
@@ -82,27 +82,12 @@ TypePtr SchemaTypeParser::parseBaseType() {
 
   auto it = type_map.find(text);
   if (it == type_map.end()) {
-    if (allow_typevars_ && !text.empty() && islower(text[0])) {
+    if (!text.empty() && islower(text[0])) {
       // lower case identifiers that are not otherwise valid types
       // are treated as type variables
       return c10::TypeFactory::createNamed<VarType>(text);
     }
-    if (text == "double") {
-      throw ErrorReport(tok.range)
-          << "Use `float` instead of `double` in an operator's schema string. "
-             "`float` in schema corresponds to the double type in C++";
-    }
-    if (text == "int64_t") {
-      throw ErrorReport(tok.range)
-          << "Use `SymInt` or `int` instead of `int64_t` in an operator's schema string. "
-             "`SymInt` corresponds to c10::SymInt in C++ while `int` in schema corresponds "
-             "to the int64_t type in C++.";
-    }
-    throw ErrorReport(tok.range)
-        << "unknown type specifier. Common valid schema types include "
-           "Tensor, SymInt, int, float, bool, Scalar; "
-           "for a full list, please see "
-           "https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func ";
+    throw ErrorReport(tok.range) << "unknown type specifier";
   }
   return it->second;
 }
```
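With the detailed diagnostics reverted, only identifiers that are not lowercase still fail, since lowercase unknowns are once again treated as type variables unconditionally. A sketch with hypothetical type names:

```python
import torch

# Raises RuntimeError("unknown type specifier"): `Banana` is capitalized, so
# it cannot be a type variable and is not in the parser's type_map.
torch._C.parse_schema("mylib::foo(Tensor a) -> Banana")

# By contrast, a lowercase unknown such as `banana` should now parse as a
# type variable rather than erroring.
torch._C.parse_schema("mylib::foo(Tensor a) -> banana")
```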
```diff
@@ -20,13 +20,8 @@ struct TORCH_API SchemaTypeParser {
   c10::optional<at::ScalarType> parseTensorDType(const std::string& dtype);
   TypePtr parseRefinedTensor();
 
-  SchemaTypeParser(
-      Lexer& L,
-      bool parse_complete_tensor_types,
-      bool allow_typevars)
-      : complete_tensor_types(parse_complete_tensor_types),
-        L(L),
-        allow_typevars_(allow_typevars) {}
+  SchemaTypeParser(Lexer& L, bool parse_complete_tensor_types)
+      : complete_tensor_types(parse_complete_tensor_types), L(L) {}
 
  private:
   c10::optional<bool> tryToParseRequiresGrad();
```
```diff
@@ -40,7 +35,6 @@ struct TORCH_API SchemaTypeParser {
   bool complete_tensor_types;
   Lexer& L;
   size_t next_id = 0;
-  bool allow_typevars_;
 };
 } // namespace jit
 } // namespace torch
```
```diff
@@ -35,10 +35,7 @@ class IRParser {
       : L(std::make_shared<Source>(str)),
         g(graph),
         vmap(vmap),
-        type_parser(
-            L,
-            /*parse_complete_tensor_types*/ true,
-            /*allow_type_vars*/ true),
+        type_parser(L, /*parse_complete_tensor_types*/ true),
         parse_tensor_constants_(parse_tensor_constants) {}
 
   std::string parseVar();
```
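`IRParser` is what `torch._C.parse_ir` drives, with complete tensor types enabled so refined annotations such as `Float(2, 2)` parse. A minimal sketch, assuming that binding:

```python
import torch

# Parse a small graph whose values carry refined (dtype + shape) tensor types.
graph = torch._C.parse_ir("""
graph(%x : Float(2, 2)):
  %y : Float(2, 2) = aten::relu(%x)
  return (%y)
""")
print(graph)
```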
```diff
@@ -1765,11 +1765,7 @@ void initJITBindings(PyObject* module) {
       },
       py::arg("input"),
       py::arg("parse_tensor_constants") = false);
-  m.def(
-      "parse_schema",
-      &parseSchema,
-      py::arg("schema"),
-      py::arg("allow_typevars") = true);
+  m.def("parse_schema", parseSchema);
   m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
     std::ostringstream s;
     auto type = unifyTypeList(types, s);
```
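After the revert the binding takes only the schema string; the keyword that #124520 added is gone. Sketch:

```python
import torch

torch._C.parse_schema("aten::relu(Tensor self) -> Tensor")  # fine
# torch._C.parse_schema("aten::relu(Tensor self) -> Tensor", allow_typevars=False)
#   -> TypeError once reverted, since the keyword no longer exists
```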
```diff
@@ -1347,8 +1347,7 @@ bool isNoOpSlice(Node* node) {
 void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
   DepthFirstGraphNodeIterator it(graph);
-  auto schema = torch::schema(
-      "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]",
-      /*allow_typevars*/ true);
+  auto schema = torch::schema(
+      "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]");
   Node* node = nullptr;
   std::vector<Node*> to_delete;
   while ((node = it.next()) != nullptr) {
```
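This pass matches a schema that itself uses the lowercase type variable `t`, which is why the TorchScript-facing parser has to keep accepting type variables by default. A sketch inspecting that same schema from Python:

```python
import torch

s = torch._C.parse_schema(
    "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]")
print(s.arguments[0].type)  # t[] -- a list whose element type is the type variable t
```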
```diff
@@ -406,8 +406,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
 /// ```
 ///
 /// \ingroup torch-schema-overloads
-inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, bool allow_typevars=false) {
-  c10::FunctionSchema s = torch::jit::parseSchema(str, /*allow_typevars*/allow_typevars);
+inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
+  c10::FunctionSchema s = torch::jit::parseSchema(str);
   s.setAliasAnalysis(k);
   return s;
 }
```
```diff
@@ -415,8 +415,8 @@ inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, boo
 /// Function schemas can be directly constructed from string literals.
 ///
 /// \ingroup torch-schema-overloads
-inline c10::FunctionSchema schema(const char* s, bool allow_typevars=false) {
-  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
+inline c10::FunctionSchema schema(const char* s) {
+  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
 }
 
 /// \private
```
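There is no direct Python handle on these `torch::schema` overloads, but `torch.library.define` exercises the same schema grammar from Python; a hypothetical example:

```python
import torch

# Hypothetical namespace and op; the schema string is parsed by the same
# machinery the hunks above modify.
torch.library.define("mylib::scale", "(Tensor x, float factor) -> Tensor")
```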