more tests for tf benchmark purposes.
@@ -163,7 +163,7 @@ bool CudnnConvOp<T>::RunWithCudnnWorkspace(
   bool input_changed = (X.dims() != cudnn_input_dims_);
   bool filter_changed = (filter.dims() != cudnn_filter_dims_);
   if (input_changed || filter_changed) {
-    CAFFE_LOG_INFO << "Changing the cudnn descriptor configurations.";
+    CAFFE_VLOG(1) << "Changing the cudnn descriptor configurations.";
     if (input_changed) {
       cudnn_input_dims_ = X.dims();
       CUDNN_CHECK(cudnnSetTensor4dDescriptor(
@@ -271,7 +271,7 @@ bool CudnnConvGradientOp<T>::RunWithCudnnWorkspace(
   bool input_changed = (X.dims() != cudnn_input_dims_);
   bool filter_changed = (filter.dims() != cudnn_filter_dims_);
   if (input_changed || filter_changed) {
-    CAFFE_LOG_INFO << "Changing the cudnn descriptor configurations.";
+    CAFFE_VLOG(1) << "Changing the cudnn descriptor configurations.";
     if (input_changed) {
       cudnn_input_dims_ = X.dims();
       CUDNN_CHECK(cudnnSetTensor4dDescriptor(
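Both hunks above demote the same log line from INFO to VLOG(1); the surrounding code already reconfigures the cuDNN descriptors only when the input or filter shape actually changes, so steady-state iterations stay quiet. A minimal Python sketch of that shape-keyed caching idea (names are illustrative, not the Caffe2/cuDNN API):

# Sketch: rebuild the expensive descriptor only when the tensor shape changes.
class ConvRunner(object):
    def __init__(self):
        self._cached_dims = None   # plays the role of cudnn_input_dims_
        self._descriptor = None

    def _configure(self, dims):
        # Stand-in for cudnnSetTensor4dDescriptor and friends.
        self._descriptor = {"dims": dims}

    def run(self, x_dims):
        input_changed = (x_dims != self._cached_dims)
        if input_changed:          # reconfigure only on a real change
            self._cached_dims = x_dims
            self._configure(x_dims)
        return self._descriptor    # reused across same-shaped calls

runner = ConvRunner()
runner.run((32, 3, 224, 224))      # configures once
runner.run((32, 3, 224, 224))      # cache hit, no reconfiguration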
@@ -96,6 +96,11 @@ class CNNModelHelper(object):
         return self.net.MaxPool(blob_in, [blob_out, "_" + blob_out + "_idx"],
                                 order=self.order, **kwargs)[0]

+    def DepthConcat(self, blobs_in, blob_out, **kwargs):
+        """Depth Concat."""
+        return self.net.DepthConcat(blobs_in, [blob_out, "_" + blob_out + "_concat_dims"],
+                                    order=self.order, **kwargs)[0]
+
     def AddGradientOperators(self):
         self.net.AddGradientOperators()

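DepthConcat joins the input blobs along the channel axis, so which axis that is depends on the storage order the helper carries. A quick numpy illustration of the shape arithmetic (not Caffe2 code):

import numpy as np

# Two feature maps with 64 and 32 channels, batch 1, 7x7 spatial.
a = np.zeros((1, 64, 7, 7))   # NCHW: channels on axis 1
b = np.zeros((1, 32, 7, 7))

nchw = np.concatenate([a, b], axis=1)
print(nchw.shape)             # (1, 96, 7, 7)

# In NHWC the channel axis is the last one instead.
nhwc = np.concatenate([a.transpose(0, 2, 3, 1),
                       b.transpose(0, 2, 3, 1)], axis=3)
print(nhwc.shape)             # (1, 7, 7, 96)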
@@ -105,7 +105,7 @@ def VGGA(order):
     relu8 = model.Relu(conv8, "conv8")
     pool8 = model.MaxPool(relu8, "pool8", kernel=2, stride=2)

-    fcix = model.FC(pool8, "fcix", 512*7*7, 3072,
+    fcix = model.FC(pool8, "fcix", 512*7*7, 4096,
                     ('XavierFill', {}), ('ConstantFill', {}))
     reluix = model.Relu(fcix, "fcix")
     fcx = model.FC(reluix, "fcx", 4096, 4096,
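The width fix matters because the next layer, fcx, is declared with 4096 inputs, so fcix must produce 4096 outputs for the two FC layers to line up. A one-line check of the arithmetic visible in the hunk:

# fcix consumes pool8 (512 channels over a 7x7 map) and must feed fcx.
fcix_in = 512 * 7 * 7      # 25088, as in the model.FC call
fcix_out = 4096            # corrected from 3072 by this diff
fcx_in, fcx_out = 4096, 4096
assert fcix_out == fcx_in  # widths now match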
@@ -119,7 +119,84 @@ def VGGA(order):
     return model, 231


+def _InceptionModule(model, input_blob, input_depth, output_name,
+                     conv1_depth, conv3_depths, conv5_depths, pool_depth):
+    # path 1: 1x1 conv
+    conv1 = model.Conv(input_blob, output_name + ":conv1",
+                       input_depth, conv1_depth, 1,
+                       ('XavierFill', {}), ('ConstantFill', {}))
+    # path 2: 1x1 conv + 3x3 conv
+    conv3_reduce = model.Conv(input_blob, output_name + ":conv3_reduce",
+                              input_depth, conv3_depths[0], 1,
+                              ('XavierFill', {}), ('ConstantFill', {}))
+    conv3_reduce = model.Relu(conv3_reduce, conv3_reduce)
+    conv3 = model.Conv(conv3_reduce, output_name + ":conv3",
+                       conv3_depths[0], conv3_depths[1], 3,
+                       ('XavierFill', {}), ('ConstantFill', {}), pad=1)
+    # path 3: 1x1 conv + 5x5 conv
+    conv5_reduce = model.Conv(input_blob, output_name + ":conv5_reduce",
+                              input_depth, conv5_depths[0], 1,
+                              ('XavierFill', {}), ('ConstantFill', {}))
+    conv5_reduce = model.Relu(conv5_reduce, conv5_reduce)
+    conv5 = model.Conv(conv5_reduce, output_name + ":conv5",
+                       conv5_depths[0], conv5_depths[1], 5,
+                       ('XavierFill', {}), ('ConstantFill', {}), pad=2)
+    # path 4: pool + 1x1 conv
+    pool = model.MaxPool(input_blob, output_name + ":pool",
+                         kernel=3, stride=1, pad=1)
+    pool_proj = model.Conv(pool, output_name + ":pool_proj",
+                           input_depth, pool_depth, 1,
+                           ('XavierFill', {}), ('ConstantFill', {}))
+    output = model.DepthConcat([conv1, conv3, conv5, pool_proj], output_name)
+    # We run Relu after depth concat, which would save a little bit of
+    # kernel launch time (not crucial anyway).
+    output = model.Relu(output, output)
+    return output
+
+
+def Inception(order):
+    model = cnn.CNNModelHelper(order, name="inception")
+    conv1 = model.Conv("data", "conv1", 3, 64, 7,
+                       ('XavierFill', {}), ('ConstantFill', {}), stride=2, pad=3)
+    relu1 = model.Relu(conv1, "conv1")
+    pool1 = model.MaxPool(relu1, "pool1", kernel=3, stride=2, pad=1)
+    conv2a = model.Conv(pool1, "conv2a", 64, 64, 1,
+                        ('XavierFill', {}), ('ConstantFill', {}))
+    conv2a = model.Relu(conv2a, conv2a)
+    conv2 = model.Conv(conv2a, "conv2", 64, 192, 3,
+                       ('XavierFill', {}), ('ConstantFill', {}), pad=1)
+    relu2 = model.Relu(conv2, "conv2")
+    pool2 = model.MaxPool(relu2, "pool2", kernel=3, stride=2, pad=1)
+    # Inception modules
+    inc3 = _InceptionModule(model, pool2, 192, "inc3",
+                            64, [96, 128], [16, 32], 32)
+    inc4 = _InceptionModule(model, inc3, 256, "inc4",
+                            128, [128, 192], [32, 96], 64)
+    pool5 = model.MaxPool(inc4, "pool5", kernel=3, stride=2, pad=1)
+    inc5 = _InceptionModule(model, pool5, 480, "inc5",
+                            192, [96, 208], [16, 48], 64)
+    inc6 = _InceptionModule(model, inc5, 512, "inc6",
+                            160, [112, 224], [24, 64], 64)
+    inc7 = _InceptionModule(model, inc6, 512, "inc7",
+                            128, [128, 256], [24, 64], 64)
+    inc8 = _InceptionModule(model, inc7, 512, "inc8",
+                            112, [144, 288], [32, 64], 64)
+    inc9 = _InceptionModule(model, inc8, 528, "inc9",
+                            256, [160, 320], [32, 128], 128)
+    pool9 = model.MaxPool(inc9, "pool9", kernel=3, stride=2, pad=1)
+    inc10 = _InceptionModule(model, pool9, 832, "inc10",
+                             256, [160, 320], [32, 128], 128)
+    inc11 = _InceptionModule(model, inc10, 832, "inc11",
+                             384, [192, 384], [48, 128], 128)
+    pool11 = model.MaxPool(inc11, "pool11", kernel=7, stride=1)
+    fc = model.FC(pool11, "fc", 1024, 1000,
+                  ('XavierFill', {}), ('ConstantFill', {}))
+    # It seems that Soumith's benchmark does not have softmax on top
+    # for Inception. We will add it anyway so we can have a proper
+    # backward pass.
+    pred = model.Softmax(fc, "pred")
+    xent = model.LabelCrossEntropy([pred, "label"], "xent")
+    loss = model.AveragedLoss(xent, "loss")
+    return model, 224
+
+
 def Benchmark(model_gen, order, batch_size, cudnn_limit, forward_only,
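Each inception module's output depth is the sum of its four branch depths, and that sum must equal the input_depth handed to the next stage. A small consistency check of the chain above (the out_depth helper is ours, mirroring the arguments in the diff):

def out_depth(conv1_depth, conv3_depths, conv5_depths, pool_depth):
    # DepthConcat stacks the four branch outputs along the channel axis.
    return conv1_depth + conv3_depths[1] + conv5_depths[1] + pool_depth

assert out_depth(64, [96, 128], [16, 32], 32) == 256       # inc3 -> inc4
assert out_depth(128, [128, 192], [32, 96], 64) == 480     # inc4 -> inc5
assert out_depth(192, [96, 208], [16, 48], 64) == 512      # inc5 -> inc6
assert out_depth(112, [144, 288], [32, 64], 64) == 528     # inc8 -> inc9
assert out_depth(256, [160, 320], [32, 128], 128) == 832   # inc9 -> inc10
assert out_depth(384, [192, 384], [48, 128], 128) == 1024  # inc11 -> fc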
@@ -156,6 +233,12 @@ def Benchmark(model_gen, order, batch_size, cudnn_limit, forward_only,
     workspace.RunNetOnce(model.param_init_net)
     workspace.CreateNet(model.net)
     workspace.RunNet(model.net.Proto().name)
+
+    # Print out all the tensors.
+    # for name in workspace.Blobs():
+    #     content = workspace.FetchBlob(name)
+    #     print name, content if type(content) is str else content.shape
+
     start = time.time()
     for i in range(iterations):
         workspace.RunNet(model.net.Proto().name)
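Note that the net is run once before start = time.time(), so one-time allocation and lazy initialization are excluded from the measured loop. The same warm-up-then-measure pattern in plain Python (the lambda workload is a stand-in for workspace.RunNet):

import time

def benchmark(run_once, iterations):
    run_once()                      # warm-up: excluded from the measurement
    start = time.time()
    for _ in range(iterations):
        run_once()
    elapsed = time.time() - start
    return elapsed / iterations     # seconds per steady-state iteration

print(benchmark(lambda: sum(range(10000)), iterations=10))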
@@ -177,6 +260,6 @@ if __name__ == '__main__':
         parser.print_help()

     workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
-    model_map = {'AlexNet': AlexNet, 'OverFeat': OverFeat, 'VGGA': VGGA}
+    model_map = {'AlexNet': AlexNet, 'OverFeat': OverFeat, 'VGGA': VGGA, 'Inception': Inception}
     Benchmark(model_map[args.model], args.order, args.batch_size, args.cudnn_ws,
               args.forward_only, args.iterations)
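With 'Inception' registered in model_map, the new model becomes selectable from the command line. A plausible invocation, assuming the argparse flags mirror the args.* attributes read above (the parser definition is outside this diff, so the exact flag names are an assumption):

python convnet_benchmarks.py --model Inception --order NCHW --batch_size 32 --cudnn_ws 64 --iterations 10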