[ONNX] Implement converter for higher-order op scan (#154513)

Fixes #151327

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154513
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
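As a rough illustration (not part of this diff), a model that calls torch.ops.higher_order.scan can now be exported to ONNX, where the call lowers to a single ONNX Scan node. A minimal sketch, assuming a PyTorch build that includes this change; the add_step helper and the CumSum toy model are invented for illustration:

    import torch

    def add_step(carry: torch.Tensor, x: torch.Tensor):
        # combine_fn contract: carried state first, then per-iteration outputs
        next_carry = carry + x
        return [next_carry, next_carry.clone()]

    class CumSum(torch.nn.Module):
        def forward(self, x):
            init = torch.zeros_like(x[0])
            # scan iterates over dim 0 of x; ys stacks the per-step outputs
            carry, ys = torch.ops.higher_order.scan(
                add_step, [init], [x], additional_inputs=[]
            )
            return carry, ys

    onnx_program = torch.onnx.export(CumSum(), (torch.randn(4, 3),), dynamo=True)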
Commit e6252f62ef (parent b618817479), committed by PyTorch MergeBot.
@@ -15,8 +15,7 @@ from torch.testing._internal import common_utils
 from torch.utils import _pytree as torch_pytree
 
 
-@common_utils.instantiate_parametrized_tests
-class DynamoExporterTest(common_utils.TestCase):
+class _WithExport:
     def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
         onnx_program = torch.onnx.export(
             model,
@@ -30,6 +29,9 @@ class DynamoExporterTest(common_utils.TestCase):
         assert onnx_program is not None
         return onnx_program
 
+
+@common_utils.instantiate_parametrized_tests
+class DynamoExporterTest(common_utils.TestCase, _WithExport):
     def test_insert_contiguous_between_transpose_and_view(self):
         class Model(torch.nn.Module):
             def forward(self, query, key, value):
@@ -307,7 +309,7 @@ class DynamoExporterTest(common_utils.TestCase):
                 return x + y
 
         dim0_x = torch.export.Dim("dim0_x", min=6)
-        dynamic_shapes = {"x": {0: dim0_x}, "y": None}
+        dynamic_shapes = {"x": {0: dim0_x}, "y": torch.export.Dim.STATIC}
         # specialized input y to 5 during tracing
         onnx_program = self.export(
             Model(),
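For context on the change above: in torch.export's dynamic_shapes, torch.export.Dim.STATIC is the explicit marker for an input (or axis) that stays fixed, and these tests move to it from the older None spelling. A minimal sketch of the two markers side by side, reusing the names from the test:

    dim0_x = torch.export.Dim("dim0_x", min=6)   # named symbolic (dynamic) axis
    dynamic_shapes = {
        "x": {0: dim0_x},                        # axis 0 of x is dynamic
        "y": torch.export.Dim.STATIC,            # y is fully static (previously spelled None)
    }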
@@ -548,11 +550,11 @@ class DynamoExporterTest(common_utils.TestCase):
         # all of these should be fine
         dynamic_shapes = (
             {0: dx, 1: torch.export.Dim.AUTO},
-            {0: dy, 1: None},
+            {0: dy, 1: torch.export.Dim.STATIC},
             {0: dz, 1: 3},
         )
         onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
-        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
         # make sure the naming is working
         self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")
 
@@ -564,7 +566,7 @@ class DynamoExporterTest(common_utils.TestCase):
         inputs = (torch.zeros((2, 3)),)
         dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
         onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
-        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
         self.assertIn(
             "Max",
             [node.op_type for node in onnx_program.model.graph],
@@ -578,7 +580,7 @@ class DynamoExporterTest(common_utils.TestCase):
         inputs = (torch.zeros((2, 3)),)
         dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
         onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
-        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
         self.assertIn(
             "Min",
             [node.op_type for node in onnx_program.model.graph],
@@ -593,7 +595,7 @@ class DynamoExporterTest(common_utils.TestCase):
         inputs = (torch.zeros((2, 2)),)
         dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
         onnx_program = self.export(SymNotModel(), inputs, dynamic_shapes=dynamic_shapes)
-        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
         self.assertIn(
             "Not",
             [node.op_type for node in onnx_program.model.graph],
@@ -610,12 +612,106 @@ class DynamoExporterTest(common_utils.TestCase):
         onnx_program = self.export(
             SymFloatModel(), inputs, dynamic_shapes=dynamic_shapes
         )
-        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
         self.assertIn(
             "Cast",
             [node.op_type for node in onnx_program.model.graph],
         )
 
+    def test_scan_cdist_add(self):
+        def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
+            sub = samex - x.reshape((1, -1))
+            sq = sub * sub
+            rd = torch.sqrt(sq.sum(axis=1))
+            return [unused.clone(), rd]
+
+        class ScanModel(torch.nn.Module):
+            def forward(self, x):
+                z = torch.tensor([0], dtype=torch.float32)
+                y = x.clone()
+                out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y])
+                return out[1]
+
+        inputs = (
+            torch.tensor(
+                [[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32
+            ),
+        )
+        onnx_program = self.export(ScanModel(), inputs)
+        onnx_testing.assert_onnx_program(onnx_program)
+
+    def test_scan_cdist_dynamic_shapes(self):
+        def dist(y: torch.Tensor, scanned_x: torch.Tensor):
+            sub = y - scanned_x.reshape((1, -1))
+            sq = sub * sub
+            rd = torch.sqrt(sq.sum(axis=1))
+            return [y.clone(), rd]
+
+        class ScanModel(torch.nn.Module):
+            def forward(self, x, y):
+                carry, out = torch.ops.higher_order.scan(
+                    dist, [y], [x], additional_inputs=[]
+                )
+                return out
+
+        x_rows = torch.export.Dim("x_rows")
+        y_rows = torch.export.Dim("y_rows")
+        dim = torch.export.Dim("dim")
+        inputs = (torch.randn(3, 4), torch.randn(5, 4))
+        onnx_program = self.export(
+            ScanModel(),
+            inputs,
+            dynamic_shapes=({0: x_rows, 1: dim}, {0: y_rows, 1: dim}),
+        )
+        onnx_testing.assert_onnx_program(onnx_program)
+
+    @pytest.mark.xfail(reason="Data dependent error.")
+    def test_scan_loop_inplace(self):
+        def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
+            copy = torch.zeros(padded.shape)
+            for i in range(pos.shape[0]):
+                p = pos[i]
+                copy[i, :p] = padded[i, :p]
+            return copy
+
+        def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
+            def pad_row(padded, p):
+                row = torch.zeros((padded.shape[0],))
+                torch._check(p.item() > 0)
+                torch._check(p.item() < padded.shape[0])
+                # This check is not always true; we add it anyway to force this
+                # dimension to be >= 2 and avoid an exception about a dynamic
+                # dimension in {0, 1}.
+                if torch.compiler.is_exporting():
+                    torch._check(p.item() > 1)
+                row[: p.item()] = padded[: p.item()]
+                return (row,)
+
+            return torch.ops.higher_order.scan(pad_row, [], [padded, pos], [])
+
+        def select_when_exporting(f, f_scan):
+            return f_scan if torch.compiler.is_exporting() else f
+
+        class ScanModel(torch.nn.Module):
+            def forward(self, images, position):
+                return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
+                    images, position
+                )
+
+        DYN = torch.export.Dim.DYNAMIC
+        x = torch.randn((5, 6))
+        y = torch.arange(5, dtype=torch.int64) + 1
+        ep = torch.export.export(
+            ScanModel(),
+            (x, y),
+            dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
+            strict=False,
+        )
+        onnx_program = self.export(ep)
+        onnx_testing.assert_onnx_program(onnx_program)
+
+
+@common_utils.instantiate_parametrized_tests
+class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
     def test_group_norm_opset_21(self):
         class Model(torch.nn.Module):
             def forward(self, x):
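For intuition about what test_scan_cdist_dynamic_shapes above computes: scanning dist over the rows of x against a carried y yields the pairwise Euclidean distance matrix. A hedged eager-mode sanity check, assuming the scan HOP is callable outside export; the comparison against torch.cdist is mine, not part of the test:

    def dist(y: torch.Tensor, scanned_x: torch.Tensor):
        sub = y - scanned_x.reshape((1, -1))
        return [y.clone(), torch.sqrt((sub * sub).sum(axis=1))]

    x, y = torch.randn(3, 4), torch.randn(5, 4)
    _, out = torch.ops.higher_order.scan(dist, [y], [x], additional_inputs=[])
    # out[i, j] == ||x[i] - y[j]||, i.e. the same matrix as torch.cdist(x, y)
    torch.testing.assert_close(out, torch.cdist(x, y))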
@@ -19,7 +19,7 @@ def call_op(
     *args: ir.Value,
     _num_outputs: int = 1,
     _domain: str = "",
-    **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol,
+    **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol | Sequence[int],
 ) -> Sequence[ir.Value]:
     """Call an operator with the given arguments and keyword arguments.
 
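The widened kwargs annotation matters because ONNX Scan carries INTS attributes such as scan_input_directions, which the scan converter added further down passes as plain Python lists of int. A hedged sketch of such a call; inits, xs, and subgraph are placeholder values for illustration:

    outputs = call_op(
        "Scan",
        *inits,                               # ir.Value state inputs
        *xs,                                  # ir.Value scan inputs
        _num_outputs=len(inits) + 1,
        body=subgraph,                        # ir.Graph attribute
        num_scan_inputs=len(xs),              # int attribute
        scan_input_directions=[0] * len(xs),  # Sequence[int], newly allowed here
    )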
@@ -92,3 +92,66 @@ def higher_order_cond(
             (), else_node.outputs, nodes=[else_node], name=false_func.name
         ),
     )
+
+
+@onnx_impl(torch.ops.higher_order.scan, no_compile=True)
+def higher_order_scan(
+    body_func: ir.Function,
+    scan_inits: Sequence[ir.Value],
+    scan_inputs: Sequence[ir.Value],
+    additional_inputs: Sequence[ir.Value] | None,
+    reverse: bool = False,
+) -> Sequence[ir.Value]:
+    """https://github.com/pytorch/pytorch/blob/66ac724b56e6c37a534f3e066423ef2f41d7477f/torch/_higher_order_ops/scan.py#L109"""
+    subgraph_inputs = [
+        *[
+            ir.Value(
+                name=f"{inp.name}_{body_func.name}__subgraph_in",
+                shape=inp.shape,
+                type=ir.TensorType(inp.dtype),  # type: ignore[arg-type]
+            )
+            for inp in scan_inits
+        ],
+        *[
+            ir.Value(
+                name=f"{inp.name}_{body_func.name}__subgraph_in",
+                # The iterated element passed to the body subgraph does not have a
+                # sequence axis. It will have a rank one less than the rank of the
+                # corresponding scan_input.
+                shape=ir.Shape(inp.shape[1:]),  # type: ignore[index]
+                type=ir.TensorType(inp.dtype),  # type: ignore[arg-type]
+            )
+            for inp in scan_inputs
+        ],
+    ]
+    # The one and only node in the Scan subgraph that calls the body_func
+    body_node = ir.Node(
+        body_func.domain,
+        body_func.name,
+        [
+            *subgraph_inputs,
+            *(additional_inputs or []),
+        ],
+        num_outputs=len(body_func.outputs),
+    )
+
+    # ONNX Runtime complains about duplicate output names if we don't rename them,
+    # but there doesn't seem to be an actual violation of SSA form without renaming.
+    for func_out, out in zip(body_func.outputs, body_node.outputs):
+        out.name = f"{func_out.name}_{body_func.name}"
+
+    n_outputs = len(body_func.outputs) - len(scan_inits)
+    return call_op(
+        "Scan",
+        *scan_inits,
+        *scan_inputs,
+        _num_outputs=len(body_func.outputs),
+        body=ir.Graph(
+            subgraph_inputs,
+            body_node.outputs,
+            nodes=[body_node],
+            name=body_func.name,
+        ),
+        num_scan_inputs=len(scan_inputs),
+        scan_input_directions=[(1 if reverse else 0) for _ in scan_inputs],
+        scan_output_directions=[(1 if reverse else 0) for _ in range(n_outputs)],
+    )
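For reference, a hedged pure-PyTorch sketch of the Scan semantics the converter above targets (forward direction only, i.e. all directions set to 0): state variables are threaded through every step, and the remaining body outputs are stacked along a new leading axis. This helper is illustrative only and not part of the diff:

    import torch

    def scan_reference(body, inits, xs):
        # body(*state, *slices) -> (next_state..., per_step_outputs...)
        state = list(inits)
        collected = []
        for i in range(xs[0].shape[0]):            # iterate over the scan axis
            outs = body(*state, *(x[i] for x in xs))
            state = list(outs[: len(inits)])       # carried state, same shapes as inits
            collected.append(outs[len(inits):])    # per-step scan outputs
        stacked = [torch.stack(ys) for ys in zip(*collected)]
        return [*state, *stacked]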