Files
pytorch/test/onnx/export_onnx_tests_generator.py
Edward Yang 173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report an import unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00

146 lines
5.2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from torch.autograd import Variable
from onnx import numpy_helper
import io
import onnx
import os
import shutil
import torch
import traceback
import test_onnx_common
from common_nn import module_tests
from test_nn import new_module_tests
# Take a test case (a dict) as input, return the test name.
def get_test_name(testcase):
    """Build the test name for *testcase*.

    Prefers an explicit ``fullname``; otherwise uses the constructor's
    class name, suffixed with ``desc`` when present.
    """
    if "fullname" in testcase:
        return "test_" + testcase["fullname"]
    name = "test_" + testcase["constructor"].__name__
    if "desc" in testcase:
        name = name + "_" + testcase["desc"]
    return name
# Take a test case (a dict) as input, return the input for the module.
def gen_input(testcase):
    """Generate the input tensor for *testcase*.

    Supports two test-case shapes:
      * ``input_size`` — a tuple passed to ``torch.randn``. A ``()`` size on a
        test whose ``desc`` ends in "scalar" is promoted to ``(1,)`` (the
        testcase dict is mutated in place, matching the original behavior).
      * ``input_fn`` — a zero-arg factory producing the input.

    Returns a ``Variable``; returns ``None`` when the test case carries
    neither key (callers are expected to pass well-formed cases).
    """
    if "input_size" in testcase:
        if testcase["input_size"] == () and "desc" in testcase and testcase["desc"][-6:] == "scalar":
            testcase["input_size"] = (1,)
        return Variable(torch.randn(*testcase["input_size"]))
    elif "input_fn" in testcase:
        value = testcase["input_fn"]()
        if isinstance(value, Variable):
            return value
        # Bug fix: wrap the value we already produced instead of invoking
        # input_fn a second time — with a randomized factory the second call
        # would yield a *different* input than the one just inspected.
        return Variable(value)
def gen_module(testcase):
    """Instantiate the module under test and switch it to eval mode.

    ``constructor`` is called with ``constructor_args`` when present,
    otherwise with no arguments.
    """
    ctor_args = testcase.get("constructor_args", ())
    module = testcase["constructor"](*ctor_args)
    module.train(False)
    return module
def print_stats(FunctionalModule_nums, nn_module):
    """Print a support summary for the exported modules.

    ``nn_module`` maps a module name to a status bitmask: 1 means every
    related case exported, 2 means every case failed, 3 means a mix.
    """
    print("{} functional modules detected.".format(FunctionalModule_nums))

    # Bucket module names by their status code in a single pass.
    buckets = {1: [], 2: [], 3: []}
    for name, status in nn_module.items():
        if status in buckets:
            buckets[status].append(name)
    supported = buckets[1]
    unsupported = buckets[2]
    not_fully_supported = buckets[3]

    def report(header, names):
        print(header)
        for name in names:
            print(name)

    # Fully Supported Ops: All related test cases of these ops have been exported
    # Semi-Supported Ops: Part of related test cases of these ops have been exported
    # Unsupported Ops: None of related test cases of these ops have been exported
    report("{} Fully Supported Operators:".format(len(supported)), supported)
    report("{} Semi-Supported Operators:".format(len(not_fully_supported)),
           not_fully_supported)
    report("{} Unsupported Operators:".format(len(unsupported)), unsupported)
def convert_tests(testcases, sets=1):
    """Export PyTorch module test cases to ONNX test-data directories.

    For each test case: build the module, export it to ONNX with ATen
    fallback, validate it, and write ``model.onnx`` plus ``sets`` pairs of
    input/output protobuf files under
    ``test_onnx_common.pytorch_converted_dir/<test_name>``.
    Failures are counted and recorded per module name; a summary is printed
    at the end via ``print_stats``.
    """
    print("Collect {} test cases from PyTorch.".format(len(testcases)))
    failed = 0
    FunctionalModule_nums = 0
    # Maps module name -> status bitmask: |=1 on any success, |=2 on any
    # failure (so 3 means partially supported).
    nn_module = {}
    for t in testcases:
        test_name = get_test_name(t)
        module = gen_module(t)
        # str(module) looks like "Linear(...)"; take the class name prefix.
        module_name = str(module).split("(")[0]
        if (module_name == "FunctionalModule"):
            # FunctionalModule wrappers are only counted, not tracked per-op.
            FunctionalModule_nums += 1
        else:
            if (module_name not in nn_module):
                nn_module[module_name] = 0
        try:
            input = gen_input(t)
            f = io.BytesIO()
            torch.onnx._export(module, input, f,
                               operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
            onnx_model = onnx.load_from_string(f.getvalue())
            onnx.checker.check_model(onnx_model)
            # Doc strings embed source locations; strip them so the stored
            # model is stable across machines.
            onnx.helper.strip_doc_string(onnx_model)
            output_dir = os.path.join(test_onnx_common.pytorch_converted_dir, test_name)
            # Start from a clean directory for this test case.
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)
            with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                file.write(onnx_model.SerializeToString())
            for i in range(sets):
                output = module(input)
                data_dir = os.path.join(output_dir, "test_data_set_{}".format(i))
                os.makedirs(data_dir)
                # Single-input/single-output today; enumerate keeps the
                # file-naming scheme ready for multi-tensor cases.
                for index, var in enumerate([input]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "input_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())
                for index, var in enumerate([output]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "output_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())
                # Fresh random input for the next data set.
                input = gen_input(t)
            if (module_name != "FunctionalModule"):
                nn_module[module_name] |= 1
        except:  # noqa: E722
            # Deliberately broad: any export/validation failure for this case
            # is logged and tallied, and the remaining cases still run.
            traceback.print_exc()
            if (module_name != "FunctionalModule"):
                nn_module[module_name] |= 2
            failed += 1
    print("Collect {} test cases from PyTorch repo, failed to export {} cases.".format(
        len(testcases), failed))
    print("PyTorch converted cases are stored in {}.".format(test_onnx_common.pytorch_converted_dir))
    print_stats(FunctionalModule_nums, nn_module)
if __name__ == '__main__':
    # Export every module test case collected from the PyTorch test suites.
    all_cases = module_tests + new_module_tests
    convert_tests(all_cases)