(1) blob DebugString

(2) cnn bugfix and enhancement
(3) device checker fix
(4) suppress prints from caffe2_python
(5) moar tests!
Yangqing Jia
2015-09-12 10:32:49 -07:00
parent 8490282cd1
commit 0b66c01462
8 changed files with 391 additions and 29 deletions

View File

@@ -14,6 +14,9 @@ import traceback
 from build_env import Env
 
+# global variables
+CAFFE2_RUN_TEST = False
+
 class Colors(object):
     HEADER = '\033[95m'
     OKBLUE = '\033[94m'
@@ -344,11 +347,13 @@ class BuildTarget(object):
         Brewery.Register(self.name, self)
 
     def GetSignature(self):
-        """Generate the signature of the build object."""
+        """Generate the signature of the build object, and see if we need to
+        rebuild it."""
         src_digest = ''.join([hashlib.sha256(open(f, 'rb').read()).hexdigest()
                               for f in self.files])
         dep_digest = ''.join([Brewery.Signature(d) for d in self.deps])
-        return hashlib.sha256(src_digest + dep_digest).hexdigest()
+        command_digest = str(self.command_groups)
+        return hashlib.sha256(src_digest + dep_digest + command_digest).hexdigest()
 
     def SetUpAndBuild(self, built_signature):
         # Add successful optional dependencies into deps.
@@ -482,7 +487,7 @@ class cc_target(BuildTarget):
             self.command_groups = [cpp_commands, link_commands, link_shared_commands]
         else:
             self.command_groups = [cpp_commands, link_commands]
-        if self.is_test:
+        if self.is_test and CAFFE2_RUN_TEST:
             # Add test command
             self.command_groups.append([
                 ' '.join([self.OutputName(), '--caffe_test_root',
@@ -516,11 +521,12 @@ class mpi_test(cc_target):
     def SetUp(self):
         cc_target.SetUp(self)
-        self.command_groups.append([
-            ' '.join(['mpirun --allow-run-as-root -n',
-                      str(self.mpi_size), self.OutputName(),
-                      '--caffe_test_root', os.path.abspath(Env.GENDIR),
-                      '--gtest_filter=-*.LARGE_*'])])
+        if CAFFE2_RUN_TEST:
+            self.command_groups.append([
+                ' '.join(['mpirun --allow-run-as-root -n',
+                          str(self.mpi_size), self.OutputName(),
+                          '--caffe_test_root', os.path.abspath(Env.GENDIR),
+                          '--gtest_filter=-*.LARGE_*'])])
 
 
 class cuda_library(BuildTarget):
@@ -585,9 +591,10 @@ class py_test(BuildTarget):
         CopyToGenDir(self.srcs)
         if len(self.srcs) > 1:
             raise RuntimeError('py_test should only take one python source file.')
-        # Add test command
-        self.command_groups = [
-            ['python %s' % GenFilename(self.srcs[0])]]
+        if CAFFE2_RUN_TEST:
+            # Add test command
+            self.command_groups = [
+                ['python %s' % GenFilename(self.srcs[0])]]
 
 
 class cc_thirdparty_target(BuildTarget):
@@ -685,6 +692,11 @@ def main(argv):
         # Build all targets.
         targets = sys.argv[2:]
         Brewery.Build(targets)
+    elif sys.argv[1] == 'test':
+        global CAFFE2_RUN_TEST
+        CAFFE2_RUN_TEST = True
+        targets = sys.argv[2:]
+        Brewery.Build(targets)
     elif sys.argv[1] == 'draw':
         # Draws the dependency graph.
         Brewery.Draw()
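Two things change in this build script: a target's test command is now only attached when the new 'test' subcommand flips CAFFE2_RUN_TEST, and a target's signature now also covers its build commands, so editing a compile or link flag invalidates the cached output even when no source file changed. A minimal standalone sketch of the signature idea (a hypothetical helper, not the actual Brewery API):

    import hashlib

    def target_signature(files, dep_signatures, command_groups):
        # Hash the contents of every source file.
        src_digest = ''.join(hashlib.sha256(open(f, 'rb').read()).hexdigest()
                             for f in files)
        # Fold in the signatures already computed for the dependencies.
        dep_digest = ''.join(dep_signatures)
        # New in this commit: fold in the build commands themselves.
        command_digest = str(command_groups)
        return hashlib.sha256(
            src_digest + dep_digest + command_digest).hexdigest()

Invoking the script with 'test' as the first argument (the diff only shows the dispatch on sys.argv[1]) then behaves like 'build' plus running each target's test command.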

View File

@@ -2,6 +2,8 @@
 #define CAFFE2_CORE_BLOB_H_
 
 #include <cstddef>
+#include <sstream>
+#include <typeinfo>
 #include <vector>
 
 #include "caffe2/core/common.h"
@@ -143,6 +145,20 @@ class Tensor {
   virtual ~Tensor() {}
 
+  // A utility function to print the debug string for the tensor. Note that this
+  // is very slow since it involves quite some string operations, so do not use
+  // it in your performance-critical code.
+  string DebugString() const {
+    std::stringstream ss;
+    ss << "A Tensor of data type " << typeid(dtype).name()
+       << " and dimension (";
+    for (int d : dims_) {
+      ss << d << ",";
+    }
+    ss << ").";
+    return ss.str();
+  }
+
   void Reshape(const vector<int>& dims) {
     dims_ = dims;
     ndim_ = dims_.size();
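The new DebugString is for ad-hoc logging only, as the comment warns. A small Python mirror of its output format, handy when comparing numpy inputs against the C++ log lines (a hypothetical helper, not part of pycaffe2; the real type name comes from typeid and is compiler-mangled):

    import numpy as np

    def tensor_debug_string(arr):
        # Same format as Tensor::DebugString() above, including the trailing
        # comma inside the parentheses.
        dims = ''.join('%d,' % d for d in arr.shape)
        return 'A Tensor of data type %s and dimension (%s).' % (arr.dtype, dims)

    print tensor_debug_string(np.zeros((2, 3, 5), dtype=np.float32))
    # A Tensor of data type float32 and dimension (2,3,5,).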

View File

@@ -91,3 +91,13 @@ py_test(
     ":pycaffe2",
   ],
 )
+
+py_test(
+  name = "gradient_check_test",
+  srcs = [
+    "gradient_check_test.py",
+  ],
+  deps = [
+    ":pycaffe2",
+  ],
+)

View File

@@ -190,7 +190,7 @@ PyObject* ResetWorkspace(PyObject* self, PyObject* args) {
                     "specifying the root folder of the workspace.");
     return NULL;
   }
-  LOG(INFO) << "Resetting workspace.";
+  VLOG(1) << "Resetting workspace.";
   if (root_folder == nullptr) {
     gWorkspaces[gCurrentWorkspaceName].reset(
         new Workspace());
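This is the "suppress prints from caffe2_python" item from the commit message: downgraded from LOG(INFO) to VLOG(1), the reset notice only appears when verbose logging is explicitly enabled, rather than on every call to ResetWorkspace.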

View File

@@ -54,7 +54,9 @@ class CNNModelHelper(object):
             conv_blobs.append(
                 splitted_blobs[i].Conv([weight, bias], blob_out + '_gconv_%d' % i,
                                        kernel=kernel, order=self.order, **kwargs))
-        concat = self.net.DepthConcat(conv_blobs, blob_out, order=self.order)
+        concat, concat_dims = self.net.DepthConcat(
+            conv_blobs, [blob_out, "_" + blob_out + "_concat_dims"],
+            order=self.order)
         return concat
 
     def FC(self, blob_in, blob_out, dim_in, dim_out, weight_init, bias_init,
@@ -72,9 +74,18 @@ class CNNModelHelper(object):
         return self.net.LRN(blob_in, [blob_out, "_" + blob_out + "_scale"],
                             order=self.order, **kwargs)[0]
 
+    def Dropout(self, blob_in, blob_out, **kwargs):
+        """Dropout"""
+        return self.net.Dropout(blob_in, [blob_out, "_" + blob_out + "_mask"],
+                                **kwargs)[0]
+
     def MaxPool(self, blob_in, blob_out, **kwargs):
         """Max pooling"""
-        return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)
+        return self.net.MaxPool(blob_in, [blob_out, "_" + blob_out + "_idx"],
+                                order=self.order, **kwargs)[0]
+
+    def AddGradientOperators(self):
+        return self.net.AddGradientOperators()
+
     def __getattr__(self, operator_type):
         """Catch-all for all other operators, mostly those without params."""

View File

@@ -2,7 +2,7 @@ import numpy as np
 from pycaffe2 import core, workspace
 
 
 class DeviceChecker(object):
-    """A gradient checker in Python.
+    """A device checker in Python to check consistency across multiple devices.
 
     This is not the most efficient way to check gradients, as the Python interface
     will involve a lot of copy back and forth operations. Use at your own risk.
@@ -12,19 +12,12 @@ class DeviceChecker(object):
         self._device_options = device_options
 
     def CheckSimple(self, op, inputs, outputs_to_check):
-        """Checks the operator in a very simple fashion by stacking a sum of squares
-        on the top.
+        """Checks the operator with different device implementations.
 
         Inputs:
           op: the operator to be checked.
           inputs: the input data in numpy arrays.
-          input_to_check: an index specifying which input blob we should
-              check.
-          outputs_with_grads: indices specifying which output blobs will we
-              need to check gradients with. For these outputs, we will collect a
-              squared sum and also feed in their gradients.
-          grad_operator: the gradient operator. If not given, we will get the
-              gradient operator from the gradient registry.
+          outputs_to_check: the outputs to check between devices.
         Outputs:
           boolean: True if it passes, False if it does not pass.
         """
@@ -53,23 +46,27 @@ class DeviceChecker(object):
                     print x.flatten()
                     print y.flatten()
                     success = False
+                    continue
+                #else:
+                #    print ('Passed device pair (0, %d), %s %s' %
+                #           (i, outputs_to_check[j], y.shape))
         workspace.SwitchWorkspace(old_ws_name)
         return success
 
-    def CheckNet(self, net, inputs={}, ignore=set()):
+    def CheckNet(self, net, inputs={}, blobs_to_check=None, ignore=set()):
         """Checks a network by inspecting all of its intermediate results, and see
         if things match.
         """
         old_ws_name = workspace.CurrentWorkspace()
         results = []
-        blobs_to_check = sum([list(op.output) for op in net.operators], [])
+        if blobs_to_check is None:
+            blobs_to_check = sum([list(op.output) for op in net.op], [])
         blobs_to_check = [b for b in blobs_to_check if b not in ignore]
         workspace.SwitchWorkspace("_device_check_", True)
         for i, device_option in enumerate(self._device_options):
             for name, arr in inputs.iteritems():
                 print 'feeding', name
                 workspace.FeedBlob(name, arr, device_option)
-            for op in net.operators:
+            for op in net.op:
                 op.device_option.CopyFrom(device_option)
             workspace.RunNetOnce(net)
             results.append(
@@ -86,6 +83,8 @@ class DeviceChecker(object):
                     print x.flatten()
                     print y.flatten()
                     success = False
+                    continue
+                #else:
+                #    print ('Passed device pair (%d, %d), %s %s: %s' %
+                #           (i, j, blobs_to_check[j], y.shape, str(y.flatten())))
         workspace.SwitchWorkspace(old_ws_name)
         return success
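Besides switching from net.operators to net.op (the actual repeated field name in the NetDef proto), CheckNet gains an optional blobs_to_check argument so a caller can restrict the comparison instead of diffing every operator output. A sketch of the new call, with hypothetical blob names and a net proto assumed to come from something like model.net.Proto():

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from pycaffe2 import device_checker

    cpu_device = caffe2_pb2.DeviceOption()
    cpu_device.device_type = caffe2_pb2.CPU
    gpu_device = caffe2_pb2.DeviceOption()
    gpu_device.device_type = caffe2_pb2.CUDA
    checker = device_checker.DeviceChecker(1e-5, [cpu_device, gpu_device])
    data = np.random.rand(4, 3, 32, 32).astype(np.float32)
    # Only compare the blobs we care about; noisy ones (e.g. max-pool index
    # blobs) can simply be left out here instead of listed in ignore.
    ret = checker.CheckNet(net_proto, inputs={"data": data},
                           blobs_to_check=["conv1", "pool1", "pred"])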

View File

@@ -0,0 +1,217 @@
import numpy as np
from pycaffe2 import core, device_checker, gradient_checker, workspace
from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
import sys
import unittest

if workspace.has_gpu_support and workspace.NumberOfGPUs() > 0:
    gpu_device_option = caffe2_pb2.DeviceOption()
    gpu_device_option.device_type = caffe2_pb2.CUDA
    cpu_device_option = caffe2_pb2.DeviceOption()
    device_checker = device_checker.DeviceChecker(
        0.01, [gpu_device_option, cpu_device_option])
    gradient_checkers = [
        gradient_checker.GradientChecker(
            0.005, 0.05, gpu_device_option, "gpu_checker_ws"),
        gradient_checker.GradientChecker(
            0.01, 0.05, cpu_device_option, "cpu_checker_ws"),
    ]
else:
    cpu_device_option = caffe2_pb2.DeviceOption()
    device_checker = device_checker.DeviceChecker(
        0.01, [cpu_device_option])
    gradient_checkers = [
        gradient_checker.GradientChecker(
            0.01, 0.05, cpu_device_option, "cpu_checker_ws")
    ]


class TestConvLegacyPooling(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            # stride, kernel, legacy_pad, size, order
            (1, 1, 1, 7, "NHWC"),
            (1, 1, 2, 7, "NHWC"),
            (1, 3, 1, 7, "NHWC"),
            (1, 3, 2, 7, "NHWC"),
            (1, 5, 1, 14, "NHWC"),
            (1, 5, 2, 14, "NHWC"),
            (2, 7, 1, 24, "NHWC"),
            (2, 7, 2, 24, "NHWC"),
            (1, 1, 1, 7, "NCHW"),
            (1, 1, 2, 7, "NCHW"),
            (1, 3, 1, 7, "NCHW"),
            (1, 3, 2, 7, "NCHW"),
            (1, 5, 1, 14, "NCHW"),
            (1, 5, 2, 14, "NCHW"),
            (2, 7, 1, 24, "NCHW"),
            (2, 7, 2, 24, "NCHW"),
        ]

    def testConvolutionLegacyPadding(self):
        for stride, kernel, legacy_pad, size, order in self.test_configs:
            print 'conv', stride, kernel, legacy_pad, size, order
            op = core.CreateOperator("Conv")(
                ["X", "w", "b"], ["Y"], stride=stride, kernel=kernel,
                legacy_pad=legacy_pad, order=order)
            if order == "NHWC":
                X = np.random.rand(2, size, size, 3).astype(np.float32) - 0.5
                w = np.random.rand(4, kernel, kernel, 3).astype(np.float32) - 0.5
            else:
                X = np.random.rand(2, 3, size, size).astype(np.float32) - 0.5
                w = np.random.rand(4, 3, kernel, kernel).astype(np.float32) - 0.5
            b = np.random.rand(4).astype(np.float32) - 0.5
            res = device_checker.CheckSimple(op, [X, w, b], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                for i in range(3):
                    res, grad, grad_estimated = checker.CheckSimple(
                        op, [X, w, b], i, [0])
                    self.assertTrue(res)


class TestMaxPoolingLegacyPadding(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            (2, 3, 2, 12, "NHWC"),
            (2, 3, 2, 16, "NHWC"),
            (1, 3, 2, 8, "NHWC"),
            (1, 3, 2, 14, "NHWC"),
            (2, 3, 2, 14, "NHWC"),
            (1, 3, 2, 7, "NHWC"),
            (2, 3, 2, 12, "NCHW"),
            (2, 3, 2, 16, "NCHW"),
            (1, 3, 2, 8, "NCHW"),
            (1, 3, 2, 14, "NCHW"),
            (2, 3, 2, 14, "NCHW"),
            (1, 3, 2, 7, "NCHW"),
        ]

    def testMaxPoolingLegacyPadding(self):
        for stride, kernel, legacy_pad, size, order in self.test_configs:
            print 'MaxPool', stride, kernel, legacy_pad, size, order
            op = core.CreateOperator("MaxPool")(
                ["X"], ["Y", "Y_maxid"], stride=stride, kernel=kernel,
                legacy_pad=legacy_pad, order=order)
            # In order to avoid the problem of race conditions, we will do a
            # randperm so that the values will be apart at least 0.01
            if order == "NHWC":
                X = np.random.permutation(1 * size * size * 3).reshape(
                    1, size, size, 3).astype(np.float32) * 0.01
            else:
                X = np.random.permutation(1 * size * size * 3).reshape(
                    1, 3, size, size).astype(np.float32) * 0.01
            res = device_checker.CheckSimple(op, [X], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
                self.assertTrue(res)


class TestAveragePoolingLegacyPadding(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            (1, 7, 1, 7, "NHWC"),
            (1, 7, 2, 7, "NHWC"),
            (1, 7, 1, 7, "NCHW"),
            (1, 7, 2, 7, "NCHW"),
        ]

    def testAveragePoolingLegacyPadding(self):
        for stride, kernel, legacy_pad, size, order in self.test_configs:
            print 'AveragePool', stride, kernel, legacy_pad, size, order
            op = core.CreateOperator("AveragePool")(
                ["X"], ["Y"], stride=stride, kernel=kernel,
                legacy_pad=legacy_pad, order=order)
            if order == "NHWC":
                X = np.random.rand(2, size, size, 3).astype(np.float32)
            else:
                X = np.random.rand(2, 3, size, size).astype(np.float32)
            res = device_checker.CheckSimple(op, [X], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
                self.assertTrue(res)


class TestLRN(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            (6, 10),
            (3, 13),
        ]

    def testLRN(self):
        for input_size, depth in self.test_configs:
            op = core.CreateOperator("LRN")(
                ["X"], ["Y", "Y_scale"], size=11, alpha=0.001, beta=0.5,
                bias=2.0, order="NHWC")
            X = np.random.rand(2, input_size, input_size, depth).astype(np.float32)
            res = device_checker.CheckSimple(op, [X], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
                self.assertTrue(res)


class TestDepthConcat(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            # input_size, depth1, depth2, depth3, depth4
            (3, 2, 3, 4, 5),
            (4, 5, 4, 3, 2),
        ]

    def testDepthConcatNHWC(self):
        for input_size, d1, d2, d3, d4 in self.test_configs:
            op = core.CreateOperator("DepthConcat")(
                ["X1", "X2", "X3", "X4"], ["Y", "Y_dims"], order="NHWC")
            Xs = [np.random.rand(2, input_size, input_size, d1).astype(np.float32),
                  np.random.rand(2, input_size, input_size, d2).astype(np.float32),
                  np.random.rand(2, input_size, input_size, d3).astype(np.float32),
                  np.random.rand(2, input_size, input_size, d4).astype(np.float32)]
            for i in range(4):
                res = device_checker.CheckSimple(op, Xs, [0])
                self.assertTrue(res)
                for checker in gradient_checkers:
                    res, grad, grad_estimated = checker.CheckSimple(op, Xs, i, [0])
                    self.assertTrue(res)

    def testDepthConcatNCHW(self):
        for input_size, d1, d2, d3, d4 in self.test_configs:
            op = core.CreateOperator("DepthConcat")(
                ["X1", "X2", "X3", "X4"], ["Y", "Y_dims"], order="NCHW")
            Xs = [np.random.rand(2, d1, input_size, input_size).astype(np.float32),
                  np.random.rand(2, d2, input_size, input_size).astype(np.float32),
                  np.random.rand(2, d3, input_size, input_size).astype(np.float32),
                  np.random.rand(2, d4, input_size, input_size).astype(np.float32)]
            for i in range(4):
                res = device_checker.CheckSimple(op, Xs, [0])
                self.assertTrue(res)
                for checker in gradient_checkers:
                    res, grad, grad_estimated = checker.CheckSimple(op, Xs, i, [0])
                    self.assertTrue(res)


class TestRelu(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            # input size
            (1, 1),
            (2, 1),
            (1, 3, 3, 1),
            (2, 3, 3, 1),
            (1, 5, 5, 3),
            (2, 5, 5, 3),
        ]

    def testRelu(self):
        for input_size in self.test_configs:
            op = core.CreateOperator("Relu")(["X"], ["Y"])
            X = np.random.rand(*input_size).astype(np.float32)
            # go away from the origin point to avoid kink problems
            X += 0.01 * np.sign(X)
            X[X == 0] = 0.01
            res = device_checker.CheckSimple(op, [X], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
                self.assertTrue(res)


if __name__ == '__main__':
    unittest.main()
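Two tricks in the tests above are worth spelling out. The max-pooling inputs are built with np.random.permutation rather than np.random.rand, so every value is a distinct multiple of 0.01 and the argmax of each pooling window is unambiguous on every device; a quick standalone check of that spacing claim:

    import numpy as np

    X = np.random.permutation(3 * 4 * 4).astype(np.float32) * 0.01
    gaps = np.diff(np.sort(X))
    # Distinct multiples of 0.01: neighboring sorted values differ by 0.01,
    # up to float32 rounding.
    assert gaps.min() >= 0.01 - 1e-6

And testRelu nudges every input away from zero because the numerical gradient check straddles ReLU's kink at the origin:

    X = np.random.rand(2, 5).astype(np.float32)
    X += 0.01 * np.sign(X)  # push positive values further from the kink
    X[X == 0] = 0.01        # np.sign(0) == 0, so exact zeros are moved by hand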

View File

@@ -0,0 +1,97 @@
import numpy as np
import unittest
from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
from pycaffe2 import core, cnn, workspace, device_checker


class TestModelDevice(unittest.TestCase):
    def setUp(self):
        pass

    def _MiniAlexNetNoDropout(self, order):
        # First, AlexNet using the cnn wrapper.
        model = cnn.CNNModelHelper("alexnet", order=order)
        conv1 = model.Conv("data", "conv1", 3, 16, 11,
                           ("XavierFill", {}),
                           ("ConstantFill", {}), stride=4, pad=0)
        relu1 = model.Relu(conv1, "relu1")
        norm1 = model.LRN(relu1, "norm1", size=5, alpha=0.0001, beta=0.75)
        pool1 = model.MaxPool(norm1, "pool1", kernel=3, stride=2)
        conv2 = model.GroupConv(pool1, "conv2", 16, 32, 5,
                                ("XavierFill", {}),
                                ("ConstantFill", {"value": 0.1}),
                                group=2, stride=1, pad=2)
        relu2 = model.Relu(conv2, "relu2")
        norm2 = model.LRN(relu2, "norm2", size=5, alpha=0.0001, beta=0.75)
        pool2 = model.MaxPool(norm2, "pool2", kernel=3, stride=2)
        conv3 = model.Conv(pool2, "conv3", 32, 64, 3,
                           ("XavierFill", {'std': 0.01}),
                           ("ConstantFill", {}), pad=1)
        relu3 = model.Relu(conv3, "relu3")
        conv4 = model.GroupConv(relu3, "conv4", 64, 64, 3,
                                ("XavierFill", {}),
                                ("ConstantFill", {"value": 0.1}),
                                group=2, pad=1)
        relu4 = model.Relu(conv4, "relu4")
        conv5 = model.GroupConv(relu4, "conv5", 64, 32, 3,
                                ("XavierFill", {}),
                                ("ConstantFill", {"value": 0.1}),
                                group=2, pad=1)
        relu5 = model.Relu(conv5, "relu5")
        pool5 = model.MaxPool(relu5, "pool5", kernel=3, stride=2)
        fc6 = model.FC(pool5, "fc6", 1152, 1024,
                       ("XavierFill", {}),
                       ("ConstantFill", {"value": 0.1}))
        relu6 = model.Relu(fc6, "relu6")
        fc7 = model.FC(relu6, "fc7", 1024, 1024,
                       ("XavierFill", {}),
                       ("ConstantFill", {"value": 0.1}))
        relu7 = model.Relu(fc7, "relu7")
        fc8 = model.FC(relu7, "fc8", 1024, 5,
                       ("XavierFill", {}),
                       ("ConstantFill", {"value": 0.0}))
        pred = model.Softmax(fc8, "pred")
        xent = model.LabelCrossEntropy([pred, "label"], "xent")
        loss, xent_grad = model.AveragedLoss([xent], ["loss", xent.Grad()])
        model.AddGradientOperators()
        return model

    def _testMiniAlexNet(self, order):
        # First, we get all the random initialization of parameters.
        model = self._MiniAlexNetNoDropout(order)
        workspace.ResetWorkspace()
        workspace.RunNetOnce(model.param_init_net)
        inputs = dict(
            [(str(name), workspace.FetchBlob(str(name)))
             for name in model.params])
        if order == "NCHW":
            inputs["data"] = np.random.rand(4, 3, 227, 227).astype(np.float32)
        else:
            inputs["data"] = np.random.rand(4, 227, 227, 3).astype(np.float32)
        inputs["label"] = np.array([1, 2, 3, 4]).astype(np.int32)

        cpu_device = caffe2_pb2.DeviceOption()
        cpu_device.device_type = caffe2_pb2.CPU
        gpu_device = caffe2_pb2.DeviceOption()
        gpu_device.device_type = caffe2_pb2.CUDA

        checker = device_checker.DeviceChecker(
            1e-5, [cpu_device, gpu_device])
        ret = checker.CheckNet(
            model.net.Proto(), inputs,
            # The indices sometimes may be sensitive to small numerical differences
            # in the input, so we ignore checking them.
            ignore=['_pool1_idx', '_pool2_idx', '_pool5_idx'])
        self.assertEqual(ret, True)

    def testMiniAlexNet(self):
        self._testMiniAlexNet(order="NCHW")
        self._testMiniAlexNet(order="NHWC")


if __name__ == '__main__':
    if not workspace.has_gpu_support:
        print 'No GPU support. Skipping gpu test.'
    elif workspace.NumberOfGPUs() == 0:
        print 'No GPU device. Skipping gpu test.'
    else:
        unittest.main()