[quant][pt2e] Move batch norm op between eval/train for cuda (#123957)

Summary: Previously, `move_exported_model_to_train/eval` only switched
between the CPU versions of the batch norm op. This commit adds support
for the CUDA versions of the op as well. Note that this fix is temporary;
we will no longer need to differentiate between the two cases once batch
norm consolidation lands.
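
A minimal sketch (not part of the original summary) of the behavior this enables, assuming a CUDA build and the pre-autograd export API used by the test below; the op names and argument positions follow the test changes in this commit:

```python
import torch
import torch.ao.quantization
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

m = M().train().cuda()
example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
m = capture_pre_autograd_graph(m, example_inputs)

# On CUDA the captured graph uses cudnn_batch_norm for both modes; switching
# between train and eval flips the op's `training` argument (args[5]) rather
# than swapping in a separate "no_training" op as on CPU.
torch.ao.quantization.move_exported_model_to_eval(m)
bn_nodes = [
    n for n in m.graph.nodes
    if n.target == torch.ops.aten.cudnn_batch_norm.default
]
assert bn_nodes and not bn_nodes[0].args[5]  # training flag is off in eval
```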

Test Plan:
python test/test_quantization.py -k test_move_exported_model_bn

Reviewers: jerryzh168

Subscribers: jerryzh168, leslie-fang-intel, supriyar

Differential Revision: [D56070054](https://our.internmc.facebook.com/intern/diff/D56070054)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123957
Approved by: https://github.com/jerryzh168
Author: andrewor14
Date: 2024-04-23 14:13:01 -07:00
Committed by: PyTorch MergeBot
Parent: 64af899fdf
Commit: 4efb28c900
4 changed files with 81 additions and 45 deletions

View File

@@ -1826,6 +1826,18 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def test_move_exported_model_dropout_inplace(self):
self._test_move_exported_model_dropout(inplace=True)
def _get_bn_train_eval_ops(self, is_cuda: bool):
if is_cuda:
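# On CUDA the same cudnn_batch_norm op is returned for both modes; train
# vs eval is carried by the op's `training` argument (args[5]), which the
# test assertions below check.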
return (
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.cudnn_batch_norm.default,
)
else:
return (
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten._native_batch_norm_legit_no_training.default,
)
def test_move_exported_model_bn(self):
"""
Test switching batch_norm behavior between train and eval modes using
@@ -1840,12 +1852,18 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def forward(self, x):
return self.bn(x)
example_inputs = (torch.randn(1, 3, 3, 3),)
m = M().train()
is_cuda = torch.cuda.is_available()
if is_cuda:
m = M().train().cuda()
example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
else:
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
m = capture_pre_autograd_graph(m, example_inputs)
# Assert that batch norm op exists and is in train mode
bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
bn_node = self._get_node(m, bn_train_op)
self.assertTrue(bn_node is not None)
self.assertTrue(bn_node.args[5])
@@ -1853,16 +1871,14 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.ao.quantization.move_exported_model_to_eval(m)
# Assert that batch norm op is now in eval mode
bn_node = self._get_node(
m, torch.ops.aten._native_batch_norm_legit_no_training.default
)
bn_node = self._get_node(m, bn_eval_op)
self.assertTrue(bn_node is not None)
# Move to train
torch.ao.quantization.move_exported_model_to_train(m)
# Assert that batch norm op is now in train mode again
bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
bn_node = self._get_node(m, bn_train_op)
self.assertTrue(bn_node is not None)
self.assertTrue(bn_node.args[5])
@@ -1908,22 +1924,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
x = self.dropout(x)
return x
example_inputs = (torch.randn(1, 3, 3, 3),)
m = M().train()
is_cuda = torch.cuda.is_available()
if is_cuda:
m = M().train().cuda()
example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
else:
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
m = capture_pre_autograd_graph(m, example_inputs)
def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
targets = [n.target for n in m.graph.nodes]
bn_train_target = torch.ops.aten._native_batch_norm_legit.default
bn_eval_target = torch.ops.aten._native_batch_norm_legit_no_training.default
if train:
self.assertTrue(bn_train_target in targets)
self.assertTrue(bn_eval_target not in targets)
else:
self.assertTrue(bn_eval_target in targets)
self.assertTrue(bn_train_target not in targets)
bn_op = bn_train_op if train else bn_eval_op
bn_node = self._get_node(m, bn_op)
self.assertTrue(bn_node is not None)
if is_cuda:
self.assertEqual(bn_node.args[5], train)
dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
self.assertTrue(dropout_node.args[2] == train)
self.assertEqual(dropout_node.args[2], train)
# Before wrapping: this is not OK
with self.assertRaises(NotImplementedError):

View File

@@ -23,6 +23,7 @@ from torch.ao.quantization.qconfig import (
)
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
_assert_and_get_unique_device,
activation_is_statically_quantized,
)
from torch.ao.quantization.observer import _is_activation_post_process
@@ -222,26 +223,13 @@ def graph_module_from_producer_nodes(
graph_module = GraphModule(root, graph)
return graph_module
# TODO: delete
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
"""
Returns the unique device for a module, or None if no device is found.
Throws an error if multiple devices are detected.
"""
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
"""
As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564
"""
if {torch.device("cpu"), torch.device("meta")} == devices:
warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
devices = {torch.device("cpu")}
""
assert len(devices) <= 1, (
"prepare only works with cpu or single-device CUDA modules, "
f"but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
return _assert_and_get_unique_device(module)
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
"""

View File

@@ -3,6 +3,8 @@ import types
import torch
import torch.nn.functional as F
from torch.ao.quantization.utils import _assert_and_get_unique_device
__all__ = [
"model_is_exported",
@@ -136,20 +138,26 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
device = _assert_and_get_unique_device(m)
is_cuda = device is not None and device.type == "cuda"
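# Trace the match/replacement patterns on CUDA when the model lives there,
# so they capture the cudnn batch norm op rather than the CPU variant.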
bn_train_aten = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_train),
example_inputs,
is_cuda,
)
bn_eval_aten = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_eval),
example_inputs,
is_cuda,
)
if train_to_eval:
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_train), example_inputs
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_eval), example_inputs
)
match_pattern = bn_train_aten
replacement_pattern = bn_eval_aten
else:
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_eval), example_inputs
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_train), example_inputs
)
match_pattern = bn_eval_aten
replacement_pattern = bn_train_aten
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

View File

@@ -688,6 +688,27 @@ def get_fqn_to_example_inputs(
torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign]
return fqn_to_example_inputs
def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
"""
Returns the unique device for a module, or None if no device is found.
Throws an error if multiple devices are detected.
"""
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
"""
As a temp workaround for AIMP HHC publish we added a CPU check. Remove it later. T163614564
"""
if {torch.device("cpu"), torch.device("meta")} == devices:
warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We select 'cpu'.")
devices = {torch.device("cpu")}
""
assert len(devices) <= 1, (
"prepare only works with cpu or single-device CUDA modules, "
f"but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
__all__ = [
"NodePattern",
"Pattern",