Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[quant][pt2e] Move batch norm op between eval/train for cuda (#123957)"
This reverts commit 4efb28c90025ea3d979b720942cd97a274fac6da. Reverted https://github.com/pytorch/pytorch/pull/123957 on behalf of https://github.com/jeanschmidt due to reverting to check if it will fix rocm jobs on main ([comment](https://github.com/pytorch/pytorch/pull/123957#issuecomment-2075158146))
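For context, the reverted change touched PyTorch's PT2E export-time train/eval switching flow. A minimal sketch of that flow, reconstructed from the test edited below (the module, input shape, and the capture_pre_autograd_graph import path are assumptions taken from the test code, not a stable public contract):

    import torch
    from torch._export import capture_pre_autograd_graph

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            return self.bn(x)

    example_inputs = (torch.randn(1, 3, 3, 3),)
    m = capture_pre_autograd_graph(M().train(), example_inputs)

    # Rewrites train-mode batch norm / dropout aten ops to their eval variants ...
    torch.ao.quantization.move_exported_model_to_eval(m)
    # ... and back again.
    torch.ao.quantization.move_exported_model_to_train(m)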
@@ -1826,18 +1826,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
     def test_move_exported_model_dropout_inplace(self):
         self._test_move_exported_model_dropout(inplace=True)
 
-    def _get_bn_train_eval_ops(self, is_cuda: bool):
-        if is_cuda:
-            return (
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-            )
-        else:
-            return (
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten._native_batch_norm_legit_no_training.default,
-            )
-
     def test_move_exported_model_bn(self):
         """
         Test switching batch_norm behavior between train and eval modes using
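The deleted helper above encoded which aten batch norm op the exported graph is expected to contain: torch.ops.aten.cudnn_batch_norm.default on CUDA (the same op serves both train and eval), versus _native_batch_norm_legit / _native_batch_norm_legit_no_training on CPU. A small, hedged way to see which op your own build produces (op names can vary across PyTorch versions; this only inspects the graph):

    import torch

    def print_bn_ops(gm: torch.fx.GraphModule) -> None:
        # List every batch-norm-like aten call in an exported FX graph.
        for n in gm.graph.nodes:
            if n.op == "call_function" and "batch_norm" in str(n.target):
                print(n.target)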
@@ -1852,18 +1840,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x):
                 return self.bn(x)
 
-        is_cuda = torch.cuda.is_available()
-        if is_cuda:
-            m = M().train().cuda()
-            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
-        else:
-            m = M().train()
-            example_inputs = (torch.randn(1, 3, 3, 3),)
-        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        m = M().train()
         m = capture_pre_autograd_graph(m, example_inputs)
 
         # Assert that batch norm op exists and is in train mode
-        bn_node = self._get_node(m, bn_train_op)
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
         self.assertTrue(bn_node is not None)
         self.assertTrue(bn_node.args[5])
@@ -1871,14 +1853,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         torch.ao.quantization.move_exported_model_to_eval(m)
 
         # Assert that batch norm op is now in eval mode
-        bn_node = self._get_node(m, bn_eval_op)
+        bn_node = self._get_node(
+            m, torch.ops.aten._native_batch_norm_legit_no_training.default
+        )
         self.assertTrue(bn_node is not None)
 
         # Move to train
         torch.ao.quantization.move_exported_model_to_train(m)
 
         # Assert that batch norm op is now in train mode again
-        bn_node = self._get_node(m, bn_train_op)
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
         self.assertTrue(bn_node is not None)
         self.assertTrue(bn_node.args[5])
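The bn_node.args[5] assertions above rely on the positional schema of aten._native_batch_norm_legit.default, whose sixth argument (after input, weight, bias, running_mean, running_var) is the training flag. A hedged helper for reading that flag off a matched node (other batch norm overloads order their arguments differently):

    import torch.fx

    def bn_is_training(bn_node: torch.fx.Node) -> bool:
        # args[5] is the `training` boolean for _native_batch_norm_legit.
        return bool(bn_node.args[5])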
@@ -1924,25 +1908,22 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 x = self.dropout(x)
                 return x
 
-        is_cuda = torch.cuda.is_available()
-        if is_cuda:
-            m = M().train().cuda()
-            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
-        else:
-            m = M().train()
-            example_inputs = (torch.randn(1, 3, 3, 3),)
-        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        m = M().train()
         m = capture_pre_autograd_graph(m, example_inputs)
 
         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
             targets = [n.target for n in m.graph.nodes]
-            bn_op = bn_train_op if train else bn_eval_op
-            bn_node = self._get_node(m, bn_op)
-            self.assertTrue(bn_node is not None)
-            if is_cuda:
-                self.assertEqual(bn_node.args[5], train)
+            bn_train_target = torch.ops.aten._native_batch_norm_legit.default
+            bn_eval_target = torch.ops.aten._native_batch_norm_legit_no_training.default
+            if train:
+                self.assertTrue(bn_train_target in targets)
+                self.assertTrue(bn_eval_target not in targets)
+            else:
+                self.assertTrue(bn_eval_target in targets)
+                self.assertTrue(bn_train_target not in targets)
             dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
-            self.assertEqual(dropout_node.args[2], train)
+            self.assertTrue(dropout_node.args[2] == train)
 
         # Before wrapping: this is not OK
         with self.assertRaises(NotImplementedError):
@@ -23,7 +23,6 @@ from torch.ao.quantization.qconfig import (
 )
 from torch.ao.quantization.stubs import DeQuantStub
 from torch.ao.quantization.utils import (
-    _assert_and_get_unique_device,
     activation_is_statically_quantized,
 )
 from torch.ao.quantization.observer import _is_activation_post_process
@@ -223,13 +222,26 @@ def graph_module_from_producer_nodes(
     graph_module = GraphModule(root, graph)
     return graph_module
 
-# TODO: delete
 def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     """
     Returns the unique device for a module, or None if no device is found.
     Throws an error if multiple devices are detected.
     """
-    return _assert_and_get_unique_device(module)
+    devices = {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+    """
+    As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564
+    """
+    if {torch.device("cpu"), torch.device("meta")} == devices:
+        warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
+        devices = {torch.device("cpu")}
+    ""
+    assert len(devices) <= 1, (
+        "prepare only works with cpu or single-device CUDA modules, "
+        f"but got devices {devices}"
+    )
+    device = next(iter(devices)) if len(devices) > 0 else None
+    return device
 
 def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
     """
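The revert restores the inline device-detection logic above instead of delegating to the private _assert_and_get_unique_device helper the reverted PR had moved into torch.ao.quantization.utils. A self-contained sketch of what the restored logic computes (the Linear module is illustrative only):

    import torch

    m = torch.nn.Linear(4, 4)
    devices = {p.device for p in m.parameters()} | {b.device for b in m.buffers()}
    # For a freshly constructed module this is {device(type='cpu')}, so the helper
    # would return torch.device("cpu"); parameters spread over several CUDA
    # devices would instead trip its assert.
    print(devices)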
@@ -3,8 +3,6 @@ import types
 import torch
 import torch.nn.functional as F
 
-from torch.ao.quantization.utils import _assert_and_get_unique_device
-
 
 __all__ = [
     "model_is_exported",
@@ -138,26 +136,20 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
         torch.randn(1),  # bn_running_mean
         torch.randn(1),  # bn_running_var
     )
 
-    device = _assert_and_get_unique_device(m)
-    is_cuda = device is not None and device.type == "cuda"
-    bn_train_aten = _get_aten_graph_module_for_pattern(
-        _WrapperModule(bn_train),
-        example_inputs,
-        is_cuda,
-    )
-    bn_eval_aten = _get_aten_graph_module_for_pattern(
-        _WrapperModule(bn_eval),
-        example_inputs,
-        is_cuda,
-    )
-
     if train_to_eval:
-        match_pattern = bn_train_aten
-        replacement_pattern = bn_eval_aten
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
     else:
-        match_pattern = bn_eval_aten
-        replacement_pattern = bn_train_aten
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
 
     from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
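The hunk above reverts _replace_batchnorm to trace its match and replacement patterns without the CUDA flag. In both versions the underlying mechanism is FX subgraph rewriting: a traced pattern graph is matched in the model and swapped for a traced replacement graph. A self-contained sketch of that mechanism using the public torch.fx.subgraph_rewriter.replace_pattern (relu/sigmoid stand in for the train/eval batch norm patterns; the internal _get_aten_graph_module_for_pattern helper is not used here):

    import torch
    import torch.fx
    from torch.fx.subgraph_rewriter import replace_pattern

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1.0

    def pattern(x):
        return torch.relu(x)

    def replacement(x):
        return torch.sigmoid(x)

    gm = torch.fx.symbolic_trace(M())
    replace_pattern(gm, pattern, replacement)  # rewrites relu -> sigmoid in place
    print(gm.code)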
@@ -688,27 +688,6 @@ def get_fqn_to_example_inputs(
         torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
     return fqn_to_example_inputs
 
-def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
-    """
-    Returns the unique device for a module, or None if no device is found.
-    Throws an error if multiple devices are detected.
-    """
-    devices = {p.device for p in module.parameters()} | \
-        {p.device for p in module.buffers()}
-    """
-    As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564
-    """
-    if {torch.device("cpu"), torch.device("meta")} == devices:
-        warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
-        devices = {torch.device("cpu")}
-    ""
-    assert len(devices) <= 1, (
-        "prepare only works with cpu or single-device CUDA modules, "
-        f"but got devices {devices}"
-    )
-    device = next(iter(devices)) if len(devices) > 0 else None
-    return device
-
 __all__ = [
     "NodePattern",
     "Pattern",