mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 23:15:01 +08:00
Add __main__ guards to fx tests (#154715)
This PR is part of a series attempting to re-submit #134592 as smaller PRs. In fx tests: - Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run. - Raise a RuntimeError on tests which have been disabled (not run) - Remove any remaining uses of "unittest.main()"" Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf9cad31df
commit
c8d44a2296
@ -9,7 +9,7 @@ from torch.fx.passes.dialect.common.cse_pass import CSEPass
|
|||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
parametrize,
|
parametrize,
|
||||||
run_tests,
|
raise_on_run_directly,
|
||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,4 +128,4 @@ class TestCommonPass(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch
|
|||||||
from torch.fx import symbolic_trace
|
from torch.fx import symbolic_trace
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
|
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
banned_ops = get_CSE_banned_ops()
|
banned_ops = get_CSE_banned_ops()
|
||||||
@ -259,4 +259,4 @@ class TestCSEPass(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -5,7 +5,11 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.testing._internal.common_utils import IS_MACOS, TestCase
|
from torch.testing._internal.common_utils import (
|
||||||
|
IS_MACOS,
|
||||||
|
raise_on_run_directly,
|
||||||
|
TestCase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDCE(TestCase):
|
class TestDCE(TestCase):
|
||||||
@ -328,3 +332,7 @@ class TestDCE(TestCase):
|
|||||||
# collective nodes should not be removed because they have side effects.
|
# collective nodes should not be removed because they have side effects.
|
||||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.experimental._dynamism import track_dynamism_across_examples
|
from torch.fx.experimental._dynamism import track_dynamism_across_examples
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestDynamism(TestCase):
|
class TestDynamism(TestCase):
|
||||||
@ -148,4 +148,7 @@ class TestDynamism(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch
|
|||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.fx.experimental import const_fold
|
from torch.fx.experimental import const_fold
|
||||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestConstFold(TestCase):
|
class TestConstFold(TestCase):
|
||||||
@ -706,3 +706,7 @@ class TestConstFold(TestCase):
|
|||||||
base_result = mod(in_x, in_y)
|
base_result = mod(in_x, in_y)
|
||||||
fold_result = mod_folded(in_x, in_y)
|
fold_result = mod_folded(in_x, in_y)
|
||||||
self.assertTrue(torch.equal(fold_result, base_result))
|
self.assertTrue(torch.equal(fold_result, base_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -89,3 +89,10 @@ class TestFXNodeHook(TestCase):
|
|||||||
assert gm._create_node_hooks == []
|
assert gm._create_node_hooks == []
|
||||||
assert gm._erase_node_hooks == []
|
assert gm._erase_node_hooks == []
|
||||||
assert gm._replace_hooks == []
|
assert gm._replace_hooks == []
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -1,10 +1,8 @@
|
|||||||
# Owner(s): ["module: fx"]
|
# Owner(s): ["module: fx"]
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
class MyModuleBase(torch.nn.Module):
|
class MyModuleBase(torch.nn.Module):
|
||||||
@ -158,4 +156,4 @@ class TestConstParamShapeInControlFlow(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -223,3 +223,10 @@ class TestSplitOutputType(TestCase):
|
|||||||
|
|
||||||
self.assertTrue(type(gm_output) == type(split_gm_output))
|
self.assertTrue(type(gm_output) == type(split_gm_output))
|
||||||
self.assertTrue(torch.equal(gm_output, split_gm_output))
|
self.assertTrue(torch.equal(gm_output, split_gm_output))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -167,3 +167,10 @@ class TestFXNodeSource(TestCase):
|
|||||||
"Interpreter_FlattenInputOutputSignature",
|
"Interpreter_FlattenInputOutputSignature",
|
||||||
CREATE_STR,
|
CREATE_STR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -11,14 +11,6 @@ from torch.fx.traceback import NodeSourceAction
|
|||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
raise RuntimeError(
|
|
||||||
"This test file is not meant to be run directly, use:\n\n"
|
|
||||||
"\tpython test/test_fx.py TESTNAME\n\n"
|
|
||||||
"instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGraphTransformObserver(TestCase):
|
class TestGraphTransformObserver(TestCase):
|
||||||
def test_graph_transform_observer(self):
|
def test_graph_transform_observer(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
@ -186,3 +178,10 @@ class TestGraphTransformObserver(TestCase):
|
|||||||
self.assertEqual(len(gm2._create_node_hooks), 0)
|
self.assertEqual(len(gm2._create_node_hooks), 0)
|
||||||
self.assertEqual(len(gm2._erase_node_hooks), 0)
|
self.assertEqual(len(gm2._erase_node_hooks), 0)
|
||||||
self.assertEqual(len(gm2._deepcopy_hooks), 0)
|
self.assertEqual(len(gm2._deepcopy_hooks), 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from torch.fx.experimental.rewriter import RewritingTracer
|
|||||||
from torch.fx.experimental.unify_refinements import infer_symbolic_types
|
from torch.fx.experimental.unify_refinements import infer_symbolic_types
|
||||||
from torch.fx.passes.shape_prop import ShapeProp
|
from torch.fx.passes.shape_prop import ShapeProp
|
||||||
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
|
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1168,4 +1168,4 @@ class TypeCheckerTest(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -14,8 +14,7 @@ import torch
|
|||||||
import torch.library
|
import torch.library
|
||||||
from torch._dynamo.testing import make_test_cls_with_patches
|
from torch._dynamo.testing import make_test_cls_with_patches
|
||||||
from torch._inductor.test_case import TestCase
|
from torch._inductor.test_case import TestCase
|
||||||
from torch.testing._internal.common_utils import TEST_WITH_ASAN
|
from torch.testing._internal.inductor_utils import HAS_CPU
|
||||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
|
|
||||||
|
|
||||||
|
|
||||||
# Make the helper files in test/ importable
|
# Make the helper files in test/ importable
|
||||||
@ -93,8 +92,7 @@ class TestGraphPickler(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._inductor.test_case import run_tests
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
# Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
|
"enabled in discover_tests.py if required."
|
||||||
if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
|
)
|
||||||
run_tests(needs="filelock")
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from torch.fx._lazy_graph_module import (
|
|||||||
)
|
)
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.package import PackageExporter, PackageImporter
|
from torch.package import PackageExporter, PackageImporter
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestLazyGraphModule(TestCase):
|
class TestLazyGraphModule(TestCase):
|
||||||
@ -276,4 +276,7 @@ class TestLazyGraphModule(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
|
|||||||
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
|
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
|
||||||
SubgraphMatcherWithNameNodeMap,
|
SubgraphMatcherWithNameNodeMap,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
||||||
|
|
||||||
@ -269,4 +269,7 @@ class TestMatcher(JitTestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -100,3 +100,10 @@ class TestNetMinBaseBlock(TestCase):
|
|||||||
|
|
||||||
def test_continugous_partial_discrepancy_beginning(self) -> None:
|
def test_continugous_partial_discrepancy_beginning(self) -> None:
|
||||||
self.assert_problematic_nodes(["linear", "linear2"])
|
self.assert_problematic_nodes(["linear", "linear2"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# Owner(s): ["module: fx"]
|
# Owner(s): ["module: fx"]
|
||||||
|
|
||||||
import unittest
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -49,4 +48,7 @@ class TestPartitionerOrder(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from torch.fx.passes.infra.pass_manager import (
|
|||||||
PassManager,
|
PassManager,
|
||||||
this_before_that_pass_constraint,
|
this_before_that_pass_constraint,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||||
|
|
||||||
|
|
||||||
# Pass that uses PassBase and returns a PassResult (best scenario)
|
# Pass that uses PassBase and returns a PassResult (best scenario)
|
||||||
@ -228,3 +228,7 @@ class TestPassManager(TestCase):
|
|||||||
error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass"
|
error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass"
|
||||||
with self.assertRaisesRegex(Exception, error_msg):
|
with self.assertRaisesRegex(Exception, error_msg):
|
||||||
pm(traced_m)
|
pm(traced_m)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -108,3 +108,10 @@ class TestShapeInference(unittest.TestCase):
|
|||||||
gm = generate_graph_module(m)
|
gm = generate_graph_module(m)
|
||||||
input_tensors = [torch.randn(1, 1)]
|
input_tensors = [torch.randn(1, 1)]
|
||||||
infer_shape(gm, input_tensors)
|
infer_shape(gm, input_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise RuntimeError(
|
||||||
|
"This test is not currently used and should be "
|
||||||
|
"enabled in discover_tests.py if required."
|
||||||
|
)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from torch.fx.passes.utils.source_matcher_utils import (
|
|||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
parametrize,
|
parametrize,
|
||||||
|
raise_on_run_directly,
|
||||||
skipIfTorchDynamo,
|
skipIfTorchDynamo,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
@ -481,3 +482,6 @@ class TestSourceMatcher(JitTestCase):
|
|||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestSourceMatcher)
|
instantiate_parametrized_tests(TestSourceMatcher)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise_on_run_directly("test/test_fx.py")
|
||||||
|
|||||||
@ -902,6 +902,11 @@ def prof_callable(callable, *args, **kwargs):
|
|||||||
|
|
||||||
return callable(*args, **kwargs)
|
return callable(*args, **kwargs)
|
||||||
|
|
||||||
|
def raise_on_run_directly(file_to_call):
|
||||||
|
raise RuntimeError("This test file is not meant to be run directly, "
|
||||||
|
f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
|
||||||
|
"instead.")
|
||||||
|
|
||||||
def prof_func_call(*args, **kwargs):
|
def prof_func_call(*args, **kwargs):
|
||||||
return prof_callable(func_call, *args, **kwargs)
|
return prof_callable(func_call, *args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user