Add __main__ guards to quantization tests (#154728)

This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In quantization tests:

- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
This commit is contained in:
Anthony Barbier
2025-06-10 19:46:02 +00:00
committed by PyTorch MergeBot
parent 07eb374e7e
commit 954ce94950
28 changed files with 171 additions and 14 deletions

View File

@ -1,5 +1,7 @@
# Owner(s): ["oncall: quantization"]
from torch.testing._internal.common_utils import raise_on_run_directly
from .common import AOMigrationTestCase
@ -359,3 +361,7 @@ class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
_ = torch.ao.nn.intrinsic.quantized.dynamic
_ = torch.nn.intrinsic.quantized.dynamic
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -1,5 +1,7 @@
# Owner(s): ["oncall: quantization"]
from torch.testing._internal.common_utils import raise_on_run_directly
from .common import AOMigrationTestCase
@ -219,3 +221,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
"weight_is_statically_quantized",
]
self._test_function_import("utils", function_list)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -1,5 +1,7 @@
# Owner(s): ["oncall: quantization"]
from torch.testing._internal.common_utils import raise_on_run_directly
from .common import AOMigrationTestCase
@ -150,3 +152,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
"maybe_get_next_module",
]
self._test_function_import("fx.utils", function_list)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -20,7 +20,11 @@ from torch.testing._internal.common_quantized import (
)
# Testing utils
from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase
from torch.testing._internal.common_utils import (
IS_AVX512_VNNI_SUPPORTED,
raise_on_run_directly,
TestCase,
)
from torch.testing._internal.quantization_torch_package_models import (
LinearReluFunctional,
)
@ -565,3 +569,7 @@ class TestSerialization(TestCase):
def test_linear_relu_package_quantization_transforms(self):
m = LinearReluFunctional(4).eval()
self._test_package(m, input_size=(1, 1, 4, 4), generate=False)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -134,3 +134,10 @@ class TestAdaround(QuantizationTestCase):
ada_loss = F.mse_loss(ada_out, float_out)
fq_loss = F.mse_loss(fq_out, float_out)
self.assertTrue(ada_loss.item() < fq_loss.item())
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -11,7 +11,7 @@ from torch.testing._internal.common_quantization import (
SingleLayerLinearModel,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_ARM64, IS_FBCODE
from torch.testing._internal.common_utils import raise_on_run_directly, IS_ARM64, IS_FBCODE
import unittest
@ -141,3 +141,6 @@ class TestQuantizationDocs(QuantizationTestCase):
code = self._get_code(path_from_pytorch, unique_identifier)
self._test_code(code, global_inputs)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -16,7 +16,7 @@ from torch.testing._internal.common_quantization import (
_make_conv_test_input,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_PPC
from torch.testing._internal.common_utils import raise_on_run_directly, IS_PPC
class TestQuantizedFunctionalOps(QuantizationTestCase):
def test_relu_api(self):
@ -235,3 +235,6 @@ class TestQuantizedFunctionalOps(QuantizationTestCase):
out_exp = torch.quantize_per_tensor(F.grid_sample(X, grid), scale=scale, zero_point=zero_point, dtype=torch.quint8)
np.testing.assert_array_almost_equal(
out.int_repr().numpy(), out_exp.int_repr().numpy(), decimal=0)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -31,6 +31,7 @@ from torch.testing._internal.common_quantized import (
qengine_is_qnnpack,
qengine_is_onednn,
)
from torch.testing._internal.common_utils import raise_on_run_directly
import torch.fx
from hypothesis import assume, given
from hypothesis import strategies as st
@ -2095,3 +2096,6 @@ class TestReferenceQuantizedModule(QuantizationTestCase):
self.assertTrue(qmax == 127)
found += 1
self.assertTrue(found == 2)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -24,8 +24,15 @@ import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import IS_PPC, IS_MACOS, IS_SANDCASTLE, IS_FBCODE, IS_ARM64
from torch.testing._internal.common_utils import (
raise_on_run_directly,
TestCase,
IS_PPC,
IS_MACOS,
IS_SANDCASTLE,
IS_FBCODE,
IS_ARM64
)
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
override_quantized_engine, supported_qengines, override_qengines, _snr
@ -8265,3 +8272,6 @@ class TestComparatorOps(TestCase):
note(f"result 3: {result}")
self.assertEqual(result_ref, result,
msg=f"'tensor.{op}(scalar)'' failed")
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -91,3 +91,9 @@ class TestQConfig(TestCase):
fake_quantize_weight = qconfig.weight()
self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1])
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: quantization"]
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
@ -220,3 +220,6 @@ class TestUtils(TestCase):
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
assert x.dtype == dtype
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -18,6 +18,7 @@ from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM,
)
from torch.testing._internal.common_utils import raise_on_run_directly
class TestBiasCorrectionEager(QuantizationTestCase):
@ -119,3 +120,7 @@ class TestBiasCorrectionEager(QuantizationTestCase):
for _ in range(50)
]
self.correct_artificial_bias_quantize(float_model, img_data)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -7,6 +7,7 @@ import torch.ao.quantization._equalize as _equalize
import torch.nn as nn
from torch.ao.quantization.fuse_modules import fuse_modules
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import raise_on_run_directly
class TestEqualizeEager(QuantizationTestCase):
@ -203,3 +204,7 @@ class TestEqualizeEager(QuantizationTestCase):
input = torch.randn(20, 3)
self.assertEqual(fused_model1(input), fused_model2(input))
self.assertEqual(fused_model1(input), model(input))
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -38,7 +38,7 @@ from torch.testing._internal.common_quantization import (
test_only_eval_fn,
)
from torch.testing._internal.common_quantized import override_qengines
from torch.testing._internal.common_utils import IS_ARM64
from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly
class SubModule(torch.nn.Module):
@ -612,3 +612,7 @@ class TestNumericSuiteEager(QuantizationTestCase):
from torchvision.models.quantization import mobilenet_v3_large
self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -43,6 +43,7 @@ from torch.testing._internal.common_quantization import (
FunctionalConvReluModel,
FunctionalConvReluConvModel,
)
from torch.testing._internal.common_utils import raise_on_run_directly
# Standard Libraries
import copy
@ -894,3 +895,6 @@ class TestEqualizeFx(QuantizationTestCase):
# Check the order of nodes in the graph
self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -30,6 +30,7 @@ from torch.testing._internal.common_quantization import (
skipIfNoQNNPACK,
override_quantized_engine,
)
from torch.testing._internal.common_utils import raise_on_run_directly
"""
@ -1956,3 +1957,6 @@ def _get_prepped_for_calibration_model_helper(model, detector_set, example_input
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
return (prepared_for_callibrate_model, model_report)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -37,7 +37,7 @@ from torch.testing._internal.common_quantization import (
skip_if_no_torchvision,
TwoLayerLinearModel
)
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import raise_on_run_directly, skipIfTorchDynamo
from torch.ao.quantization.quantization_mappings import (
get_default_static_quant_module_mappings,
get_default_dynamic_quant_module_mappings,
@ -2915,3 +2915,6 @@ class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
m, (torch.randn(1, 3, 224, 224),),
qconfig_dict=qconfig_dict,
should_log_inputs=False)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -4,6 +4,7 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import raise_on_run_directly
class TestFusionPasses(QuantizationTestCase):
@ -104,3 +105,7 @@ class TestFusionPasses(QuantizationTestCase):
).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
output = scripted_m(qA, 3.0, qC)
self.assertEqual(ref_output, output)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -528,3 +528,10 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
def test_device_side_api(self):
model = MyConvLinearModule()
self._check_device_side_api(model)
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -71,7 +71,10 @@ from torch.testing._internal.common_quantized import (
qengine_is_fbgemm,
qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import set_default_dtype
from torch.testing._internal.common_utils import (
raise_on_run_directly,
set_default_dtype,
)
from torch.testing._internal.jit_utils import (
attrs_with_prefix,
get_forward,
@ -3880,3 +3883,7 @@ class TestQuantizeJit(QuantizationTestCase):
)
# compare result with eager mode
self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -26,7 +26,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
class TestHelperModules:
@ -307,3 +307,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
example_inputs,
BackendAQuantizer(),
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -9,7 +9,11 @@ from torch.ao.quantization.pt2e.graph_utils import (
get_equivalent_types,
update_equivalent_types_dict,
)
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
from torch.testing._internal.common_utils import (
IS_WINDOWS,
raise_on_run_directly,
TestCase,
)
class TestGraphUtils(TestCase):
@ -121,3 +125,7 @@ class TestGraphUtils(TestCase):
[torch.nn.Conv2d, torch.nn.ReLU6],
)
self.assertEqual(len(fused_partitions), 1)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -12,7 +12,11 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
from torch.testing._internal.common_utils import (
IS_WINDOWS,
raise_on_run_directly,
skipIfCrossRef,
)
class TestHelperModules:
@ -513,3 +517,7 @@ class TestMetaDataPorting(QuantizationTestCase):
BackendAQuantizer(),
node_tags,
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -21,7 +21,12 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
from torch.testing._internal.common_utils import (
IS_WINDOWS,
raise_on_run_directly,
skipIfCrossRef,
TestCase,
)
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@ -346,3 +351,7 @@ class TestNumericDebugger(TestCase):
# may change with future node ordering changes.
self.assertNotEqual(handles_after_modification["relu_default"], 0)
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -43,6 +43,7 @@ from torch.testing._internal.common_quantization import (
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import raise_on_run_directly
class PT2EQATTestCase(QuantizationTestCase):
@ -1177,3 +1178,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
self.checkGraphModuleNodes(
exported_model.graph_module, expected_node_occurrence=node_occurrence
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -17,6 +17,7 @@ from torch.testing._internal.common_quantization import (
skipIfNoQNNPACK,
TestHelperModules,
)
from torch.testing._internal.common_utils import raise_on_run_directly
@skipIfNoQNNPACK
@ -306,3 +307,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
ref_node_occurrence,
non_ref_node_occurrence,
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -25,7 +25,10 @@ from torch.testing._internal.common_quantization import (
skipIfNoX86,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
)
class NodePosType(Enum):
@ -2858,3 +2861,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_list,
lower=True,
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")

View File

@ -38,6 +38,7 @@ from torch.testing._internal.common_quantization import (
TestHelperModules,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import raise_on_run_directly
@skipIfNoQNNPACK
@ -1080,3 +1081,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
self.assertTrue(
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")