mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
(1) blob debugstring (2) cnn bugfix and enhancement (3) device checker fix (4) suppress prints from caffe2_python (5) moar tests!
34 brewery.py
@@ -14,6 +14,9 @@ import traceback
 
 from build_env import Env
 
+# global variables
+CAFFE2_RUN_TEST = False
+
 class Colors(object):
   HEADER = '\033[95m'
   OKBLUE = '\033[94m'
@@ -344,11 +347,13 @@ class BuildTarget(object):
     Brewery.Register(self.name, self)
 
   def GetSignature(self):
-    """Generate the signature of the build object."""
+    """Generate the signature of the build object, and see if we need to
+    rebuild it."""
     src_digest = ''.join([hashlib.sha256(open(f, 'rb').read()).hexdigest()
                           for f in self.files])
     dep_digest = ''.join([Brewery.Signature(d) for d in self.deps])
-    return hashlib.sha256(src_digest + dep_digest).hexdigest()
+    command_digest = str(self.command_groups)
+    return hashlib.sha256(src_digest + dep_digest + command_digest).hexdigest()
 
   def SetUpAndBuild(self, built_signature):
     # Add successful optional dependencies into deps.
@@ -482,7 +487,7 @@ class cc_target(BuildTarget):
       self.command_groups = [cpp_commands, link_commands, link_shared_commands]
     else:
       self.command_groups = [cpp_commands, link_commands]
-    if self.is_test:
+    if self.is_test and CAFFE2_RUN_TEST:
       # Add test command
       self.command_groups.append([
           ' '.join([self.OutputName(), '--caffe_test_root',
@@ -516,11 +521,12 @@ class mpi_test(cc_target):
 
   def SetUp(self):
     cc_target.SetUp(self)
-    self.command_groups.append([
-        ' '.join(['mpirun --allow-run-as-root -n',
-                  str(self.mpi_size), self.OutputName(),
-                  '--caffe_test_root', os.path.abspath(Env.GENDIR),
-                  '--gtest_filter=-*.LARGE_*'])])
+    if CAFFE2_RUN_TEST:
+      self.command_groups.append([
+          ' '.join(['mpirun --allow-run-as-root -n',
+                    str(self.mpi_size), self.OutputName(),
+                    '--caffe_test_root', os.path.abspath(Env.GENDIR),
+                    '--gtest_filter=-*.LARGE_*'])])
 
 
 class cuda_library(BuildTarget):
@@ -585,9 +591,10 @@ class py_test(BuildTarget):
     CopyToGenDir(self.srcs)
     if len(self.srcs) > 1:
       raise RuntimeError('py_test should only take one python source file.')
-    # Add test command
-    self.command_groups = [
-        ['python %s' % GenFilename(self.srcs[0])]]
+    if CAFFE2_RUN_TEST:
+      # Add test command
+      self.command_groups = [
+          ['python %s' % GenFilename(self.srcs[0])]]
 
 
 class cc_thirdparty_target(BuildTarget):
@@ -685,6 +692,11 @@ def main(argv):
     # Build all targets.
     targets = sys.argv[2:]
     Brewery.Build(targets)
+  elif sys.argv[1] == 'test':
+    global CAFFE2_RUN_TEST
+    CAFFE2_RUN_TEST = True
+    targets = sys.argv[2:]
+    Brewery.Build(targets)
  elif sys.argv[1] == 'draw':
    # Draws the dependency graph.
    Brewery.Draw()
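Note on the GetSignature change: previously the digest covered only source contents and dependency signatures, so changing a target's command groups (e.g. the test commands now gated on CAFFE2_RUN_TEST) would not invalidate a cached build. A condensed standalone sketch of the new logic (Python 2 to match the repo; the real method reads self.files, self.deps and self.command_groups):

    import hashlib

    def get_signature(files, dep_signatures, command_groups):
      # Digest of every source file's contents.
      src_digest = ''.join([hashlib.sha256(open(f, 'rb').read()).hexdigest()
                            for f in files])
      # Precomputed signatures of all dependencies (Brewery.Signature(d)).
      dep_digest = ''.join(dep_signatures)
      # New in this commit: the build commands themselves, so running
      # "python brewery.py test <targets>" (which appends test commands)
      # yields a different signature than a plain build, forcing a re-run.
      command_digest = str(command_groups)
      return hashlib.sha256(src_digest + dep_digest + command_digest).hexdigest()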
caffe2/core/blob.h
@@ -2,6 +2,8 @@
 #define CAFFE2_CORE_BLOB_H_
 
+#include <cstddef>
+#include <sstream>
 #include <typeinfo>
 #include <vector>
 
 #include "caffe2/core/common.h"
@@ -143,6 +145,20 @@ class Tensor {
 
   virtual ~Tensor() {}
 
+  // A utility function to print the debug string for the tensor. Note that this
+  // is very slow since it involves quite some string operations, so do not use
+  // it in your performance-critical code.
+  string DebugString() const {
+    std::stringstream ss;
+    ss << "A Tensor of data type " << typeid(dtype).name()
+       << " and dimension (";
+    for (int d : dims_) {
+      ss << d << ",";
+    }
+    ss << ").";
+    return ss.str();
+  }
+
   void Reshape(const vector<int>& dims) {
     dims_ = dims;
     ndim_ = dims_.size();
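DebugString() itself is C++-only; purely to illustrate the string it builds, here is a hypothetical Python mock of the same formatting (note each dimension is followed by a trailing comma; with GCC's typeid names, a float tensor of shape (2, 3) prints as below):

    def debug_string(dtype_name, dims):
      # Mimics Tensor::DebugString() above: type name, then "dim," per dimension.
      return "A Tensor of data type %s and dimension (%s)." % (
          dtype_name, "".join("%d," % d for d in dims))

    print debug_string("f", [2, 3])
    # -> A Tensor of data type f and dimension (2,3,).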
pycaffe2/BUILD
@@ -91,3 +91,13 @@ py_test(
     ":pycaffe2",
   ],
 )
+
+py_test(
+  name = "gradient_check_test",
+  srcs = [
+    "gradient_check_test.py",
+  ],
+  deps = [
+    ":pycaffe2",
+  ],
+)
pycaffe2/caffe2_python.cc
@@ -190,7 +190,7 @@ PyObject* ResetWorkspace(PyObject* self, PyObject* args) {
         "specifying the root folder of the workspace.");
     return NULL;
   }
-  LOG(INFO) << "Resetting workspace.";
+  VLOG(1) << "Resetting workspace.";
   if (root_folder == nullptr) {
     gWorkspaces[gCurrentWorkspaceName].reset(
         new Workspace());
pycaffe2/cnn.py
@@ -54,7 +54,9 @@ class CNNModelHelper(object):
       conv_blobs.append(
           splitted_blobs[i].Conv([weight, bias], blob_out + '_gconv_%d' % i,
                                  kernel=kernel, order=self.order, **kwargs))
-    concat = self.net.DepthConcat(conv_blobs, blob_out, order=self.order)
+    concat, concat_dims = self.net.DepthConcat(
+        conv_blobs, [blob_out, "_" + blob_out + "_concat_dims"],
+        order=self.order)
     return concat
 
   def FC(self, blob_in, blob_out, dim_in, dim_out, weight_init, bias_init,
@@ -72,9 +74,18 @@ class CNNModelHelper(object):
     return self.net.LRN(blob_in, [blob_out, "_" + blob_out + "_scale"],
                         order=self.order, **kwargs)[0]
 
+  def Dropout(self, blob_in, blob_out, **kwargs):
+    """Dropout"""
+    return self.net.Dropout(blob_in, [blob_out, "_" + blob_out + "_mask"],
+                            **kwargs)[0]
+
   def MaxPool(self, blob_in, blob_out, **kwargs):
     """Max pooling"""
-    return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)
+    return self.net.MaxPool(blob_in, [blob_out, "_" + blob_out + "_idx"],
+                            order=self.order, **kwargs)[0]
 
   def AddGradientOperators(self):
     return self.net.AddGradientOperators()
 
   def __getattr__(self, operator_type):
     """Catch-all for all other operators, mostly those without params."""
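Both cnn.py fixes follow the pattern LRN already uses above: declare the operator with its auxiliary output blob, and hand back only the primary blob to the caller. A minimal usage sketch (blob names hypothetical, arguments mirroring the calls in model_device_test.py below):

    from pycaffe2 import cnn

    model = cnn.CNNModelHelper("example", order="NCHW")
    # MaxPool now also creates a hidden "_pool1_idx" blob but returns only
    # "pool1"; Dropout likewise creates "_drop1_mask" and returns "drop1".
    pool1 = model.MaxPool("conv1", "pool1", kernel=3, stride=2)
    drop1 = model.Dropout(pool1, "drop1")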
pycaffe2/device_checker.py
@@ -2,7 +2,7 @@ import numpy as np
 from pycaffe2 import core, workspace
 
 class DeviceChecker(object):
-  """A gradient checker in Python.
+  """A device checker in Python to check consistency across multiple devices.
 
   This is not the most efficient way to check gradients, as the Python interface
   will involve a lot of copy back and forth operations. Use at your own risk.
@@ -12,19 +12,12 @@ class DeviceChecker(object):
     self._device_options = device_options
 
   def CheckSimple(self, op, inputs, outputs_to_check):
-    """Checks the operator in a very simple fashion by stacking a sum of squares
-    on the top.
+    """Checks the operator with different device implementations.
 
     Inputs:
       op: the operator to be checked.
       inputs: the input data in numpy arrays.
-      input_to_check: an index specifying which input blob we should
-          check.
-      outputs_with_grads: indices specifying which output blobs will we
-          need to check gradients with. For these outputs, we will collect a
-          squared sum and also feed in their gradients.
-      grad_operator: the gradient operator. If not given, we will get the
-          gradient operator from the gradient registry.
+      outputs_to_check: the outputs to check between devices.
     Outputs:
       boolean: True if it passes, False if it does not pass.
     """
@@ -53,23 +46,27 @@ class DeviceChecker(object):
           print x.flatten()
           print y.flatten()
           success = False
           continue
+        #else:
+        #  print ('Passed device pair (0, %d), %s %s' %
+        #         (i, outputs_to_check[j], y.shape))
     workspace.SwitchWorkspace(old_ws_name)
     return success
 
-  def CheckNet(self, net, inputs={}, ignore=set()):
+  def CheckNet(self, net, inputs={}, blobs_to_check=None, ignore=set()):
     """Checks a network by inspecting all of its intermediate results, and see
     if things match.
     """
     old_ws_name = workspace.CurrentWorkspace()
     results = []
-    blobs_to_check = sum([list(op.output) for op in net.operators], [])
+    if blobs_to_check is None:
+      blobs_to_check = sum([list(op.output) for op in net.op], [])
     blobs_to_check = [b for b in blobs_to_check if b not in ignore]
     workspace.SwitchWorkspace("_device_check_", True)
     for i, device_option in enumerate(self._device_options):
       for name, arr in inputs.iteritems():
         print 'feeding', name
         workspace.FeedBlob(name, arr, device_option)
-      for op in net.operators:
+      for op in net.op:
         op.device_option.CopyFrom(device_option)
       workspace.RunNetOnce(net)
       results.append(
@@ -86,6 +83,8 @@ class DeviceChecker(object):
           print x.flatten()
           print y.flatten()
           success = False
           continue
+        #else:
+        #  print ('Passed device pair (%d, %d), %s %s: %s' %
+        #         (i, j, blobs_to_check[j], y.shape, str(y.flatten())))
     workspace.SwitchWorkspace(old_ws_name)
     return success
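A sketch of how the revised CheckNet API is meant to be called, pieced together from model_device_test.py below (the one-layer model and blob names here are hypothetical):

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from pycaffe2 import cnn, device_checker, workspace

    # Build and initialize a tiny model the same way model_device_test does.
    model = cnn.CNNModelHelper("tiny", order="NCHW")
    model.Conv("data", "conv1", 3, 8, 3, ("XavierFill", {}),
               ("ConstantFill", {}), stride=1, pad=1)
    workspace.ResetWorkspace()
    workspace.RunNetOnce(model.param_init_net)
    inputs = dict([(str(name), workspace.FetchBlob(str(name)))
                   for name in model.params])
    inputs["data"] = np.random.rand(2, 3, 8, 8).astype(np.float32)

    cpu_device = caffe2_pb2.DeviceOption()   # device_type defaults to CPU
    gpu_device = caffe2_pb2.DeviceOption()
    gpu_device.device_type = caffe2_pb2.CUDA

    checker = device_checker.DeviceChecker(1e-5, [cpu_device, gpu_device])
    # New in this commit: blobs_to_check limits the comparison to named blobs;
    # when omitted, every operator output in net.op is checked (minus ignore).
    ok = checker.CheckNet(model.net.Proto(), inputs, blobs_to_check=["conv1"])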
217 pycaffe2/gradient_check_test.py (new file)
@@ -0,0 +1,217 @@
import numpy as np
from pycaffe2 import core, device_checker, gradient_checker, workspace
from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2

import sys
import unittest

if workspace.has_gpu_support and workspace.NumberOfGPUs() > 0:
  gpu_device_option = caffe2_pb2.DeviceOption()
  gpu_device_option.device_type = caffe2_pb2.CUDA
  cpu_device_option = caffe2_pb2.DeviceOption()
  device_checker = device_checker.DeviceChecker(
      0.01, [gpu_device_option, cpu_device_option])
  gradient_checkers = [
      gradient_checker.GradientChecker(
          0.005, 0.05, gpu_device_option, "gpu_checker_ws"),
      gradient_checker.GradientChecker(
          0.01, 0.05, cpu_device_option, "cpu_checker_ws"),
  ]
else:
  cpu_device_option = caffe2_pb2.DeviceOption()
  device_checker = device_checker.DeviceChecker(
      0.01, [cpu_device_option])
  gradient_checkers = [
      gradient_checker.GradientChecker(
          0.01, 0.05, cpu_device_option, "cpu_checker_ws")
  ]


class TestConvLegacyPooling(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        # stride, kernel, legacy_pad, size, order
        (1, 1, 1, 7, "NHWC"),
        (1, 1, 2, 7, "NHWC"),
        (1, 3, 1, 7, "NHWC"),
        (1, 3, 2, 7, "NHWC"),
        (1, 5, 1, 14, "NHWC"),
        (1, 5, 2, 14, "NHWC"),
        (2, 7, 1, 24, "NHWC"),
        (2, 7, 2, 24, "NHWC"),
        (1, 1, 1, 7, "NCHW"),
        (1, 1, 2, 7, "NCHW"),
        (1, 3, 1, 7, "NCHW"),
        (1, 3, 2, 7, "NCHW"),
        (1, 5, 1, 14, "NCHW"),
        (1, 5, 2, 14, "NCHW"),
        (2, 7, 1, 24, "NCHW"),
        (2, 7, 2, 24, "NCHW"),
    ]

  def testConvolutionLegacyPadding(self):
    for stride, kernel, legacy_pad, size, order in self.test_configs:
      print 'conv', stride, kernel, legacy_pad, size, order
      op = core.CreateOperator("Conv")(
          ["X", "w", "b"], ["Y"], stride=stride, kernel=kernel,
          legacy_pad=legacy_pad, order=order)
      if order == "NHWC":
        X = np.random.rand(2, size, size, 3).astype(np.float32) - 0.5
        w = np.random.rand(4, kernel, kernel, 3).astype(np.float32) - 0.5
      else:
        X = np.random.rand(2, 3, size, size).astype(np.float32) - 0.5
        w = np.random.rand(4, 3, kernel, kernel).astype(np.float32) - 0.5
      b = np.random.rand(4).astype(np.float32) - 0.5
      res = device_checker.CheckSimple(op, [X, w, b], [0])
      self.assertTrue(res)
      for checker in gradient_checkers:
        for i in range(3):
          res, grad, grad_estimated = checker.CheckSimple(
              op, [X, w, b], i, [0])
          self.assertTrue(res)


class TestMaxPoolingLegacyPadding(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        (2, 3, 2, 12, "NHWC"),
        (2, 3, 2, 16, "NHWC"),
        (1, 3, 2, 8, "NHWC"),
        (1, 3, 2, 14, "NHWC"),
        (2, 3, 2, 14, "NHWC"),
        (1, 3, 2, 7, "NHWC"),
        (2, 3, 2, 12, "NCHW"),
        (2, 3, 2, 16, "NCHW"),
        (1, 3, 2, 8, "NCHW"),
        (1, 3, 2, 14, "NCHW"),
        (2, 3, 2, 14, "NCHW"),
        (1, 3, 2, 7, "NCHW"),
    ]

  def testMaxPoolingLegacyPadding(self):
    for stride, kernel, legacy_pad, size, order in self.test_configs:
      print 'MaxPool', stride, kernel, legacy_pad, size, order
      op = core.CreateOperator("MaxPool")(
          ["X"], ["Y", "Y_maxid"], stride=stride, kernel=kernel,
          legacy_pad=legacy_pad, order=order)
      # In order to avoid the problem of race conditions, we will do a randperm
      # so that the values will be apart at least 0.01
      if order == "NHWC":
        X = np.random.permutation(1 * size * size * 3).reshape(1, size, size, 3).astype(np.float32) * 0.01
      else:
        X = np.random.permutation(1 * size * size * 3).reshape(1, 3, size, size).astype(np.float32) * 0.01
      res = device_checker.CheckSimple(op, [X], [0])
      self.assertTrue(res)
      for checker in gradient_checkers:
        res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
        self.assertTrue(res)


class TestAveragePoolingLegacyPadding(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        (1, 7, 1, 7, "NHWC"),
        (1, 7, 2, 7, "NHWC"),
        (1, 7, 1, 7, "NCHW"),
        (1, 7, 2, 7, "NCHW"),
    ]

  def testAveragePoolingLegacyPadding(self):
    for stride, kernel, legacy_pad, size, order in self.test_configs:
      print 'AveragePool', stride, kernel, legacy_pad, size, order
      op = core.CreateOperator("AveragePool")(
          ["X"], ["Y"], stride=stride, kernel=kernel,
          legacy_pad=legacy_pad, order=order)
      if order == "NHWC":
        X = np.random.rand(2, size, size, 3).astype(np.float32)
      else:
        X = np.random.rand(2, 3, size, size).astype(np.float32)
      res = device_checker.CheckSimple(op, [X], [0])
      self.assertTrue(res)
      for checker in gradient_checkers:
        res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
        self.assertTrue(res)


class TestLRN(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        (6, 10),
        (3, 13),
    ]

  def testLRN(self):
    for input_size, depth in self.test_configs:
      op = core.CreateOperator("LRN")(
          ["X"], ["Y", "Y_scale"], size=11, alpha=0.001, beta=0.5, bias=2.0, order="NHWC")
      X = np.random.rand(2, input_size, input_size, depth).astype(np.float32)
      res = device_checker.CheckSimple(op, [X], [0])
      self.assertTrue(res)
      for checker in gradient_checkers:
        res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
        self.assertTrue(res)


class TestDepthConcat(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        # input_size, depth1, depth2, depth3, depth4
        (3, 2, 3, 4, 5),
        (4, 5, 4, 3, 2),
    ]

  def testDepthConcatNHWC(self):
    for input_size, d1, d2, d3, d4 in self.test_configs:
      op = core.CreateOperator("DepthConcat")(
          ["X1", "X2", "X3", "X4"], ["Y", "Y_dims"], order="NHWC")
      Xs = [np.random.rand(2, input_size, input_size, d1).astype(np.float32),
            np.random.rand(2, input_size, input_size, d2).astype(np.float32),
            np.random.rand(2, input_size, input_size, d3).astype(np.float32),
            np.random.rand(2, input_size, input_size, d4).astype(np.float32)]
      for i in range(4):
        res = device_checker.CheckSimple(op, Xs, [0])
        self.assertTrue(res)
        for checker in gradient_checkers:
          res, grad, grad_estimated = checker.CheckSimple(op, Xs, i, [0])
          self.assertTrue(res)

  def testDepthConcatNCHW(self):
    for input_size, d1, d2, d3, d4 in self.test_configs:
      op = core.CreateOperator("DepthConcat")(
          ["X1", "X2", "X3", "X4"], ["Y", "Y_dims"], order="NCHW")
      Xs = [np.random.rand(2, d1, input_size, input_size).astype(np.float32),
            np.random.rand(2, d2, input_size, input_size).astype(np.float32),
            np.random.rand(2, d3, input_size, input_size).astype(np.float32),
            np.random.rand(2, d4, input_size, input_size).astype(np.float32)]
      for i in range(4):
        res = device_checker.CheckSimple(op, Xs, [0])
        self.assertTrue(res)
        for checker in gradient_checkers:
          res, grad, grad_estimated = checker.CheckSimple(op, Xs, i, [0])
          self.assertTrue(res)


class TestRelu(unittest.TestCase):
  def setUp(self):
    self.test_configs = [
        # input size
        (1, 1),
        (2, 1),
        (1, 3, 3, 1),
        (2, 3, 3, 1),
        (1, 5, 5, 3),
        (2, 5, 5, 3),
    ]

  def testRelu(self):
    for input_size in self.test_configs:
      op = core.CreateOperator("Relu")(["X"], ["Y"])
      X = np.random.rand(*input_size).astype(np.float32)
      # go away from the origin point to avoid kink problems
      X += 0.01 * np.sign(X)
      X[X==0] = 0.01
      res = device_checker.CheckSimple(op, [X], [0])
      self.assertTrue(res)
      for checker in gradient_checkers:
        res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
        self.assertTrue(res)


if __name__ == '__main__':
  unittest.main()
97 pycaffe2/model_device_test.py (new file)
@@ -0,0 +1,97 @@
import numpy as np
import unittest

from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
from pycaffe2 import core, cnn, workspace, device_checker

class TestModelDevice(unittest.TestCase):
  def setUp(self):
    pass

  def _MiniAlexNetNoDropout(self, order):
    # First, AlexNet using the cnn wrapper.
    model = cnn.CNNModelHelper("alexnet", order=order)
    conv1 = model.Conv("data", "conv1", 3, 16, 11,
                       ("XavierFill", {}),
                       ("ConstantFill", {}), stride=4, pad=0)
    relu1 = model.Relu(conv1, "relu1")
    norm1 = model.LRN(relu1, "norm1", size=5, alpha=0.0001, beta=0.75)
    pool1 = model.MaxPool(norm1, "pool1", kernel=3, stride=2)
    conv2 = model.GroupConv(pool1, "conv2", 16, 32, 5,
                            ("XavierFill", {}),
                            ("ConstantFill", {"value": 0.1}),
                            group=2, stride=1, pad=2)
    relu2 = model.Relu(conv2, "relu2")
    norm2 = model.LRN(relu2, "norm2", size=5, alpha=0.0001, beta=0.75)
    pool2 = model.MaxPool(norm2, "pool2", kernel=3, stride=2)
    conv3 = model.Conv(pool2, "conv3", 32, 64, 3,
                       ("XavierFill", {'std': 0.01}),
                       ("ConstantFill", {}), pad=1)
    relu3 = model.Relu(conv3, "relu3")
    conv4 = model.GroupConv(relu3, "conv4", 64, 64, 3,
                            ("XavierFill", {}),
                            ("ConstantFill", {"value": 0.1}),
                            group=2, pad=1)
    relu4 = model.Relu(conv4, "relu4")
    conv5 = model.GroupConv(relu4, "conv5", 64, 32, 3,
                            ("XavierFill", {}),
                            ("ConstantFill", {"value": 0.1}),
                            group=2, pad=1)
    relu5 = model.Relu(conv5, "relu5")
    pool5 = model.MaxPool(relu5, "pool5", kernel=3, stride=2)
    fc6 = model.FC(pool5, "fc6", 1152, 1024,
                   ("XavierFill", {}),
                   ("ConstantFill", {"value": 0.1}))
    relu6 = model.Relu(fc6, "relu6")
    fc7 = model.FC(relu6, "fc7", 1024, 1024,
                   ("XavierFill", {}),
                   ("ConstantFill", {"value": 0.1}))
    relu7 = model.Relu(fc7, "relu7")
    fc8 = model.FC(relu7, "fc8", 1024, 5,
                   ("XavierFill", {}),
                   ("ConstantFill", {"value": 0.0}))
    pred = model.Softmax(fc8, "pred")
    xent = model.LabelCrossEntropy([pred, "label"], "xent")
    loss, xent_grad = model.AveragedLoss([xent], ["loss", xent.Grad()])
    model.AddGradientOperators()
    return model

  def _testMiniAlexNet(self, order):
    # First, we get all the random initialization of parameters.
    model = self._MiniAlexNetNoDropout(order)
    workspace.ResetWorkspace()
    workspace.RunNetOnce(model.param_init_net)
    inputs = dict(
        [(str(name), workspace.FetchBlob(str(name))) for name in model.params])
    if order == "NCHW":
      inputs["data"] = np.random.rand(4, 3, 227, 227).astype(np.float32)
    else:
      inputs["data"] = np.random.rand(4, 227, 227, 3).astype(np.float32)
    inputs["label"] = np.array([1, 2, 3, 4]).astype(np.int32)

    cpu_device = caffe2_pb2.DeviceOption()
    cpu_device.device_type = caffe2_pb2.CPU
    gpu_device = caffe2_pb2.DeviceOption()
    gpu_device.device_type = caffe2_pb2.CUDA

    checker = device_checker.DeviceChecker(
        1e-5, [cpu_device, gpu_device])
    ret = checker.CheckNet(
        model.net.Proto(), inputs,
        # The indices sometimes may be sensitive to small numerical differences
        # in the input, so we ignore checking them.
        ignore=['_pool1_idx', '_pool2_idx', '_pool5_idx'])
    self.assertEqual(ret, True)

  def testMiniAlexNet(self):
    self._testMiniAlexNet(order="NCHW")
    self._testMiniAlexNet(order="NHWC")


if __name__ == '__main__':
  if not workspace.has_gpu_support:
    print 'No GPU support. Skipping gpu test.'
  elif workspace.NumberOfGPUs() == 0:
    print 'No GPU device. Skipping gpu test.'
  else:
    unittest.main()