ONNXIFI transform (#9569)

Summary:
Cut-off runnable subgraph and off-load to ONNXIFI backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9569

Reviewed By: Maratyszcza

Differential Revision: D8930408

Pulled By: yinghai

fbshipit-source-id: 2b494f7f8dc10c00e58cf0fed5c4a9434be6155b
This commit is contained in:
Yinghai Lu
2018-07-20 15:08:02 -07:00
committed by Facebook Github Bot
parent 01581037dc
commit 45e5c17ecf
8 changed files with 689 additions and 10 deletions

View File

@ -53,23 +53,26 @@ def convert_onnx_model_to_trt_op(onnx_model,
op.ParseFromString(trt_str)
return op
def _infer_shapes(init_net, pred_net, inputs):
ws, outputs = c2_native_run_net(init_net, pred_net, inputs)
# Assume the workspace is already filled with init weights
def _infer_shapes(pred_net, inputs):
workspace.RunNetOnce(pred_net)
hints = {}
for op in pred_net.op:
for o in op.output:
if o not in hints:
blob = ws.FetchBlob(o)
blob = workspace.FetchBlob(o)
if hasattr(blob, 'shape'):
hints[o] = blob.shape
for i in op.input:
if i not in hints:
blob = ws.FetchBlob(i)
blob = workspace.FetchBlob(i)
if hasattr(blob, 'shape'):
hints[i] = blob.shape
return hints
def transform_caffe2_net(
pred_net,
input_shapes,
@ -91,7 +94,7 @@ def transform_caffe2_net(
input_data = {}
for k,v in input_shapes.items():
input_data[k] = np.random.randn(*v).astype(np.float32)
shape_hints = _infer_shapes(init_net, pred_net, input_data)
shape_hints = _infer_shapes(pred_net, input_data)
for k,v in input_shapes.items():
shape_hints[k] = v