Add __main__ guards to ao tests (#154612)

This is the first PR of a series in an attempt to get the content of #134592 merged as smaller PRs (Given that the original one was closed due to a lack of reviewers).

This specific PR contains:
- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Update ao tests.

There will be follow up PRs to update the other test suites but I don't have permissions to create branches directly on pytorch/pytorch so I can't create a stack and therefore will have to create them one at the time.

Cc @jerryzh168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154612
Approved by: https://github.com/jcaip
This commit is contained in:
Anthony Barbier
2025-06-10 18:33:05 +00:00
committed by PyTorch MergeBot
parent 0b677560e6
commit b1b8e57cda
11 changed files with 68 additions and 60 deletions

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import copy import copy
import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -10,11 +9,10 @@ from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier
ActivationSparsifier, ActivationSparsifier,
) )
from torch.ao.pruning.sparsifier.utils import module_to_fqn from torch.ao.pruning.sparsifier.utils import module_to_fqn
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
logging.basicConfig( TestCase,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
) )
@ -405,3 +403,7 @@ class TestActivationSparsifier(TestCase):
# check state_dict() after squash_mask() # check state_dict() after squash_mask()
self._check_state_dict(activation_sparsifier) self._check_state_dict(activation_sparsifier)
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,8 +1,6 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import logging
import torch import torch
import torch.ao.quantization as tq import torch.ao.quantization as tq
from torch import nn from torch import nn
@ -15,13 +13,13 @@ from torch.ao.quantization.quantize_fx import (
prepare_qat_fx, prepare_qat_fx,
) )
from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, xfailIfS390X from torch.testing._internal.common_utils import (
raise_on_run_directly,
TestCase,
logging.basicConfig( xfailIfS390X,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
) )
sparse_defaults = { sparse_defaults = {
"sparsity_level": 0.8, "sparsity_level": 0.8,
"sparse_block_shape": (1, 4), "sparse_block_shape": (1, 4),
@ -642,3 +640,7 @@ class TestFxComposability(TestCase):
sparsity_level, sparse_config[0]["sparsity_level"] sparsity_level, sparse_config[0]["sparsity_level"]
) )
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,19 +1,13 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import copy import copy
import logging
import warnings import warnings
import torch import torch
from torch import nn from torch import nn
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ImplementedDataScheduler(BaseDataScheduler): class ImplementedDataScheduler(BaseDataScheduler):
@ -180,3 +174,7 @@ class TestBaseDataScheduler(TestCase):
name, _, _ = self._get_name_data_config(some_data, defaults) name, _, _ = self._get_name_data_config(some_data, defaults)
assert scheduler1.base_param[name] == scheduler2.base_param[name] assert scheduler1.base_param[name] == scheduler2.base_param[name]
assert scheduler1._last_param[name] == scheduler2._last_param[name] assert scheduler1._last_param[name] == scheduler2._last_param[name]
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -2,7 +2,6 @@
import copy import copy
import itertools import itertools
import logging
import math import math
import torch import torch
@ -15,12 +14,7 @@ from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
post_training_sparse_quantize, post_training_sparse_quantize,
) )
from torch.nn.utils.parametrize import is_parametrized from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ImplementedSparsifier(BaseDataSparsifier): class ImplementedSparsifier(BaseDataSparsifier):
@ -792,3 +786,7 @@ class TestQuantizationUtils(TestCase):
assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway
assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway
assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -19,20 +19,16 @@ from torch.testing._internal.common_quantized import (
qengine_is_qnnpack, qengine_is_qnnpack,
qengine_is_x86, qengine_is_x86,
) )
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
TestCase,
)
# TODO: Once more test files are created, move the contents to a ao folder. # TODO: Once more test files are created, move the contents to a ao folder.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False # Prevent duplicate logs if root logger also has handlers
class TestQuantizedSparseKernels(TestCase): class TestQuantizedSparseKernels(TestCase):
@ -331,4 +327,4 @@ class TestQuantizedSparseLayers(TestCase):
if __name__ == "__main__": if __name__ == "__main__":
run_tests() raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,18 +1,11 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import logging
import torch import torch
from torch import nn from torch import nn
from torch.ao.pruning.sparsifier import utils from torch.ao.pruning.sparsifier import utils
from torch.nn.utils import parametrize from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ModelUnderTest(nn.Module): class ModelUnderTest(nn.Module):
@ -173,3 +166,7 @@ class TestFakeSparsity(TestCase):
y = model(x) y = model(x)
y_hat = model_trace(x) y_hat = model_trace(x)
self.assertEqual(y_hat, y) self.assertEqual(y_hat, y)
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -4,7 +4,7 @@ import warnings
from torch import nn from torch import nn
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
class ImplementedScheduler(BaseScheduler): class ImplementedScheduler(BaseScheduler):
@ -190,3 +190,7 @@ class TestCubicScheduler(TestCase):
self.sorted_sparse_levels, self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level afer delta_t * n steps ", msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
) )
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import itertools import itertools
import logging
import re import re
import torch import torch
@ -18,12 +17,7 @@ from torch.testing._internal.common_pruning import (
MockSparseLinear, MockSparseLinear,
SimpleLinear, SimpleLinear,
) )
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class TestBaseSparsifier(TestCase): class TestBaseSparsifier(TestCase):
@ -484,3 +478,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
assert mask[row, col] == 1 assert mask[row, col] == 1
else: else:
assert mask[row, col] == 0 assert mask[row, col] == 0
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -18,7 +18,7 @@ from torch.testing._internal.common_quantization import (
SingleLayerLinearModel, SingleLayerLinearModel,
TwoLayerLinearModel, TwoLayerLinearModel,
) )
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
logging.basicConfig( logging.basicConfig(
@ -147,3 +147,7 @@ class TestSparsityUtilFunctions(TestCase):
self.assertEqual(arg_info["module_fqn"], "foo.bar") self.assertEqual(arg_info["module_fqn"], "foo.bar")
self.assertEqual(arg_info["tensor_name"], "baz") self.assertEqual(arg_info["tensor_name"], "baz")
self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import copy import copy
import logging
import random import random
import torch import torch
@ -29,13 +28,13 @@ from torch.testing._internal.common_pruning import (
SimpleConv2d, SimpleConv2d,
SimpleLinear, SimpleLinear,
) )
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
logging.basicConfig( TestCase,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
) )
DEVICES = { DEVICES = {
torch.device("cpu"), torch.device("cpu"),
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
@ -1089,3 +1088,7 @@ class TestFPGMPruner(TestCase):
self._test_update_mask_on_multiple_layer( self._test_update_mask_on_multiple_layer(
expected_conv1, expected_conv2, device expected_conv1, expected_conv2, device
) )
if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")

View File

@ -1,4 +1,5 @@
# Owner(s): ["module: unknown"] # Owner(s): ["module: unknown"]
import logging
# Kernels # Kernels
from ao.sparsity.test_kernels import ( # noqa: F401 from ao.sparsity.test_kernels import ( # noqa: F401
@ -56,4 +57,9 @@ from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
run_tests() run_tests()