#include #include #include #include #include namespace torch { namespace jit { TEST(UpgraderUtils, FindCorrectUpgrader) { std::vector dummy_entry = { {4, "foo__0_3", "foo.bar()"}, {8, "foo__4_7", "foo.bar()"}, }; auto upgrader_at_6 = findUpgrader(dummy_entry, 6); EXPECT_TRUE(upgrader_at_6.has_value()); EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7"); auto upgrader_at_1 = findUpgrader(dummy_entry, 1); EXPECT_TRUE(upgrader_at_1.has_value()); EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3"); auto upgrader_at_10 = findUpgrader(dummy_entry, 10); EXPECT_TRUE(upgrader_at_1.has_value()); EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3"); } TEST(UpgraderUtils, IsVersionMapSorted) { auto map = get_operator_version_map(); // tests if the each list of UpgraderEntry in the map is sorted by // their bumped_at_version field. for (const auto& entry : map) { std::vector versions; for (const auto& el : entry.second) { versions.push_back(el.bumped_at_version); } EXPECT_TRUE(std::is_sorted(versions.begin(), versions.end())); } } TEST(UpgraderUtils, FindIfOpIsCurrent) { std::vector dummy_entry = { {4, "foo__0_3", "foo.bar()"}, {8, "foo__4_7", "foo.bar()"}, }; auto isCurrent = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 6); auto isCurrentV2 = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 8); EXPECT_FALSE(isCurrent); EXPECT_TRUE(isCurrentV2); // symbol based look up test_only_add_entry("foo", dummy_entry[0]); test_only_add_entry("foo", dummy_entry[1]); EXPECT_FALSE(isOpSymbolCurrent("foo", 6)); EXPECT_TRUE(isOpSymbolCurrent("foo", 8)); test_only_remove_entry("foo"); } TEST(UpgraderUtils, CanLoadHistoricOp) { std::vector dummy_entry = { {4, "foo__0_3", "foo.bar()"}, {8, "foo__4_7", "foo.foo()"}, }; std::vector schemas = {"foo.bar()", "foo.foo()"}; // symbol based look up test_only_add_entry("old_op_not_exist.first", dummy_entry[0]); test_only_add_entry("old_op_not_exist.second", dummy_entry[1]); auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2); EXPECT_EQ(oldSchemas.size(), 2); for (const auto& entry : oldSchemas) { EXPECT_TRUE( std::find(schemas.begin(), schemas.end(), entry) != schemas.end()); } auto oldSchemasWithCurrentVersion = loadPossibleHistoricOps("old_op_not_exist", 9); EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0); test_only_remove_entry("old_op_not_exist.first"); test_only_remove_entry("old_op_not_exist.first"); // it is ok to have old schemas without overload test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]); auto oldSchemasNoOverload = loadPossibleHistoricOps("old_op_not_exist_no_overload", 2); EXPECT_EQ(oldSchemasNoOverload.size(), 1); EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()"); test_only_remove_entry("old_op_not_exist_no_overload"); } } // namespace jit } // namespace torch