mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Reviewed By: avikchaudhuri Differential Revision: D67530154 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143688 Approved by: https://github.com/tugsbayasgalan
153 lines
5.3 KiB
Python
153 lines
5.3 KiB
Python
# Owner(s): ["oncall: export"]
|
|
# flake8: noqa
|
|
import copy
|
|
import io
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
import torch.utils._pytree as pytree
|
|
from torch._dynamo.test_case import TestCase
|
|
from torch.export import export, load, save
|
|
from torch.export._trace import _export
|
|
from torch.testing._internal.common_device_type import (
|
|
instantiate_device_type_tests,
|
|
ops,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
IS_WINDOWS,
|
|
run_tests,
|
|
TestCase as TorchTestCase,
|
|
)
|
|
from torch.testing._internal.hop_db import (
|
|
hop_db,
|
|
hop_that_doesnt_have_opinfo_test_allowlist,
|
|
)
|
|
|
|
|
|
hop_tests = []
|
|
|
|
for op_info in hop_db:
|
|
op_info_hop_name = op_info.name
|
|
if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
|
|
continue
|
|
hop_tests.append(op_info)
|
|
|
|
|
|
class TestHOPGeneric(TestCase):
|
|
def test_all_hops_have_op_info(self):
|
|
from torch._ops import _higher_order_ops
|
|
|
|
hops_that_have_op_info = set([k.name for k in hop_db])
|
|
all_hops = _higher_order_ops.keys()
|
|
|
|
missing_ops = []
|
|
|
|
for op in all_hops:
|
|
if (
|
|
op not in hops_that_have_op_info
|
|
and op not in hop_that_doesnt_have_opinfo_test_allowlist
|
|
):
|
|
missing_ops.append(op)
|
|
|
|
self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestHOP(TestCase):
    """Round-trip each OpInfo-covered higher-order op through export variants.

    Every ``test_*`` method is parametrized over ``hop_tests`` by ``@ops``
    (float32 only) and receives ``device``, ``dtype`` and ``op``. For each
    sample input it exports a thin wrapper module and checks that the exported
    program reproduces the eager outputs.
    """

    def _sample_cases(self, device, dtype, op):
        """Yield ``(model, args, kwargs)`` for every sample input of ``op``.

        ``model`` is a fresh module whose ``forward`` passes its positional
        arguments straight to ``op.op``. ``args`` is the sample's ``input``
        (splatted when it is a tuple) followed by the sample's extra ``args``;
        ``kwargs`` is the sample's ``kwargs`` unchanged.
        """

        class Foo(torch.nn.Module):
            def forward(self, *args):
                return op.op(*args)

        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            model = Foo()
            inputs = (
                sample.input if isinstance(sample.input, tuple) else (sample.input,)
            )
            yield model, (*inputs, *sample.args), sample.kwargs

    def _compare(self, eager_model, ep, args, kwargs):
        """Assert that ``ep.module()`` reproduces ``eager_model`` on the inputs.

        Args:
            eager_model: the original module, run eagerly.
            ep: the exported program under test (anything with ``.module()``).
            args, kwargs: inputs; each side receives its own deep copy so that
                in-place mutation by one run cannot affect the other.

        Outputs are flattened with pytree and compared leaf by leaf (type and
        value). The leaf counts must match first; otherwise ``zip`` would
        silently ignore missing or extra outputs.
        """
        eager_args, eager_kwargs = copy.deepcopy(args), copy.deepcopy(kwargs)
        export_args, export_kwargs = copy.deepcopy(args), copy.deepcopy(kwargs)

        flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
        flat_loaded_outputs = pytree.tree_leaves(
            ep.module()(*export_args, **export_kwargs)
        )

        self.assertEqual(
            len(flat_orig_outputs),
            len(flat_loaded_outputs),
            msg="eager and exported outputs have a different number of leaves",
        )
        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
            self.assertEqual(type(orig), type(loaded))
            self.assertEqual(orig, loaded)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_aot_export(self, device, dtype, op):
        # Uses the public torch.export.export in strict mode.
        for model, args, kwargs in self._sample_cases(device, dtype, op):
            ep = export(model, args, kwargs, strict=True)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_pre_dispatch_export(self, device, dtype, op):
        for model, args, kwargs in self._sample_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_retrace_export(self, device, dtype, op):
        # Pre-dispatch export followed by decomposition (a retrace).
        for model, args, kwargs in self._sample_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_serialize_export(self, device, dtype, op):
        # Decomposed program is saved to an in-memory buffer and loaded back.
        for model, args, kwargs in self._sample_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()

            buffer = io.BytesIO()
            save(ep, buffer)
            buffer.seek(0)
            ep = load(buffer)

            if "while_loop" in str(op):
                # while_loop's arguments are cast into a list after
                # deserialization, but while_loop expects them to still be a
                # tuple, so running the loaded program is expected to fail.
                with self.assertRaisesRegex(
                    RuntimeError, "carried_inputs must be a tuple"
                ):
                    self._compare(model, ep, args, kwargs)
            else:
                self._compare(model, ep, args, kwargs)
|
|
|
|
|
# Generate device-specific variants of TestHOP's @ops-parametrized tests (each
# takes device/dtype/op) and place them in this module's globals() so the test
# runner can discover them.
instantiate_device_type_tests(TestHOP, globals())

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|