mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
ONNXIFI transform (#9569)
Summary: Cut-off runnable subgraph and off-load to ONNXIFI backend Pull Request resolved: https://github.com/pytorch/pytorch/pull/9569 Reviewed By: Maratyszcza Differential Revision: D8930408 Pulled By: yinghai fbshipit-source-id: 2b494f7f8dc10c00e58cf0fed5c4a9434be6155b
This commit is contained in:
committed by
Facebook Github Bot
parent
01581037dc
commit
45e5c17ecf
@ -53,23 +53,26 @@ def convert_onnx_model_to_trt_op(onnx_model,
|
||||
op.ParseFromString(trt_str)
|
||||
return op
|
||||
|
||||
def _infer_shapes(init_net, pred_net, inputs):
|
||||
ws, outputs = c2_native_run_net(init_net, pred_net, inputs)
|
||||
|
||||
# Assume the workspace is already filled with init weights
|
||||
def _infer_shapes(pred_net, inputs):
|
||||
workspace.RunNetOnce(pred_net)
|
||||
hints = {}
|
||||
for op in pred_net.op:
|
||||
for o in op.output:
|
||||
if o not in hints:
|
||||
blob = ws.FetchBlob(o)
|
||||
blob = workspace.FetchBlob(o)
|
||||
if hasattr(blob, 'shape'):
|
||||
hints[o] = blob.shape
|
||||
for i in op.input:
|
||||
if i not in hints:
|
||||
blob = ws.FetchBlob(i)
|
||||
blob = workspace.FetchBlob(i)
|
||||
if hasattr(blob, 'shape'):
|
||||
hints[i] = blob.shape
|
||||
|
||||
return hints
|
||||
|
||||
|
||||
def transform_caffe2_net(
|
||||
pred_net,
|
||||
input_shapes,
|
||||
@ -91,7 +94,7 @@ def transform_caffe2_net(
|
||||
input_data = {}
|
||||
for k,v in input_shapes.items():
|
||||
input_data[k] = np.random.randn(*v).astype(np.float32)
|
||||
shape_hints = _infer_shapes(init_net, pred_net, input_data)
|
||||
shape_hints = _infer_shapes(pred_net, input_data)
|
||||
|
||||
for k,v in input_shapes.items():
|
||||
shape_hints[k] = v
|
||||
|
Reference in New Issue
Block a user