mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add __main__ guards to ao tests (#154612)
This is the first PR of a series in an attempt to get the content of #134592 merged as smaller PRs (Given that the original one was closed due to a lack of reviewers). This specific PR contains: - Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run. - Update ao tests. There will be follow up PRs to update the other test suites but I don't have permissions to create branches directly on pytorch/pytorch so I can't create a stack and therefore will have to create them one at the time. Cc @jerryzh168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154612 Approved by: https://github.com/jcaip
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b677560e6
commit
b1b8e57cda
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,11 +9,10 @@ from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier
|
||||
ActivationSparsifier,
|
||||
)
|
||||
from torch.ao.pruning.sparsifier.utils import module_to_fqn
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
@ -405,3 +403,7 @@ class TestActivationSparsifier(TestCase):
|
||||
|
||||
# check state_dict() after squash_mask()
|
||||
self._check_state_dict(activation_sparsifier)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,8 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.ao.quantization as tq
|
||||
from torch import nn
|
||||
@ -15,13 +13,13 @@ from torch.ao.quantization.quantize_fx import (
|
||||
prepare_qat_fx,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
from torch.testing._internal.common_utils import TestCase, xfailIfS390X
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
TestCase,
|
||||
xfailIfS390X,
|
||||
)
|
||||
|
||||
|
||||
sparse_defaults = {
|
||||
"sparsity_level": 0.8,
|
||||
"sparse_block_shape": (1, 4),
|
||||
@ -642,3 +640,7 @@ class TestFxComposability(TestCase):
|
||||
sparsity_level, sparse_config[0]["sparsity_level"]
|
||||
)
|
||||
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,19 +1,13 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
|
||||
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class ImplementedDataScheduler(BaseDataScheduler):
|
||||
@ -180,3 +174,7 @@ class TestBaseDataScheduler(TestCase):
|
||||
name, _, _ = self._get_name_data_config(some_data, defaults)
|
||||
assert scheduler1.base_param[name] == scheduler2.base_param[name]
|
||||
assert scheduler1._last_param[name] == scheduler2._last_param[name]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
@ -15,12 +14,7 @@ from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
|
||||
post_training_sparse_quantize,
|
||||
)
|
||||
from torch.nn.utils.parametrize import is_parametrized
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class ImplementedSparsifier(BaseDataSparsifier):
|
||||
@ -792,3 +786,7 @@ class TestQuantizationUtils(TestCase):
|
||||
assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway
|
||||
assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway
|
||||
assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -19,20 +19,16 @@ from torch.testing._internal.common_quantized import (
|
||||
qengine_is_qnnpack,
|
||||
qengine_is_x86,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Once more test files are created, move the contents to a ao folder.
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False # Prevent duplicate logs if root logger also has handlers
|
||||
|
||||
|
||||
class TestQuantizedSparseKernels(TestCase):
|
||||
@ -331,4 +327,4 @@ class TestQuantizedSparseLayers(TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,18 +1,11 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.ao.pruning.sparsifier import utils
|
||||
from torch.nn.utils import parametrize
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class ModelUnderTest(nn.Module):
|
||||
@ -173,3 +166,7 @@ class TestFakeSparsity(TestCase):
|
||||
y = model(x)
|
||||
y_hat = model_trace(x)
|
||||
self.assertEqual(y_hat, y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -4,7 +4,7 @@ import warnings
|
||||
|
||||
from torch import nn
|
||||
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class ImplementedScheduler(BaseScheduler):
|
||||
@ -190,3 +190,7 @@ class TestCubicScheduler(TestCase):
|
||||
self.sorted_sparse_levels,
|
||||
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
@ -18,12 +17,7 @@ from torch.testing._internal.common_pruning import (
|
||||
MockSparseLinear,
|
||||
SimpleLinear,
|
||||
)
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class TestBaseSparsifier(TestCase):
|
||||
@ -484,3 +478,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
|
||||
assert mask[row, col] == 1
|
||||
else:
|
||||
assert mask[row, col] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -18,7 +18,7 @@ from torch.testing._internal.common_quantization import (
|
||||
SingleLayerLinearModel,
|
||||
TwoLayerLinearModel,
|
||||
)
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -147,3 +147,7 @@ class TestSparsityUtilFunctions(TestCase):
|
||||
self.assertEqual(arg_info["module_fqn"], "foo.bar")
|
||||
self.assertEqual(arg_info["tensor_name"], "baz")
|
||||
self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
@ -29,13 +28,13 @@ from torch.testing._internal.common_pruning import (
|
||||
SimpleConv2d,
|
||||
SimpleLinear,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
DEVICES = {
|
||||
torch.device("cpu"),
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
|
||||
@ -1089,3 +1088,7 @@ class TestFPGMPruner(TestCase):
|
||||
self._test_update_mask_on_multiple_layer(
|
||||
expected_conv1, expected_conv2, device
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
import logging
|
||||
|
||||
# Kernels
|
||||
from ao.sparsity.test_kernels import ( # noqa: F401
|
||||
@ -56,4 +57,9 @@ from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user