mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add __main__ guards to ao tests (#154612)
This is the first PR of a series in an attempt to get the content of #134592 merged as smaller PRs (Given that the original one was closed due to a lack of reviewers). This specific PR contains: - Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run. - Update ao tests. There will be follow up PRs to update the other test suites but I don't have permissions to create branches directly on pytorch/pytorch so I can't create a stack and therefore will have to create them one at the time. Cc @jerryzh168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154612 Approved by: https://github.com/jcaip
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b677560e6
commit
b1b8e57cda
@ -1,7 +1,6 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -10,11 +9,10 @@ from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier
|
|||||||
ActivationSparsifier,
|
ActivationSparsifier,
|
||||||
)
|
)
|
||||||
from torch.ao.pruning.sparsifier.utils import module_to_fqn
|
from torch.ao.pruning.sparsifier.utils import module_to_fqn
|
||||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
|
from torch.testing._internal.common_utils import (
|
||||||
|
raise_on_run_directly,
|
||||||
|
skipIfTorchDynamo,
|
||||||
logging.basicConfig(
|
TestCase,
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -405,3 +403,7 @@ class TestActivationSparsifier(TestCase):
|
|||||||
|
|
||||||
# check state_dict() after squash_mask()
|
# check state_dict() after squash_mask()
|
||||||
self._check_state_dict(activation_sparsifier)
|
self._check_state_dict(activation_sparsifier)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.ao.quantization as tq
|
import torch.ao.quantization as tq
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -15,13 +13,13 @@ from torch.ao.quantization.quantize_fx import (
|
|||||||
prepare_qat_fx,
|
prepare_qat_fx,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||||
from torch.testing._internal.common_utils import TestCase, xfailIfS390X
|
from torch.testing._internal.common_utils import (
|
||||||
|
raise_on_run_directly,
|
||||||
|
TestCase,
|
||||||
logging.basicConfig(
|
xfailIfS390X,
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
sparse_defaults = {
|
sparse_defaults = {
|
||||||
"sparsity_level": 0.8,
|
"sparsity_level": 0.8,
|
||||||
"sparse_block_shape": (1, 4),
|
"sparse_block_shape": (1, 4),
|
||||||
@ -642,3 +640,7 @@ class TestFxComposability(TestCase):
|
|||||||
sparsity_level, sparse_config[0]["sparsity_level"]
|
sparsity_level, sparse_config[0]["sparsity_level"]
|
||||||
)
|
)
|
||||||
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
|
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,19 +1,13 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
|
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
|
||||||
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
|
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImplementedDataScheduler(BaseDataScheduler):
|
class ImplementedDataScheduler(BaseDataScheduler):
|
||||||
@ -180,3 +174,7 @@ class TestBaseDataScheduler(TestCase):
|
|||||||
name, _, _ = self._get_name_data_config(some_data, defaults)
|
name, _, _ = self._get_name_data_config(some_data, defaults)
|
||||||
assert scheduler1.base_param[name] == scheduler2.base_param[name]
|
assert scheduler1.base_param[name] == scheduler2.base_param[name]
|
||||||
assert scheduler1._last_param[name] == scheduler2._last_param[name]
|
assert scheduler1._last_param[name] == scheduler2._last_param[name]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -15,12 +14,7 @@ from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
|
|||||||
post_training_sparse_quantize,
|
post_training_sparse_quantize,
|
||||||
)
|
)
|
||||||
from torch.nn.utils.parametrize import is_parametrized
|
from torch.nn.utils.parametrize import is_parametrized
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImplementedSparsifier(BaseDataSparsifier):
|
class ImplementedSparsifier(BaseDataSparsifier):
|
||||||
@ -792,3 +786,7 @@ class TestQuantizationUtils(TestCase):
|
|||||||
assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway
|
assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway
|
||||||
assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway
|
assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway
|
||||||
assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway
|
assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -19,20 +19,16 @@ from torch.testing._internal.common_quantized import (
|
|||||||
qengine_is_qnnpack,
|
qengine_is_qnnpack,
|
||||||
qengine_is_x86,
|
qengine_is_x86,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
from torch.testing._internal.common_utils import (
|
||||||
|
raise_on_run_directly,
|
||||||
|
skipIfTorchDynamo,
|
||||||
|
TestCase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Once more test files are created, move the contents to a ao folder.
|
# TODO: Once more test files are created, move the contents to a ao folder.
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
logger.addHandler(handler)
|
|
||||||
logger.propagate = False # Prevent duplicate logs if root logger also has handlers
|
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizedSparseKernels(TestCase):
|
class TestQuantizedSparseKernels(TestCase):
|
||||||
@ -331,4 +327,4 @@ class TestQuantizedSparseLayers(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,18 +1,11 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.ao.pruning.sparsifier import utils
|
from torch.ao.pruning.sparsifier import utils
|
||||||
from torch.nn.utils import parametrize
|
from torch.nn.utils import parametrize
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelUnderTest(nn.Module):
|
class ModelUnderTest(nn.Module):
|
||||||
@ -173,3 +166,7 @@ class TestFakeSparsity(TestCase):
|
|||||||
y = model(x)
|
y = model(x)
|
||||||
y_hat = model_trace(x)
|
y_hat = model_trace(x)
|
||||||
self.assertEqual(y_hat, y)
|
self.assertEqual(y_hat, y)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -4,7 +4,7 @@ import warnings
|
|||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
|
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
class ImplementedScheduler(BaseScheduler):
|
class ImplementedScheduler(BaseScheduler):
|
||||||
@ -190,3 +190,7 @@ class TestCubicScheduler(TestCase):
|
|||||||
self.sorted_sparse_levels,
|
self.sorted_sparse_levels,
|
||||||
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
|
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,12 +17,7 @@ from torch.testing._internal.common_pruning import (
|
|||||||
MockSparseLinear,
|
MockSparseLinear,
|
||||||
SimpleLinear,
|
SimpleLinear,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBaseSparsifier(TestCase):
|
class TestBaseSparsifier(TestCase):
|
||||||
@ -484,3 +478,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
|
|||||||
assert mask[row, col] == 1
|
assert mask[row, col] == 1
|
||||||
else:
|
else:
|
||||||
assert mask[row, col] == 0
|
assert mask[row, col] == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -18,7 +18,7 @@ from torch.testing._internal.common_quantization import (
|
|||||||
SingleLayerLinearModel,
|
SingleLayerLinearModel,
|
||||||
TwoLayerLinearModel,
|
TwoLayerLinearModel,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -147,3 +147,7 @@ class TestSparsityUtilFunctions(TestCase):
|
|||||||
self.assertEqual(arg_info["module_fqn"], "foo.bar")
|
self.assertEqual(arg_info["module_fqn"], "foo.bar")
|
||||||
self.assertEqual(arg_info["tensor_name"], "baz")
|
self.assertEqual(arg_info["tensor_name"], "baz")
|
||||||
self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
|
self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -29,13 +28,13 @@ from torch.testing._internal.common_pruning import (
|
|||||||
SimpleConv2d,
|
SimpleConv2d,
|
||||||
SimpleLinear,
|
SimpleLinear,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
|
from torch.testing._internal.common_utils import (
|
||||||
|
raise_on_run_directly,
|
||||||
|
skipIfTorchDynamo,
|
||||||
logging.basicConfig(
|
TestCase,
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
DEVICES = {
|
DEVICES = {
|
||||||
torch.device("cpu"),
|
torch.device("cpu"),
|
||||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
|
||||||
@ -1089,3 +1088,7 @@ class TestFPGMPruner(TestCase):
|
|||||||
self._test_update_mask_on_multiple_layer(
|
self._test_update_mask_on_multiple_layer(
|
||||||
expected_conv1, expected_conv2, device
|
expected_conv1, expected_conv2, device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_ao_sparsity.py")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Owner(s): ["module: unknown"]
|
# Owner(s): ["module: unknown"]
|
||||||
|
import logging
|
||||||
|
|
||||||
# Kernels
|
# Kernels
|
||||||
from ao.sparsity.test_kernels import ( # noqa: F401
|
from ao.sparsity.test_kernels import ( # noqa: F401
|
||||||
@ -56,4 +57,9 @@ from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
|
||||||
run_tests()
|
run_tests()
|
||||||
|
Reference in New Issue
Block a user