Add __main__ guards to ao tests (#154612)

This is the first PR in a series aiming to get the content of #134592 merged as smaller PRs (the original PR was closed due to a lack of reviewers).

This specific PR contains:
- Add and use a common `raise_on_run_directly` helper for when a user directly runs a test file that should not be run that way; it prints the file the user should have run instead (a sketch of the idea follows this list).
- Update the ao tests to use it.
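
For context, here is a minimal sketch of what such a helper could look like. The module and function name match the imports in the diff (`torch.testing._internal.common_utils.raise_on_run_directly`), but the signature and error message below are illustrative, not the merged implementation:

```python
# Illustrative sketch only -- the real helper in
# torch/testing/_internal/common_utils.py may word its error differently.
def raise_on_run_directly(file_to_run: str) -> None:
    """Raise when a test file that should not be run directly is run directly.

    Args:
        file_to_run: the test file the user should have run instead,
            e.g. "test/test_ao_sparsity.py".
    """
    raise RuntimeError(
        "This test file is not meant to be run directly. "
        f"Run it via:\n\n\tpython {file_to_run}\n"
    )
```

Each updated test file then ends with an `if __name__ == "__main__":` guard that calls `raise_on_run_directly("test/test_ao_sparsity.py")`, as the diffs below show.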

There will be follow-up PRs to update the other test suites, but I don't have permission to create branches directly on pytorch/pytorch, so I can't create a stack and will have to create these PRs one at a time.

Cc @jerryzh168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154612
Approved by: https://github.com/jcaip
Commit b1b8e57cda (parent 0b677560e6)
Author: Anthony Barbier
Date: 2025-06-10 18:33:05 +00:00
Committed by: PyTorch MergeBot
11 changed files with 68 additions and 60 deletions

View File

@@ -1,7 +1,6 @@
 # Owner(s): ["module: unknown"]
 
 import copy
-import logging
 
 import torch
 import torch.nn as nn
@@ -10,11 +9,10 @@ from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier
     ActivationSparsifier,
 )
 from torch.ao.pruning.sparsifier.utils import module_to_fqn
-from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    skipIfTorchDynamo,
+    TestCase,
 )
@@ -405,3 +403,7 @@ class TestActivationSparsifier(TestCase):
 
         # check state_dict() after squash_mask()
         self._check_state_dict(activation_sparsifier)
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,8 +1,6 @@
 # Owner(s): ["module: unknown"]
 
-import logging
-
 import torch
 import torch.ao.quantization as tq
 from torch import nn
@@ -15,13 +13,13 @@ from torch.ao.quantization.quantize_fx import (
     prepare_qat_fx,
 )
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
-from torch.testing._internal.common_utils import TestCase, xfailIfS390X
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    TestCase,
+    xfailIfS390X,
 )
 
 
 sparse_defaults = {
     "sparsity_level": 0.8,
     "sparse_block_shape": (1, 4),
@@ -642,3 +640,7 @@ class TestFxComposability(TestCase):
             sparsity_level, sparse_config[0]["sparsity_level"]
         )
         self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,19 +1,13 @@
 # Owner(s): ["module: unknown"]
 
 import copy
-import logging
 import warnings
 
 import torch
 from torch import nn
 from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
 from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
-from torch.testing._internal.common_utils import TestCase
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 class ImplementedDataScheduler(BaseDataScheduler):
@@ -180,3 +174,7 @@ class TestBaseDataScheduler(TestCase):
         name, _, _ = self._get_name_data_config(some_data, defaults)
         assert scheduler1.base_param[name] == scheduler2.base_param[name]
         assert scheduler1._last_param[name] == scheduler2._last_param[name]
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -2,7 +2,6 @@
 
 import copy
 import itertools
-import logging
 import math
 
 import torch
@@ -15,12 +14,7 @@ from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
     post_training_sparse_quantize,
 )
 from torch.nn.utils.parametrize import is_parametrized
-from torch.testing._internal.common_utils import TestCase
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 class ImplementedSparsifier(BaseDataSparsifier):
@@ -792,3 +786,7 @@ class TestQuantizationUtils(TestCase):
         assert abs(sl_embbag1 - 0.80) <= 0.05  # +- 5% leeway
         assert abs(sl_emb_seq_0 - 0.80) <= 0.05  # +- 5% leeway
         assert abs(sl_emb_seq_1 - 0.80) <= 0.05  # +- 5% leeway
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -19,20 +19,16 @@ from torch.testing._internal.common_quantized import (
     qengine_is_qnnpack,
     qengine_is_x86,
 )
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    skipIfTorchDynamo,
+    TestCase,
+)
 
 # TODO: Once more test files are created, move the contents to a ao folder.
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-handler = logging.StreamHandler()
-formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-handler.setFormatter(formatter)
-logger.addHandler(handler)
-logger.propagate = False  # Prevent duplicate logs if root logger also has handlers
 
 
 class TestQuantizedSparseKernels(TestCase):
@@ -331,4 +327,4 @@ class TestQuantizedSparseLayers(TestCase):
 
 
 if __name__ == "__main__":
-    run_tests()
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,18 +1,11 @@
 # Owner(s): ["module: unknown"]
 
-import logging
-
 import torch
 from torch import nn
 from torch.ao.pruning.sparsifier import utils
 from torch.nn.utils import parametrize
-from torch.testing._internal.common_utils import TestCase
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 class ModelUnderTest(nn.Module):
@@ -173,3 +166,7 @@ class TestFakeSparsity(TestCase):
         y = model(x)
         y_hat = model_trace(x)
         self.assertEqual(y_hat, y)
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -4,7 +4,7 @@ import warnings
 
 from torch import nn
 from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 class ImplementedScheduler(BaseScheduler):
@@ -190,3 +190,7 @@ class TestCubicScheduler(TestCase):
             self.sorted_sparse_levels,
             msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
         )
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,7 +1,6 @@
 # Owner(s): ["module: unknown"]
 
 import itertools
-import logging
 import re
 
 import torch
@@ -18,12 +17,7 @@ from torch.testing._internal.common_pruning import (
     MockSparseLinear,
     SimpleLinear,
 )
-from torch.testing._internal.common_utils import TestCase
-
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 class TestBaseSparsifier(TestCase):
@@ -484,3 +478,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
                     assert mask[row, col] == 1
                 else:
                     assert mask[row, col] == 0
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -18,7 +18,7 @@ from torch.testing._internal.common_quantization import (
     SingleLayerLinearModel,
     TwoLayerLinearModel,
 )
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
 
 
 logging.basicConfig(
@@ -147,3 +147,7 @@ class TestSparsityUtilFunctions(TestCase):
         self.assertEqual(arg_info["module_fqn"], "foo.bar")
         self.assertEqual(arg_info["tensor_name"], "baz")
         self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,6 +1,5 @@
 # Owner(s): ["module: unknown"]
 import copy
-import logging
 import random
 
 import torch
@@ -29,13 +28,13 @@ from torch.testing._internal.common_pruning import (
     SimpleConv2d,
     SimpleLinear,
 )
-from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+from torch.testing._internal.common_utils import (
+    raise_on_run_directly,
+    skipIfTorchDynamo,
+    TestCase,
 )
 
 
 DEVICES = {
     torch.device("cpu"),
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
@@ -1089,3 +1088,7 @@ class TestFPGMPruner(TestCase):
             self._test_update_mask_on_multiple_layer(
                 expected_conv1, expected_conv2, device
             )
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_ao_sparsity.py")

View File

@@ -1,4 +1,5 @@
 # Owner(s): ["module: unknown"]
+import logging
 
 # Kernels
 from ao.sparsity.test_kernels import (  # noqa: F401
@@ -56,4 +57,9 @@ from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions  # noqa: F
 
 
 if __name__ == "__main__":
+    logging.basicConfig(
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        level=logging.INFO,
+    )
     run_tests()