mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move other div variants to upgraders map
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73586 Approved by: https://github.com/gmagogsfm
This commit is contained in:
committed by
PyTorch MergeBot
parent
20ba6e6935
commit
31d9f7c303
@ -133,11 +133,53 @@ class TestUpgraders(JitTestCase):
|
||||
traced_func = torch.jit.trace(test_func, ())
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(traced_func, buffer)
|
||||
|
||||
current_flag_value = torch._C._get_version_calculator_flag()
|
||||
# calculate based on old version
|
||||
torch._C._calculate_package_version_based_on_upgraders(False)
|
||||
buffer.seek(0)
|
||||
loaded_func = torch.jit.load(buffer)
|
||||
version = self._load_model_version(loaded_func)
|
||||
self.assertTrue(version == 4)
|
||||
|
||||
# calculate based on new version
|
||||
torch._C._calculate_package_version_based_on_upgraders(True)
|
||||
buffer.seek(0)
|
||||
loaded_func = torch.jit.load(buffer)
|
||||
version = self._load_model_version(loaded_func)
|
||||
self.assertTrue(version == 4)
|
||||
|
||||
# make sure we preserve old behaviou
|
||||
torch._C._calculate_package_version_based_on_upgraders(current_flag_value)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_full_other_variants(self):
|
||||
def test_func():
|
||||
a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64)
|
||||
return a
|
||||
|
||||
scripted_func = torch.jit.script(test_func)
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_func, buffer)
|
||||
|
||||
current_flag_value = torch._C._get_version_calculator_flag()
|
||||
# calculate based on old version
|
||||
torch._C._calculate_package_version_based_on_upgraders(False)
|
||||
buffer.seek(0)
|
||||
loaded_func = torch.jit.load(buffer)
|
||||
version = self._load_model_version(loaded_func)
|
||||
self.assertTrue(version == 5)
|
||||
|
||||
# calculate based on new version
|
||||
torch._C._calculate_package_version_based_on_upgraders(True)
|
||||
buffer.seek(0)
|
||||
loaded_func = torch.jit.load(buffer)
|
||||
version = self._load_model_version(loaded_func)
|
||||
self.assertTrue(version == 5)
|
||||
|
||||
# make sure we preserve old behaviou
|
||||
torch._C._calculate_package_version_based_on_upgraders(current_flag_value)
|
||||
|
||||
@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
|
||||
def test_aten_linspace(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
|
||||
|
@ -381,6 +381,10 @@ def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
|
||||
|
||||
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
|
||||
|
||||
def _calculate_package_version_based_on_upgraders(val: _bool): ...
|
||||
|
||||
def _get_version_calculator_flag() -> _bool: ...
|
||||
|
||||
def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
|
||||
def _jit_script_compile_overload(
|
||||
qualname: str,
|
||||
|
@ -46,30 +46,56 @@ def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point()):
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div_Tensor_mode_0_3", R"SCRIPT(
|
||||
def div_Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
|
||||
return self.divide(other, rounding_mode=rounding_mode)
|
||||
)SCRIPT"},
|
||||
{"div_Scalar_0_3", R"SCRIPT(
|
||||
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||
if (self.is_floating_point() or isinstance(other, float)):
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div_Scalar_mode_0_3", R"SCRIPT(
|
||||
def div_Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
|
||||
return self.divide(other, rounding_mode=rounding_mode)
|
||||
)SCRIPT"},
|
||||
{"div_out_0_3", R"SCRIPT(
|
||||
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
|
||||
return self.true_divide(other, out=out)
|
||||
return self.divide(other, rounding_mode='trunc', out=out)
|
||||
)SCRIPT"},
|
||||
{"div_out_mode_0_3", R"SCRIPT(
|
||||
def div_out_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None, out: Tensor) -> Tensor:
|
||||
return self.divide(other, rounding_mode=rounding_mode, out=out)
|
||||
)SCRIPT"},
|
||||
{"div__Tensor_0_3", R"SCRIPT(
|
||||
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
|
||||
if (self.is_floating_point() or other.is_floating_point()):
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div__Tensor_mode_0_3", R"SCRIPT(
|
||||
def div__Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
|
||||
return self.divide_(other, rounding_mode=rounding_mode)
|
||||
)SCRIPT"},
|
||||
{"div__Scalar_0_3", R"SCRIPT(
|
||||
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
|
||||
if (self.is_floating_point() or isinstance(other, float)):
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
)SCRIPT"},
|
||||
{"div__Scalar_mode_0_3", R"SCRIPT(
|
||||
def div__Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
|
||||
return self.divide_(other, rounding_mode=rounding_mode)
|
||||
)SCRIPT"},
|
||||
{"full_names_0_4", R"SCRIPT(
|
||||
def full_names_0_4(size:List[int], fill_value:number, *, names:Optional[List[str]]=None,
|
||||
dtype:Optional[int]=None, layout:Optional[int]=None, device:Optional[Device]=None,
|
||||
pin_memory:Optional[bool]=None) -> Tensor:
|
||||
return torch.full(size, fill_value, names=names, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
|
||||
)SCRIPT"},
|
||||
{"full_0_4", R"SCRIPT(
|
||||
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
|
||||
|
@ -36,26 +36,50 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
|
||||
{{4,
|
||||
"div_Tensor_0_3",
|
||||
"aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
|
||||
{"aten::div.Tensor_mode",
|
||||
{{4,
|
||||
"div_Tensor_mode_0_3",
|
||||
"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
|
||||
{"aten::div.Scalar",
|
||||
{{4,
|
||||
"div_Scalar_0_3",
|
||||
"aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
|
||||
{"aten::div.Scalar_mode",
|
||||
{{4,
|
||||
"div_Scalar_mode_0_3",
|
||||
"aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
|
||||
{"aten::div.out",
|
||||
{{4,
|
||||
"div_out_0_3",
|
||||
"aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
|
||||
{"aten::div.out_mode",
|
||||
{{4,
|
||||
"div_out_mode_0_3",
|
||||
"aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
|
||||
{"aten::div_.Tensor",
|
||||
{{4,
|
||||
"div__Tensor_0_3",
|
||||
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
|
||||
{"aten::div_.Tensor_mode",
|
||||
{{4,
|
||||
"div__Tensor_mode_0_3",
|
||||
"aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
|
||||
{"aten::div_.Scalar",
|
||||
{{4,
|
||||
"div__Scalar_0_3",
|
||||
"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
|
||||
{"aten::div_.Scalar_mode",
|
||||
{{4,
|
||||
"div__Scalar_mode_0_3",
|
||||
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
|
||||
{"aten::full",
|
||||
{{5,
|
||||
"full_0_4",
|
||||
"aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
|
||||
{"aten::full.names",
|
||||
{{5,
|
||||
"full_names_0_4",
|
||||
"aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
|
||||
{"aten::full.out",
|
||||
{{5,
|
||||
"full_out_0_4",
|
||||
@ -96,5 +120,15 @@ void test_only_reset_flag() {
|
||||
isVersionMapSorted = false;
|
||||
}
|
||||
|
||||
static bool calculatePackageVersionBasedOnUpgraders = false;
|
||||
|
||||
void calculate_package_version_based_on_upgraders(bool val) {
|
||||
calculatePackageVersionBasedOnUpgraders = val;
|
||||
}
|
||||
|
||||
bool get_version_calculator_flag() {
|
||||
return calculatePackageVersionBasedOnUpgraders;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -13,6 +13,13 @@ struct UpgraderEntry {
|
||||
std::string old_schema;
|
||||
};
|
||||
|
||||
// Toggle the behaviour of calculating version for the module.
|
||||
// If this is true, we calculate solely based on upgraders
|
||||
// If this is false, we calculate it based on historic per op version map
|
||||
TORCH_API void calculate_package_version_based_on_upgraders(bool val);
|
||||
|
||||
TORCH_API bool get_version_calculator_flag();
|
||||
|
||||
TORCH_API const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
|
||||
get_operator_version_map();
|
||||
|
||||
|
@ -1570,6 +1570,10 @@ void initJitScriptBindings(PyObject* module) {
|
||||
})
|
||||
.def_property_readonly("owner", &Method::owner);
|
||||
m.def("_generate_upgraders_graph", &generate_upgraders_graph);
|
||||
m.def(
|
||||
"_calculate_package_version_based_on_upgraders",
|
||||
&calculate_package_version_based_on_upgraders);
|
||||
m.def("_get_version_calculator_flag", &get_version_calculator_flag);
|
||||
m.def(
|
||||
"_compile_graph_to_code_table",
|
||||
[](const std::string& name, const std::shared_ptr<Graph>& graph) {
|
||||
|
@ -753,15 +753,22 @@ struct PythonPrintImpl {
|
||||
if (version_entry != get_operator_version_map().end()) {
|
||||
const auto& entry = version_entry->second;
|
||||
// TODO (tugsuu) move this calculation into a seperate step.
|
||||
min_version_ = std::max(
|
||||
min_version_, uint64_t(entry[entry.size() - 1].bumped_at_version));
|
||||
uint64_t current_version = entry[entry.size() - 1].bumped_at_version;
|
||||
uint64_t legacy_version_map_version =
|
||||
get_min_version_for_kind(node->kind());
|
||||
|
||||
// True means we solely calculate based on upgrader version
|
||||
if (get_version_calculator_flag()) {
|
||||
min_version_ = std::max(min_version_, current_version);
|
||||
} else {
|
||||
if (legacy_version_map_version != 0) {
|
||||
min_version_ = std::max(min_version_, legacy_version_map_version);
|
||||
} else {
|
||||
min_version_ = std::max(min_version_, current_version);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// We want to manually bump the minimum versions for
|
||||
// other variants of aten::div and aten::full which
|
||||
// are not covered by the new upgraders
|
||||
min_version_ =
|
||||
std::max(min_version_, get_min_version_for_kind(node->kind()));
|
||||
#else
|
||||
min_version_ =
|
||||
std::max(min_version_, get_min_version_for_kind(node->kind()));
|
||||
|
Reference in New Issue
Block a user