Add option to load historic operators in IR when the operator is deprecated (#71148)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71148

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D33521300

Pulled By: tugsbayasgalan

fbshipit-source-id: a0607dba5e7233590384326537017eb0b18da419
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar
2022-01-12 11:05:35 -08:00
committed by Facebook GitHub Bot
parent 8f4cec2231
commit 70951884d4
6 changed files with 111 additions and 6 deletions

View File

@ -4,6 +4,8 @@
#include <test/cpp/jit/test_utils.h>
#include <vector>
namespace torch {
namespace jit {
@ -58,5 +60,40 @@ TEST(UpgraderUtils, FindIfOpIsCurrent) {
test_only_remove_entry("foo");
}
TEST(UpgraderUtils, CanLoadHistoricOp) {
std::vector<UpgraderEntry> dummy_entry = {
{4, "foo__0_3", "foo.bar()"},
{8, "foo__4_7", "foo.foo()"},
};
std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};
// symbol based look up
test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);
auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
EXPECT_EQ(oldSchemas.size(), 2);
for (const auto& entry : oldSchemas) {
EXPECT_TRUE(
std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
}
auto oldSchemasWithCurrentVersion =
loadPossibleHistoricOps("old_op_not_exist", 9);
EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0);
test_only_remove_entry("old_op_not_exist.first");
test_only_remove_entry("old_op_not_exist.first");
// it is ok to have old schemas without overload
test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
auto oldSchemasNoOverload =
loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
EXPECT_EQ(oldSchemasNoOverload.size(), 1);
EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
test_only_remove_entry("old_op_not_exist_no_overload");
}
} // namespace jit
} // namespace torch

View File

@ -615,6 +615,8 @@ static Value* emitBuiltinNode(
if (!version.has_value() ||
isOpSymbolCurrent(matched_schema.schema_name, version.value())) {
n->getOperation();
} else {
n->setHistoricSchemaName(matched_schema.schema_name);
}
return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
@ -678,6 +680,18 @@ Value* emitBuiltinCall(
schemas.push_back(&op->schema());
}
// we might have seen old historic
// ops that are deprecated
if (variants.empty()) {
auto oldSchemas =
loadPossibleHistoricOps(name.toQualString(), graph_version);
upgrader_schemas.reserve(oldSchemas.size());
for (const auto& old_schema_entry : oldSchemas) {
FunctionSchema old_schema = parseSchema(old_schema_entry);
upgrader_schemas.emplace_back(old_schema);
}
}
// TODO (tugsuu): make sure this is optimized later
for (const auto& schema : upgrader_schemas) {
schemas.push_back(&schema);
@ -710,7 +724,7 @@ Value* emitBuiltinCall(
auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self);
if (matched.first < variants.size()) {
if (matched.first < variants.size() + upgrader_schemas.size()) {
return emitBuiltinNode(matched.second, loc, graph, name, graph_version);
} else {
auto& fn = *builtin_functions[matched.first - variants.size()];

View File

@ -338,6 +338,12 @@ struct TORCH_API Node {
topo_position_t topo_position_ = 0;
// a managing wrapper for Python to allow invalidation
std::shared_ptr<Wrap<Node>> wrap_;
// Stores the full schema name, if the operator is historic
// When the operator is deprecated or the name of the operator
// is changed, we need to rely on this name
// to retrieve old schemas to successfully apply upgraders
// for this operator.
c10::optional<std::string> historic_schema_name_ = c10::nullopt;
protected:
Node(Graph* graph_, NodeKind kind_); // defined after graph
@ -362,6 +368,14 @@ struct TORCH_API Node {
return wrap_;
}
const c10::optional<std::string> getHistoricSchemaName() {
return historic_schema_name_;
}
void setHistoricSchemaName(const std::string& name) {
historic_schema_name_ = name;
}
Node*& next() {
return next_in_graph[kNextDirection];
}

View File

@ -49,5 +49,29 @@ bool isOpSymbolCurrent(const std::string& name, size_t current_version) {
return true;
}
std::vector<std::string> loadPossibleHistoricOps(
const std::string& name,
c10::optional<size_t> version) {
std::vector<std::string> possibleSchemas;
if (!version.has_value()) {
return possibleSchemas;
}
for (const auto& entry : get_operator_version_map()) {
auto old_symbol_name = entry.first;
// strip off the overload name, if exist
auto base_name = old_symbol_name.substr(0, old_symbol_name.find('.'));
if (base_name == name) {
auto possibleUpgrader = findUpgrader(entry.second, version.value());
if (possibleUpgrader.has_value()) {
possibleSchemas.push_back(possibleUpgrader.value().old_schema);
}
}
}
return possibleSchemas;
}
} // namespace jit
} // namespace torch

View File

@ -30,5 +30,13 @@ TORCH_API bool isOpSymbolCurrent(
const std::string& name,
size_t current_version);
// Returns the possible old schemas for the operator that
// doesn't exist anymore. This can be true for deprecated
// operators. Since name is always a symbol name, there
// can be multiple schemas for different overloads.
TORCH_API std::vector<std::string> loadPossibleHistoricOps(
const std::string& name,
c10::optional<size_t> version);
} // namespace jit
} // namespace torch

View File

@ -28,20 +28,28 @@ struct OldOpsReplacerWithUpgraders {
DepthFirstGraphNodeIterator graph_it(graph_);
Node* node = graph_it.next();
while (node) {
if (auto schema = node->maybeSchema()) {
auto schema_name = getFullSchemaName(*schema);
// load the schema name for this op
c10::optional<std::string> schema_name = c10::nullopt;
if (auto op_schema = node->maybeSchema()) {
schema_name = getFullSchemaName(*op_schema);
} else {
schema_name = node->getHistoricSchemaName();
}
if (schema_name.has_value()) {
// this implies there was a version bump because of this operator
auto version_entry = get_operator_version_map().find(schema_name);
auto version_entry =
get_operator_version_map().find(schema_name.value());
if (version_entry != get_operator_version_map().end()) {
const auto& entry = version_entry->second;
auto upgrader_entry =
findUpgrader(version_entry->second, current_version);
if (!upgrader_entry.has_value()) {
if (!isOpSymbolCurrent(schema_name, current_version)) {
if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
TORCH_INTERNAL_ASSERT(
false,
"Upgrader must be present for ",
schema_name,
schema_name.value(),
". The upgrader might have deprecated");
}
node = graph_it.next();