[ONNX] Opt into ruff fmt (#134120)

Add ONNX directory to use ruff format.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134120
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
This commit is contained in:
Justin Chu
2024-08-22 22:44:03 +00:00
committed by PyTorch MergeBot
parent 25499de814
commit b319fa3fd9
60 changed files with 313 additions and 276 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""
from __future__ import annotations
import operator

View File

@ -22,8 +22,7 @@ if typing.TYPE_CHECKING:
class _SarifLogBuilder(Protocol):
def sarif_log(self) -> sarif.SarifLog:
...
def sarif_log(self) -> sarif.SarifLog: ...
def _assert_has_diagnostics(
@ -344,9 +343,7 @@ class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
self.assertIn("test_diagnostics.py", frame.location.uri)
def test_diagnostics_records_cpp_call_stack(self):
diagnostic = (
self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
)
diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
stack = diagnostic.cpp_call_stack
assert stack is not None # for mypy
self.assertGreater(len(stack.frames), 0)
@ -368,9 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
def setUp(self):
self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack:
self.context: infra.DiagnosticContext[
infra.Diagnostic
] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.context: infra.DiagnosticContext[infra.Diagnostic] = (
stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
)
self.addCleanup(stack.pop_all().close)
return super().setUp()
@ -400,12 +397,14 @@ class TestDiagnosticsInfra(common_utils.TestCase):
},
):
diagnostic1 = infra.Diagnostic(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
custom_rules.custom_rule, # type: ignore[attr-defined]
infra.Level.WARNING,
)
self.context.log(diagnostic1)
diagnostic2 = infra.Diagnostic(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
custom_rules.custom_rule_2, # type: ignore[attr-defined]
infra.Level.ERROR,
)
self.context.log(diagnostic2)

View File

@ -49,9 +49,9 @@ class TestGlobalHelpers(common_utils.TestCase):
class TestOverrideDict(common_utils.TestCase):
def setUp(self):
self.override_dict: registration.OverrideDict[
str, int
] = registration.OverrideDict()
self.override_dict: registration.OverrideDict[str, int] = (
registration.OverrideDict()
)
def test_get_item_returns_base_value_when_no_override(self):
self.override_dict.set_base("a", 42)

View File

@ -44,7 +44,7 @@ class _netG(nn.Module):
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
nn.Tanh(),
# state size. (nc) x 64 x 64
)

View File

@ -294,8 +294,8 @@ def xfail(error_message: str, reason: Optional[str] = None):
except Exception as e:
if isinstance(e, torch.onnx.OnnxExporterError):
# diagnostic message is in the cause of the exception
assert error_message in str(
e.__cause__
assert (
error_message in str(e.__cause__)
), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
else:
assert error_message in str(

View File

@ -175,9 +175,7 @@ def _init_test_roi_heads_faster_rcnn():
resolution = box_roi_pool.output_size[0]
representation_size = 1024
box_head = faster_rcnn.TwoMLPHead(
out_channels * resolution**2, representation_size
)
box_head = faster_rcnn.TwoMLPHead(out_channels * resolution**2, representation_size)
representation_size = 1024
box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes)

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter."""
import io
from typing import List

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
from typing import List
import onnx_test_common

View File

@ -6,6 +6,7 @@ Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--produce-onnx-test-data: generate onnx test data
--accept: accept onnx updates and overwrite models
"""
import glob
import inspect
import io
@ -879,7 +880,8 @@ class TestOperators(common_utils.TestCase):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(
x_in[list(x_in.keys())[0]], list(x_in.keys())[0] # noqa: RUF015
x_in[list(x_in.keys())[0]], # noqa: RUF015
list(x_in.keys())[0], # noqa: RUF015
)
return x_out

View File

@ -483,7 +483,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
def forward(self, x_in):
x_out = {}
x_out["test_key_out"] = torch.add(
x_in[list(x_in.keys())[0]], list(x_in.keys())[0] # noqa: RUF015
x_in[list(x_in.keys())[0]], # noqa: RUF015
list(x_in.keys())[0], # noqa: RUF015
)
return x_out

View File

@ -174,7 +174,9 @@ class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
)
],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
class_name_func=lambda cls,
num,
params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class TestUtilityFuns(_BaseTestCase):
opset_version = None

View File

