Extend Net.RunAllOnGPU() to support RecurrentNetwork op (#15713)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15713

[caffe2] Extend Net.RunAllOnGPU() to support RecurrentNetwork op

Reviewed By: dzhulgakov

Differential Revision: D13576507

fbshipit-source-id: f517127492c9d516ece663d42fef84338c70344e
This commit is contained in:
Nikita Shulga
2019-02-08 14:20:31 -08:00
committed by Facebook Github Bot
parent 48fe839d56
commit 0799a81cb7
2 changed files with 40 additions and 4 deletions

View File

@ -2118,15 +2118,28 @@ class Net(object):
raise ValueError('{} is not supported'.format(aggregator))
return GradientSlice(indices=unique, values=new_g)
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
"""A convenient function to run everything on the GPU."""
@staticmethod
def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = workspace.GpuDeviceType
device_option.device_id = gpu_id
self._net.device_option.CopyFrom(device_option)
net.device_option.CopyFrom(device_option)
if use_cudnn:
for op in self._net.op:
for op in net.op:
op.engine = "CUDNN"
# Move RecurrentNetwork operators on GPU as well
for op in net.op:
if op.type != "RecurrentNetwork":
continue
for arg in op.arg:
if arg.name == "step_net":
Net._RunAllOnGPU(arg.n, gpu_id, use_cudnn)
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
"""A convenient function to run everything on the GPU."""
self._RunAllOnGPU(self._net, gpu_id, use_cudnn)
def RunAllOnMKL(self):
"""A convenient function to run everything using MKLDNN."""