Files
pytorch/test/test_ao_sparsity.py
Anthony Barbier b1b8e57cda Add __main__ guards to ao tests (#154612)
This is the first PR of a series in an attempt to get the content of #134592 merged as smaller PRs (given that the original one was closed due to a lack of reviewers).

This specific PR contains:
- Add and use a common `raise_on_run_directly` method for when a user directly runs a test file that should not be run this way. It prints which file the user should have run instead.
- Update ao tests.

There will be follow-up PRs to update the other test suites, but I don't have permissions to create branches directly on pytorch/pytorch, so I can't create a stack and will therefore have to create them one at a time.

Cc @jerryzh168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154612
Approved by: https://github.com/jcaip
2025-06-10 18:33:09 +00:00

66 lines
1.6 KiB
Python

# Owner(s): ["module: unknown"]
import logging
# Kernels
from ao.sparsity.test_kernels import ( # noqa: F401
TestQuantizedSparseKernels,
TestQuantizedSparseLayers,
)
# Parametrizations
from ao.sparsity.test_parametrization import TestFakeSparsity # noqa: F401
# Scheduler
from ao.sparsity.test_scheduler import TestCubicScheduler, TestScheduler # noqa: F401
# Sparsifier
from ao.sparsity.test_sparsifier import ( # noqa: F401
TestBaseSparsifier,
TestNearlyDiagonalSparsifier,
TestWeightNormSparsifier,
)
# Structured Pruning
from ao.sparsity.test_structured_sparsifier import ( # noqa: F401
TestBaseStructuredSparsifier,
TestFPGMPruner,
TestSaliencyPruner,
)
from torch.testing._internal.common_utils import IS_ARM64, run_tests
# Composability
if not IS_ARM64:
    # These suites are imported only so their TestCase classes appear in this
    # module's namespace (hence `noqa: F401`); when IS_ARM64 is true the names
    # are never bound here.
    # NOTE(review): the reason ARM64 is excluded is not visible in this file --
    # confirm against ao/sparsity/test_composability.py before changing.
    from ao.sparsity.test_composability import TestComposability  # noqa: F401
    from ao.sparsity.test_composability import TestFxComposability  # noqa: F401
# Activation Sparsifier
from ao.sparsity.test_activation_sparsifier import ( # noqa: F401
TestActivationSparsifier,
)
# Data Scheduler
from ao.sparsity.test_data_scheduler import TestBaseDataScheduler # noqa: F401
# Data Sparsifier
from ao.sparsity.test_data_sparsifier import ( # noqa: F401
TestBaseDataSparsifier,
TestNormDataSparsifiers,
TestQuantizationUtils,
)
# Utilities
from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F401
if __name__ == "__main__":
    # Only when this aggregator is executed directly (not imported): set the
    # root logger to INFO with a timestamped record format, then hand control
    # to the shared `run_tests` entry point imported above.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    run_tests()