[ONNX] Fix check_training_mode in symbolic_helper (#78376)

`check_training_mode` always warned that an op is set to training because it was comparing an int `op_train_mode` with an Enum `GLOBALS.training_mode`. This PR fixes the behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78376
Approved by: https://github.com/garymm
This commit is contained in:
Justin Chu
2022-05-27 00:38:16 +00:00
committed by PyTorch MergeBot
parent dfd78bf4ab
commit 299fbbccec
2 changed files with 94 additions and 21 deletions

View File

@ -0,0 +1,71 @@
# Owner(s): ["module: onnx"]
"""Unit tests on `torch.onnx.symbolic_helper`."""
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils
class TestHelperFunctions(common_utils.TestCase):
def setUp(self):
super().setUp()
self._initial_training_mode = GLOBALS.training_mode
def tearDown(self):
GLOBALS.training_mode = self._initial_training_mode
@common_utils.parametrize(
"op_train_mode,export_mode",
[
common_utils.subtest(
[1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
),
common_utils.subtest(
[0, torch.onnx.TrainingMode.EVAL],
name="modes_match_op_train_mode_0_export_mode_eval",
),
common_utils.subtest(
[1, torch.onnx.TrainingMode.TRAINING],
name="modes_match_op_train_mode_1_export_mode_training",
),
],
)
def test_check_training_mode_does_not_warn_when(
self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
):
GLOBALS.training_mode = export_mode
self.assertNotWarn(
lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
)
@common_utils.parametrize(
"op_train_mode,export_mode",
[
common_utils.subtest(
[0, torch.onnx.TrainingMode.TRAINING],
name="modes_do_not_match_op_train_mode_0_export_mode_training",
),
common_utils.subtest(
[1, torch.onnx.TrainingMode.EVAL],
name="modes_do_not_match_op_train_mode_1_export_mode_eval",
),
],
)
def test_check_training_mode_warns_when(
self,
op_train_mode: int,
export_mode: torch.onnx.TrainingMode,
):
with self.assertWarnsRegex(
UserWarning, f"ONNX export mode is set to {export_mode}"
):
GLOBALS.training_mode = export_mode
symbolic_helper.check_training_mode(op_train_mode, "testop")
common_utils.instantiate_parametrized_tests(TestHelperFunctions)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -1114,27 +1114,29 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
return padding
def check_training_mode(op_train_mode, op_name):
op_train_mode = True if op_train_mode == 1 else False
if GLOBALS.training_mode is not None and op_train_mode != GLOBALS.training_mode:
op_mode = "training " if op_train_mode else "inference"
training_mode = "training " if GLOBALS.training_mode else "inference"
# setting the model mode could result in op_mode != _flags.training_mode
# if the model is a FuncModule. In this case we warn the user of
# the state and export depending on op_mode
# This is to support use-cases of fixing certain layer weights
# in training.
warnings.warn(
"ONNX export mode is set to "
+ training_mode
+ " mode, but operator "
+ op_name
+ " is set to "
+ op_mode
+ " mode. The operators will be exported in "
+ op_mode
+ ", as specified by the functional operator."
)
def check_training_mode(op_train_mode: int, op_name: str) -> None:
"""Warns the user if the model's training mode and the export mode do not agree."""
if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
return
if op_train_mode:
op_mode_enum = _C_onnx.TrainingMode.TRAINING
else:
op_mode_enum = _C_onnx.TrainingMode.EVAL
if op_mode_enum == GLOBALS.training_mode:
# The modes agree. Do nothing
return
op_mode_text = f"train={bool(op_train_mode)}"
# Setting the model mode could result in op_mode != GLOBALS.training_mode
# if the model is a FuncModule. In this case we warn the user of
# the state and export depending on op_mode
# This is to support use-cases of fixing certain layer weights
# in training.
warnings.warn(
f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
f"is set to {op_mode_text}. Exporting with {op_mode_text}."
)
def _flatten_helper(g, input, start_dim, end_dim, dim):