@ -185,7 +185,9 @@ class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
# {"onnx_backend": verification.OnnxBackend.ONNX},
{"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
],
class_name_func=lambda cls, idx, input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
class_name_func=lambda cls,
idx,
input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
onnx_backend: verification.OnnxBackend

View File

@ -79,7 +79,9 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
with tempfile.NamedTemporaryFile(suffix=".pte") as f:
torch.export.save(exported_program, f.name)
del exported_program # Delete the exported program to ensure that we are loading from file
del (
exported_program
) # Delete the exported program to ensure that we are loading from file
loaded_exported_program = torch.export.load(f.name)
self._compare_onnx_and_torch_exported_program(

View File

@ -47,8 +47,12 @@ USE_BLACK_FILELIST = re.compile(
"test/[a-h]*/**",
# test/[i-j]*/**
"test/[i-j]*/**",
# test/[k-z]*/**
"test/[k-z]*/**",
# test/[k-n]*/**
"test/[k-n]*/**",
# test/optim/**
"test/optim/**",
# "test/[p-z]*/**",
"test/[p-z]*/**",
# torch/**
# torch/_[a-h]*/**
"torch/_[a-h]*/**",
@ -62,8 +66,10 @@ USE_BLACK_FILELIST = re.compile(
"torch/d*/**",
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/[o-z]*/**
"torch/[o-z]*/**",
# torch/optim/**
"torch/optim/**",
# torch/[p-z]*/**
"torch/[p-z]*/**",
],
),
)

View File

@ -84,7 +84,6 @@ from .utils import (
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,
symbolic_caffe2,
symbolic_helper,
@ -215,12 +214,13 @@ def export(
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
output_names=["sum"],
)
Produces::
@ -256,7 +256,7 @@ def export(
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
}
},
)
Produces::

View File

