mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR is part of a series attempting to re-submit #134592 as smaller PRs. In fx tests: (1) add and use a common raise_on_run_directly method for when a user directly runs a test file that should not be run this way, printing the file which the user should have run instead; (2) raise a RuntimeError in tests which have been disabled (not run); (3) remove any remaining uses of "unittest.main()". Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715 Approved by: https://github.com/Skylion007
99 lines
2.5 KiB
Python
99 lines
2.5 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
#
|
|
# Tests the graph pickler by using pickling on all the inductor tests.
|
|
#
|
|
|
|
import contextlib
|
|
import importlib
|
|
import os
|
|
import sys
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch.library
|
|
from torch._dynamo.testing import make_test_cls_with_patches
|
|
from torch._inductor.test_case import TestCase
|
|
from torch.testing._internal.inductor_utils import HAS_CPU
|
|
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
|
check_model,
|
|
CommonTemplate,
|
|
copy_tests,
|
|
TestFailure,
|
|
)
|
|
|
|
|
|
importlib.import_module("filelock")
|
|
|
|
# Per-test overrides handed to copy_tests() further down in this file.
# xfail by default, set is_skip=True to skip
test_failures = {
    # TypeError: cannot pickle 'generator' object
    #
    # The first argument has to be a real tuple: ("cpu") is only a
    # parenthesized string, so the trailing comma is required.
    # NOTE(review): presumably copy_tests() checks the device suffix against
    # this value with `in`, which would have matched the bare string only by
    # substring membership -- confirm against inductor/test_torchinductor.py.
    "test_layer_norm_graph_pickler": TestFailure(("cpu",), is_skip=True),
}
|
|
|
|
|
|
def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
    """Build the graph-pickler variant of a test class.

    Delegates to ``make_test_cls_with_patches``, passing ``"GraphPickler"``
    and ``"_graph_pickler"`` as its naming arguments and a single patch that
    sets ``torch._inductor.compile_fx.fx_compile_mode`` to
    ``FxCompileMode.SERIALIZE``.

    Args:
        cls: The test class to derive the patched class from.
        xfail_prop: Forwarded unchanged as the ``xfail_prop`` keyword
            (presumably the attribute name that marks expected failures --
            confirm against ``make_test_cls_with_patches``).

    Returns:
        Whatever ``make_test_cls_with_patches`` returns for these arguments.
    """
    compile_fx = torch._inductor.compile_fx
    # (target object, attribute name, replacement value)
    serialize_patch = (
        compile_fx,
        "fx_compile_mode",
        compile_fx.FxCompileMode.SERIALIZE,
    )
    return make_test_cls_with_patches(
        cls,
        "GraphPickler",
        "_graph_pickler",
        serialize_patch,
        xfail_prop=xfail_prop,
    )
|
|
|
|
|
|
# Graph-pickler flavour of the shared CommonTemplate test class.
GraphPicklerCommonTemplate = make_test_cls(CommonTemplate)


# The CPU test class is only defined when HAS_CPU is truthy.
if HAS_CPU:

    class GraphPicklerCpuTests(TestCase):
        """Container for the CPU copies of the GraphPicklerCommonTemplate tests."""

        # NOTE(review): presumably read by the copied template tests as the
        # model checker and the target device -- confirm against CommonTemplate.
        common = check_model
        device = "cpu"

    # Copy the template's tests onto the class using the "cpu" suffix and the
    # overrides declared in test_failures.
    copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures)
|
|
|
|
|
|
class TestGraphPickler(TestCase):
    """Stand-alone tests that run with fx_compile_mode patched to SERIALIZE."""

    def setUp(self):
        # Reset dynamo before the base-class setup runs.
        torch._dynamo.reset()
        super().setUp()

        # The patch lives in an ExitStack so tearDown can undo it with one
        # close() call.
        self._stack = contextlib.ExitStack()
        serialize_mode = patch(
            "torch._inductor.compile_fx.fx_compile_mode",
            torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
        )
        self._stack.enter_context(serialize_mode)

    def tearDown(self):
        # Undo the patch first, then run the base-class teardown, and finally
        # reset dynamo again.
        self._stack.close()
        super().tearDown()
        torch._dynamo.reset()

    def test_simple(self):
        # Make sure that compiling works when we pass the input + output from
        # fx_codegen_and_compile() through serde.

        def fn(lhs, rhs):
            return lhs + rhs

        example_inputs = (
            torch.tensor([False, True]),
            torch.tensor([True, True]),
        )
        check_model(self, fn, example_inputs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly is rejected on purpose: the error message
    # directs the reader to discover_tests.py to enable it.
    raise RuntimeError(
        "This test is not currently used and should be "
        "enabled in discover_tests.py if required."
    )
|