Files
pytorch/caffe2/python/test_util.py
Will Feng 4c06f1f2bb CircleCI: enable all flaky tests (#13356)
Summary:
A few Caffe2 tests are currently disabled in the `py2-gcc4.8-ubuntu14.04` test job because they are known to be flaky. https://github.com/pytorch/pytorch/pull/13055 likely fixed the flakiness, and this PR re-enables the tests to verify that.

Fixes https://github.com/pytorch/pytorch/issues/12395.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13356

Differential Revision: D12858206

Pulled By: yf225

fbshipit-source-id: 491c9c4a5c48ac1b791fdc9d78acf66091e80457
2018-10-31 09:34:49 -07:00

62 lines
1.5 KiB
Python

## @package test_util
# Module caffe2.python.test_util
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from caffe2.python import core, workspace
import unittest
import os
def rand_array(*dims):
    """Return a float32 array of shape ``dims`` filled with random values.

    The values are ``np.random.rand(*dims)`` samples shifted down by 0.5.

    Args:
        *dims: Dimensions forwarded to ``np.random.rand``. With no
            dimensions the result is a 0-dim array.

    Returns:
        A ``np.ndarray`` with dtype ``np.float32``.
    """
    # np.random.rand() called without dims returns a plain float instead of
    # a 0-dim array, so the value is wrapped in np.array() before the cast.
    centered = np.random.rand(*dims) - 0.5
    return np.array(centered).astype(np.float32)
def randBlob(name, type, *dims, **kwargs):
    """Feed a randomly filled blob named ``name`` via ``workspace.FeedBlob``.

    Args:
        name: Name of the blob to feed.
        type: dtype the random samples are cast to with ``astype``.
        *dims: Dimensions forwarded to ``np.random.rand``.
        **kwargs: Only ``offset`` is read (default ``0.0``); it is added to
            every element. Any other keyword is ignored.
    """
    shift = kwargs.get('offset', 0.0)
    # NOTE(review): the offset is added after the cast to ``type``, so the
    # fed array's dtype follows NumPy promotion rules for ``array + shift``
    # -- confirm this is intended for non-float types.
    values = np.random.rand(*dims).astype(type) + shift
    workspace.FeedBlob(name, values)
def randBlobFloat32(name, *dims, **kwargs):
    """Feed a random blob named ``name`` using ``np.float32`` as the type.

    Thin wrapper around ``randBlob``; ``dims`` and ``kwargs`` are forwarded
    unchanged.
    """
    float_type = np.float32
    randBlob(name, float_type, *dims, **kwargs)
def randBlobsFloat32(names, *dims, **kwargs):
    """Feed one random float32 blob per entry in ``names``.

    Every blob is created by ``randBlobFloat32`` with the same ``dims`` and
    ``kwargs``.
    """
    for blob_name in names:
        randBlobFloat32(blob_name, *dims, **kwargs)
def numOps(net):
    """Return the number of entries in the ``op`` field of ``net.Proto()``."""
    net_proto = net.Proto()
    return len(net_proto.op)
def str_compare(a, b, encoding="utf8"):
    """Compare ``a`` and ``b`` for equality, decoding ``bytes`` operands first.

    Args:
        a: First value; decoded with ``encoding`` if it is ``bytes``.
        b: Second value; decoded with ``encoding`` if it is ``bytes``.
        encoding: Codec used to decode ``bytes`` operands.

    Returns:
        The result of ``a == b`` after any decoding.
    """
    # The generator is consumed left to right, so ``a`` is decoded before ``b``.
    left, right = (
        item.decode(encoding) if isinstance(item, bytes) else item
        for item in (a, b)
    )
    return left == right
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` base that initializes Caffe2 and resets the
    workspace around every test."""

    @classmethod
    def setUpClass(cls):
        """Run Caffe2 global init and clear the engine preferences."""
        init_args = ['caffe2', '--caffe2_log_level=0']
        workspace.GlobalInit(init_args)
        # Clear the default engine settings to separate out their
        # effect from the op tests.
        core.SetEnginePref({}, {})

    def setUp(self):
        """Create a workspace object on ``self.ws`` and reset the workspace."""
        self.ws = workspace.C.Workspace()
        workspace.ResetWorkspace()

    def tearDown(self):
        """Reset the workspace once the test has finished."""
        workspace.ResetWorkspace()