Move other div variants to upgraders map

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73586

Approved by: https://github.com/gmagogsfm
Tugsbayasgalan Manlaibaatar
2022-05-16 10:44:45 -07:00
committed by PyTorch MergeBot
parent 20ba6e6935
commit 31d9f7c303
7 changed files with 131 additions and 7 deletions

View File

@@ -133,11 +133,53 @@ class TestUpgraders(JitTestCase):
traced_func = torch.jit.trace(test_func, ())
buffer = io.BytesIO()
torch.jit.save(traced_func, buffer)
current_flag_value = torch._C._get_version_calculator_flag()
# calculate based on old version
torch._C._calculate_package_version_based_on_upgraders(False)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 4)
# calculate based on new version
torch._C._calculate_package_version_based_on_upgraders(True)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 4)
# make sure we preserve old behavior
torch._C._calculate_package_version_based_on_upgraders(current_flag_value)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_other_variants(self):
def test_func():
a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64)
return a
scripted_func = torch.jit.script(test_func)
buffer = io.BytesIO()
torch.jit.save(scripted_func, buffer)
current_flag_value = torch._C._get_version_calculator_flag()
# calculate based on old version
torch._C._calculate_package_version_based_on_upgraders(False)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 5)
# calculate based on new version
torch._C._calculate_package_version_based_on_upgraders(True)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 5)
# make sure we preserve old behavior
torch._C._calculate_package_version_based_on_upgraders(current_flag_value)
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_linspace(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"

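Both tests above rely on a `_load_model_version` helper defined elsewhere in the test class. A minimal sketch of how such a helper can recover the operator version from a saved TorchScript archive (the exact name and location of the version entry inside the zip is an assumption here):

import io
import zipfile

import torch

def _load_model_version(loaded_model):
    # Re-serialize the module and read the plain-text version record
    # stored inside the zip archive.
    buffer = io.BytesIO()
    torch.jit.save(loaded_model, buffer)
    buffer.seek(0)
    with zipfile.ZipFile(buffer) as archive:
        # TorchScript archives keep the produced file-format / operator
        # version in an entry named "version"; its path prefix can vary
        # between releases, so search for it rather than hard-coding it.
        name = next(n for n in archive.namelist() if n.endswith("version"))
        return int(archive.read(name).decode("utf-8"))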
View File

@@ -381,6 +381,10 @@ def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
def _calculate_package_version_based_on_upgraders(val: _bool): ...
def _get_version_calculator_flag() -> _bool: ...
def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
def _jit_script_compile_overload(
qualname: str,

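The two bindings declared here are process-global toggles, so callers are expected to capture and restore the previous value, as the tests do. A small usage sketch (the try/finally pattern is just a suggestion, not part of the API):

import torch

prev = torch._C._get_version_calculator_flag()
torch._C._calculate_package_version_based_on_upgraders(True)
try:
    # Any TorchScript serialization performed here computes the minimum
    # package version solely from the upgrader map.
    ...
finally:
    # Restore the previous behavior for the rest of the process.
    torch._C._calculate_package_version_based_on_upgraders(prev)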
View File

@@ -46,30 +46,56 @@ def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point()):
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
{"div_Tensor_mode_0_3", R"SCRIPT(
def div_Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
return self.divide(other, rounding_mode=rounding_mode)
)SCRIPT"},
{"div_Scalar_0_3", R"SCRIPT(
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
if (self.is_floating_point() or isinstance(other, float)):
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
{"div_Scalar_mode_0_3", R"SCRIPT(
def div_Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
return self.divide(other, rounding_mode=rounding_mode)
)SCRIPT"},
{"div_out_0_3", R"SCRIPT(
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
return self.true_divide(other, out=out)
return self.divide(other, rounding_mode='trunc', out=out)
)SCRIPT"},
{"div_out_mode_0_3", R"SCRIPT(
def div_out_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None, out: Tensor) -> Tensor:
return self.divide(other, rounding_mode=rounding_mode, out=out)
)SCRIPT"},
{"div__Tensor_0_3", R"SCRIPT(
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
if (self.is_floating_point() or other.is_floating_point()):
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
{"div__Tensor_mode_0_3", R"SCRIPT(
def div__Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
return self.divide_(other, rounding_mode=rounding_mode)
)SCRIPT"},
{"div__Scalar_0_3", R"SCRIPT(
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
if (self.is_floating_point() or isinstance(other, float)):
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
{"div__Scalar_mode_0_3", R"SCRIPT(
def div__Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
return self.divide_(other, rounding_mode=rounding_mode)
)SCRIPT"},
{"full_names_0_4", R"SCRIPT(
def full_names_0_4(size:List[int], fill_value:number, *, names:Optional[List[str]]=None,
dtype:Optional[int]=None, layout:Optional[int]=None, device:Optional[Device]=None,
pin_memory:Optional[bool]=None) -> Tensor:
return torch.full(size, fill_value, names=names, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
{"full_0_4", R"SCRIPT(
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,

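For context, the div upgraders registered above reproduce the pre-bump semantics, under which dividing two integer inputs truncated instead of promoting to floating point. A quick eager-mode illustration of the difference (this is plain torch.div, not the upgrader machinery itself):

import torch

a = torch.tensor([7, 8, 9])
b = torch.tensor([2, 2, 2])

# Today's aten::div always performs true division.
print(torch.div(a, b))                         # tensor([3.5000, 4.0000, 4.5000])

# The *_0_3 upgraders fall back to truncation when no floating-point
# input is involved, matching the old behavior.
print(torch.div(a, b, rounding_mode='trunc'))  # tensor([3, 4, 4])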
View File

@@ -36,26 +36,50 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
{{4,
"div_Tensor_0_3",
"aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
{"aten::div.Tensor_mode",
{{4,
"div_Tensor_mode_0_3",
"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
{"aten::div.Scalar",
{{4,
"div_Scalar_0_3",
"aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
{"aten::div.Scalar_mode",
{{4,
"div_Scalar_mode_0_3",
"aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
{"aten::div.out",
{{4,
"div_out_0_3",
"aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
{"aten::div.out_mode",
{{4,
"div_out_mode_0_3",
"aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
{"aten::div_.Tensor",
{{4,
"div__Tensor_0_3",
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
{"aten::div_.Tensor_mode",
{{4,
"div__Tensor_mode_0_3",
"aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
{"aten::div_.Scalar",
{{4,
"div__Scalar_0_3",
"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
{"aten::div_.Scalar_mode",
{{4,
"div__Scalar_mode_0_3",
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
{"aten::full",
{{5,
"full_0_4",
"aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
{"aten::full.names",
{{5,
"full_names_0_4",
"aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
{"aten::full.out",
{{5,
"full_out_0_4",
@@ -96,5 +120,15 @@ void test_only_reset_flag() {
isVersionMapSorted = false;
}
static bool calculatePackageVersionBasedOnUpgraders = false;
void calculate_package_version_based_on_upgraders(bool val) {
calculatePackageVersionBasedOnUpgraders = val;
}
bool get_version_calculator_flag() {
return calculatePackageVersionBasedOnUpgraders;
}
} // namespace jit
} // namespace torch

View File

@@ -13,6 +13,13 @@ struct UpgraderEntry {
std::string old_schema;
};
// Toggle how the minimum version of a serialized module is calculated.
// If this is true, we calculate it solely based on the upgraders.
// If this is false, we calculate it based on the historic per-op version map.
TORCH_API void calculate_package_version_based_on_upgraders(bool val);
TORCH_API bool get_version_calculator_flag();
TORCH_API const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map();

View File

@@ -1570,6 +1570,10 @@ void initJitScriptBindings(PyObject* module) {
})
.def_property_readonly("owner", &Method::owner);
m.def("_generate_upgraders_graph", &generate_upgraders_graph);
m.def(
"_calculate_package_version_based_on_upgraders",
&calculate_package_version_based_on_upgraders);
m.def("_get_version_calculator_flag", &get_version_calculator_flag);
m.def(
"_compile_graph_to_code_table",
[](const std::string& name, const std::shared_ptr<Graph>& graph) {

View File

@@ -753,15 +753,22 @@ struct PythonPrintImpl {
if (version_entry != get_operator_version_map().end()) {
const auto& entry = version_entry->second;
// TODO (tugsuu) move this calculation into a separate step.
min_version_ = std::max(
min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
uint64_t current_version = entry[entry.size() - 1].bumped_at_version;
uint64_t legacy_version_map_version =
get_min_version_for_kind(node->kind());
// True means we calculate the version solely based on the upgraders
if (get_version_calculator_flag()) {
min_version_ = std::max(min_version_, current_version);
} else {
if (legacy_version_map_version != 0) {
min_version_ = std::max(min_version_, legacy_version_map_version);
} else {
min_version_ = std::max(min_version_, current_version);
}
}
}
}
// We want to manually bump the minimum versions for
// other variants of aten::div and aten::full which
// are not covered by the new upgraders
min_version_ =
std::max(min_version_, get_min_version_for_kind(node->kind()));
#else
min_version_ =
std::max(min_version_, get_min_version_for_kind(node->kind()));
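Condensed, the selection logic above behaves as follows (a plain-Python sketch; the helper name is illustrative and not part of the actual class):

def pick_min_version(current_min, upgrader_version, legacy_version, use_upgraders_only):
    # upgrader_version: bumped_at_version of the op's newest upgrader entry
    # legacy_version:   value from the historic per-op version map (0 if absent)
    if use_upgraders_only:
        return max(current_min, upgrader_version)
    if legacy_version != 0:
        return max(current_min, legacy_version)
    return max(current_min, upgrader_version)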