mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[ONNX] Support symbolic arguments in onnx exporter (#157734)
Previous to this PR, torch.onnx.export(..., dynamo=True, veriy=True, report=True) does not support symbolic arguments. Such examples are like follwing:
```python
class M(torch.nn.Module):
def forward(self, a, x):
return a + torch.tensor(1) + x
op = torch.onnx.export(M(), (1, torch.ones(2)),
dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
dynamo=True, report=True)
```
symbolic arguments are like constant arguments that they don't have tensor_meta wither. Besides, torch.export.export supports model inputs having constants, which is different from the legacy issue: https://github.com/pytorch/pytorch/issues/99534 where we tried to get the FX directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.
NOTE: If the ConstantArugment shows up in exported_program, it was kept to align the length of inputs to nn.Module, but it's irrelevant to the model graph, hwich is why in ONNX model the input is omitted.
The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs #157719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734
Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
163f0d8f2a
commit
08e9dd280f
@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
@ -73,6 +74,68 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
|
||||
)
|
||||
|
||||
def test_symbolic_argument_user_input_is_supported_by_report_and_call(self):
|
||||
class constant_plus_tensor_inputs(torch.nn.Module):
|
||||
def forward(self, a, x):
|
||||
return a + torch.tensor(1) + x
|
||||
|
||||
# Capture log output
|
||||
log_capture = io.StringIO()
|
||||
log_handler = logging.StreamHandler(log_capture)
|
||||
log_handler.setLevel(logging.ERROR)
|
||||
# Get the logger used in _core.py
|
||||
logger = logging.getLogger("torch.onnx._internal.exporter._core")
|
||||
original_level = logger.level
|
||||
logger.addHandler(log_handler)
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
try:
|
||||
with common_utils.TemporaryDirectoryName() as temp_dir:
|
||||
self.assert_export(
|
||||
constant_plus_tensor_inputs(),
|
||||
(
|
||||
1,
|
||||
torch.ones(2),
|
||||
),
|
||||
dynamic_shapes=(
|
||||
torch.export.Dim.DYNAMIC,
|
||||
{0: torch.export.Dim.DYNAMIC},
|
||||
),
|
||||
report=True,
|
||||
artifacts_dir=temp_dir,
|
||||
)
|
||||
# Check if the expected error was logged
|
||||
log_output = log_capture.getvalue()
|
||||
self.assertNotIn("Failed to save report due to an error", log_output)
|
||||
self.assertNotIn("KeyError: 'tensor_meta'", log_output)
|
||||
# Note: We don't call assert_onnx_program here because it will fail
|
||||
# due to the input name mismatch issue mentioned in your error
|
||||
|
||||
finally:
|
||||
# Clean up logging
|
||||
logger.removeHandler(log_handler)
|
||||
logger.setLevel(original_level)
|
||||
|
||||
def test_constant_argument_user_input_is_omitted_in_onnx_graph(self):
|
||||
class constant_plus_tensor_inputs(torch.nn.Module):
|
||||
def forward(self, a, x):
|
||||
return a + torch.tensor(1) + x
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
constant_plus_tensor_inputs(),
|
||||
(
|
||||
1,
|
||||
torch.ones(2),
|
||||
),
|
||||
dynamic_shapes=(
|
||||
None,
|
||||
{0: torch.export.Dim.DYNAMIC},
|
||||
),
|
||||
dynamo=True,
|
||||
)
|
||||
|
||||
self.assertEqual(len(onnx_program.model.graph.inputs), 1)
|
||||
|
||||
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
|
||||
self.assert_export(
|
||||
SampleModelForDynamicShapes(),
|
||||
|
||||
Reference in New Issue
Block a user