Updated PyTorch ONNX exporter (markdown)

Gary Miguel
2022-03-07 14:00:19 -08:00
parent 0850c2073b
commit 5ffe24063d

@ -17,6 +17,7 @@ Documentation for developing the PyTorch-ONNX exporter (`torch.onnx`).
* [Relevant parts of PyTorch repo](#relevant-parts-of-pytorch-repo)
* [Features](#features)
* [Quantized model export](#quantized-model-export)
* [Updating default opset_version](#updating-default-opset_version)
# Development process
@ -201,4 +202,60 @@ An example of adding unit tests for a new symbolic function: [Add binary_cross_e
## Quantized model export
To support quantized model export, we need to unpack the quantized tensor inputs and the PackedParam weights (https://github.com/pytorch/pytorch/pull/69232). We construct through `TupleConstruct` to have a 1-to-1 input mapping,
so that we can use `replaceAllUsesWith` API for its successors. In addition, we support quantized namespace export, and the developers can add more symbolics for quantized operators conveniently in the current framework.
so that we can use `replaceAllUsesWith` API for its successors. In addition, we support quantized namespace export, and the developers can add more symbolics for quantized operators conveniently in the current framework.
# Updating default opset_version
```python
#!/usr/bin/env python
"""Outputs what the default value of opset_version should be.
The current policy is that the default should be set to the
latest released version as of 18 months ago.
Usage:
Only argument should be path to onnx git repo.
If run from the root of the PyTorch repo, run:
$ <this script> third_party/onnx
"""
import datetime
import os
import re
import sys
import subprocess
if len(sys.argv) != 2:
sys.exit("need exactly 1 argument, the path to onnx git repo")
onnx_dir = sys.argv[1]
os.chdir(onnx_dir)
date = datetime.datetime.now() - datetime.timedelta(days=18*30)
onnx_commit = subprocess.check_output(("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), encoding="utf-8").strip()
onnx_tags = subprocess.check_output(("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8")
tag_tups = []
semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
for tag in onnx_tags.splitlines():
match = semver_pat.match(tag)
if match:
tag_tups.append(tuple(int(x) for x in match.groups()))
min_tup = sorted(tag_tups)[0]
version_str = "{}.{}.{}".format(*min_tup)
head_commit = subprocess.check_output(("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8").strip()
subprocess.check_call(("git", "checkout", f"v{version_str}"), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
try:
from onnx import helper
for version in helper.VERSION_TABLE:
if version[0] == version_str:
print(version[2])
sys.exit() # success
sys.exit(f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}")
finally:
subprocess.check_call(("git", "checkout", head_commit), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
```