Add __main__ guards to fx tests (#154715)

This PR is part of a series attempting to re-submit #134592 as smaller PRs.

In fx tests:

- Add and use a common raise_on_run_directly helper for test files that should not be run directly; the error message tells the user which file to run instead (a sketch of the resulting patterns follows this list).
- Raise a RuntimeError in test files whose tests are currently disabled (not run).
- Remove any remaining uses of "unittest.main()".
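
For context, these are the two "__main__" guard patterns the diffs below introduce (a real test file uses exactly one of them). The helper definition is a sketch matching the final diff in this commit:

    def raise_on_run_directly(file_to_call):
        # Point the user at the test file they should have run instead.
        raise RuntimeError("This test file is not meant to be run directly, "
                           f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
                           "instead.")

    # Pattern 1: test files whose tests are run via test/test_fx.py
    if __name__ == "__main__":
        raise_on_run_directly("test/test_fx.py")

    # Pattern 2: test files which are currently disabled
    if __name__ == "__main__":
        raise RuntimeError(
            "This test is not currently used and should be "
            "enabled in discover_tests.py if required."
        )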

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715
Approved by: https://github.com/Skylion007
Author: Anthony Barbier
Date: 2025-06-04 14:38:45 +00:00
Committed by: PyTorch MergeBot
Parent: cf9cad31df
Commit: c8d44a2296
20 changed files with 102 additions and 36 deletions

View File

@@ -9,7 +9,7 @@ from torch.fx.passes.dialect.common.cse_pass import CSEPass
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
-    run_tests,
+    raise_on_run_directly,
     TestCase,
 )
@@ -128,4 +128,4 @@ class TestCommonPass(TestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -6,7 +6,7 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase

 banned_ops = get_CSE_banned_ops()
@@ -259,4 +259,4 @@ class TestCSEPass(TestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -5,7 +5,11 @@ from typing import Optional
 import torch
 import torch.fx
-from torch.testing._internal.common_utils import IS_MACOS, TestCase
+from torch.testing._internal.common_utils import (
+    IS_MACOS,
+    raise_on_run_directly,
+    TestCase,
+)

 class TestDCE(TestCase):
@@ -328,3 +332,7 @@ class TestDCE(TestCase):
         # collective nodes should not be removed because they have side effects.
         self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
         torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -2,7 +2,7 @@
 import torch
 from torch.fx.experimental._dynamism import track_dynamism_across_examples
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import TestCase

 class TestDynamism(TestCase):
@@ -148,4 +148,7 @@ class TestDynamism(TestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -6,7 +6,7 @@ import torch
 import torch.fx
 from torch.fx.experimental import const_fold
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase

 class TestConstFold(TestCase):
@@ -706,3 +706,7 @@ class TestConstFold(TestCase):
         base_result = mod(in_x, in_y)
         fold_result = mod_folded(in_x, in_y)
         self.assertTrue(torch.equal(fold_result, base_result))
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -89,3 +89,10 @@ class TestFXNodeHook(TestCase):
         assert gm._create_node_hooks == []
         assert gm._erase_node_hooks == []
         assert gm._replace_hooks == []
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -1,10 +1,8 @@
 # Owner(s): ["module: fx"]
-import unittest
-
 import torch
 import torch.fx
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase

 class MyModuleBase(torch.nn.Module):
@@ -158,4 +156,4 @@ class TestConstParamShapeInControlFlow(TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -223,3 +223,10 @@ class TestSplitOutputType(TestCase):
         self.assertTrue(type(gm_output) == type(split_gm_output))
         self.assertTrue(torch.equal(gm_output, split_gm_output))
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -167,3 +167,10 @@ class TestFXNodeSource(TestCase):
             "Interpreter_FlattenInputOutputSignature",
             CREATE_STR,
         )
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -11,14 +11,6 @@ from torch.fx.traceback import NodeSourceAction
 from torch.testing._internal.common_utils import TestCase

-if __name__ == "__main__":
-    raise RuntimeError(
-        "This test file is not meant to be run directly, use:\n\n"
-        "\tpython test/test_fx.py TESTNAME\n\n"
-        "instead."
-    )
-

 class TestGraphTransformObserver(TestCase):
     def test_graph_transform_observer(self):
         class M(torch.nn.Module):
@@ -186,3 +178,10 @@ class TestGraphTransformObserver(TestCase):
         self.assertEqual(len(gm2._create_node_hooks), 0)
         self.assertEqual(len(gm2._erase_node_hooks), 0)
         self.assertEqual(len(gm2._deepcopy_hooks), 0)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -17,7 +17,7 @@ from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.experimental.unify_refinements import infer_symbolic_types
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase

 try:
@@ -1168,4 +1168,4 @@ class TypeCheckerTest(TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -14,8 +14,7 @@ import torch
 import torch.library
 from torch._dynamo.testing import make_test_cls_with_patches
 from torch._inductor.test_case import TestCase
-from torch.testing._internal.common_utils import TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
+from torch.testing._internal.inductor_utils import HAS_CPU

 # Make the helper files in test/ importable
@@ -93,8 +92,7 @@ class TestGraphPickler(TestCase):
 if __name__ == "__main__":
-    from torch._inductor.test_case import run_tests
-
-    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
-    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
-        run_tests(needs="filelock")
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -15,7 +15,7 @@ from torch.fx._lazy_graph_module import (
 )
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.package import PackageExporter, PackageImporter
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import TestCase

 class TestLazyGraphModule(TestCase):
@@ -276,4 +276,7 @@ class TestLazyGraphModule(TestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -19,7 +19,7 @@ from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
 )
-from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
+from torch.testing._internal.common_utils import IS_WINDOWS
 from torch.testing._internal.jit_utils import JitTestCase
@@ -269,4 +269,7 @@ class TestMatcher(JitTestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -100,3 +100,10 @@ class TestNetMinBaseBlock(TestCase):
     def test_continugous_partial_discrepancy_beginning(self) -> None:
         self.assert_problematic_nodes(["linear", "linear2"])
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -1,6 +1,5 @@
 # Owner(s): ["module: fx"]
-import unittest
 from collections.abc import Mapping

 import torch
@@ -49,4 +48,7 @@ class TestPartitionerOrder(TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -9,7 +9,7 @@ from torch.fx.passes.infra.pass_manager import (
     PassManager,
     this_before_that_pass_constraint,
 )
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import raise_on_run_directly, TestCase

 # Pass that uses PassBase and returns a PassResult (best scenario)
@@ -228,3 +228,7 @@ class TestPassManager(TestCase):
         error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass"
         with self.assertRaisesRegex(Exception, error_msg):
             pm(traced_m)
+
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -108,3 +108,10 @@ class TestShapeInference(unittest.TestCase):
         gm = generate_graph_module(m)
         input_tensors = [torch.randn(1, 1)]
         infer_shape(gm, input_tensors)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

View File

@@ -18,6 +18,7 @@ from torch.fx.passes.utils.source_matcher_utils import (
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    raise_on_run_directly,
     skipIfTorchDynamo,
 )
 from torch.testing._internal.jit_utils import JitTestCase
@@ -481,3 +482,6 @@ class TestSourceMatcher(JitTestCase):
 instantiate_parametrized_tests(TestSourceMatcher)
+
+if __name__ == "__main__":
+    raise_on_run_directly("test/test_fx.py")

View File

@@ -902,6 +902,11 @@ def prof_callable(callable, *args, **kwargs):
     return callable(*args, **kwargs)

+
+def raise_on_run_directly(file_to_call):
+    raise RuntimeError("This test file is not meant to be run directly, "
+                       f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
+                       "instead.")
+
 def prof_func_call(*args, **kwargs):
     return prof_callable(func_call, *args, **kwargs)
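
As a usage sketch, here is what a converted fx test file looks like end-to-end. The file name and test body are hypothetical; only the import and the guard mirror this commit:

    # Hypothetical file test/fx/test_example.py (illustrative, not part of this PR)
    from torch.testing._internal.common_utils import raise_on_run_directly, TestCase


    class TestExample(TestCase):
        def test_nothing(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        # `python test/fx/test_example.py` now fails fast with a RuntimeError
        # telling the user to run `python test/test_fx.py TESTNAME` instead.
        raise_on_run_directly("test/test_fx.py")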