mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Updated PyTorch ONNX exporter (markdown)
@ -17,6 +17,7 @@ Documentation for developing the PyTorch-ONNX exporter (`torch.onnx`).
|
||||
* [Relevant parts of PyTorch repo](#relevant-parts-of-pytorch-repo)
|
||||
* [Features](#features)
|
||||
* [Quantized model export](#quantized-model-export)
|
||||
* [Updating default opset_version](#updating-default-opset_version)
|
||||
|
||||
# Development process
|
||||
|
||||
@ -201,4 +202,60 @@ An example of adding unit tests for a new symbolic function: [Add binary_cross_e
|
||||
## Quantized model export
|
||||
|
||||
To support quantized model export, we need to unpack the quantized tensor inputs and the PackedParam weights (https://github.com/pytorch/pytorch/pull/69232). We construct through `TupleConstruct` to have a 1-to-1 input mapping,
|
||||
so that we can use `replaceAllUsesWith` API for its successors. In addition, we support quantized namespace export, and the developers can add more symbolics for quantized operators conveniently in the current framework.
|
||||
so that we can use `replaceAllUsesWith` API for its successors. In addition, we support quantized namespace export, and the developers can add more symbolics for quantized operators conveniently in the current framework.
|
||||
|
||||
# Updating default opset_version
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Outputs what the default value of opset_version should be.
|
||||
|
||||
The current policy is that the default should be set to the
|
||||
latest released version as of 18 months ago.
|
||||
|
||||
Usage:
|
||||
Only argument should be path to onnx git repo.
|
||||
If run from the root of the PyTorch repo, run:
|
||||
$ <this script> third_party/onnx
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
sys.exit("need exactly 1 argument, the path to onnx git repo")
|
||||
|
||||
onnx_dir = sys.argv[1]
|
||||
os.chdir(onnx_dir)
|
||||
|
||||
date = datetime.datetime.now() - datetime.timedelta(days=18*30)
|
||||
onnx_commit = subprocess.check_output(("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), encoding="utf-8").strip()
|
||||
onnx_tags = subprocess.check_output(("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8")
|
||||
tag_tups = []
|
||||
semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
|
||||
for tag in onnx_tags.splitlines():
|
||||
match = semver_pat.match(tag)
|
||||
if match:
|
||||
tag_tups.append(tuple(int(x) for x in match.groups()))
|
||||
|
||||
min_tup = sorted(tag_tups)[0]
|
||||
version_str = "{}.{}.{}".format(*min_tup)
|
||||
|
||||
head_commit = subprocess.check_output(("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8").strip()
|
||||
|
||||
subprocess.check_call(("git", "checkout", f"v{version_str}"), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
try:
|
||||
from onnx import helper
|
||||
for version in helper.VERSION_TABLE:
|
||||
if version[0] == version_str:
|
||||
print(version[2])
|
||||
sys.exit() # success
|
||||
sys.exit(f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}")
|
||||
finally:
|
||||
subprocess.check_call(("git", "checkout", head_commit), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
```
|
Reference in New Issue
Block a user