@ -6,6 +6,7 @@ Do not use this module outside of `torch.onnx` and its tests.
Be very judicious when adding any new global variables. Do not create new global
variables unless they are absolutely necessary.
"""
import torch._C._onnx as _C_onnx
# This module should only depend on _constants and nothing else in torch.onnx to keep

View File

@ -107,9 +107,9 @@ class OnnxRegistry:
# NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important
# not to directly modify this variable. Instead, access to it should be done through
# the public methods: register_custom_op, get_ops, and is_registered_op.
self._registry: dict[
registration.OpName, list[registration.ONNXFunction]
] = defaultdict(list)
self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
defaultdict(list)
)
# FIXME: Avoid importing onnxscript into torch
from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401
registration,
@ -392,8 +392,10 @@ class ResolvedExportOptions(ExportOptions):
)
self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
self.decomposition_table = decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
self.onnx_registry
self.decomposition_table = (
decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
self.onnx_registry
)
)
from torch.onnx._internal.fx import onnxfunction_dispatcher
@ -766,6 +768,7 @@ class ONNXProgram:
... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
... self.fc1 = torch.nn.Linear(9216, 128, bias=False)
... self.fc2 = torch.nn.Linear(128, 10, bias=False)
...
... def forward(self, x, b):
... tensor_x = self.conv1(x)
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
@ -778,11 +781,13 @@ class ONNXProgram:
... tensor_x = self.fc2(tensor_x)
... output = torch.nn.functional.log_softmax(tensor_x, dim=1)
... (
... self.my_buffer2.add_(1.0) + self.my_buffer1
... self.my_buffer2.add_(1.0) + self.my_buffer1
... ) # Mutate buffer through in-place addition
... return output
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
>>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({})
>>> exported_program = torch.export.export(
... CustomModule(), args=inputs
... ).run_decompositions({})
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
>>> pprint.pprint(onnx_program.model_signature)
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
@ -1194,9 +1199,7 @@ class Exporter:
with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
self.options
), torch._dynamo.config.patch(
dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
):
), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
graph_module = self.options.fx_tracer.generate_fx(
self.options, self.model, self.model_args, self.model_kwargs
)
@ -1401,17 +1404,19 @@ def dynamo_export(
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x, bias=None):
out = self.linear(x)
out = out + bias
return out
model = MyModel()
kwargs = {"bias": 3.}
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(
model,
*args,
**kwargs).save("my_simple_model.onnx")
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
"my_simple_model.onnx"
)
**Example 2 - Exporting with dynamic shapes**
::
@ -1419,10 +1424,8 @@ def dynamo_export(
# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
model,
*args,
**kwargs,
export_options=export_options)
model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")

View File

@ -1,4 +1,5 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations
import contextlib

View File

@ -22,9 +22,9 @@ class ArtifactContent(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
rendered: Optional[
_multiformat_message_string.MultiformatMessageString
] = dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
rendered: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
)
text: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "text"}
)

View File

@ -19,10 +19,10 @@ class Conversion(object):
"""Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format."""
tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
analysis_tool_log_files: Optional[
List[_artifact_location.ArtifactLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
)
)
invocation: Optional[_invocation.Invocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocation"}

View File

@ -53,10 +53,10 @@ class ExternalProperties(object):
invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocations"}
)
logical_locations: Optional[
List[_logical_location.LogicalLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
)
)
policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "policies"}
@ -76,10 +76,10 @@ class ExternalProperties(object):
taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxonomies"}
)
thread_flow_locations: Optional[
List[_thread_flow_location.ThreadFlowLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
)
)
translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"}

View File

@ -36,10 +36,10 @@ class Invocation(object):
environment_variables: Any = dataclasses.field(
default=None, metadata={"schema_property_name": "environmentVariables"}
)
executable_location: Optional[
_artifact_location.ArtifactLocation
] = dataclasses.field(
default=None, metadata={"schema_property_name": "executableLocation"}
executable_location: Optional[_artifact_location.ArtifactLocation] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "executableLocation"}
)
)
exit_code: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "exitCode"}
@ -71,10 +71,10 @@ class Invocation(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
response_files: Optional[
List[_artifact_location.ArtifactLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "responseFiles"}
response_files: Optional[List[_artifact_location.ArtifactLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "responseFiles"}
)
)
rule_configuration_overrides: Optional[
List[_configuration_override.ConfigurationOverride]
@ -96,21 +96,22 @@ class Invocation(object):
stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "stdoutStderr"}
)
tool_configuration_notifications: Optional[
List[_notification.Notification]
] = dataclasses.field(
default=None,
metadata={"schema_property_name": "toolConfigurationNotifications"},
tool_configuration_notifications: Optional[List[_notification.Notification]] = (
dataclasses.field(
default=None,
metadata={"schema_property_name": "toolConfigurationNotifications"},
)
)
tool_execution_notifications: Optional[
List[_notification.Notification]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "toolExecutionNotifications"}
tool_execution_notifications: Optional[List[_notification.Notification]] = (
dataclasses.field(
default=None,
metadata={"schema_property_name": "toolExecutionNotifications"},
)
)
working_directory: Optional[
_artifact_location.ArtifactLocation
] = dataclasses.field(
default=None, metadata={"schema_property_name": "workingDirectory"}
working_directory: Optional[_artifact_location.ArtifactLocation] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "workingDirectory"}
)
)

View File

@ -24,26 +24,26 @@ class Location(object):
default=None, metadata={"schema_property_name": "annotations"}
)
id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"})
logical_locations: Optional[
List[_logical_location.LogicalLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
)
)
message: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "message"}
)
physical_location: Optional[
_physical_location.PhysicalLocation
] = dataclasses.field(
default=None, metadata={"schema_property_name": "physicalLocation"}
physical_location: Optional[_physical_location.PhysicalLocation] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "physicalLocation"}
)
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
relationships: Optional[
List[_location_relationship.LocationRelationship]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "relationships"}
relationships: Optional[List[_location_relationship.LocationRelationship]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "relationships"}
)
)

View File

@ -21,10 +21,10 @@ class PhysicalLocation(object):
address: Optional[_address.Address] = dataclasses.field(
default=None, metadata={"schema_property_name": "address"}
)
artifact_location: Optional[
_artifact_location.ArtifactLocation
] = dataclasses.field(
default=None, metadata={"schema_property_name": "artifactLocation"}
artifact_location: Optional[_artifact_location.ArtifactLocation] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "artifactLocation"}
)
)
context_region: Optional[_region.Region] = dataclasses.field(
default=None, metadata={"schema_property_name": "contextRegion"}

View File

@ -19,10 +19,10 @@ class ReportingDescriptor(object):
"""Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting."""
id: str = dataclasses.field(metadata={"schema_property_name": "id"})
default_configuration: Optional[
_reporting_configuration.ReportingConfiguration
] = dataclasses.field(
default=None, metadata={"schema_property_name": "defaultConfiguration"}
default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "defaultConfiguration"}
)
)
deprecated_guids: Optional[List[str]] = dataclasses.field(
default=None, metadata={"schema_property_name": "deprecatedGuids"}
@ -33,17 +33,17 @@ class ReportingDescriptor(object):
deprecated_names: Optional[List[str]] = dataclasses.field(
default=None, metadata={"schema_property_name": "deprecatedNames"}
)
full_description: Optional[
_multiformat_message_string.MultiformatMessageString
] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
)
)
guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "guid"}
)
help: Optional[
_multiformat_message_string.MultiformatMessageString
] = dataclasses.field(default=None, metadata={"schema_property_name": "help"})
help: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(default=None, metadata={"schema_property_name": "help"})
)
help_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "helpUri"}
)

View File

@ -28,10 +28,10 @@ class ReportingDescriptorReference(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
tool_component: Optional[
_tool_component_reference.ToolComponentReference
] = dataclasses.field(
default=None, metadata={"schema_property_name": "toolComponent"}
tool_component: Optional[_tool_component_reference.ToolComponentReference] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "toolComponent"}
)
)

View File

@ -38,10 +38,10 @@ class Result(object):
attachments: Optional[List[_attachment.Attachment]] = dataclasses.field(
default=None, metadata={"schema_property_name": "attachments"}
)
baseline_state: Optional[
Literal["new", "unchanged", "updated", "absent"]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "baselineState"}
baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "baselineState"}
)
)
code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field(
default=None, metadata={"schema_property_name": "codeFlows"}
@ -55,10 +55,10 @@ class Result(object):
fixes: Optional[List[_fix.Fix]] = dataclasses.field(
default=None, metadata={"schema_property_name": "fixes"}
)
graph_traversals: Optional[
List[_graph_traversal.GraphTraversal]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "graphTraversals"}
graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "graphTraversals"}
)
)
graphs: Optional[List[_graph.Graph]] = dataclasses.field(
default=None, metadata={"schema_property_name": "graphs"}
@ -96,9 +96,9 @@ class Result(object):
related_locations: Optional[List[_location.Location]] = dataclasses.field(
default=None, metadata={"schema_property_name": "relatedLocations"}
)
rule: Optional[
_reporting_descriptor_reference.ReportingDescriptorReference
] = dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
)
rule_id: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "ruleId"}
)

View File

@ -16,10 +16,10 @@ from torch.onnx._internal.diagnostics.infra.sarif import (
class ResultProvenance(object):
"""Contains information about how and when a result was detected."""
conversion_sources: Optional[
List[_physical_location.PhysicalLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "conversionSources"}
conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "conversionSources"}
)
)
first_detection_run_guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "firstDetectionRunGuid"}

View File

@ -38,17 +38,17 @@ class Run(object):
artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
default=None, metadata={"schema_property_name": "artifacts"}
)
automation_details: Optional[
_run_automation_details.RunAutomationDetails
] = dataclasses.field(
default=None, metadata={"schema_property_name": "automationDetails"}
automation_details: Optional[_run_automation_details.RunAutomationDetails] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "automationDetails"}
)
)
baseline_guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "baselineGuid"}
)
column_kind: Optional[
Literal["utf16CodeUnits", "unicodeCodePoints"]
] = dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"})
column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = (
dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"})
)
conversion: Optional[_conversion.Conversion] = dataclasses.field(
default=None, metadata={"schema_property_name": "conversion"}
)
@ -73,10 +73,10 @@ class Run(object):
language: str = dataclasses.field(
default="en-US", metadata={"schema_property_name": "language"}
)
logical_locations: Optional[
List[_logical_location.LogicalLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
)
)
newline_sequences: List[str] = dataclasses.field(
default_factory=lambda: ["\r\n", "\n"],
@ -97,23 +97,23 @@ class Run(object):
results: Optional[List[_result.Result]] = dataclasses.field(
default=None, metadata={"schema_property_name": "results"}
)
run_aggregates: Optional[
List[_run_automation_details.RunAutomationDetails]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "runAggregates"}
run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "runAggregates"}
)
)
special_locations: Optional[
_special_locations.SpecialLocations
] = dataclasses.field(
default=None, metadata={"schema_property_name": "specialLocations"}
special_locations: Optional[_special_locations.SpecialLocations] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "specialLocations"}
)
)
taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxonomies"}
)
thread_flow_locations: Optional[
List[_thread_flow_location.ThreadFlowLocation]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
)
)
translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"}

View File

@ -21,10 +21,10 @@ class ToolComponent(object):
"""A component, such as a plug-in or the driver, of the analysis tool that was run."""
name: str = dataclasses.field(metadata={"schema_property_name": "name"})
associated_component: Optional[
_tool_component_reference.ToolComponentReference
] = dataclasses.field(
default=None, metadata={"schema_property_name": "associatedComponent"}
associated_component: Optional[_tool_component_reference.ToolComponentReference] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "associatedComponent"}
)
)
contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field(
default_factory=lambda: ["localizedData", "nonLocalizedData"],
@ -36,10 +36,10 @@ class ToolComponent(object):
download_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "downloadUri"}
)
full_description: Optional[
_multiformat_message_string.MultiformatMessageString
] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
)
)
full_name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullName"}
@ -71,10 +71,10 @@ class ToolComponent(object):
"schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"
},
)
notifications: Optional[
List[_reporting_descriptor.ReportingDescriptor]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "notifications"}
notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "notifications"}
)
)
organization: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "organization"}
@ -91,9 +91,9 @@ class ToolComponent(object):
release_date_utc: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "releaseDateUtc"}
)
rules: Optional[
List[_reporting_descriptor.ReportingDescriptor]
] = dataclasses.field(default=None, metadata={"schema_property_name": "rules"})
rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
dataclasses.field(default=None, metadata={"schema_property_name": "rules"})
)
semantic_version: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "semanticVersion"}
)
@ -110,10 +110,10 @@ class ToolComponent(object):
taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxa"}
)
translation_metadata: Optional[
_translation_metadata.TranslationMetadata
] = dataclasses.field(
default=None, metadata={"schema_property_name": "translationMetadata"}
translation_metadata: Optional[_translation_metadata.TranslationMetadata] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "translationMetadata"}
)
)
version: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "version"}

View File

@ -20,10 +20,10 @@ class TranslationMetadata(object):
download_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "downloadUri"}
)
full_description: Optional[
_multiformat_message_string.MultiformatMessageString
] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "fullDescription"}
)
)
full_name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullName"}

View File

@ -808,9 +808,9 @@ def _exported_program_to_onnx_program(
value, Sequence
), f"Input '{value_name}' should not be a sequence. This is unexpected."
value.metadata_props[
"pkg.torch.export.graph_signature.InputSpec.kind"
] = input_kind.name
value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = (
input_kind.name
)
value.metadata_props[
"pkg.torch.export.graph_signature.InputSpec.persistent"
] = str(persistent)
@ -859,9 +859,9 @@ def _exported_program_to_onnx_program(
)
for value in _values:
value.metadata_props[
"pkg.torch.export.graph_signature.OutputSpec.kind"
] = output_kind.name
value.metadata_props["pkg.torch.export.graph_signature.OutputSpec.kind"] = (
output_kind.name
)
if output_kind == graph_signature.OutputKind.USER_OUTPUT:
model.graph.outputs.append(value)
@ -1218,7 +1218,9 @@ def export(
if byte_size < 2 * 1024 * 1024 * 1024:
# The checker may segfault so we need to run it in a separate process
_isolated.safe_call(
onnx.checker.check_model, onnx_program.model_proto, full_check=True # type: ignore[attr-defined]
onnx.checker.check_model,
onnx_program.model_proto,
full_check=True, # type: ignore[attr-defined]
)
export_status.onnx_checker = True
verbose_print("Run `onnx.checker` on the ONNX model... ✅")
@ -1312,9 +1314,7 @@ def export(
_format_exceptions_for_all_strategies(failed_results)
)
if onnx_runtime_error_message:
traceback_lines.append(
"# ⚠️ ONNX Runtime error -----------------------"
)
traceback_lines.append("# ⚠️ ONNX Runtime error -----------------------")
traceback_lines.append(onnx_runtime_error_message)
if not traceback_lines:
traceback_lines.append("No errors")

View File

@ -304,8 +304,8 @@ def _get_allowed_types_from_type_annotation(
allowed_types = set()
subtypes = typing.get_args(type_)
for subtype in subtypes:
assert subtype is not type(
None
assert (
subtype is not type(None)
), "Union should not contain None type because it is handled by _is_optional."
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
return allowed_types

View File

@ -235,8 +235,7 @@ class Transform(abc.ABC):
)
@abc.abstractmethod
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
...
def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ...
@diagnostics.diagnose_call(
diagnostics.rules.fx_pass,
@ -321,5 +320,4 @@ class Analysis(abc.ABC):
self.onnxfunction_dispatcher = onnxfunction_dispatcher
@abc.abstractmethod
def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult:
...
def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: ...

View File

@ -11,6 +11,7 @@ https://github.com/pytorch/pytorch/issues/115883
This solution will no longer be required once the issue is resolved.
"""
from __future__ import annotations
import abc

