[ONNX] Fix check_training_mode in symbolic_helper (#78376)
`check_training_mode` always warned that an op was set to training mode because it compared an int `op_train_mode` with an Enum `GLOBALS.training_mode`. This PR fixes the behavior by converting the int to the corresponding enum before comparing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78376
Approved by: https://github.com/garymm
Committed by: PyTorch MergeBot
Parent: dfd78bf4ab
Commit: 299fbbccec
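As a minimal standalone sketch of the bug (illustration only, not part of the diff; it assumes nothing beyond the public `torch.onnx.TrainingMode` enum): an int never compares equal to the enum member, so the old mismatch check fired on every call.

import torch

# Old behavior: the raw int (coerced to a bool) was compared directly
# against a TrainingMode enum member. The comparison is never true, so
# "modes differ" was reported even when the modes actually matched.
op_train_mode = 1  # operator requests training behavior
print(op_train_mode == torch.onnx.TrainingMode.TRAINING)  # False: int vs. Enum

# The fix maps the int onto the enum first, then compares like with like.
op_mode_enum = (
    torch.onnx.TrainingMode.TRAINING if op_train_mode else torch.onnx.TrainingMode.EVAL
)
print(op_mode_enum == torch.onnx.TrainingMode.TRAINING)  # True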
test/onnx/test_symbolic_helper.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+# Owner(s): ["module: onnx"]
+"""Unit tests on `torch.onnx.symbolic_helper`."""
+
+import torch
+from torch.onnx import symbolic_helper
+from torch.onnx._globals import GLOBALS
+from torch.testing._internal import common_utils
+
+
+class TestHelperFunctions(common_utils.TestCase):
+    def setUp(self):
+        super().setUp()
+        self._initial_training_mode = GLOBALS.training_mode
+
+    def tearDown(self):
+        GLOBALS.training_mode = self._initial_training_mode
+
+    @common_utils.parametrize(
+        "op_train_mode,export_mode",
+        [
+            common_utils.subtest(
+                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
+            ),
+            common_utils.subtest(
+                [0, torch.onnx.TrainingMode.EVAL],
+                name="modes_match_op_train_mode_0_export_mode_eval",
+            ),
+            common_utils.subtest(
+                [1, torch.onnx.TrainingMode.TRAINING],
+                name="modes_match_op_train_mode_1_export_mode_training",
+            ),
+        ],
+    )
+    def test_check_training_mode_does_not_warn_when(
+        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
+    ):
+        GLOBALS.training_mode = export_mode
+        self.assertNotWarn(
+            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
+        )
+
+    @common_utils.parametrize(
+        "op_train_mode,export_mode",
+        [
+            common_utils.subtest(
+                [0, torch.onnx.TrainingMode.TRAINING],
+                name="modes_do_not_match_op_train_mode_0_export_mode_training",
+            ),
+            common_utils.subtest(
+                [1, torch.onnx.TrainingMode.EVAL],
+                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
+            ),
+        ],
+    )
+    def test_check_training_mode_warns_when(
+        self,
+        op_train_mode: int,
+        export_mode: torch.onnx.TrainingMode,
+    ):
+        with self.assertWarnsRegex(
+            UserWarning, f"ONNX export mode is set to {export_mode}"
+        ):
+            GLOBALS.training_mode = export_mode
+            symbolic_helper.check_training_mode(op_train_mode, "testop")
+
+
+common_utils.instantiate_parametrized_tests(TestHelperFunctions)
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
@@ -1114,27 +1114,29 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
     return padding
 
 
-def check_training_mode(op_train_mode, op_name):
-    op_train_mode = True if op_train_mode == 1 else False
-    if GLOBALS.training_mode is not None and op_train_mode != GLOBALS.training_mode:
-        op_mode = "training " if op_train_mode else "inference"
-        training_mode = "training " if GLOBALS.training_mode else "inference"
-        # setting the model mode could result in op_mode != _flags.training_mode
-        # if the model is a FuncModule. In this case we warn the user of
-        # the state and export depending on op_mode
-        # This is to support use-cases of fixing certain layer weights
-        # in training.
-        warnings.warn(
-            "ONNX export mode is set to "
-            + training_mode
-            + " mode, but operator "
-            + op_name
-            + " is set to "
-            + op_mode
-            + " mode. The operators will be exported in "
-            + op_mode
-            + ", as specified by the functional operator."
-        )
+def check_training_mode(op_train_mode: int, op_name: str) -> None:
+    """Warns the user if the model's training mode and the export mode do not agree."""
+    if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
+        return
+
+    if op_train_mode:
+        op_mode_enum = _C_onnx.TrainingMode.TRAINING
+    else:
+        op_mode_enum = _C_onnx.TrainingMode.EVAL
+    if op_mode_enum == GLOBALS.training_mode:
+        # The modes agree. Do nothing
+        return
+
+    op_mode_text = f"train={bool(op_train_mode)}"
+    # Setting the model mode could result in op_mode != GLOBALS.training_mode
+    # if the model is a FuncModule. In this case we warn the user of
+    # the state and export depending on op_mode
+    # This is to support use-cases of fixing certain layer weights
+    # in training.
+    warnings.warn(
+        f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
+        f"is set to {op_mode_text}. Exporting with {op_mode_text}."
+    )
 
 
 def _flatten_helper(g, input, start_dim, end_dim, dim):