mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51762 Update test_util.py to add a `make_tempdir()` function to the `TestCase` class. The main advantage of this function is that the temporary directory will be automatically cleaned up when the test case finishes, so that the test case does not need to worry about manually cleaning up this directory. This also prefixes the directory name with `caffe2_test.` so that it is more obvious where the temporary directories came from if they are ever left behind after a crashed or killed test process. This updates the tests in `operator_test/load_save_test.py` to use this new function, so they no longer have to perform their own manual cleanup in each test. Test Plan: python caffe2/python/operator_test/load_save_test.py Reviewed By: mraway Differential Revision: D26271178 Pulled By: simpkins fbshipit-source-id: 51175eefed39d65c03484482e84923e5f39a4768
116 lines
3.4 KiB
Python
116 lines
3.4 KiB
Python
## @package test_util
|
|
# Module caffe2.python.test_util
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
from caffe2.python import core, workspace
|
|
|
|
import os
|
|
import pathlib
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
from typing import Any, Callable, Tuple, Type
|
|
from types import TracebackType
|
|
|
|
|
|
def rand_array(*dims):
    """Return a float32 ndarray of shape ``dims`` with values centered on 0.

    Samples come from ``np.random.rand`` (uniform on [0, 1)) shifted by -0.5.
    With no ``dims``, ``np.random.rand()`` yields a plain Python float rather
    than a 0-dim array, so the value is wrapped in ``np.array`` before the
    dtype conversion; ``rand_array()`` therefore still returns an ndarray.
    """
    centered = np.random.rand(*dims) - 0.5
    as_array = np.array(centered)
    return as_array.astype(np.float32)
|
|
|
|
|
|
def randBlob(name, type, *dims, **kwargs):
    """Feed a random blob called ``name`` into the current workspace.

    The values are ``np.random.rand(*dims)`` cast to ``type`` and then shifted
    by the optional ``offset`` keyword argument (default 0.0). Other keyword
    arguments are ignored.

    Note: ``dims`` should be non-empty. ``np.random.rand()`` with no arguments
    returns a Python float, which has no ``astype`` method.
    """
    shift = kwargs.get('offset', 0.0)
    values = np.random.rand(*dims).astype(type)
    workspace.FeedBlob(name, values + shift)
|
|
|
|
|
|
def randBlobFloat32(name, *dims, **kwargs):
    """Feed a random float32 blob via ``randBlob``.

    Keyword arguments (such as ``offset``) are forwarded unchanged.
    """
    dtype = np.float32
    randBlob(name, dtype, *dims, **kwargs)
|
|
|
|
|
|
def randBlobsFloat32(names, *dims, **kwargs):
    """Feed one random float32 blob with the given ``dims`` for each name.

    Every blob is created by ``randBlobFloat32``, and the same ``dims`` and
    keyword arguments (e.g. ``offset``) are forwarded to each call.
    """
    for blob_name in names:
        randBlobFloat32(blob_name, *dims, **kwargs)
|
|
|
|
|
|
def numOps(net):
    """Return the number of entries in ``net.Proto().op`` (the net's operators)."""
    ops = net.Proto().op
    return len(ops)
|
|
|
|
|
|
def str_compare(a, b, encoding="utf8"):
    """Compare two values for equality, decoding ``bytes`` operands first.

    Any operand that is ``bytes`` is decoded with ``encoding`` before the
    ``==`` comparison. All other values are compared as-is.
    """
    def _as_text(value):
        return value.decode(encoding) if isinstance(value, bytes) else value

    return _as_text(a) == _as_text(b)
|
|
|
|
|
|
def get_default_test_flags():
    """Return a new list of the caffe2 init flags used by tests.

    The list starts with a 'caffe2' entry, presumably standing in for the
    program name (argv[0]). It then sets the log level to 0 and, going by
    the flag names, turns off zero-fill and turns on junk-fill in the CPU
    allocator.
    """
    program_name = 'caffe2'
    options = (
        '--caffe2_log_level=0',
        '--caffe2_cpu_allocator_do_zero_fill=0',
        '--caffe2_cpu_allocator_do_junk_fill=1',
    )
    return [program_name, *options]