View File

@ -94,7 +94,9 @@ class _PyTreeExtensionContext:
for _, class_type in named_model_output_classes:
self.register_pytree_node(
class_type, model_output_flatten, model_output_unflatten # type: ignore[arg-type ]
class_type,
model_output_flatten,
model_output_unflatten, # type: ignore[arg-type ]
)

View File

@ -626,7 +626,8 @@ class FxOnnxInterpreter:
):
# aten ops and other stateless functions.
if node.target == operator.getitem and isinstance(
fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr,index]
fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index]
tuple,
):
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
index = node.args[1]
@ -660,9 +661,10 @@ class FxOnnxInterpreter:
diagnostic_context=self.diagnostic_context,
)
with onnxscript.evaluator.default_as(onnxscript_tracer):
output: onnxscript_graph_building.TorchScriptTensor | tuple[
onnxscript_graph_building.TorchScriptTensor, ...
] = symbolic_fn(*onnx_args, **onnx_kwargs)
output: (
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = symbolic_fn(*onnx_args, **onnx_kwargs)
assert (
output is not None
), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
@ -779,9 +781,10 @@ class FxOnnxInterpreter:
# be considered.
unique_module_name = f"{sub_module._get_name()}_{node.target}"
outputs: onnxscript_graph_building.TorchScriptTensor | tuple[
onnxscript_graph_building.TorchScriptTensor, ...
] = parent_onnxscript_graph.add_module_call( # type: ignore[assignment]
outputs: (
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment]
unique_module_name, sub_onnxscript_graph, onnx_args
)

View File

@ -147,8 +147,10 @@ class FXSymbolicTracer(_exporter_legacy.FXGraphExtractor):
for v in x.values():
out += v
return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
"""
def __init__(self, concrete_args: dict[str, Any] | None = None):

View File

@ -415,10 +415,9 @@ class _OnnxSchemaChecker:
inputs = (Tensor[2, 3], Tensor[2, 3])
attributes = {"alpha": 1.0}
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal:
...
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
```
Result: Perfect match.

View File

@ -295,7 +295,7 @@ def _convert_torch_args_to_onnxfunction_args(
args: list[fx_type_utils.Argument],
kwargs: dict[str, fx_type_utils.Argument],
allow_extra_kwargs: bool = False,
) -> tuple[list[Any], dict[str, Any],]:
) -> tuple[list[Any], dict[str, Any]]:
"""Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema.
NOTE: This is different from the param_schema separating in dispatcher, since at this point

View File

@ -3,6 +3,7 @@
These functions should NOT be directly invoked outside of `passes` package.
"""
from __future__ import annotations
import collections

View File

@ -66,9 +66,7 @@ class Decompose(_pass.Transform):
# Apply decomposition table to the input graph.
assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), (
fake_mode
):
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
decomposed_module = proxy_tensor.make_fx(
module,
decomposition_table=self.decomposition_table,

View File

@ -814,7 +814,9 @@ class Modularize(_pass.Transform):
>>> out = self.linear(out)
>>> return out
>>>
>>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2]))
>>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(
... torch.tensor([0, 1, 2])
... )
>>> gm.print_readable()
>>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()

View File

@ -76,16 +76,13 @@ class TypePromotionRule(abc.ABC):
# A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
# Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
@abc.abstractmethod
def __hash__(self) -> int:
...
def __hash__(self) -> int: ...
@abc.abstractmethod
def __repr__(self):
...
def __repr__(self): ...
@abc.abstractmethod
def __eq__(self, other: object) -> bool:
...
def __eq__(self, other: object) -> bool: ...
def is_valid(self) -> bool:
"""Check if the rule is valid."""

View File

@ -95,7 +95,9 @@ class TorchExport(_exporter_legacy.FXGraphExtractor):
model = model.run_decompositions(options.decomposition_table)
# Export FX graph to ONNX ModelProto.
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value]
return self.pre_export_passes( # type: ignore[return-value]
options, model, model.graph_module, updated_model_args
)
def pre_export_passes(
self,

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types."""
from __future__ import annotations
from typing import (
@ -31,8 +32,7 @@ if TYPE_CHECKING:
@runtime_checkable
class TensorLike(Protocol):
@property
def dtype(self) -> torch.dtype | None:
...
def dtype(self) -> torch.dtype | None: ...
def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool:

View File

@ -40,8 +40,7 @@ class InputAdaptStep(Protocol):
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
) -> tuple[Sequence[Any], Mapping[str, Any]]:
...
) -> tuple[Sequence[Any], Mapping[str, Any]]: ...
class InputAdapter:
@ -98,8 +97,7 @@ class OutputAdaptStep(Protocol):
self,
model_outputs: Any,
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
) -> Any:
...
) -> Any: ...
class OutputAdapter:
@ -573,7 +571,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
A tuple of the model args and kwargs.
"""
ordered_params = tuple(
model.state_dict[name] for name in model.graph_signature.parameters # type: ignore[union-attr,index]
model.state_dict[name] # type: ignore[union-attr,index]
for name in model.graph_signature.parameters # type: ignore[union-attr]
)
non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[union-attr]
ordered_buffers = []
@ -583,7 +582,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
else:
ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index]
ordered_constant_tensors = tuple(
model.constants[fqn] for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr,index]
model.constants[fqn] # type: ignore[union-attr,index]
for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr]
)
# NOTE: calling convention is first params, then buffers, then args as user supplied them.

View File

@ -304,7 +304,7 @@ def _get_onnx_devices(
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
],
...,
]
],
) -> Tuple["ORTC.OrtDevice", ...]:
def _device_id_or_zero(device_id: int) -> int:
return device_id or 0
@ -403,7 +403,12 @@ def _adjust_scalar_from_onnx_to_fx(
torch.SymBool,
bool,
],
) -> Union[torch.Tensor, int, float, bool,]:
) -> Union[
torch.Tensor,
int,
float,
bool,
]:
"""Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
if isinstance(
@ -561,9 +566,9 @@ class OrtExecutionInfoPerSession:
self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
# This is the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[
Tuple[torch.Tensor, ...], torch.Tensor
] = example_outputs
self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
example_outputs
)
def is_supported(self, *args):
# Compare the args and the input schema in ONNX model and

