Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add __main__ guards to quantization tests (#154728)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In the quantization tests:

- Add and use a common raise_on_run_directly method for when a user directly runs a test file that should not be run this way; it prints the file the user should have run instead.
- Raise a RuntimeError in tests which have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
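For context, the pattern is small: each test module that should only run via test/test_quantization.py ends with a __main__ guard that calls the shared helper instead of unittest.main(). The sketch below illustrates the idea; the exact signature and wording of raise_on_run_directly live in torch.testing._internal.common_utils, so treat this as an approximation rather than the canonical implementation.

```python
# Approximate sketch of the shared helper; the real implementation lives in
# torch.testing._internal.common_utils and its exact message may differ.
def raise_on_run_directly(file_to_run: str) -> None:
    # Point the user at the entry point they should have used instead.
    raise RuntimeError(
        "This test file is not meant to be run directly. "
        f"Run it via: python {file_to_run} TESTNAME"
    )


# Usage at the bottom of an individual quantization test module (as in the hunks below):
if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")
```

Test files that are currently disabled (not discovered at all) instead raise a RuntimeError directly, pointing at discover_tests.py, as their hunks below show.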
Committed by: PyTorch MergeBot
Parent: 07eb374e7e
Commit: 954ce94950
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: quantization"]
+from torch.testing._internal.common_utils import raise_on_run_directly
 from .common import AOMigrationTestCase
@@ -359,3 +361,7 @@ class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
 _ = torch.ao.nn.intrinsic.quantized.dynamic
 _ = torch.nn.intrinsic.quantized.dynamic
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: quantization"]
+from torch.testing._internal.common_utils import raise_on_run_directly
 from .common import AOMigrationTestCase
@@ -219,3 +221,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
 "weight_is_statically_quantized",
 ]
 self._test_function_import("utils", function_list)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: quantization"]
+from torch.testing._internal.common_utils import raise_on_run_directly
 from .common import AOMigrationTestCase
@@ -150,3 +152,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
 "maybe_get_next_module",
 ]
 self._test_function_import("fx.utils", function_list)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -20,7 +20,11 @@ from torch.testing._internal.common_quantized import (
 )
 # Testing utils
-from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase
+from torch.testing._internal.common_utils import (
+    IS_AVX512_VNNI_SUPPORTED,
+    raise_on_run_directly,
+    TestCase,
+)
 from torch.testing._internal.quantization_torch_package_models import (
 LinearReluFunctional,
 )
@@ -565,3 +569,7 @@ class TestSerialization(TestCase):
 def test_linear_relu_package_quantization_transforms(self):
 m = LinearReluFunctional(4).eval()
 self._test_package(m, input_size=(1, 1, 4, 4), generate=False)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -134,3 +134,10 @@ class TestAdaround(QuantizationTestCase):
 ada_loss = F.mse_loss(ada_out, float_out)
 fq_loss = F.mse_loss(fq_out, float_out)
 self.assertTrue(ada_loss.item() < fq_loss.item())
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )
@@ -11,7 +11,7 @@ from torch.testing._internal.common_quantization import (
 SingleLayerLinearModel,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
-from torch.testing._internal.common_utils import IS_ARM64, IS_FBCODE
+from torch.testing._internal.common_utils import raise_on_run_directly, IS_ARM64, IS_FBCODE
 import unittest
@@ -141,3 +141,6 @@ class TestQuantizationDocs(QuantizationTestCase):
 code = self._get_code(path_from_pytorch, unique_identifier)
 self._test_code(code, global_inputs)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -16,7 +16,7 @@ from torch.testing._internal.common_quantization import (
 _make_conv_test_input,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
-from torch.testing._internal.common_utils import IS_PPC
+from torch.testing._internal.common_utils import raise_on_run_directly, IS_PPC
 class TestQuantizedFunctionalOps(QuantizationTestCase):
 def test_relu_api(self):
@@ -235,3 +235,6 @@ class TestQuantizedFunctionalOps(QuantizationTestCase):
 out_exp = torch.quantize_per_tensor(F.grid_sample(X, grid), scale=scale, zero_point=zero_point, dtype=torch.quint8)
 np.testing.assert_array_almost_equal(
 out.int_repr().numpy(), out_exp.int_repr().numpy(), decimal=0)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -31,6 +31,7 @@ from torch.testing._internal.common_quantized import (
 qengine_is_qnnpack,
 qengine_is_onednn,
 )
+from torch.testing._internal.common_utils import raise_on_run_directly
 import torch.fx
 from hypothesis import assume, given
 from hypothesis import strategies as st
@@ -2095,3 +2096,6 @@ class TestReferenceQuantizedModule(QuantizationTestCase):
 self.assertTrue(qmax == 127)
 found += 1
 self.assertTrue(found == 2)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -24,8 +24,15 @@ import torch.testing._internal.hypothesis_utils as hu
 hu.assert_deadline_disabled()
 from torch.testing._internal.common_cuda import SM80OrLater
-from torch.testing._internal.common_utils import TestCase
-from torch.testing._internal.common_utils import IS_PPC, IS_MACOS, IS_SANDCASTLE, IS_FBCODE, IS_ARM64
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    TestCase,
+    IS_PPC,
+    IS_MACOS,
+    IS_SANDCASTLE,
+    IS_FBCODE,
+    IS_ARM64
+)
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
 override_quantized_engine, supported_qengines, override_qengines, _snr
@@ -8265,3 +8272,6 @@ class TestComparatorOps(TestCase):
 note(f"result 3: {result}")
 self.assertEqual(result_ref, result,
 msg=f"'tensor.{op}(scalar)'' failed")
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -91,3 +91,9 @@ class TestQConfig(TestCase):
 fake_quantize_weight = qconfig.weight()
 self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1])
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: quantization"]
 import torch
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 from torch.ao.quantization.utils import get_fqn_to_example_inputs
 from torch.ao.nn.quantized.modules.utils import _quantize_weight
 from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
@@ -220,3 +220,6 @@ class TestUtils(TestCase):
 [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
 ], dtype=torch.uint8))
 assert x.dtype == dtype
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -18,6 +18,7 @@ from torch.testing._internal.common_quantization import (
 QuantizationTestCase,
 skipIfNoFBGEMM,
 )
+from torch.testing._internal.common_utils import raise_on_run_directly
 class TestBiasCorrectionEager(QuantizationTestCase):
@@ -119,3 +120,7 @@ class TestBiasCorrectionEager(QuantizationTestCase):
 for _ in range(50)
 ]
 self.correct_artificial_bias_quantize(float_model, img_data)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -7,6 +7,7 @@ import torch.ao.quantization._equalize as _equalize
 import torch.nn as nn
 from torch.ao.quantization.fuse_modules import fuse_modules
 from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.testing._internal.common_utils import raise_on_run_directly
 class TestEqualizeEager(QuantizationTestCase):
@@ -203,3 +204,7 @@ class TestEqualizeEager(QuantizationTestCase):
 input = torch.randn(20, 3)
 self.assertEqual(fused_model1(input), fused_model2(input))
 self.assertEqual(fused_model1(input), model(input))
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -38,7 +38,7 @@ from torch.testing._internal.common_quantization import (
 test_only_eval_fn,
 )
 from torch.testing._internal.common_quantized import override_qengines
-from torch.testing._internal.common_utils import IS_ARM64
+from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly
 class SubModule(torch.nn.Module):
@@ -612,3 +612,7 @@ class TestNumericSuiteEager(QuantizationTestCase):
 from torchvision.models.quantization import mobilenet_v3_large
 self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -43,6 +43,7 @@ from torch.testing._internal.common_quantization import (
 FunctionalConvReluModel,
 FunctionalConvReluConvModel,
 )
+from torch.testing._internal.common_utils import raise_on_run_directly
 # Standard Libraries
 import copy
@@ -894,3 +895,6 @@ class TestEqualizeFx(QuantizationTestCase):
 # Check the order of nodes in the graph
 self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -30,6 +30,7 @@ from torch.testing._internal.common_quantization import (
 skipIfNoQNNPACK,
 override_quantized_engine,
 )
+from torch.testing._internal.common_utils import raise_on_run_directly
 """
@@ -1956,3 +1957,6 @@ def _get_prepped_for_calibration_model_helper(model, detector_set, example_input
 prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
 return (prepared_for_callibrate_model, model_report)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -37,7 +37,7 @@ from torch.testing._internal.common_quantization import (
 skip_if_no_torchvision,
 TwoLayerLinearModel
 )
-from torch.testing._internal.common_utils import skipIfTorchDynamo
+from torch.testing._internal.common_utils import raise_on_run_directly, skipIfTorchDynamo
 from torch.ao.quantization.quantization_mappings import (
 get_default_static_quant_module_mappings,
 get_default_dynamic_quant_module_mappings,
@@ -2915,3 +2915,6 @@ class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
 m, (torch.randn(1, 3, 224, 224),),
 qconfig_dict=qconfig_dict,
 should_log_inputs=False)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -4,6 +4,7 @@
 import torch
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.testing._internal.common_utils import raise_on_run_directly
 class TestFusionPasses(QuantizationTestCase):
@@ -104,3 +105,7 @@ class TestFusionPasses(QuantizationTestCase):
 ).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
 output = scripted_m(qA, 3.0, qC)
 self.assertEqual(ref_output, output)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -528,3 +528,10 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
 def test_device_side_api(self):
 model = MyConvLinearModule()
 self._check_device_side_api(model)
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )
@@ -71,7 +71,10 @@ from torch.testing._internal.common_quantized import (
 qengine_is_fbgemm,
 qengine_is_qnnpack,
 )
-from torch.testing._internal.common_utils import set_default_dtype
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    set_default_dtype,
+)
 from torch.testing._internal.jit_utils import (
 attrs_with_prefix,
 get_forward,
@@ -3880,3 +3883,7 @@ class TestQuantizeJit(QuantizationTestCase):
 )
 # compare result with eager mode
 self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -26,7 +26,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
 )
 from torch.export import export_for_training
 from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.testing._internal.common_utils import IS_WINDOWS
+from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
 class TestHelperModules:
@@ -307,3 +307,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
 example_inputs,
 BackendAQuantizer(),
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -9,7 +9,11 @@ from torch.ao.quantization.pt2e.graph_utils import (
 get_equivalent_types,
 update_equivalent_types_dict,
 )
-from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    raise_on_run_directly,
+    TestCase,
+)
 class TestGraphUtils(TestCase):
@@ -121,3 +125,7 @@ class TestGraphUtils(TestCase):
 [torch.nn.Conv2d, torch.nn.ReLU6],
 )
 self.assertEqual(len(fused_partitions), 1)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -12,7 +12,11 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
 from torch.fx import Node
 from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    raise_on_run_directly,
+    skipIfCrossRef,
+)
 class TestHelperModules:
@@ -513,3 +517,7 @@ class TestMetaDataPorting(QuantizationTestCase):
 BackendAQuantizer(),
 node_tags,
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -21,7 +21,12 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
 )
 from torch.export import export_for_training
 from torch.testing._internal.common_quantization import TestHelperModules
-from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    raise_on_run_directly,
+    skipIfCrossRef,
+    TestCase,
+)
 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@@ -346,3 +351,7 @@ class TestNumericDebugger(TestCase):
 # may change with future node ordering changes.
 self.assertNotEqual(handles_after_modification["relu_default"], 0)
 self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -43,6 +43,7 @@ from torch.testing._internal.common_quantization import (
 skipIfNoQNNPACK,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
+from torch.testing._internal.common_utils import raise_on_run_directly
 class PT2EQATTestCase(QuantizationTestCase):
@@ -1177,3 +1178,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
 self.checkGraphModuleNodes(
 exported_model.graph_module, expected_node_occurrence=node_occurrence
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -17,6 +17,7 @@ from torch.testing._internal.common_quantization import (
 skipIfNoQNNPACK,
 TestHelperModules,
 )
+from torch.testing._internal.common_utils import raise_on_run_directly
 @skipIfNoQNNPACK
@@ -306,3 +307,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
 ref_node_occurrence,
 non_ref_node_occurrence,
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -25,7 +25,10 @@ from torch.testing._internal.common_quantization import (
 skipIfNoX86,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
-from torch.testing._internal.common_utils import skipIfTorchDynamo
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    skipIfTorchDynamo,
+)
 class NodePosType(Enum):
@@ -2858,3 +2861,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
 node_list,
 lower=True,
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")
@@ -38,6 +38,7 @@ from torch.testing._internal.common_quantization import (
 TestHelperModules,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
+from torch.testing._internal.common_utils import raise_on_run_directly
 @skipIfNoQNNPACK
@@ -1080,3 +1081,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
 self.assertTrue(
 compute_sqnr(after_quant_result, after_quant_result_fx) > 35
 )
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_quantization.py")