|
|
|
|
|
|
def caffe2_flaky(test_method):
    """Mark ``test_method`` as flaky and return the same object.

    This works together with the CAFFE2_RUN_FLAKY_TESTS environment variable,
    which selects "flaky tests" mode:
    - when the mode is on, only flaky tests are run;
    - when the mode is off, only non-flaky tests are run.

    NOTE: apply this as the top-level (outermost) decorator on a test method,
    so that the attribute is set on the object the test runner looks up.
    """
    setattr(test_method, "__caffe2_flaky__", True)
    return test_method
|
|
|
|
|
|
def is_flaky_test_mode():
    """Return True iff CAFFE2_RUN_FLAKY_TESTS is set to exactly '1'.

    An unset variable is treated as '0' (regular, non-flaky mode).
    """
    flag = os.environ.get('CAFFE2_RUN_FLAKY_TESTS', '0')
    return flag == '1'
|
|
|
|
|
|
class TestCase(unittest.TestCase):
    """Base class for caffe2 Python tests.

    - Calls ``workspace.GlobalInit`` once per test class.
    - Honours the flaky-test mode selected by CAFFE2_RUN_FLAKY_TESTS.
    - Resets the global workspace before and after every test.
    - Provides automatically removed temporary directories via
      ``make_tempdir()``.
    """

    @classmethod
    def setUpClass(cls):
        workspace.GlobalInit(get_default_test_flags())
        # Clear the default engine preferences so that their effect is kept
        # separate from the operator tests.
        core.SetEnginePref({}, {})

    def setUp(self):
        # Skip tests based on whether we're in flaky test mode and whether
        # the test is decorated as a flaky test (see caffe2_flaky).
        test_method = getattr(self, self._testMethodName)
        is_flaky_test = getattr(test_method, "__caffe2_flaky__", False)
        flaky_mode = is_flaky_test_mode()
        if flaky_mode and not is_flaky_test:
            raise unittest.SkipTest("Non-flaky tests are skipped in flaky test mode")
        elif not flaky_mode and is_flaky_test:
            raise unittest.SkipTest("Flaky tests are skipped in regular test mode")

        self.ws = workspace.C.Workspace()
        workspace.ResetWorkspace()

    def tearDown(self):
        workspace.ResetWorkspace()

    def make_tempdir(self) -> pathlib.Path:
        """Create a fresh temporary directory that is removed after the test.

        The directory name is prefixed with ``caffe2_test.``, so directories
        left behind by a crashed or killed test process are easy to identify.
        Removal is registered with ``addCleanup``, so it runs after
        ``tearDown`` regardless of the test outcome.

        Returns:
            Path of the newly created directory.
        """
        tmp_folder = pathlib.Path(tempfile.mkdtemp(prefix="caffe2_test."))
        self.addCleanup(self._remove_tempdir, tmp_folder)
        return tmp_folder

    def _remove_tempdir(self, path: pathlib.Path) -> None:
        """Recursively delete ``path``, tolerating entries that are already gone.

        ``FileNotFoundError`` is ignored, for example when the test removed
        part or all of the directory itself. Any other error is re-raised.
        """
        import sys

        def _raise_unless_missing(exc: BaseException) -> None:
            if not isinstance(exc, FileNotFoundError):
                raise exc

        if sys.version_info >= (3, 12):
            # Python 3.12 deprecated rmtree's ``onerror`` argument (passing it
            # emits a DeprecationWarning) in favour of ``onexc``. The new
            # callback receives the exception instance directly.
            def _onexc(
                fn: Callable[..., Any],
                failed_path: str,
                exc: BaseException,
            ) -> None:
                _raise_unless_missing(exc)

            shutil.rmtree(str(path), onexc=_onexc)
        else:
            def _onerror(
                fn: Callable[..., Any],
                failed_path: str,
                exc_info: Tuple[Type[BaseException], BaseException, TracebackType],
            ) -> None:
                # exc_info[1] already carries exc_info[2] as its __traceback__.
                _raise_unless_missing(exc_info[1])

            shutil.rmtree(str(path), onerror=_onerror)
|