View File

@ -276,10 +276,13 @@ def onnx_symbolic(
Usage::
```
@onnx_symbolic("aten::symbolic_b", opset=10, decorate=[quantized_aten_handler(scale=1/128, zero_point=0)])
@onnx_symbolic(
"aten::symbolic_b",
opset=10,
decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
)
@symbolic_helper.parse_args("v", "v", "b")
def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value:
...
def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...
```
Args:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types."""
from __future__ import annotations
import enum

View File

@ -1,4 +1,5 @@
"""ONNX exporter exceptions."""
from __future__ import annotations
import textwrap

View File

@ -234,7 +234,7 @@ def parse_args(
"""
def decorator(
fn: Callable[_Concatenate[_U, _P], _T]
fn: Callable[_Concatenate[_U, _P], _T],
) -> Callable[_Concatenate[_U, _P], _T]:
fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined]

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 11."""
from __future__ import annotations
import functools

View File

@ -148,9 +148,7 @@ def scaled_dot_product_attention(
assert (not is_causal) or (
is_causal and symbolic_helper._is_none(attn_mask)
), "is_causal and attn_mask cannot be set at the same time"
assert (
not enable_gqa
), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
scale = symbolic_helper._maybe_get_const(scale, "f")
if symbolic_helper._is_none(scale):
@ -254,7 +252,7 @@ def _causal_attention_mask(
Equivalent to::
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.zeros(L, S, dtype=torch.float)
attn_mask = attn_mask.masked_fill(not mask, -float('inf'))
attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
Args:
query: Tensor of shape [..., L, E]

View File

@ -56,7 +56,9 @@ def grid_sampler(
if symbolic_helper._get_tensor_rank(input) == 5:
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
padding_mode_enum
]
return g.op(
"GridSample",
input,

View File

@ -57,7 +57,9 @@ def _grid_sampler(
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
mode_s = convert_grid_sample_mode(mode_s)
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg, index]
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
padding_mode_enum # type: ignore[index]
]
return g.op(
"GridSample",
input,

View File

@ -2746,7 +2746,9 @@ def native_layer_norm(
# mean and normalized, so we need to Cast it back
if is_type_half:
denominator = g.op(
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() # type: ignore[possibly-undefined]
"Cast",
denominator,
to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined]
)
rdenominator = g.op("Reciprocal", denominator)
else:
@ -4368,7 +4370,8 @@ def _generic_rnn(
reform_weights(g, w, hidden_size, reform_permutation) for w in weights
)
return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined]
symbolic_helper._unsqueeze_helper(g, x, [0])
for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined]
)
def transform_weights(layer_index):
@ -4498,9 +4501,10 @@ def _lstm_full(
bidirectional,
batch_first,
):
hidden, weight = symbolic_helper._unpack_list(
hidden_v
), symbolic_helper._unpack_list(weight_v)
hidden, weight = (
symbolic_helper._unpack_list(hidden_v),
symbolic_helper._unpack_list(weight_v),
)
return _generic_rnn(
g,
"LSTM",
@ -4529,9 +4533,10 @@ def _lstm_packed(
train,
bidirectional,
):
hidden, weight = symbolic_helper._unpack_list(
hidden_v
), symbolic_helper._unpack_list(weight_v)
hidden, weight = (
symbolic_helper._unpack_list(hidden_v),
symbolic_helper._unpack_list(weight_v),
)
return _generic_rnn(
g,
"LSTM",

View File

@ -4,6 +4,7 @@
These models can be loaded with the ONNX library and then
converted to models which run on other deep learning frameworks.
"""
from __future__ import annotations
import contextlib
@ -224,13 +225,7 @@ def export(
3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::
args = (
x,
{
"y": input_y,
"z": input_z
}
)
args = (x, {"y": input_y, "z": input_z})
All but the last element of the tuple will be passed as non-keyword arguments,
and named arguments will be set from the last element. If a named argument is
@ -252,22 +247,14 @@ def export(
(
x,
# WRONG: will be interpreted as named arguments
{y: z}
{y: z},
),
"test.onnx.pb"
"test.onnx.pb",
)
Write::
torch.onnx.export(
model,
(
x,
{y: z},
{}
),
"test.onnx.pb"
)
torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb")
f: Path to the output ONNX model file. E.g. "model.onnx".
kwargs: Named arguments to the model.
@ -369,12 +356,13 @@ def export(
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
output_names=["sum"],
)
Produces::
@ -410,7 +398,7 @@ def export(
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
}
},
)
Produces::
@ -1398,9 +1386,9 @@ def _setup_trace_module_map(
and start from the first non-numeric atom.
Example:
>>> _unqualified_variable_name('__main__.Foo.bar')
>>> _unqualified_variable_name("__main__.Foo.bar")
'bar'
>>> _unqualified_variable_name('__main__.Foo.bar.0')
>>> _unqualified_variable_name("__main__.Foo.bar.0")
'bar.0'
"""
name_atoms = qualified_name.split(".")
@ -1605,7 +1593,9 @@ def _export(
if keep_initializers_as_inputs is not True:
params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type]
graph,
params_dict, # type: ignore[arg-type]
getattr(model, "training", False), # type: ignore[arg-type]
)
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
if export_params:
@ -1863,7 +1853,9 @@ def _run_symbolic_function(
}
if namespace == "onnx":
# Clone node to trigger ONNX shape inference
return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]
return graph_context.op(
op_name, *inputs, **attrs, outputs=node.outputsSize()
) # type: ignore[attr-defined]
raise errors.UnsupportedOperatorError(
symbolic_function_name,

View File

@ -217,8 +217,8 @@ def _compare_onnx_pytorch_outputs_in_np(
pt_outs: _OutputsType,
options: VerificationOptions,
):
assert len(onnx_outs) == len(
pt_outs
assert (
len(onnx_outs) == len(pt_outs)
), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
acceptable_error_percentage = options.acceptable_error_percentage
if acceptable_error_percentage and (