mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51754 This API allows you to manage multiple python interpreters in a single process to deploy PyTorch models packaged with torch.package. torch/csrc/deploy/deploy.h contains the API definition torch/csrc/deploy/test_deploy.cpp has some examples. Notes: * mutex is added to PyTorchStreamReader to make it safe to use from multiple threads at once. * USE_DEPLOY is only true for the special libtorch_deployinterpreter.so library, when enabled we use a hash table to maintain PyObject <> at::Tensor mappping rather than the internal pointer in Tensor since >1 interpreter may have a reference to the tensor. * serialization.py has some additional functions for creating pickle objects but keeping storages in memory for use transfering tensors between interpreters Test Plan: Imported from OSS Reviewed By: wconstab Differential Revision: D26329468 Pulled By: zdevito fbshipit-source-id: d75f4ebb9a27f1d911179d9996041bcb3ca04a07
67 lines
2.0 KiB
Python
67 lines
2.0 KiB
Python
# used by the benchmarking program to wrap cpu models for GPU use
import torch

from copy import deepcopy

def to_device(i, d):
    """Recursively move tensors (possibly nested in tuples/lists) to a device.

    Args:
        i: a ``torch.Tensor`` or an arbitrarily nested tuple/list of tensors.
        d: target device, anything accepted by ``Tensor.to`` (e.g. ``'cuda:0'``).

    Returns:
        A tensor on device *d*, or a tuple mirroring the input container
        structure. Note that lists are deliberately converted to tuples.

    Raises:
        RuntimeError: if *i* (or any nested element) is not a tensor, tuple,
            or list.
    """
    if isinstance(i, torch.Tensor):
        return i.to(device=d)
    elif isinstance(i, (tuple, list)):
        return tuple(to_device(e, d) for e in i)
    else:
        # Name the offending type so failures are actionable instead of the
        # original opaque "inputs are weird".
        raise RuntimeError(
            f'cannot move inputs of type {type(i).__name__} to device {d}')
class GPUWrapper(torch.nn.Module):
    """Replicate a CPU model onto every visible CUDA device.

    One replica of ``root`` is placed on each ``cuda:i``; ``forward`` routes a
    call to a replica chosen by caller-supplied thread id and runs it on a
    per-thread CUDA stream, returning results on CPU.
    """

    def __init__(self, root):
        """Copy *root* onto every CUDA device.

        Args:
            root: the CPU module to replicate. Device 0 receives *root*
                itself (moved in place); other devices receive deep copies.

        Raises:
            RuntimeError: if no CUDA device is visible. Failing fast here
                avoids a confusing ``ZeroDivisionError`` later in ``forward``
                on ``tid % len(self.models)``.
        """
        super().__init__()
        self.models = []   # list of (replica, device-string) pairs
        self.streams = {}  # tid -> torch.cuda.Stream, created lazily in forward
        num_devices = torch.cuda.device_count()
        if num_devices == 0:
            raise RuntimeError(
                'GPUWrapper requires at least one visible CUDA device')
        for i in range(num_devices):
            # Device 0 keeps the original module; the rest get deep copies.
            m = deepcopy(root) if i != 0 else root
            d = f'cuda:{i}'
            m.to(device=d)
            self.models.append((m, d))

    def __getstate__(self):
        # CUDA streams are not picklable; only the replicas are serialized.
        return self.models

    def __setstate__(self, models):
        # pickle bypasses __init__, so re-run Module.__init__ by hand.
        super().__init__()
        self.models = models
        self.streams = {}
        for m, d in models:
            # Wait for any in-flight work on each replica's device.
            torch.cuda.synchronize(d)

    # roi_align, 2210 count, ROIAlign_cuda.cu: add threadsync: problem goes away, return rand problem goes away,
    # use different streams here, problem goes away.
    def forward(self, tid, *args):
        """Run the replica assigned to thread *tid* on its own CUDA stream.

        Args:
            tid: caller's thread id; selects a replica round-robin and keys
                the per-thread stream cache.
            *args: inputs for the wrapped model (tensors or nested
                tuples/lists of tensors); moved to the replica's device.

        Returns:
            The model's output moved back to CPU (lists become tuples via
            ``to_device``).
        """
        m, d = self.models[tid % len(self.models)]
        if tid not in self.streams:
            self.streams[tid] = torch.cuda.Stream(d)
        s = self.streams[tid]
        with torch.cuda.stream(s):
            iput = to_device(args, d)
            r = to_device(m(*iput), 'cpu')
            return r
if __name__ == '__main__':
    # Smoke test: load a torch.package'd model and an example input, run the
    # model on CPU and through GPUWrapper, and verify both agree numerically.
    def check_close(expected, actual):
        # Recurse through nested containers; compare leaf tensors directly.
        if not isinstance(expected, (list, tuple)):
            print(torch.max(torch.abs(expected - actual)))
            assert torch.allclose(expected, actual)
            return
        for exp_elem, act_elem in zip(expected, actual):
            check_close(exp_elem, act_elem)

    import sys

    from torch.package import PackageImporter

    importer = PackageImporter(sys.argv[1])
    torch.version.interp = 0
    model = importer.load_pickle('model', 'model.pkl')
    example = importer.load_pickle('model', 'example.pkl')
    cpu_result = model(*example)

    wrapped = GPUWrapper(model)
    gpu_result = wrapped(*example)
    check_close(cpu_result, gpu_result)