mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a half-finished CNN model-helper wrapper (pycaffe2/cnn.py) and related context changes.
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
#ifndef CAFFE2_CORE_CONTEXT_H_
|
||||
#define CAFFE2_CORE_CONTEXT_H_
|
||||
|
||||
#include <ctime>
|
||||
#include <random>
|
||||
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
@ -11,9 +12,10 @@ namespace caffe2 {
|
||||
class CPUContext {
|
||||
public:
|
||||
CPUContext() : random_generator_(0) {}
|
||||
explicit CPUContext(const DeviceOption& device_option)
|
||||
: random_generator_(device_option.random_seed()) {
|
||||
DCHECK_EQ(device_option.device_type(), CPU);
|
||||
explicit CPUContext(const DeviceOption& option)
|
||||
: random_generator_(
|
||||
option.has_random_seed() ? option.random_seed() : time(NULL)) {
|
||||
CHECK_EQ(option.device_type(), CPU);
|
||||
}
|
||||
virtual ~CPUContext() {}
|
||||
inline void SwitchToDevice() {}
|
||||
|
@ -1,6 +1,8 @@
|
||||
#ifndef CAFFE2_CORE_CONTEXT_GPU_H_
|
||||
#define CAFFE2_CORE_CONTEXT_GPU_H_
|
||||
|
||||
#include <ctime>
|
||||
|
||||
#include "caffe2/core/common_gpu.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/cuda_memorypool.h"
|
||||
@ -26,7 +28,9 @@ class CUDAContext {
|
||||
|
||||
explicit CUDAContext(const DeviceOption& option)
|
||||
: cuda_stream_(nullptr), cublas_handle_(nullptr),
|
||||
random_seed_(option.random_seed()), curand_generator_(nullptr) {
|
||||
random_seed_(
|
||||
option.has_random_seed() ? option.random_seed() : time(NULL)),
|
||||
curand_generator_(nullptr) {
|
||||
DCHECK_EQ(option.device_type(), CUDA);
|
||||
cuda_gpu_id_ = option.has_cuda_gpu_id() ?
|
||||
option.cuda_gpu_id() : GetDefaultGPUID();
|
||||
|
@ -28,6 +28,7 @@ py_library(
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"caffe_translator.py",
|
||||
"cnn.py",
|
||||
"core.py",
|
||||
"core_gradients.py",
|
||||
"device_checker.py",
|
||||
|
81
pycaffe2/cnn.py
Normal file
81
pycaffe2/cnn.py
Normal file
@ -0,0 +1,81 @@
|
||||
from pycaffe2 import core
|
||||
|
||||
class CNNModelHelper(object):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Parameter-creating methods (Conv, GroupConv, FC) register their
    initialization operators on `param_init_net` and record the created
    parameter blobs in `params`; the actual computation operators go on `net`.
    """

    def __init__(self, name, order):
        """Creates the main net and the parameter-initialization net.

        Args:
          name: base name for the main net; the initialization net is
              named `name + '_init'`.
          order: CNN storage order, either "NHWC" or "NCHW".

        Raises:
          ValueError: if `order` is neither "NHWC" nor "NCHW".
        """
        self.net = core.Net(name)
        self.param_init_net = core.Net(name + '_init')
        self.params = []
        self.order = order
        if self.order not in ("NHWC", "NCHW"):
            raise ValueError("Cannot understand the CNN storage order.")

    def Conv(self, blob_in, blob_out, dim_in, dim_out, kernel,
             weight_init, bias_init, **kwargs):
        """Convolution. We intentionally do not provide odd kernel/stride/pad
        settings in order to discourage the use of odd cases.

        `weight_init` and `bias_init` are (operator_name, kwargs_dict) pairs
        describing the fill operators run on the initialization net.
        """
        # Weight layout follows the storage order: output channels first in
        # NCHW, channels-last in NHWC.
        weight_shape = ([dim_out, dim_in, kernel, kernel] if self.order == "NCHW"
                        else [dim_out, kernel, kernel, dim_in])
        weight = getattr(self.param_init_net, weight_init[0])(
            [], blob_out + '_w', shape=weight_shape, **weight_init[1])
        bias = getattr(self.param_init_net, bias_init[0])(
            [], blob_out + '_b', shape=[dim_out], **bias_init[1])
        self.params.extend([weight, bias])
        return self.net.Conv([blob_in, weight, bias], blob_out, kernel=kernel,
                             order=self.order, **kwargs)

    def GroupConv(self, blob_in, blob_out, dim_in, dim_out, kernel,
                  weight_init, bias_init, group=1, **kwargs):
        """Grouped convolution: splits the input depth into `group` equal
        slices, convolves each slice independently, and depth-concatenates
        the per-group outputs into `blob_out`.

        Raises:
          ValueError: if `dim_in` is not divisible by `group`.
        """
        if dim_in % group:
            raise ValueError("dim_in should be divisible by group.")
        # Floor division keeps the shapes integral under Python 3, where a
        # plain `/` would yield floats.
        dim_in_per_group = dim_in // group
        dim_out_per_group = dim_out // group
        splitted_blobs = self.net.DepthSplit(
            blob_in,
            ['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
            dimensions=[dim_in_per_group] * group,
            order=self.order)
        weight_shape = (
            [dim_out_per_group, dim_in_per_group, kernel, kernel]
            if self.order == "NCHW"
            else [dim_out_per_group, kernel, kernel, dim_in_per_group])
        conv_blobs = []
        for i in range(group):
            weight = getattr(self.param_init_net, weight_init[0])(
                [], blob_out + '_gconv_%d_w' % i, shape=weight_shape,
                **weight_init[1])
            bias = getattr(self.param_init_net, bias_init[0])(
                [], blob_out + '_gconv_%d_b' % i, shape=[dim_out_per_group],
                **bias_init[1])
            self.params.extend([weight, bias])
            conv_blobs.append(
                splitted_blobs[i].Conv(
                    [weight, bias], blob_out + '_gconv_%d' % i,
                    kernel=kernel, order=self.order, **kwargs))
        concat = self.net.DepthConcat(conv_blobs, blob_out, order=self.order)
        return concat

    def FC(self, blob_in, blob_out, dim_in, dim_out, weight_init, bias_init,
           **kwargs):
        """Fully connected layer: creates a [dim_out, dim_in] weight and a
        [dim_out] bias, registers them as params, and adds an FC operator.
        """
        weight = getattr(self.param_init_net, weight_init[0])(
            [], blob_out + '_w', shape=[dim_out, dim_in], **weight_init[1])
        bias = getattr(self.param_init_net, bias_init[0])(
            [], blob_out + '_b', shape=[dim_out], **bias_init[1])
        self.params.extend([weight, bias])
        return self.net.FC([blob_in, weight, bias], blob_out, **kwargs)

    def LRN(self, blob_in, blob_out, **kwargs):
        """Local response normalization. The operator also emits an internal
        scale blob ("_<blob_out>_scale"); only the normalized output is
        returned to the caller.
        """
        return self.net.LRN(blob_in, [blob_out, "_" + blob_out + "_scale"],
                            order=self.order, **kwargs)[0]

    def MaxPool(self, blob_in, blob_out, **kwargs):
        """Max pooling in the model's storage order."""
        return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)

    def __getattr__(self, operator_type):
        """Catch-all for all other operators, mostly those without params.

        Delegates unknown attribute lookups to the underlying net, which
        exposes one callable per registered operator type.
        """
        return getattr(self.net, operator_type)
|
@ -16,6 +16,8 @@ class BlobReference(object):
|
||||
def __init__(self, name, net):
    """Records the blob's name and the net that produced it."""
    # Free-form dict where helper functions may stash whatever
    # meta-information they need about this blob.
    self.meta = {}
    self._name = name
    self._from_net = net
|
||||
|
||||
def __str__(self):
    """The string form of a blob reference is simply the blob's name."""
    return self._name
|
||||
|
Reference in New Issue
Block a user