Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[quant][pt2e] Move batch norm op between eval/train for cuda (#123957)"
This reverts commit 4efb28c90025ea3d979b720942cd97a274fac6da. Reverted https://github.com/pytorch/pytorch/pull/123957 on behalf of https://github.com/jeanschmidt due to reverting to check if it will fix rocm jobs on main ([comment](https://github.com/pytorch/pytorch/pull/123957#issuecomment-2075158146))
@@ -1826,18 +1826,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
     def test_move_exported_model_dropout_inplace(self):
         self._test_move_exported_model_dropout(inplace=True)
 
-    def _get_bn_train_eval_ops(self, is_cuda: bool):
-        if is_cuda:
-            return (
-                torch.ops.aten.cudnn_batch_norm.default,
-                torch.ops.aten.cudnn_batch_norm.default,
-            )
-        else:
-            return (
-                torch.ops.aten._native_batch_norm_legit.default,
-                torch.ops.aten._native_batch_norm_legit_no_training.default,
-            )
-
     def test_move_exported_model_bn(self):
         """
         Test switching batch_norm behavior between train and eval modes using
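For the reverted code, this helper encoded which pair of (train, eval) batch-norm ops to expect per device: on CUDA it expected cudnn_batch_norm.default for both modes (distinguished only by the training flag in the op's arguments), while on CPU train mode uses _native_batch_norm_legit.default and eval mode _native_batch_norm_legit_no_training.default. A minimal, hedged sketch of how one could check which variant actually appears on the current machine (the capture_pre_autograd_graph import path is an assumption for the PyTorch version this test suite targets):

import torch

from torch._export import capture_pre_autograd_graph  # assumed import path for this PyTorch version


class TinyBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
m = TinyBN().train().to(device)
example_inputs = (torch.randn(1, 3, 3, 3, device=device),)
gm = capture_pre_autograd_graph(m, example_inputs)

# Print every aten target in the captured graph whose name mentions batch_norm.
print([n.target for n in gm.graph.nodes if "batch_norm" in str(n.target)])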
@@ -1852,18 +1840,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x):
                 return self.bn(x)
 
-        is_cuda = torch.cuda.is_available()
-        if is_cuda:
-            m = M().train().cuda()
-            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
-        else:
-            m = M().train()
-            example_inputs = (torch.randn(1, 3, 3, 3),)
-        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        m = M().train()
         m = capture_pre_autograd_graph(m, example_inputs)
 
         # Assert that batch norm op exists and is in train mode
-        bn_node = self._get_node(m, bn_train_op)
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
         self.assertTrue(bn_node is not None)
         self.assertTrue(bn_node.args[5])
 
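Both the reverted and the restored test bodies lean on the _get_node helper from the test base class, which is not part of this diff. A hedged sketch of what such a lookup typically does, using only the m.graph.nodes / n.target access pattern visible in the last hunk below:

from typing import Optional

import torch


def get_node(gm: torch.fx.GraphModule, target) -> Optional[torch.fx.Node]:
    # Return the first node in the exported graph whose call target is the given aten op.
    for n in gm.graph.nodes:
        if n.target == target:
            return n
    return None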
@@ -1871,14 +1853,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         torch.ao.quantization.move_exported_model_to_eval(m)
 
         # Assert that batch norm op is now in eval mode
-        bn_node = self._get_node(m, bn_eval_op)
+        bn_node = self._get_node(
+            m, torch.ops.aten._native_batch_norm_legit_no_training.default
+        )
         self.assertTrue(bn_node is not None)
 
         # Move to train
         torch.ao.quantization.move_exported_model_to_train(m)
 
         # Assert that batch norm op is now in train mode again
-        bn_node = self._get_node(m, bn_train_op)
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
         self.assertTrue(bn_node is not None)
         self.assertTrue(bn_node.args[5])
 
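The assertTrue(bn_node.args[5]) checks in this hunk rely on the positional schema of the batch-norm op: for aten._native_batch_norm_legit (and likewise for the cudnn variant handled by the reverted helper) the arguments are (input, weight, bias, running_mean, running_var, training, momentum, eps), so index 5 is the training flag. A tiny hedged helper making that explicit:

def bn_is_training(bn_node) -> bool:
    # args = (input, weight, bias, running_mean, running_var, training, momentum, eps)
    return bool(bn_node.args[5])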
@@ -1924,25 +1908,22 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 x = self.dropout(x)
                 return x
 
-        is_cuda = torch.cuda.is_available()
-        if is_cuda:
-            m = M().train().cuda()
-            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
-        else:
-            m = M().train()
-            example_inputs = (torch.randn(1, 3, 3, 3),)
-        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops(is_cuda)
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        m = M().train()
         m = capture_pre_autograd_graph(m, example_inputs)
 
         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
             targets = [n.target for n in m.graph.nodes]
-            bn_op = bn_train_op if train else bn_eval_op
-            bn_node = self._get_node(m, bn_op)
-            self.assertTrue(bn_node is not None)
-            if is_cuda:
-                self.assertEqual(bn_node.args[5], train)
+            bn_train_target = torch.ops.aten._native_batch_norm_legit.default
+            bn_eval_target = torch.ops.aten._native_batch_norm_legit_no_training.default
+            if train:
+                self.assertTrue(bn_train_target in targets)
+                self.assertTrue(bn_eval_target not in targets)
+            else:
+                self.assertTrue(bn_eval_target in targets)
+                self.assertTrue(bn_train_target not in targets)
             dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
-            self.assertEqual(dropout_node.args[2], train)
+            self.assertTrue(dropout_node.args[2] == train)
 
         # Before wrapping: this is not OK
         with self.assertRaises(NotImplementedError):
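Putting the restored (pre-#123957) flow together, a condensed CPU-only sketch of the behavior these tests assert, using only the APIs that appear in the hunks above (op names and argument positions as noted earlier; the capture_pre_autograd_graph import path is again an assumption):

import torch
import torch.ao.quantization

from torch._export import capture_pre_autograd_graph  # assumed import path for this PyTorch version


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.bn(x))


m = capture_pre_autograd_graph(M().train(), (torch.randn(1, 3, 3, 3),))

# Captured in train mode: the graph carries the train-mode batch-norm op.
targets = [n.target for n in m.graph.nodes]
assert torch.ops.aten._native_batch_norm_legit.default in targets

# Switching to eval swaps it for the no-training variant.
torch.ao.quantization.move_exported_model_to_eval(m)
targets = [n.target for n in m.graph.nodes]
assert torch.ops.aten._native_batch_norm_legit_no_training.default in targets
assert torch.ops.aten._native_batch_norm_legit.default not in targets

# Switching back restores the train-mode op (and, per the test above, flips the
# dropout node's train argument as well).
torch.ao.quantization.move_exported_model_to_train(m)
targets = [n.target for n in m.graph.nodes]
assert torch.ops.aten._native_batch_norm_legit.default in targets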