[ONNX] Opt into ruff fmt (#134120)
Add the ONNX directory to the set of paths formatted with ruff format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134120
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 25499de814
Commit: b319fa3fd9
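For reference, opting a tree into ruff's formatter amounts to running `ruff format` over it (and, as the diff below does, removing the paths from the legacy black filelist). A minimal sketch of verifying the result, assuming the ONNX paths covered by this commit and a local `ruff` install:

# A minimal sketch (assumed paths): list files that `ruff format` would still
# rewrite under the opted-in ONNX directories, without modifying anything.
import subprocess

result = subprocess.run(
    ["ruff", "format", "--check", "torch/onnx", "test/onnx"],
    capture_output=True,
    text=True,
)
# `--check` prints the offending files and exits non-zero if reformatting is needed.
print(result.stdout or "All files already formatted.")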
@@ -1,5 +1,6 @@
 # Owner(s): ["module: onnx"]
 """Unit tests for the internal registration wrapper module."""
+
 from __future__ import annotations

 import operator
@@ -22,8 +22,7 @@ if typing.TYPE_CHECKING:


 class _SarifLogBuilder(Protocol):
-    def sarif_log(self) -> sarif.SarifLog:
-        ...
+    def sarif_log(self) -> sarif.SarifLog: ...


 def _assert_has_diagnostics(
@@ -344,9 +343,7 @@ class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
         self.assertIn("test_diagnostics.py", frame.location.uri)

     def test_diagnostics_records_cpp_call_stack(self):
-        diagnostic = (
-            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
-        )
+        diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
         stack = diagnostic.cpp_call_stack
         assert stack is not None  # for mypy
         self.assertGreater(len(stack.frames), 0)
@@ -368,9 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
     def setUp(self):
         self.rules = _RuleCollectionForTest()
         with contextlib.ExitStack() as stack:
-            self.context: infra.DiagnosticContext[
-                infra.Diagnostic
-            ] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
+            self.context: infra.DiagnosticContext[infra.Diagnostic] = (
+                stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
+            )
             self.addCleanup(stack.pop_all().close)
         return super().setUp()

@@ -400,12 +397,14 @@ class TestDiagnosticsInfra(common_utils.TestCase):
             },
         ):
             diagnostic1 = infra.Diagnostic(
-                custom_rules.custom_rule, infra.Level.WARNING  # type: ignore[attr-defined]
+                custom_rules.custom_rule,  # type: ignore[attr-defined]
+                infra.Level.WARNING,
             )
             self.context.log(diagnostic1)

             diagnostic2 = infra.Diagnostic(
-                custom_rules.custom_rule_2, infra.Level.ERROR  # type: ignore[attr-defined]
+                custom_rules.custom_rule_2,  # type: ignore[attr-defined]
+                infra.Level.ERROR,
             )
             self.context.log(diagnostic2)

@@ -49,9 +49,9 @@ class TestGlobalHelpers(common_utils.TestCase):


 class TestOverrideDict(common_utils.TestCase):
     def setUp(self):
-        self.override_dict: registration.OverrideDict[
-            str, int
-        ] = registration.OverrideDict()
+        self.override_dict: registration.OverrideDict[str, int] = (
+            registration.OverrideDict()
+        )

     def test_get_item_returns_base_value_when_no_override(self):
         self.override_dict.set_base("a", 42)
@@ -44,7 +44,7 @@ class _netG(nn.Module):
             nn.ReLU(True),
             # state size. (ngf) x 32 x 32
             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
-            nn.Tanh()
+            nn.Tanh(),
             # state size. (nc) x 64 x 64
         )

@@ -294,8 +294,8 @@ def xfail(error_message: str, reason: Optional[str] = None):
         except Exception as e:
             if isinstance(e, torch.onnx.OnnxExporterError):
                 # diagnostic message is in the cause of the exception
-                assert error_message in str(
-                    e.__cause__
+                assert (
+                    error_message in str(e.__cause__)
                 ), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
             else:
                 assert error_message in str(
@@ -175,9 +175,7 @@ def _init_test_roi_heads_faster_rcnn():

     resolution = box_roi_pool.output_size[0]
     representation_size = 1024
-    box_head = faster_rcnn.TwoMLPHead(
-        out_channels * resolution**2, representation_size
-    )
+    box_head = faster_rcnn.TwoMLPHead(out_channels * resolution**2, representation_size)

     representation_size = 1024
     box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes)
@@ -1,6 +1,7 @@
 # Owner(s): ["module: onnx"]

 """Test the support on onnxscript in PyTorch-ONNX converter."""
+
 import io
 from typing import List
@@ -1,6 +1,7 @@
 # Owner(s): ["module: onnx"]

 """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
+
 from typing import List

 import onnx_test_common
@@ -6,6 +6,7 @@ Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
 --produce-onnx-test-data: generate onnx test data
 --accept: accept onnx updates and overwrite models
 """
+
 import glob
 import inspect
 import io
@@ -879,7 +880,8 @@ class TestOperators(common_utils.TestCase):
             def forward(self, x_in):
                 x_out = {}
                 x_out["test_key_out"] = torch.add(
-                    x_in[list(x_in.keys())[0]], list(x_in.keys())[0]  # noqa: RUF015
+                    x_in[list(x_in.keys())[0]],  # noqa: RUF015
+                    list(x_in.keys())[0],  # noqa: RUF015
                 )
                 return x_out

@@ -483,7 +483,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             def forward(self, x_in):
                 x_out = {}
                 x_out["test_key_out"] = torch.add(
-                    x_in[list(x_in.keys())[0]], list(x_in.keys())[0]  # noqa: RUF015
+                    x_in[list(x_in.keys())[0]],  # noqa: RUF015
+                    list(x_in.keys())[0],  # noqa: RUF015
                 )
                 return x_out

@@ -174,7 +174,9 @@ class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
             _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
         )
     ],
-    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
+    class_name_func=lambda cls,
+    num,
+    params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
 )
 class TestUtilityFuns(_BaseTestCase):
     opset_version = None
@@ -185,7 +185,9 @@ class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
         # {"onnx_backend": verification.OnnxBackend.ONNX},
         {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
     ],
-    class_name_func=lambda cls, idx, input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
+    class_name_func=lambda cls,
+    idx,
+    input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
 )
 class TestFindMismatch(pytorch_test_common.ExportTestCase):
     onnx_backend: verification.OnnxBackend
@@ -79,7 +79,9 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):

         with tempfile.NamedTemporaryFile(suffix=".pte") as f:
             torch.export.save(exported_program, f.name)
-            del exported_program  # Delete the exported program to ensure that we are loading from file
+            del (
+                exported_program
+            )  # Delete the exported program to ensure that we are loading from file
             loaded_exported_program = torch.export.load(f.name)

         self._compare_onnx_and_torch_exported_program(
@@ -47,8 +47,12 @@ USE_BLACK_FILELIST = re.compile(
         "test/[a-h]*/**",
         # test/[i-j]*/**
         "test/[i-j]*/**",
-        # test/[k-z]*/**
-        "test/[k-z]*/**",
+        # test/[k-n]*/**
+        "test/[k-n]*/**",
+        # test/optim/**
+        "test/optim/**",
+        # "test/[p-z]*/**",
+        "test/[p-z]*/**",
         # torch/**
         # torch/_[a-h]*/**
         "torch/_[a-h]*/**",
@@ -62,8 +66,10 @@ USE_BLACK_FILELIST = re.compile(
         "torch/d*/**",
         # torch/[e-n]*/**
         "torch/[e-n]*/**",
-        # torch/[o-z]*/**
-        "torch/[o-z]*/**",
+        # torch/optim/**
+        "torch/optim/**",
+        # torch/[p-z]*/**
+        "torch/[p-z]*/**",
     ],
     ),
 )
@@ -84,7 +84,6 @@ from .utils import (

-
 from . import (  # usort: skip. Keep the order instead of sorting lexicographically
     _deprecation,
     errors,
     symbolic_caffe2,
     symbolic_helper,
@@ -215,12 +214,13 @@ def export(
             def forward(self, x):
                 return torch.sum(x, dim=1)

+
         torch.onnx.export(
             SumModule(),
             (torch.ones(2, 2),),
             "onnx.pb",
             input_names=["x"],
-            output_names=["sum"]
+            output_names=["sum"],
         )

     Produces::
@@ -256,7 +256,7 @@ def export(
                 "x": {0: "my_custom_axis_name"},
                 # list value: automatic names
                 "sum": [0],
-            }
+            },
         )

     Produces::
@@ -6,6 +6,7 @@ Do not use this module outside of `torch.onnx` and its tests.
 Be very judicious when adding any new global variables. Do not create new global
 variables unless they are absolutely necessary.
 """
+
 import torch._C._onnx as _C_onnx

 # This module should only depend on _constants and nothing else in torch.onnx to keep
@@ -107,9 +107,9 @@ class OnnxRegistry:
         # NOTE: _registry is the registry maps OpName to a list of ONNXFunctions. It is important
         # not to directly modify this variable. Instead, access to it should be done through
         # the public methods: register_custom_op, get_ops, and is_registered_op.
-        self._registry: dict[
-            registration.OpName, list[registration.ONNXFunction]
-        ] = defaultdict(list)
+        self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
+            defaultdict(list)
+        )
         # FIXME: Avoid importing onnxscript into torch
         from onnxscript.function_libs.torch_lib import (  # type: ignore[import]  # noqa: F401
             registration,
@@ -392,8 +392,10 @@ class ResolvedExportOptions(ExportOptions):
             )

             self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
-            self.decomposition_table = decomposition_table.create_onnx_friendly_decomposition_table(  # type: ignore[assignment]
-                self.onnx_registry
+            self.decomposition_table = (
+                decomposition_table.create_onnx_friendly_decomposition_table(  # type: ignore[assignment]
+                    self.onnx_registry
+                )
             )

             from torch.onnx._internal.fx import onnxfunction_dispatcher
@@ -766,6 +768,7 @@ class ONNXProgram:
         ...         self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
         ...         self.fc1 = torch.nn.Linear(9216, 128, bias=False)
         ...         self.fc2 = torch.nn.Linear(128, 10, bias=False)
+        ...
         ...     def forward(self, x, b):
         ...         tensor_x = self.conv1(x)
         ...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
@@ -778,11 +781,13 @@ class ONNXProgram:
         ...         tensor_x = self.fc2(tensor_x)
         ...         output = torch.nn.functional.log_softmax(tensor_x, dim=1)
         ...         (
-        ...             self.my_buffer2.add_(1.0) + self.my_buffer1
+        ...             self.my_buffer2.add_(1.0) + self.my_buffer1
         ...         )  # Mutate buffer through in-place addition
         ...         return output
         >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
-        >>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({})
+        >>> exported_program = torch.export.export(
+        ...     CustomModule(), args=inputs
+        ... ).run_decompositions({})
         >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
         >>> pprint.pprint(onnx_program.model_signature)
         ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
@@ -1194,9 +1199,7 @@ class Exporter:

         with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
             self.options
-        ), torch._dynamo.config.patch(
-            dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
-        ):
+        ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
             graph_module = self.options.fx_tracer.generate_fx(
                 self.options, self.model, self.model_args, self.model_kwargs
             )
@@ -1401,17 +1404,19 @@ def dynamo_export(
             def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(2, 2)

             def forward(self, x, bias=None):
                 out = self.linear(x)
                 out = out + bias
                 return out


         model = MyModel()
-        kwargs = {"bias": 3.}
+        kwargs = {"bias": 3.0}
         args = (torch.randn(2, 2, 2),)
-        onnx_program = torch.onnx.dynamo_export(
-            model,
-            *args,
-            **kwargs).save("my_simple_model.onnx")
+        onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
+            "my_simple_model.onnx"
+        )

     **Example 2 - Exporting with dynamic shapes**

     ::
@@ -1419,10 +1424,8 @@ def dynamo_export(
         # The previous model can be exported with dynamic shapes
         export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
         onnx_program = torch.onnx.dynamo_export(
-            model,
-            *args,
-            **kwargs,
-            export_options=export_options)
+            model, *args, **kwargs, export_options=export_options
+        )
         onnx_program.save("my_dynamic_model.onnx")

@@ -1,4 +1,5 @@
 """Utility to lazily import modules."""
+
 # mypy: allow-untyped-defs
 from __future__ import annotations

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
+
 from __future__ import annotations

 import contextlib
@@ -22,9 +22,9 @@ class ArtifactContent(object):
     properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
         default=None, metadata={"schema_property_name": "properties"}
     )
-    rendered: Optional[
-        _multiformat_message_string.MultiformatMessageString
-    ] = dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
+    rendered: Optional[_multiformat_message_string.MultiformatMessageString] = (
+        dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
+    )
     text: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "text"}
     )
@@ -19,10 +19,10 @@ class Conversion(object):
     """Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format."""

     tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
-    analysis_tool_log_files: Optional[
-        List[_artifact_location.ArtifactLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
-    )
+    analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
+        )
+    )
     invocation: Optional[_invocation.Invocation] = dataclasses.field(
         default=None, metadata={"schema_property_name": "invocation"}
@@ -53,10 +53,10 @@ class ExternalProperties(object):
     invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "invocations"}
     )
-    logical_locations: Optional[
-        List[_logical_location.LogicalLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "logicalLocations"}
-    )
+    logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "logicalLocations"}
+        )
+    )
     policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "policies"}
@@ -76,10 +76,10 @@ class ExternalProperties(object):
     taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "taxonomies"}
     )
-    thread_flow_locations: Optional[
-        List[_thread_flow_location.ThreadFlowLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "threadFlowLocations"}
-    )
+    thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "threadFlowLocations"}
+        )
+    )
     translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "translations"}
@@ -36,10 +36,10 @@ class Invocation(object):
     environment_variables: Any = dataclasses.field(
         default=None, metadata={"schema_property_name": "environmentVariables"}
     )
-    executable_location: Optional[
-        _artifact_location.ArtifactLocation
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "executableLocation"}
-    )
+    executable_location: Optional[_artifact_location.ArtifactLocation] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "executableLocation"}
+        )
+    )
     exit_code: Optional[int] = dataclasses.field(
         default=None, metadata={"schema_property_name": "exitCode"}
@@ -71,10 +71,10 @@ class Invocation(object):
     properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
         default=None, metadata={"schema_property_name": "properties"}
     )
-    response_files: Optional[
-        List[_artifact_location.ArtifactLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "responseFiles"}
-    )
+    response_files: Optional[List[_artifact_location.ArtifactLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "responseFiles"}
+        )
+    )
     rule_configuration_overrides: Optional[
         List[_configuration_override.ConfigurationOverride]
@@ -96,21 +96,22 @@ class Invocation(object):
     stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
         default=None, metadata={"schema_property_name": "stdoutStderr"}
     )
-    tool_configuration_notifications: Optional[
-        List[_notification.Notification]
-    ] = dataclasses.field(
-        default=None,
-        metadata={"schema_property_name": "toolConfigurationNotifications"},
-    )
+    tool_configuration_notifications: Optional[List[_notification.Notification]] = (
+        dataclasses.field(
+            default=None,
+            metadata={"schema_property_name": "toolConfigurationNotifications"},
+        )
+    )
-    tool_execution_notifications: Optional[
-        List[_notification.Notification]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "toolExecutionNotifications"}
-    )
+    tool_execution_notifications: Optional[List[_notification.Notification]] = (
+        dataclasses.field(
+            default=None,
+            metadata={"schema_property_name": "toolExecutionNotifications"},
+        )
+    )
-    working_directory: Optional[
-        _artifact_location.ArtifactLocation
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "workingDirectory"}
-    )
+    working_directory: Optional[_artifact_location.ArtifactLocation] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "workingDirectory"}
+        )
+    )


@@ -24,26 +24,26 @@ class Location(object):
         default=None, metadata={"schema_property_name": "annotations"}
     )
     id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"})
-    logical_locations: Optional[
-        List[_logical_location.LogicalLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "logicalLocations"}
-    )
+    logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "logicalLocations"}
+        )
+    )
     message: Optional[_message.Message] = dataclasses.field(
         default=None, metadata={"schema_property_name": "message"}
     )
-    physical_location: Optional[
-        _physical_location.PhysicalLocation
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "physicalLocation"}
-    )
+    physical_location: Optional[_physical_location.PhysicalLocation] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "physicalLocation"}
+        )
+    )
     properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
         default=None, metadata={"schema_property_name": "properties"}
     )
-    relationships: Optional[
-        List[_location_relationship.LocationRelationship]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "relationships"}
-    )
+    relationships: Optional[List[_location_relationship.LocationRelationship]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "relationships"}
+        )
+    )


@@ -21,10 +21,10 @@ class PhysicalLocation(object):
     address: Optional[_address.Address] = dataclasses.field(
         default=None, metadata={"schema_property_name": "address"}
     )
-    artifact_location: Optional[
-        _artifact_location.ArtifactLocation
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "artifactLocation"}
-    )
+    artifact_location: Optional[_artifact_location.ArtifactLocation] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "artifactLocation"}
+        )
+    )
     context_region: Optional[_region.Region] = dataclasses.field(
         default=None, metadata={"schema_property_name": "contextRegion"}
@@ -19,10 +19,10 @@ class ReportingDescriptor(object):
     """Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting."""

     id: str = dataclasses.field(metadata={"schema_property_name": "id"})
-    default_configuration: Optional[
-        _reporting_configuration.ReportingConfiguration
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "defaultConfiguration"}
-    )
+    default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "defaultConfiguration"}
+        )
+    )
     deprecated_guids: Optional[List[str]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "deprecatedGuids"}
@@ -33,17 +33,17 @@ class ReportingDescriptor(object):
     deprecated_names: Optional[List[str]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "deprecatedNames"}
     )
-    full_description: Optional[
-        _multiformat_message_string.MultiformatMessageString
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "fullDescription"}
-    )
+    full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "fullDescription"}
+        )
+    )
     guid: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "guid"}
     )
-    help: Optional[
-        _multiformat_message_string.MultiformatMessageString
-    ] = dataclasses.field(default=None, metadata={"schema_property_name": "help"})
+    help: Optional[_multiformat_message_string.MultiformatMessageString] = (
+        dataclasses.field(default=None, metadata={"schema_property_name": "help"})
+    )
     help_uri: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "helpUri"}
     )
@@ -28,10 +28,10 @@ class ReportingDescriptorReference(object):
     properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
         default=None, metadata={"schema_property_name": "properties"}
     )
-    tool_component: Optional[
-        _tool_component_reference.ToolComponentReference
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "toolComponent"}
-    )
+    tool_component: Optional[_tool_component_reference.ToolComponentReference] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "toolComponent"}
+        )
+    )


@@ -38,10 +38,10 @@ class Result(object):
     attachments: Optional[List[_attachment.Attachment]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "attachments"}
     )
-    baseline_state: Optional[
-        Literal["new", "unchanged", "updated", "absent"]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "baselineState"}
-    )
+    baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "baselineState"}
+        )
+    )
     code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "codeFlows"}
@@ -55,10 +55,10 @@ class Result(object):
     fixes: Optional[List[_fix.Fix]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "fixes"}
     )
-    graph_traversals: Optional[
-        List[_graph_traversal.GraphTraversal]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "graphTraversals"}
-    )
+    graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "graphTraversals"}
+        )
+    )
     graphs: Optional[List[_graph.Graph]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "graphs"}
@@ -96,9 +96,9 @@ class Result(object):
     related_locations: Optional[List[_location.Location]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "relatedLocations"}
     )
-    rule: Optional[
-        _reporting_descriptor_reference.ReportingDescriptorReference
-    ] = dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
+    rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
+        dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
+    )
     rule_id: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "ruleId"}
     )
@@ -16,10 +16,10 @@ from torch.onnx._internal.diagnostics.infra.sarif import (
 class ResultProvenance(object):
     """Contains information about how and when a result was detected."""

-    conversion_sources: Optional[
-        List[_physical_location.PhysicalLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "conversionSources"}
-    )
+    conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "conversionSources"}
+        )
+    )
     first_detection_run_guid: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "firstDetectionRunGuid"}
@@ -38,17 +38,17 @@ class Run(object):
     artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "artifacts"}
     )
-    automation_details: Optional[
-        _run_automation_details.RunAutomationDetails
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "automationDetails"}
-    )
+    automation_details: Optional[_run_automation_details.RunAutomationDetails] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "automationDetails"}
+        )
+    )
     baseline_guid: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "baselineGuid"}
     )
-    column_kind: Optional[
-        Literal["utf16CodeUnits", "unicodeCodePoints"]
-    ] = dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"})
+    column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = (
+        dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"})
+    )
     conversion: Optional[_conversion.Conversion] = dataclasses.field(
         default=None, metadata={"schema_property_name": "conversion"}
     )
@@ -73,10 +73,10 @@ class Run(object):
     language: str = dataclasses.field(
         default="en-US", metadata={"schema_property_name": "language"}
     )
-    logical_locations: Optional[
-        List[_logical_location.LogicalLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "logicalLocations"}
-    )
+    logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "logicalLocations"}
+        )
+    )
     newline_sequences: List[str] = dataclasses.field(
         default_factory=lambda: ["\r\n", "\n"],
@@ -97,23 +97,23 @@ class Run(object):
     results: Optional[List[_result.Result]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "results"}
     )
-    run_aggregates: Optional[
-        List[_run_automation_details.RunAutomationDetails]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "runAggregates"}
-    )
+    run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "runAggregates"}
+        )
+    )
-    special_locations: Optional[
-        _special_locations.SpecialLocations
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "specialLocations"}
-    )
+    special_locations: Optional[_special_locations.SpecialLocations] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "specialLocations"}
+        )
+    )
     taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "taxonomies"}
     )
-    thread_flow_locations: Optional[
-        List[_thread_flow_location.ThreadFlowLocation]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "threadFlowLocations"}
-    )
+    thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "threadFlowLocations"}
+        )
+    )
     translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "translations"}
@@ -21,10 +21,10 @@ class ToolComponent(object):
     """A component, such as a plug-in or the driver, of the analysis tool that was run."""

     name: str = dataclasses.field(metadata={"schema_property_name": "name"})
-    associated_component: Optional[
-        _tool_component_reference.ToolComponentReference
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "associatedComponent"}
-    )
+    associated_component: Optional[_tool_component_reference.ToolComponentReference] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "associatedComponent"}
+        )
+    )
     contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field(
         default_factory=lambda: ["localizedData", "nonLocalizedData"],
@@ -36,10 +36,10 @@ class ToolComponent(object):
     download_uri: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "downloadUri"}
     )
-    full_description: Optional[
-        _multiformat_message_string.MultiformatMessageString
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "fullDescription"}
-    )
+    full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "fullDescription"}
+        )
+    )
     full_name: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "fullName"}
     )
@@ -71,10 +71,10 @@ class ToolComponent(object):
             "schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"
         },
     )
-    notifications: Optional[
-        List[_reporting_descriptor.ReportingDescriptor]
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "notifications"}
-    )
+    notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "notifications"}
+        )
+    )
     organization: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "organization"}
     )
@@ -91,9 +91,9 @@ class ToolComponent(object):
     release_date_utc: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "releaseDateUtc"}
     )
-    rules: Optional[
-        List[_reporting_descriptor.ReportingDescriptor]
-    ] = dataclasses.field(default=None, metadata={"schema_property_name": "rules"})
+    rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
+        dataclasses.field(default=None, metadata={"schema_property_name": "rules"})
+    )
     semantic_version: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "semanticVersion"}
     )
@@ -110,10 +110,10 @@ class ToolComponent(object):
     taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field(
         default=None, metadata={"schema_property_name": "taxa"}
     )
-    translation_metadata: Optional[
-        _translation_metadata.TranslationMetadata
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "translationMetadata"}
-    )
+    translation_metadata: Optional[_translation_metadata.TranslationMetadata] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "translationMetadata"}
+        )
+    )
     version: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "version"}
     )
@@ -20,10 +20,10 @@ class TranslationMetadata(object):
     download_uri: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "downloadUri"}
     )
-    full_description: Optional[
-        _multiformat_message_string.MultiformatMessageString
-    ] = dataclasses.field(
-        default=None, metadata={"schema_property_name": "fullDescription"}
-    )
+    full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
+        dataclasses.field(
+            default=None, metadata={"schema_property_name": "fullDescription"}
+        )
+    )
     full_name: Optional[str] = dataclasses.field(
         default=None, metadata={"schema_property_name": "fullName"}
     )
@@ -808,9 +808,9 @@ def _exported_program_to_onnx_program(
             value, Sequence
         ), f"Input '{value_name}' should not be a sequence. This is unexpected."

-        value.metadata_props[
-            "pkg.torch.export.graph_signature.InputSpec.kind"
-        ] = input_kind.name
+        value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = (
+            input_kind.name
+        )
         value.metadata_props[
             "pkg.torch.export.graph_signature.InputSpec.persistent"
         ] = str(persistent)
@@ -859,9 +859,9 @@ def _exported_program_to_onnx_program(
         )

     for value in _values:
-        value.metadata_props[
-            "pkg.torch.export.graph_signature.OutputSpec.kind"
-        ] = output_kind.name
+        value.metadata_props["pkg.torch.export.graph_signature.OutputSpec.kind"] = (
+            output_kind.name
+        )
         if output_kind == graph_signature.OutputKind.USER_OUTPUT:
             model.graph.outputs.append(value)

@@ -1218,7 +1218,9 @@ def export(
         if byte_size < 2 * 1024 * 1024 * 1024:
             # The checker may segfault so we need to run it in a separate process
             _isolated.safe_call(
-                onnx.checker.check_model, onnx_program.model_proto, full_check=True  # type: ignore[attr-defined]
+                onnx.checker.check_model,
+                onnx_program.model_proto,
+                full_check=True,  # type: ignore[attr-defined]
             )
             export_status.onnx_checker = True
             verbose_print("Run `onnx.checker` on the ONNX model... ✅")
@@ -1312,9 +1314,7 @@ def export(
             _format_exceptions_for_all_strategies(failed_results)
         )
     if onnx_runtime_error_message:
-        traceback_lines.append(
-            "# ⚠️ ONNX Runtime error -----------------------"
-        )
+        traceback_lines.append("# ⚠️ ONNX Runtime error -----------------------")
         traceback_lines.append(onnx_runtime_error_message)
     if not traceback_lines:
         traceback_lines.append("No errors")
@@ -304,8 +304,8 @@ def _get_allowed_types_from_type_annotation(
         allowed_types = set()
         subtypes = typing.get_args(type_)
         for subtype in subtypes:
-            assert subtype is not type(
-                None
+            assert (
+                subtype is not type(None)
             ), "Union should not contain None type because it is handled by _is_optional."
             allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
         return allowed_types
@@ -235,8 +235,7 @@ class Transform(abc.ABC):
         )

     @abc.abstractmethod
-    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
-        ...
+    def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ...

     @diagnostics.diagnose_call(
         diagnostics.rules.fx_pass,
@@ -321,5 +320,4 @@ class Analysis(abc.ABC):
         self.onnxfunction_dispatcher = onnxfunction_dispatcher

     @abc.abstractmethod
-    def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult:
-        ...
+    def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: ...
@@ -11,6 +11,7 @@ https://github.com/pytorch/pytorch/issues/115883

 This solution will no longer be required once the issue is resolved.
 """
+
 from __future__ import annotations

 import abc
@@ -94,7 +94,9 @@ class _PyTreeExtensionContext:

         for _, class_type in named_model_output_classes:
             self.register_pytree_node(
-                class_type, model_output_flatten, model_output_unflatten  # type: ignore[arg-type ]
+                class_type,
+                model_output_flatten,
+                model_output_unflatten,  # type: ignore[arg-type ]
             )

@@ -626,7 +626,8 @@ class FxOnnxInterpreter:
         ):
             # aten ops and other stateless functions.
             if node.target == operator.getitem and isinstance(
-                fx_name_to_onnxscript_value[node.args[0].name], tuple  # type: ignore[union-attr,index]
+                fx_name_to_onnxscript_value[node.args[0].name],  # type: ignore[union-attr,index]
+                tuple,
             ):
                 onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]  # type: ignore[union-attr,index]
                 index = node.args[1]
@@ -660,9 +661,10 @@ class FxOnnxInterpreter:
             diagnostic_context=self.diagnostic_context,
         )
         with onnxscript.evaluator.default_as(onnxscript_tracer):
-            output: onnxscript_graph_building.TorchScriptTensor | tuple[
-                onnxscript_graph_building.TorchScriptTensor, ...
-            ] = symbolic_fn(*onnx_args, **onnx_kwargs)
+            output: (
+                onnxscript_graph_building.TorchScriptTensor
+                | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
+            ) = symbolic_fn(*onnx_args, **onnx_kwargs)
         assert (
             output is not None
         ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
@@ -779,9 +781,10 @@ class FxOnnxInterpreter:
         # be considered.
         unique_module_name = f"{sub_module._get_name()}_{node.target}"

-        outputs: onnxscript_graph_building.TorchScriptTensor | tuple[
-            onnxscript_graph_building.TorchScriptTensor, ...
-        ] = parent_onnxscript_graph.add_module_call(  # type: ignore[assignment]
+        outputs: (
+            onnxscript_graph_building.TorchScriptTensor
+            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
+        ) = parent_onnxscript_graph.add_module_call(  # type: ignore[assignment]
             unique_module_name, sub_onnxscript_graph, onnx_args
         )
@@ -147,8 +147,10 @@ class FXSymbolicTracer(_exporter_legacy.FXGraphExtractor):
             for v in x.values():
                 out += v
             return out
-        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
-        assert f({'a': 1, 'b': 2, 'c': 4}) == 7
+
+
+        f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
+        assert f({"a": 1, "b": 2, "c": 4}) == 7
     """

     def __init__(self, concrete_args: dict[str, Any] | None = None):
@@ -415,10 +415,9 @@ class _OnnxSchemaChecker:
         inputs = (Tensor[2, 3], Tensor[2, 3])
         attributes = {"alpha": 1.0}

-        @torch_op("aten::op")
-        def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal:
-            ...
+        @torch_op("aten::op")
+        def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
         ```
         Result: Perfect match.

@@ -295,7 +295,7 @@ def _convert_torch_args_to_onnxfunction_args(
     args: list[fx_type_utils.Argument],
     kwargs: dict[str, fx_type_utils.Argument],
     allow_extra_kwargs: bool = False,
-) -> tuple[list[Any], dict[str, Any],]:
+) -> tuple[list[Any], dict[str, Any]]:
     """Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema.

     NOTE: This is different from the param_schema separating in dispatcher, since at this point
@@ -3,6 +3,7 @@

 These functions should NOT be directly invoked outside of `passes` package.
 """
+
 from __future__ import annotations

 import collections
@@ -66,9 +66,7 @@ class Decompose(_pass.Transform):

         # Apply decomposition table to the input graph.
         assert fake_mode is not None  # for mypy
-        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), (
-            fake_mode
-        ):
+        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
             decomposed_module = proxy_tensor.make_fx(
                 module,
                 decomposition_table=self.decomposition_table,
@@ -814,7 +814,9 @@ class Modularize(_pass.Transform):
     >>>     out = self.linear(out)
     >>>     return out
     >>>
-    >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2]))
+    >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(
+    ...     torch.tensor([0, 1, 2])
+    ... )
     >>> gm.print_readable()

    >>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()
@@ -76,16 +76,13 @@ class TypePromotionRule(abc.ABC):
     # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
     # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
     @abc.abstractmethod
-    def __hash__(self) -> int:
-        ...
+    def __hash__(self) -> int: ...

     @abc.abstractmethod
-    def __repr__(self):
-        ...
+    def __repr__(self): ...

     @abc.abstractmethod
-    def __eq__(self, other: object) -> bool:
-        ...
+    def __eq__(self, other: object) -> bool: ...

     def is_valid(self) -> bool:
         """Check if the rule is valid."""
@@ -95,7 +95,9 @@ class TorchExport(_exporter_legacy.FXGraphExtractor):
         model = model.run_decompositions(options.decomposition_table)

         # Export FX graph to ONNX ModelProto.
-        return self.pre_export_passes(options, model, model.graph_module, updated_model_args)  # type: ignore[return-value]
+        return self.pre_export_passes(  # type: ignore[return-value]
+            options, model, model.graph_module, updated_model_args
+        )

     def pre_export_passes(
         self,
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Utilities for converting and operating on ONNX, JIT and torch types."""
+
 from __future__ import annotations

 from typing import (
@@ -31,8 +32,7 @@ if TYPE_CHECKING:
 @runtime_checkable
 class TensorLike(Protocol):
     @property
-    def dtype(self) -> torch.dtype | None:
-        ...
+    def dtype(self) -> torch.dtype | None: ...


 def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool:
@@ -40,8 +40,7 @@ class InputAdaptStep(Protocol):
         model_args: Sequence[Any],
         model_kwargs: Mapping[str, Any],
         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
-    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
-        ...
+    ) -> tuple[Sequence[Any], Mapping[str, Any]]: ...


 class InputAdapter:
@@ -98,8 +97,7 @@ class OutputAdaptStep(Protocol):
         self,
         model_outputs: Any,
         model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
-    ) -> Any:
-        ...
+    ) -> Any: ...


 class OutputAdapter:
@@ -573,7 +571,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
             A tuple of the model args and kwargs.
         """
         ordered_params = tuple(
-            model.state_dict[name] for name in model.graph_signature.parameters  # type: ignore[union-attr,index]
+            model.state_dict[name]  # type: ignore[union-attr,index]
+            for name in model.graph_signature.parameters  # type: ignore[union-attr]
         )
         non_persistent_buffers = set(model.graph_signature.non_persistent_buffers)  # type: ignore[union-attr]
         ordered_buffers = []
@@ -583,7 +582,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
             else:
                 ordered_buffers.append(model.state_dict[name])  # type: ignore[union-attr,index]
         ordered_constant_tensors = tuple(
-            model.constants[fqn] for fqn in model.graph_signature.lifted_tensor_constants  # type: ignore[union-attr,index]
+            model.constants[fqn]  # type: ignore[union-attr,index]
+            for fqn in model.graph_signature.lifted_tensor_constants  # type: ignore[union-attr]
         )

         # NOTE: calling convention is first params, then buffers, then args as user supplied them.
@@ -304,7 +304,7 @@ def _get_onnx_devices(
             torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
         ],
         ...,
-    ]
+    ],
 ) -> Tuple["ORTC.OrtDevice", ...]:
     def _device_id_or_zero(device_id: int) -> int:
         return device_id or 0
@@ -403,7 +403,12 @@ def _adjust_scalar_from_onnx_to_fx(
         torch.SymBool,
         bool,
     ],
-) -> Union[torch.Tensor, int, float, bool,]:
+) -> Union[
+    torch.Tensor,
+    int,
+    float,
+    bool,
+]:
     """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
     assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
     if isinstance(
@@ -561,9 +566,9 @@ class OrtExecutionInfoPerSession:
         self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
         # This is the outputs of executing the original torch.fx.GraphModule with example inputs
         # (i.e., args passed into OrtBackend._ort_acclerated_call).
-        self.example_outputs: Union[
-            Tuple[torch.Tensor, ...], torch.Tensor
-        ] = example_outputs
+        self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
+            example_outputs
+        )

     def is_supported(self, *args):
         # Compare the args and the input schema in ONNX model and
@@ -276,10 +276,13 @@ def onnx_symbolic(
     Usage::

     ```
-    @onnx_symbolic("aten::symbolic_b", opset=10, decorate=[quantized_aten_handler(scale=1/128, zero_point=0)])
+    @onnx_symbolic(
+        "aten::symbolic_b",
+        opset=10,
+        decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
+    )
     @symbolic_helper.parse_args("v", "v", "b")
-    def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value:
-        ...
+    def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...
     ```

     Args:
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Utilities for converting and operating on ONNX, JIT and torch types."""
+
 from __future__ import annotations

 import enum
@@ -1,4 +1,5 @@
 """ONNX exporter exceptions."""
+
 from __future__ import annotations

 import textwrap
@@ -234,7 +234,7 @@ def parse_args(
     """

     def decorator(
-        fn: Callable[_Concatenate[_U, _P], _T]
+        fn: Callable[_Concatenate[_U, _P], _T],
     ) -> Callable[_Concatenate[_U, _P], _T]:
         fn._arg_descriptors = arg_descriptors  # type: ignore[attr-defined]

@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 # mypy: disable-error-code=arg-type
 """This file exports ONNX ops for opset 11."""
+
 from __future__ import annotations

 import functools
@@ -148,9 +148,7 @@ def scaled_dot_product_attention(
     assert (not is_causal) or (
         is_causal and symbolic_helper._is_none(attn_mask)
     ), "is_causal and attn_mask cannot be set at the same time"
-    assert (
-        not enable_gqa
-    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"

     scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
@@ -254,7 +252,7 @@ def _causal_attention_mask(
     Equivalent to::
         mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
         attn_mask = torch.zeros(L, S, dtype=torch.float)
-        attn_mask = attn_mask.masked_fill(not mask, -float('inf'))
+        attn_mask = attn_mask.masked_fill(not mask, -float("inf"))

     Args:
         query: Tensor of shape [..., L, E]
@@ -56,7 +56,9 @@ def grid_sampler(
     if symbolic_helper._get_tensor_rank(input) == 5:
         return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
     mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
-    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
+    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[  # type: ignore[call-arg]
+        padding_mode_enum
+    ]
     return g.op(
         "GridSample",
         input,
@@ -57,7 +57,9 @@ def _grid_sampler(
     mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg, index]
     # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
     mode_s = convert_grid_sample_mode(mode_s)
-    padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg, index]
+    padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[  # type: ignore[call-arg, index]
+        padding_mode_enum  # type: ignore[index]
+    ]
     return g.op(
         "GridSample",
         input,
@@ -2746,7 +2746,9 @@ def native_layer_norm(
     # mean and normalized, so we need to Cast it back
     if is_type_half:
         denominator = g.op(
-            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()  # type: ignore[possibly-undefined]
+            "Cast",
+            denominator,
+            to_i=_type_utils.JitScalarType(input_dtype).onnx_type(),  # type: ignore[possibly-undefined]
         )
         rdenominator = g.op("Reciprocal", denominator)
     else:
@@ -4368,7 +4370,8 @@ def _generic_rnn(
             reform_weights(g, w, hidden_size, reform_permutation) for w in weights
         )
         return tuple(
-            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
+            symbolic_helper._unsqueeze_helper(g, x, [0])
+            for x in (weight_ih, weight_hh)  # type: ignore[possibly-undefined]
         )

     def transform_weights(layer_index):
@@ -4498,9 +4501,10 @@ def _lstm_full(
     bidirectional,
     batch_first,
 ):
-    hidden, weight = symbolic_helper._unpack_list(
-        hidden_v
-    ), symbolic_helper._unpack_list(weight_v)
+    hidden, weight = (
+        symbolic_helper._unpack_list(hidden_v),
+        symbolic_helper._unpack_list(weight_v),
+    )
     return _generic_rnn(
         g,
         "LSTM",
@@ -4529,9 +4533,10 @@ def _lstm_packed(
     train,
     bidirectional,
 ):
-    hidden, weight = symbolic_helper._unpack_list(
-        hidden_v
-    ), symbolic_helper._unpack_list(weight_v)
+    hidden, weight = (
+        symbolic_helper._unpack_list(hidden_v),
+        symbolic_helper._unpack_list(weight_v),
+    )
     return _generic_rnn(
         g,
         "LSTM",
@@ -4,6 +4,7 @@
 These models can be loaded with the ONNX library and then
 converted to models which run on other deep learning frameworks.
 """
+
 from __future__ import annotations

 import contextlib
@@ -224,13 +225,7 @@ def export(

     3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::

-        args = (
-            x,
-            {
-                "y": input_y,
-                "z": input_z
-            }
-        )
+        args = (x, {"y": input_y, "z": input_z})

     All but the last element of the tuple will be passed as non-keyword arguments,
     and named arguments will be set from the last element. If a named argument is
@@ -252,22 +247,14 @@ def export(
             (
                 x,
                 # WRONG: will be interpreted as named arguments
-                {y: z}
+                {y: z},
             ),
-            "test.onnx.pb"
+            "test.onnx.pb",
         )

         Write::

-            torch.onnx.export(
-                model,
-                (
-                    x,
-                    {y: z},
-                    {}
-                ),
-                "test.onnx.pb"
-            )
+            torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb")

     f: Path to the output ONNX model file. E.g. "model.onnx".
     kwargs: Named arguments to the model.
@@ -369,12 +356,13 @@ def export(
             def forward(self, x):
                 return torch.sum(x, dim=1)

+
         torch.onnx.export(
             SumModule(),
             (torch.ones(2, 2),),
             "onnx.pb",
             input_names=["x"],
-            output_names=["sum"]
+            output_names=["sum"],
         )

     Produces::
@@ -410,7 +398,7 @@ def export(
                 "x": {0: "my_custom_axis_name"},
                 # list value: automatic names
                 "sum": [0],
-            }
+            },
         )

     Produces::
@@ -1398,9 +1386,9 @@ def _setup_trace_module_map(
     and start from the first non-numeric atom.

     Example:
-        >>> _unqualified_variable_name('__main__.Foo.bar')
+        >>> _unqualified_variable_name("__main__.Foo.bar")
         'bar'
-        >>> _unqualified_variable_name('__main__.Foo.bar.0')
+        >>> _unqualified_variable_name("__main__.Foo.bar.0")
         'bar.0'
     """
     name_atoms = qualified_name.split(".")
@@ -1605,7 +1593,9 @@ def _export(

     if keep_initializers_as_inputs is not True:
         params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
-            graph, params_dict, getattr(model, "training", False)  # type: ignore[arg-type]
+            graph,
+            params_dict,  # type: ignore[arg-type]
+            getattr(model, "training", False),  # type: ignore[arg-type]
         )
     _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
     if export_params:
@@ -1863,7 +1853,9 @@ def _run_symbolic_function(
     }
     if namespace == "onnx":
         # Clone node to trigger ONNX shape inference
-        return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize())  # type: ignore[attr-defined]
+        return graph_context.op(
+            op_name, *inputs, **attrs, outputs=node.outputsSize()
+        )  # type: ignore[attr-defined]

     raise errors.UnsupportedOperatorError(
         symbolic_function_name,
@@ -217,8 +217,8 @@ def _compare_onnx_pytorch_outputs_in_np(
     pt_outs: _OutputsType,
     options: VerificationOptions,
 ):
-    assert len(onnx_outs) == len(
-        pt_outs
+    assert (
+        len(onnx_outs) == len(pt_outs)
     ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
     acceptable_error_percentage = options.acceptable_error_percentage
     if acceptable_error_percentage and (