[ONNX] Implements converter for higher order ops scan (#154513)

Fixes #151327

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154513
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Author: xadupre
Date: 2025-06-17 00:54:07 +00:00
Committed by: PyTorch MergeBot
Parent: b618817479
Commit: e6252f62ef
2 changed files with 169 additions and 10 deletions
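
For context, a minimal sketch of what this converter enables (the CumSum module, the shapes, and the final check are illustrative assumptions, not code from this PR): a model calling torch.ops.higher_order.scan can now be exported to ONNX, with the call lowered to an ONNX Scan node.

import torch

class CumSum(torch.nn.Module):
    def forward(self, x):
        def step(carry, xi):
            # combine_fn returns [next_carry, per_step_output]
            s = carry + xi
            return [s, s.clone()]

        init = torch.zeros_like(x[0])
        # Iterates over the leading axis of x, threading `init` through `step`.
        carry, ys = torch.ops.higher_order.scan(step, [init], [x], additional_inputs=[])
        return ys

onnx_program = torch.onnx.export(CumSum(), (torch.randn(4, 3),), dynamo=True)
# With this PR, the exported graph should contain a Scan node.
print([node.op_type for node in onnx_program.model.graph])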


@@ -15,8 +15,7 @@ from torch.testing._internal import common_utils
from torch.utils import _pytree as torch_pytree


@common_utils.instantiate_parametrized_tests
class DynamoExporterTest(common_utils.TestCase):
class _WithExport:
    def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
        onnx_program = torch.onnx.export(
            model,
@@ -30,6 +29,9 @@ class DynamoExporterTest(common_utils.TestCase):
        assert onnx_program is not None
        return onnx_program


@common_utils.instantiate_parametrized_tests
class DynamoExporterTest(common_utils.TestCase, _WithExport):
    def test_insert_contiguous_between_transpose_and_view(self):
        class Model(torch.nn.Module):
            def forward(self, query, key, value):
@@ -307,7 +309,7 @@ class DynamoExporterTest(common_utils.TestCase):
                return x + y

        dim0_x = torch.export.Dim("dim0_x", min=6)
        dynamic_shapes = {"x": {0: dim0_x}, "y": None}
        dynamic_shapes = {"x": {0: dim0_x}, "y": torch.export.Dim.STATIC}
        # specialized input y to 5 during tracing
        onnx_program = self.export(
            Model(),
@@ -548,11 +550,11 @@ class DynamoExporterTest(common_utils.TestCase):
        # all of these should be fine
        dynamic_shapes = (
            {0: dx, 1: torch.export.Dim.AUTO},
            {0: dy, 1: None},
            {0: dy, 1: torch.export.Dim.STATIC},
            {0: dz, 1: 3},
        )
        onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
        onnx_testing.assert_onnx_program(onnx_program)
        # make sure the naming is working
        self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")
@@ -564,7 +566,7 @@ class DynamoExporterTest(common_utils.TestCase):
        inputs = (torch.zeros((2, 3)),)
        dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
        onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
        onnx_testing.assert_onnx_program(onnx_program)
        self.assertIn(
            "Max",
            [node.op_type for node in onnx_program.model.graph],
@@ -578,7 +580,7 @@ class DynamoExporterTest(common_utils.TestCase):
        inputs = (torch.zeros((2, 3)),)
        dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
        onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
        onnx_testing.assert_onnx_program(onnx_program)
        self.assertIn(
            "Min",
            [node.op_type for node in onnx_program.model.graph],
@@ -593,7 +595,7 @@ class DynamoExporterTest(common_utils.TestCase):
        inputs = (torch.zeros((2, 2)),)
        dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
        onnx_program = self.export(SymNotModel(), inputs, dynamic_shapes=dynamic_shapes)
        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
        onnx_testing.assert_onnx_program(onnx_program)
        self.assertIn(
            "Not",
            [node.op_type for node in onnx_program.model.graph],
@@ -610,12 +612,106 @@ class DynamoExporterTest(common_utils.TestCase):
        onnx_program = self.export(
            SymFloatModel(), inputs, dynamic_shapes=dynamic_shapes
        )
        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
        onnx_testing.assert_onnx_program(onnx_program)
        self.assertIn(
            "Cast",
            [node.op_type for node in onnx_program.model.graph],
        )

    def test_scan_cdist_add(self):
        def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
            sub = samex - x.reshape((1, -1))
            sq = sub * sub
            rd = torch.sqrt(sq.sum(axis=1))
            return [unused.clone(), rd]

        class ScanModel(torch.nn.Module):
            def forward(self, x):
                z = torch.tensor([0], dtype=torch.float32)
                y = x.clone()
                out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
                return out[1]

        inputs = (
            torch.tensor(
                [[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32
            ),
        )
        onnx_program = self.export(ScanModel(), inputs)
        onnx_testing.assert_onnx_program(onnx_program)

    def test_scan_cdist_dynamic_shapes(self):
        def dist(y: torch.Tensor, scanned_x: torch.Tensor):
            sub = y - scanned_x.reshape((1, -1))
            sq = sub * sub
            rd = torch.sqrt(sq.sum(axis=1))
            return [y.clone(), rd]

        class ScanModel(torch.nn.Module):
            def forward(self, x, y):
                carry, out = torch.ops.higher_order.scan(
                    dist, [y], [x], additional_inputs=[]
                )
                return out

        x_rows = torch.export.Dim("x_rows")
        y_rows = torch.export.Dim("y_rows")
        dim = torch.export.Dim("dim")
        inputs = (torch.randn(3, 4), torch.randn(5, 4))
        onnx_program = self.export(
            ScanModel(),
            inputs,
            dynamic_shapes=({0: x_rows, 1: dim}, {0: y_rows, 1: dim}),
        )
        onnx_testing.assert_onnx_program(onnx_program)

    @pytest.mark.xfail(reason="Data dependent error.")
    def test_scan_loop_inplace(self):
        def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
            copy = torch.zeros(padded.shape)
            for i in range(pos.shape[0]):
                p = pos[i]
                copy[i, :p] = padded[i, :p]
            return copy

        def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
            def pad_row(padded, p):
                row = torch.zeros((padded.shape[0],))
                torch._check(p.item() > 0)
                torch._check(p.item() < padded.shape[0])
                # This check is not always true; we add it anyway to make this
                # dimension >= 2 and to avoid an exception about a dynamic
                # dimension in {0, 1}.
                if torch.compiler.is_exporting():
                    torch._check(p.item() > 1)
                row[: p.item()] = padded[: p.item()]
                return (row,)

            return torch.ops.higher_order.scan(pad_row, [], [padded, pos], [])

        def select_when_exporting(f, f_scan):
            return f_scan if torch.compiler.is_exporting() else f

        class ScanModel(torch.nn.Module):
            def forward(self, images, position):
                return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
                    images, position
                )

        DYN = torch.export.Dim.DYNAMIC
        x = torch.randn((5, 6))
        y = torch.arange(5, dtype=torch.int64) + 1
        ep = torch.export.export(
            ScanModel(),
            (x, y),
            dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
            strict=False,
        )
        onnx_program = self.export(ep)
        onnx_testing.assert_onnx_program(onnx_program)


@common_utils.instantiate_parametrized_tests
class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
    def test_group_norm_opset_21(self):
        class Model(torch.nn.Module):
            def forward(self, x):


@@ -19,7 +19,7 @@ def call_op(
    *args: ir.Value,
    _num_outputs: int = 1,
    _domain: str = "",
    **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol,
    **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol | Sequence[int],
) -> Sequence[ir.Value]:
    """Call an operator with the given arguments and keyword arguments.
@@ -92,3 +92,66 @@ def higher_order_cond(
            (), else_node.outputs, nodes=[else_node], name=false_func.name
        ),
    )


@onnx_impl(torch.ops.higher_order.scan, no_compile=True)
def higher_order_scan(
    body_func: ir.Function,
    scan_inits: Sequence[ir.Value],
    scan_inputs: Sequence[ir.Value],
    additional_inputs: Sequence[ir.Value] | None,
    reverse: bool = False,
) -> Sequence[ir.Value]:
    """ONNX translation of torch.ops.higher_order.scan.

    See https://github.com/pytorch/pytorch/blob/66ac724b56e6c37a534f3e066423ef2f41d7477f/torch/_higher_order_ops/scan.py#L109
    """
    subgraph_inputs = [
        *[
            ir.Value(
                name=f"{inp.name}_{body_func.name}__subgraph_in",
                shape=inp.shape,
                type=ir.TensorType(inp.dtype),  # type: ignore[arg-type]
            )
            for inp in scan_inits
        ],
        *[
            ir.Value(
                name=f"{inp.name}_{body_func.name}__subgraph_in",
                # The iterated element passed to the body subgraph does not have a sequence axis.
                # It will have a rank one less than the rank of the corresponding scan_input.
                shape=ir.Shape(inp.shape[1:]),  # type: ignore[index]
                type=ir.TensorType(inp.dtype),  # type: ignore[arg-type]
            )
            for inp in scan_inputs
        ],
    ]
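    # For illustration (shapes assumed, not from this commit): a scan_input of
    # shape (L, 3, 4) yields a subgraph input of shape (3, 4); the ONNX Scan op
    # feeds the body one slice along the scan axis per iteration.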
    # The one and only node in the Scan subgraph that calls the body_func
    body_node = ir.Node(
        body_func.domain,
        body_func.name,
        [
            *subgraph_inputs,
            *(additional_inputs or []),
        ],
        num_outputs=len(body_func.outputs),
    )
    # ONNX Runtime complains about duplicate output names if we don't rename them,
    # but there doesn't seem to be an actual violation of SSA form without renaming.
    for func_out, out in zip(body_func.outputs, body_node.outputs):
        out.name = f"{func_out.name}_{body_func.name}"

    n_outputs = len(body_func.outputs) - len(scan_inits)
    return call_op(
        "Scan",
        *scan_inits,
        *scan_inputs,
        _num_outputs=len(body_func.outputs),
        body=ir.Graph(
            subgraph_inputs,
            body_node.outputs,
            nodes=[body_node],
            name=body_func.name,
        ),
        num_scan_inputs=len(scan_inputs),
        scan_input_directions=[(1 if reverse else 0) for _ in scan_inputs],
        scan_output_directions=[(1 if reverse else 0) for _ in range(n_outputs)],
    )
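
To make the emitted operator's contract concrete, here is a reference sketch in plain Python of what the resulting ONNX Scan computes (an illustration under stated assumptions, not code from this commit): the first len(scan_inits) results are the final state variables, and the remaining results are the per-iteration outputs stacked along the scan axis.

import torch

def scan_reference(body, inits, xs, reverse=False):
    # body(*state, *slices) -> (*new_state, *per_step_outputs).
    # additional_inputs are assumed to be captured by closure in `body`.
    state = list(inits)
    collected = None
    steps = range(xs[0].shape[0])
    for i in (reversed(steps) if reverse else steps):
        results = body(*state, *[x[i] for x in xs])
        state = list(results[: len(inits)])
        outs = list(results[len(inits) :])
        if collected is None:
            collected = [[] for _ in outs]
        for acc, out in zip(collected, outs):
            acc.append(out)
    # Scan outputs are stacked along the scan axis; a reverse scan
    # (directions=1) still lines each output up with its input slice.
    stacked = [
        torch.stack(acc[::-1] if reverse else acc) for acc in (collected or [])
    ]
    return (*state, *stacked)

# e.g. scan_reference(step, [init], [x]) mirrors
# torch.ops.higher_order.scan(step, [init], [x], additional_inputs=[]).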