mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update the operator version check logic when generating models for testing upgraders (#71894)
Summary: The model generation script will check the model version, to ensure the developer run the script before they change operator Previously, the version use the old model version. However, it's hard for developer to know the old version number. In this change, it use the current max operator version to check. It's less strict, but more developer friendly Pull Request resolved: https://github.com/pytorch/pytorch/pull/71894 ghstack-source-id: 147769215 Test Plan: first time run: ``` chenlai@devvm5615:~/fbsource/fbcode(b82243650)$ buck run mode/opt //caffe2/torch/fb/mobile/upgrader_codegen:upgrader_test_models_gen Parsing buck files: finished in 0.7 sec Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules) Building: finished in 21.6 sec (100%) 11547/11547 jobs, 2/11547 updated Total time: 22.4 sec BUILD SUCCEEDED TestVersionedDivTensorExampleV7() aten::div.Tensor INFO:test.jit.fixtures_srcs.generate_models:Processing TestVersionedDivTensorExampleV7 INFO:test.jit.fixtures_srcs.generate_models:Generating model test_versioned_div_tensor_example_v7 and it's save to /data/users/chenlai/fbsource/fbcode/caffe2/test/jit/fixtures/test_versioned_div_tensor_example_v7.ptl chenlai@devvm5615:~/fbsource/fbcode(b82243650)$ ``` second time run: ``` chenlai@devvm5615:~/fbsource/fbcode(b82243650)$ rm caffe2/test/jit/fixtures/test_versioned_div_tensor_example_v4.ptl chenlai@devvm5615:~/fbsource/fbcode(b82243650)$ buck run mode/opt //caffe2/torch/fb/mobile/upgrader_codegen:upgrader_test_models_gen Action graph will be rebuilt because files have been added or removed. Parsing buck files: finished in 2.0 sec Building... 17.4 sec (99%) 9289/9290 jobs, 0/9290 updated TestVersionedDivTensorExampleV7() aten::div.Tensor INFO:test.jit.fixtures_srcs.generate_models:Processing TestVersionedDivTensorExampleV7 INFO:test.jit.fixtures_srcs.generate_models:Model test_versioned_div_tensor_example_v7 already exists, skipping chenlai@devvm5615:~/fbsource/fbcode(b82243650)$ jf s ``` Reviewed By: tugsbayasgalan Differential Revision: D33804737 fbshipit-source-id: 7424b81a700703bdf896ec606c2dac8df6dbf8a6 (cherry picked from commit 44b4e37d30077a3160b8a92209af339a6f2fc885)
This commit is contained in:
committed by
PyTorch MergeBot
parent
0cae3c0481
commit
e755a4f124
@ -64,7 +64,7 @@ is working as expected. In `test/jit/fixtures_srcs/generate_models.py`, add the
|
||||
it's corresponding changed operator like following
|
||||
```
|
||||
ALL_MODULES = {
|
||||
TestVersionedDivTensorExampleV4(): "aten::div.Tensor",
|
||||
TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
|
||||
}
|
||||
```
|
||||
This module should includes the changed operator. If the operator isn't covered in the model,
|
||||
@ -89,7 +89,7 @@ key: test module
|
||||
value: changed operator
|
||||
"""
|
||||
ALL_MODULES = {
|
||||
TestVersionedDivTensorExampleV4(): "aten::div.Tensor",
|
||||
TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
|
||||
}
|
||||
|
||||
"""
|
||||
@ -157,9 +157,9 @@ def generate_models(model_directory_path: Path):
|
||||
for a_module, expect_operator in ALL_MODULES.items():
|
||||
print(a_module, expect_operator)
|
||||
script_module = torch.jit.script(a_module)
|
||||
model_version = get_output_model_version(script_module)
|
||||
actual_model_version = get_output_model_version(script_module)
|
||||
|
||||
# For example: TestVersionedDivTensorExampleV4
|
||||
# For example: TestVersionedDivTensorExampleV7
|
||||
torch_module_name = type(a_module).__name__
|
||||
|
||||
# The corresponding model name is: test_versioned_div_tensor_example_v4
|
||||
@ -172,12 +172,11 @@ def generate_models(model_directory_path: Path):
|
||||
logger.info(f"Model {model_name} already exists, skipping")
|
||||
continue
|
||||
|
||||
actual_model_version = "v" + str(model_version)
|
||||
expect_model_version = model_name.split("_")[-1]
|
||||
if actual_model_version != expect_model_version:
|
||||
current_operator_version = torch._C._get_max_operator_version()
|
||||
if actual_model_version >= current_operator_version + 1:
|
||||
logger.error(
|
||||
f"Actual model version {actual_model_version} "
|
||||
f"doesn't match the expect model version {expect_model_version}. "
|
||||
f"is equal or larger than {current_operator_version} + 1. "
|
||||
f"Please run the script before the commit to change operator.")
|
||||
continue
|
||||
|
||||
|
Reference in New Issue
Block a user