mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 10:01:39 +08:00
Extend Net.RunAllOnGPU() to support RecurrentNetwork op (#15713)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15713 [caffe2] Extend Net.RunAllOnGPU() to support RecurrentNetwork op Reviewed By: dzhulgakov Differential Revision: D13576507 fbshipit-source-id: f517127492c9d516ece663d42fef84338c70344e
This commit is contained in:
committed by
Facebook Github Bot
parent
48fe839d56
commit
0799a81cb7
@ -2118,15 +2118,28 @@ class Net(object):
|
||||
raise ValueError('{} is not supported'.format(aggregator))
|
||||
return GradientSlice(indices=unique, values=new_g)
|
||||
|
||||
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
|
||||
"""A convenient function to run everything on the GPU."""
|
||||
@staticmethod
|
||||
def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
|
||||
device_option = caffe2_pb2.DeviceOption()
|
||||
device_option.device_type = workspace.GpuDeviceType
|
||||
device_option.device_id = gpu_id
|
||||
self._net.device_option.CopyFrom(device_option)
|
||||
net.device_option.CopyFrom(device_option)
|
||||
if use_cudnn:
|
||||
for op in self._net.op:
|
||||
for op in net.op:
|
||||
op.engine = "CUDNN"
|
||||
# Move RecurrentNetwork operators on GPU as well
|
||||
for op in net.op:
|
||||
if op.type != "RecurrentNetwork":
|
||||
continue
|
||||
for arg in op.arg:
|
||||
if arg.name == "step_net":
|
||||
Net._RunAllOnGPU(arg.n, gpu_id, use_cudnn)
|
||||
|
||||
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
|
||||
"""A convenient function to run everything on the GPU."""
|
||||
self._RunAllOnGPU(self._net, gpu_id, use_cudnn)
|
||||
|
||||
|
||||
|
||||
def RunAllOnMKL(self):
|
||||
"""A convenient function to run everything using MKLDNN."""
|
||||
|
||||
Reference in New Issue
Block a user