mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add option to load historic operators in IR when the operator is deprecated (#71148)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71148 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D33521300 Pulled By: tugsbayasgalan fbshipit-source-id: a0607dba5e7233590384326537017eb0b18da419
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8f4cec2231
commit
70951884d4
@ -4,6 +4,8 @@
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
@ -58,5 +60,40 @@ TEST(UpgraderUtils, FindIfOpIsCurrent) {
|
||||
test_only_remove_entry("foo");
|
||||
}
|
||||
|
||||
TEST(UpgraderUtils, CanLoadHistoricOp) {
|
||||
std::vector<UpgraderEntry> dummy_entry = {
|
||||
{4, "foo__0_3", "foo.bar()"},
|
||||
{8, "foo__4_7", "foo.foo()"},
|
||||
};
|
||||
|
||||
std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};
|
||||
|
||||
// symbol based look up
|
||||
test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
|
||||
test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);
|
||||
|
||||
auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
|
||||
EXPECT_EQ(oldSchemas.size(), 2);
|
||||
for (const auto& entry : oldSchemas) {
|
||||
EXPECT_TRUE(
|
||||
std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
|
||||
}
|
||||
|
||||
auto oldSchemasWithCurrentVersion =
|
||||
loadPossibleHistoricOps("old_op_not_exist", 9);
|
||||
EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0);
|
||||
|
||||
test_only_remove_entry("old_op_not_exist.first");
|
||||
test_only_remove_entry("old_op_not_exist.first");
|
||||
|
||||
// it is ok to have old schemas without overload
|
||||
test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
|
||||
auto oldSchemasNoOverload =
|
||||
loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
|
||||
EXPECT_EQ(oldSchemasNoOverload.size(), 1);
|
||||
EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
|
||||
test_only_remove_entry("old_op_not_exist_no_overload");
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -615,6 +615,8 @@ static Value* emitBuiltinNode(
|
||||
if (!version.has_value() ||
|
||||
isOpSymbolCurrent(matched_schema.schema_name, version.value())) {
|
||||
n->getOperation();
|
||||
} else {
|
||||
n->setHistoricSchemaName(matched_schema.schema_name);
|
||||
}
|
||||
|
||||
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
|
||||
@ -678,6 +680,18 @@ Value* emitBuiltinCall(
|
||||
schemas.push_back(&op->schema());
|
||||
}
|
||||
|
||||
// we might have seen old historic
|
||||
// ops that are deprecated
|
||||
if (variants.empty()) {
|
||||
auto oldSchemas =
|
||||
loadPossibleHistoricOps(name.toQualString(), graph_version);
|
||||
upgrader_schemas.reserve(oldSchemas.size());
|
||||
for (const auto& old_schema_entry : oldSchemas) {
|
||||
FunctionSchema old_schema = parseSchema(old_schema_entry);
|
||||
upgrader_schemas.emplace_back(old_schema);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (tugsuu): make sure this is optimized later
|
||||
for (const auto& schema : upgrader_schemas) {
|
||||
schemas.push_back(&schema);
|
||||
@ -710,7 +724,7 @@ Value* emitBuiltinCall(
|
||||
|
||||
auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self);
|
||||
|
||||
if (matched.first < variants.size()) {
|
||||
if (matched.first < variants.size() + upgrader_schemas.size()) {
|
||||
return emitBuiltinNode(matched.second, loc, graph, name, graph_version);
|
||||
} else {
|
||||
auto& fn = *builtin_functions[matched.first - variants.size()];
|
||||
|
@ -338,6 +338,12 @@ struct TORCH_API Node {
|
||||
topo_position_t topo_position_ = 0;
|
||||
// a managing wrapper for Python to allow invalidation
|
||||
std::shared_ptr<Wrap<Node>> wrap_;
|
||||
// Stores the full schema name, if the operator is historic
|
||||
// When the operator is deprecated or the name of the operator
|
||||
// is changed, we need to rely on this name
|
||||
// to retrieve old schemas to successfully apply upgraders
|
||||
// for this operator.
|
||||
c10::optional<std::string> historic_schema_name_ = c10::nullopt;
|
||||
|
||||
protected:
|
||||
Node(Graph* graph_, NodeKind kind_); // defined after graph
|
||||
@ -362,6 +368,14 @@ struct TORCH_API Node {
|
||||
return wrap_;
|
||||
}
|
||||
|
||||
const c10::optional<std::string> getHistoricSchemaName() {
|
||||
return historic_schema_name_;
|
||||
}
|
||||
|
||||
void setHistoricSchemaName(const std::string& name) {
|
||||
historic_schema_name_ = name;
|
||||
}
|
||||
|
||||
Node*& next() {
|
||||
return next_in_graph[kNextDirection];
|
||||
}
|
||||
|
@ -49,5 +49,29 @@ bool isOpSymbolCurrent(const std::string& name, size_t current_version) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::string> loadPossibleHistoricOps(
|
||||
const std::string& name,
|
||||
c10::optional<size_t> version) {
|
||||
std::vector<std::string> possibleSchemas;
|
||||
|
||||
if (!version.has_value()) {
|
||||
return possibleSchemas;
|
||||
}
|
||||
|
||||
for (const auto& entry : get_operator_version_map()) {
|
||||
auto old_symbol_name = entry.first;
|
||||
// strip off the overload name, if exist
|
||||
auto base_name = old_symbol_name.substr(0, old_symbol_name.find('.'));
|
||||
if (base_name == name) {
|
||||
auto possibleUpgrader = findUpgrader(entry.second, version.value());
|
||||
if (possibleUpgrader.has_value()) {
|
||||
possibleSchemas.push_back(possibleUpgrader.value().old_schema);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return possibleSchemas;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -30,5 +30,13 @@ TORCH_API bool isOpSymbolCurrent(
|
||||
const std::string& name,
|
||||
size_t current_version);
|
||||
|
||||
// Returns the possible old schemas for the operator that
|
||||
// doesn't exist anymore. This can be true for deprecated
|
||||
// operators. Since name is always a symbol name, there
|
||||
// can be multiple schemas for different overloads.
|
||||
TORCH_API std::vector<std::string> loadPossibleHistoricOps(
|
||||
const std::string& name,
|
||||
c10::optional<size_t> version);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -28,20 +28,28 @@ struct OldOpsReplacerWithUpgraders {
|
||||
DepthFirstGraphNodeIterator graph_it(graph_);
|
||||
Node* node = graph_it.next();
|
||||
while (node) {
|
||||
if (auto schema = node->maybeSchema()) {
|
||||
auto schema_name = getFullSchemaName(*schema);
|
||||
// load the schema name for this op
|
||||
c10::optional<std::string> schema_name = c10::nullopt;
|
||||
if (auto op_schema = node->maybeSchema()) {
|
||||
schema_name = getFullSchemaName(*op_schema);
|
||||
} else {
|
||||
schema_name = node->getHistoricSchemaName();
|
||||
}
|
||||
|
||||
if (schema_name.has_value()) {
|
||||
// this implies there was a version bump because of this operator
|
||||
auto version_entry = get_operator_version_map().find(schema_name);
|
||||
auto version_entry =
|
||||
get_operator_version_map().find(schema_name.value());
|
||||
if (version_entry != get_operator_version_map().end()) {
|
||||
const auto& entry = version_entry->second;
|
||||
auto upgrader_entry =
|
||||
findUpgrader(version_entry->second, current_version);
|
||||
if (!upgrader_entry.has_value()) {
|
||||
if (!isOpSymbolCurrent(schema_name, current_version)) {
|
||||
if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false,
|
||||
"Upgrader must be present for ",
|
||||
schema_name,
|
||||
schema_name.value(),
|
||||
". The upgrader might have deprecated");
|
||||
}
|
||||
node = graph_it.next();
|
||||
|
Reference in New Issue
Block a user