mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: woooo Test Plan: arc lint --apply-patches --take BLACK --paths-cmd 'hg files -I "caffe2/**/*.py"' Reviewed By: SplitInfinity Differential Revision: D28608934 fbshipit-source-id: 7768fed50a87883a95319376c0a6d73a9492bdcc
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
import os
|
|
import sys
|
|
from tempfile import NamedTemporaryFile
|
|
|
|
import torch.package.package_exporter
|
|
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
|
|
|
|
|
|
class PackageTestCase(TestCase):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self._temporary_files = []
|
|
|
|
def temp(self):
|
|
t = NamedTemporaryFile()
|
|
name = t.name
|
|
if IS_WINDOWS:
|
|
t.close() # can't read an open file in windows
|
|
else:
|
|
self._temporary_files.append(t)
|
|
return name
|
|
|
|
def setUp(self):
|
|
"""Add test/package/ to module search path. This ensures that
|
|
importing our fake packages via, e.g. `import package_a` will always
|
|
work regardless of how we invoke the test.
|
|
"""
|
|
super().setUp()
|
|
self.package_test_dir = os.path.dirname(os.path.realpath(__file__))
|
|
self.orig_sys_path = sys.path.copy()
|
|
sys.path.append(self.package_test_dir)
|
|
torch.package.package_exporter._gate_torchscript_serialization = False
|
|
|
|
def tearDown(self):
|
|
super().tearDown()
|
|
sys.path = self.orig_sys_path
|
|
|
|
# remove any temporary files
|
|
for t in self._temporary_files:
|
|
t.close()
|
|
self._temporary_files = []
|