half-finished cnn wrapper, etc.

This commit is contained in:
Yangqing Jia
2015-09-09 20:33:34 -07:00
parent d4336af327
commit d07549bed2
5 changed files with 94 additions and 4 deletions

View File

@ -1,6 +1,7 @@
#ifndef CAFFE2_CORE_CONTEXT_H_
#define CAFFE2_CORE_CONTEXT_H_
#include <ctime>
#include <random>
#include "caffe2/proto/caffe2.pb.h"
@ -11,9 +12,10 @@ namespace caffe2 {
class CPUContext {
public:
CPUContext() : random_generator_(0) {}
explicit CPUContext(const DeviceOption& device_option)
: random_generator_(device_option.random_seed()) {
DCHECK_EQ(device_option.device_type(), CPU);
explicit CPUContext(const DeviceOption& option)
: random_generator_(
option.has_random_seed() ? option.random_seed() : time(NULL)) {
CHECK_EQ(option.device_type(), CPU);
}
virtual ~CPUContext() {}
inline void SwitchToDevice() {}

View File

@ -1,6 +1,8 @@
#ifndef CAFFE2_CORE_CONTEXT_GPU_H_
#define CAFFE2_CORE_CONTEXT_GPU_H_
#include <ctime>
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context.h"
#include "caffe2/core/cuda_memorypool.h"
@ -26,7 +28,9 @@ class CUDAContext {
explicit CUDAContext(const DeviceOption& option)
: cuda_stream_(nullptr), cublas_handle_(nullptr),
random_seed_(option.random_seed()), curand_generator_(nullptr) {
random_seed_(
option.has_random_seed() ? option.random_seed() : time(NULL)),
curand_generator_(nullptr) {
DCHECK_EQ(option.device_type(), CUDA);
cuda_gpu_id_ = option.has_cuda_gpu_id() ?
option.cuda_gpu_id() : GetDefaultGPUID();

View File

@ -28,6 +28,7 @@ py_library(
srcs = [
"__init__.py",
"caffe_translator.py",
"cnn.py",
"core.py",
"core_gradients.py",
"device_checker.py",

81
pycaffe2/cnn.py Normal file
View File

@ -0,0 +1,81 @@
from pycaffe2 import core
class CNNModelHelper(object):
  """A helper model so we can write CNN models more easily, without having to
  manually define parameter initializations and operators separately.

  Parameter-creating methods take `weight_init` / `bias_init` as
  (operator_name, kwargs_dict) pairs naming the filler operator to run in the
  init net; created parameter blobs are accumulated in self.params.
  """

  def __init__(self, name, order):
    """Creates the main net and its companion parameter-init net.

    Args:
      name: base name of the net; the init net is named name + '_init'.
      order: storage order of image blobs, either "NHWC" or "NCHW".
    Raises:
      ValueError: if order is neither "NHWC" nor "NCHW".
    """
    self.net = core.Net(name)
    self.param_init_net = core.Net(name + '_init')
    self.params = []
    self.order = order
    if self.order not in ("NHWC", "NCHW"):
      raise ValueError("Cannot understand the CNN storage order.")

  def Conv(self, blob_in, blob_out, dim_in, dim_out, kernel,
           weight_init, bias_init, **kwargs):
    """Convolution. We intentionally do not provide odd kernel/stride/pad
    settings in order to discourage the use of odd cases.
    """
    # Weight layout follows the storage order: output channels lead in both
    # cases; the input-channel axis moves with the layout.
    weight_shape = ([dim_out, dim_in, kernel, kernel] if self.order == "NCHW"
                    else [dim_out, kernel, kernel, dim_in])
    weight = getattr(self.param_init_net, weight_init[0])(
        [], blob_out + '_w', shape=weight_shape, **weight_init[1])
    bias = getattr(self.param_init_net, bias_init[0])(
        [], blob_out + '_b', shape=[dim_out], **bias_init[1])
    self.params.extend([weight, bias])
    return self.net.Conv([blob_in, weight, bias], blob_out, kernel=kernel,
                         order=self.order, **kwargs)

  def GroupConv(self, blob_in, blob_out, dim_in, dim_out, kernel,
                weight_init, bias_init, group=1, **kwargs):
    """Grouped convolution: splits the input channels into `group` groups,
    convolves each group independently, and depth-concatenates the results.
    Both dim_in and dim_out must be divisible by group.
    """
    if dim_in % group:
      raise ValueError("dim_in should be divisible by group.")
    if dim_out % group:
      raise ValueError("dim_out should be divisible by group.")
    splitted_blobs = self.net.DepthSplit(
        blob_in,
        ['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
        dimensions=[dim_in // group for i in range(group)],
        order=self.order)
    # Use floor division so shapes stay integers under Python 3.
    weight_shape = ([dim_out // group, dim_in // group, kernel, kernel]
                    if self.order == "NCHW"
                    else [dim_out // group, kernel, kernel, dim_in // group])
    conv_blobs = []
    for i in range(group):
      weight = getattr(self.param_init_net, weight_init[0])(
          [], blob_out + '_gconv_%d_w' % i, shape=weight_shape,
          **weight_init[1])
      bias = getattr(self.param_init_net, bias_init[0])(
          [], blob_out + '_gconv_%d_b' % i, shape=[dim_out // group],
          **bias_init[1])
      self.params.extend([weight, bias])
      conv_blobs.append(
          splitted_blobs[i].Conv([weight, bias], blob_out + '_gconv_%d' % i,
                                 kernel=kernel, order=self.order, **kwargs))
    concat = self.net.DepthConcat(conv_blobs, blob_out, order=self.order)
    return concat

  def FC(self, blob_in, blob_out, dim_in, dim_out, weight_init, bias_init,
         **kwargs):
    """Fully connected layer: blob_out = blob_in * W^T + b."""
    weight = getattr(self.param_init_net, weight_init[0])(
        [], blob_out + '_w', shape=[dim_out, dim_in], **weight_init[1])
    bias = getattr(self.param_init_net, bias_init[0])(
        [], blob_out + '_b', shape=[dim_out], **bias_init[1])
    self.params.extend([weight, bias])
    return self.net.FC([blob_in, weight, bias], blob_out, **kwargs)

  def LRN(self, blob_in, blob_out, **kwargs):
    """Local response normalization. The auxiliary scale output is created
    but only the normalized blob is returned."""
    return self.net.LRN(blob_in, [blob_out, "_" + blob_out + "_scale"],
                        order=self.order, **kwargs)[0]

  def MaxPool(self, blob_in, blob_out, **kwargs):
    """Max pooling."""
    return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)

  def __getattr__(self, operator_type):
    """Catch-all for all other operators, mostly those without params:
    delegates unknown attribute lookups to the underlying net."""
    return getattr(self.net, operator_type)

View File

@ -16,6 +16,8 @@ class BlobReference(object):
def __init__(self, name, net):
    """Wraps the blob called `name`, remembering `net` as its producer."""
    # meta allows helper functions to put whatever metainformation is
    # needed there.
    self.meta = {}
    self._name = name
    self._from_net = net
def __str__(self):
    """A blob reference prints as its bare name."""
    return self._name