Add OpInfo entries for HOP testing (#122265)

In this PR, we add a systematic way to test that all HOPs (higher-order ops) are exportable, since the export team has been running into various bugs with newly added HOPs due to the lack of tests. We do this by creating:
- hop_db -> a list of HOP OpInfo entries, which is then used across various flows, including the export functionalities: aot-export, pre-dispatch export, retrace, and serialization/deserialization (ser/de). A sketch of such an entry follows.
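
For reference, each hop_db entry is an ordinary OpInfo built from a small Python wrapper around the HOP plus a sample-inputs generator. The sketch below mirrors the `cond` entry added in torch/testing/_internal/hop_db.py by this PR (all names are taken from that file; `cond_opinfo` is only a local name for illustration, since in hop_db.py the entry is appended directly to the hop_db list):

```python
# Mirrors the "cond" OpInfo entry from torch/testing/_internal/hop_db.py in this PR.
import functools

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import all_types_and
from torch.testing._internal.opinfo.core import OpInfo, SampleInput


def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
    # Yield SampleInputs that exercise the HOP.
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))


def simple_cond(x):
    # Plain callable wrapping the HOP so the OpInfo-based tests can invoke it.
    return torch.cond(x.shape[0] > 2, lambda x: x.cos(), lambda x: x.sin(), [x])


cond_opinfo = OpInfo(
    name="cond",
    variant_test_name="simple",
    op=simple_cond,
    sample_inputs_func=sample_inputs_cond,
    dtypes=all_types_and(torch.bool, torch.half),
    supports_out=False,
    supports_autograd=False,
    check_batched_grad=False,
    check_batched_gradgrad=False,
    check_batched_forward_grad=False,
    check_inplace_batched_forward_grad=False,
)
```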

For now, we also create an allowlist so that people can bypass the failures, but we should discourage people from doing that. A minimal sketch of what a bypass looks like is shown below.
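
Concretely, bypassing just means adding the HOP's name to the allowlist in torch/testing/_internal/hop_db.py; TestHOPGeneric.test_all_hops_have_op_info in test/export/test_hop.py then enforces that every registered HOP appears either in hop_db or on this list. A minimal sketch, where the "my_new_hop" entry is hypothetical:

```python
# In torch/testing/_internal/hop_db.py -- please consult with the torch.export
# team before adding a new entry to this list.
hop_that_doesnt_have_opinfo_test_allowlist = [
    "custom_function_call",
    "autograd_function_apply",
    # ... existing entries elided ...
    "while_loop",
    # "my_new_hop",  # hypothetical: a new HOP temporarily exempted from OpInfo tests
]
```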

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122265
Approved by: https://github.com/ydwu4, https://github.com/zou3519
Tugsbayasgalan Manlaibaatar
2024-03-28 10:33:51 -07:00
committed by PyTorch MergeBot
parent 0bfa9f4758
commit d9a08de9a4
8 changed files with 343 additions and 88 deletions


@@ -2371,7 +2371,7 @@ exclude_patterns = [
     'torch/testing/_internal/common_subclass.py',
     'torch/testing/_internal/common_utils.py',
     'torch/testing/_internal/composite_compliance.py',
-    'torch/testing/_internal/control_flow_opinfo_db.py',
+    'torch/testing/_internal/hop_db.py',
     'torch/testing/_internal/custom_op_db.py',
     'torch/testing/_internal/data/__init__.py',
     'torch/testing/_internal/data/network1.py',


@@ -67,6 +67,7 @@ nn/qat/ @jerryzh168
 /test/run_test.py @pytorch/pytorch-dev-infra
 /torch/testing/_internal/common_device_type.py @mruberry
 /torch/testing/_internal/common_utils.py @pytorch/pytorch-dev-infra
+/torch/testing/_internal/hop_db.py @tugsbayasgalan @zou3519 @ydwu4
 # Parametrizations
 /torch/nn/utils/parametriz*.py @lezcano

test/export/test_hop.py (new file, 143 lines)

@@ -0,0 +1,143 @@
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import io
import unittest
import torch
import torch._dynamo as torchdynamo
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch.export import export, load, save
from torch.export._trace import _export
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
TestCase as TorchTestCase,
)
from torch.testing._internal.hop_db import (
hop_db,
hop_that_doesnt_have_opinfo_test_allowlist,
)
hop_tests = []
for op_info in hop_db:
op_info_hop_name = op_info.name
if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
continue
hop_tests.append(op_info)
class TestHOPGeneric(TestCase):
def test_all_hops_have_op_info(self):
from torch._ops import _higher_order_ops
hops_that_have_op_info = set([k.name for k in hop_db])
all_hops = _higher_order_ops.keys()
missing_ops = []
for op in all_hops:
if (
op not in hops_that_have_op_info
and op not in hop_that_doesnt_have_opinfo_test_allowlist
):
missing_ops.append(op)
self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestHOP(TestCase):
def _compare(self, eager_model, export, args, kwargs):
eager_args = copy.deepcopy(args)
eager_kwargs = copy.deepcopy(kwargs)
export_args = copy.deepcopy(args)
export_kwargs = copy.deepcopy(kwargs)
flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
flat_loaded_outputs = pytree.tree_leaves(
export.module()(*export_args, **export_kwargs)
)
for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
self.assertEqual(type(orig), type(loaded))
self.assertEqual(orig, loaded)
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
def test_aot_export(self, device, dtype, op):
class Foo(torch.nn.Module):
def forward(self, *args):
return op.op(*args)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
for inp in sample_inputs_itr:
model = Foo()
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
args = (*input, *inp.args)
kwargs = inp.kwargs
ep = export(model, args, kwargs)
self._compare(model, ep, args, kwargs)
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
def test_pre_dispatch_export(self, device, dtype, op):
class Foo(torch.nn.Module):
def forward(self, *args):
return op.op(*args)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
for inp in sample_inputs_itr:
model = Foo()
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
args = (*input, *inp.args)
kwargs = inp.kwargs
ep = _export(model, args, kwargs, pre_dispatch=True)
self._compare(model, ep, args, kwargs)
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
def test_retrace_export(self, device, dtype, op):
class Foo(torch.nn.Module):
def forward(self, *args):
return op.op(*args)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
for inp in sample_inputs_itr:
model = Foo()
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
args = (*input, *inp.args)
kwargs = inp.kwargs
ep = _export(model, args, kwargs, pre_dispatch=True)
ep = ep.run_decompositions()
self._compare(model, ep, args, kwargs)
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
def test_serialize_export(self, device, dtype, op):
class Foo(torch.nn.Module):
def forward(self, *args):
return op.op(*args)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
for inp in sample_inputs_itr:
model = Foo()
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
args = (*input, *inp.args)
kwargs = inp.kwargs
ep = _export(model, args, kwargs, pre_dispatch=True)
ep = ep.run_decompositions()
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
ep = load(buffer)
self._compare(model, ep, args, kwargs)
instantiate_device_type_tests(TestHOP, globals())
if __name__ == "__main__":
run_tests()


@@ -35,8 +35,8 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
-from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
 from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check
+from torch.testing._internal.hop_db import hop_db
 from torch._higher_order_ops.out_dtype import out_dtype
 from functorch import (
     grad, vjp, vmap, jacrev,
@@ -4606,12 +4606,12 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info,
 class TestEagerFusionOpInfo(AOTTestCase):
-    @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
+    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures)
     def test_aot_autograd_exhaustive(self, device, dtype, op):
         _test_aot_autograd_helper(self, device, dtype, op)
 
-    @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
+    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
     @patch("functorch.compile.config.debug_assert", True)
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
              aot_autograd_failures | symbolic_aot_autograd_failures)


@@ -5,8 +5,8 @@ import torch
 from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
 from torch.testing._internal.custom_op_db import custom_op_db
+from torch.testing._internal.hop_db import hop_db
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, OpDTypes)
 from torch.testing._internal.common_utils import unMarkDynamoStrictTest
@@ -18,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
 @unMarkDynamoStrictTest
 class TestBwdGradients(TestGradients):
     # Tests that gradients are computed correctly
-    @_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
+    @_gradcheck_ops(op_db + hop_db + custom_op_db)
     def test_fn_grad(self, device, dtype, op):
         # This is verified by test_dtypes in test_ops.py
         if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -52,7 +52,7 @@ class TestBwdGradients(TestGradients):
         self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
 
     # Test that gradients of gradients are computed correctly
-    @_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
+    @_gradcheck_ops(op_db + hop_db + custom_op_db)
     def test_fn_gradgrad(self, device, dtype, op):
         self._skip_helper(op, device, dtype)
         if not op.supports_gradgrad:


@@ -2,6 +2,7 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
+import torch._dynamo
 import unittest
 import warnings
 import operator
@@ -17,7 +18,7 @@ from torch.fx.experimental.symbolic_shapes import (
     guard_int, GuardOnDataDependentSymNode
 )
 from torch.testing._internal.custom_op_db import custom_op_db
-from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
+from torch.testing._internal.hop_db import hop_db
 from torch.testing._internal.common_device_type import ops
 import torch.testing._internal.optests as optests
 from torch._C import _disabled_torch_function_impl
@@ -2012,18 +2013,34 @@ def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, o
         self.skipTest("Dynamic output shape operation in trace")
 
+def skipIfNameMatches(pattern):
+    """
+    Decorator to skip a test if its name matches the given pattern.
+    """
+    def decorator(test_func):
+        def wrapper(*args, **kwargs):
+            if re.match(pattern, test_func.__name__):
+                raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
+            return test_func(*args, **kwargs)
+        return wrapper
+    return decorator
+
+# Auto functionalize shouldn't work with make_fx directly
+filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]
+
+@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
 class TestProxyTensorOpInfo(TestCase):
-    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
+    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
     def test_make_fx_exhaustive(self, device, dtype, op):
         _test_make_fx_helper(self, device, dtype, op, "real")
 
-    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
+    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
     def test_make_fx_fake_exhaustive(self, device, dtype, op):
         _test_make_fx_helper(self, device, dtype, op, "fake")
 
-    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
+    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
              make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
     def test_make_fx_symbolic_exhaustive(self, device, dtype, op):


@@ -1,77 +0,0 @@
# mypy: ignore-errors
import torch
import functools
from torch.testing import make_tensor
from functorch.experimental.control_flow import map
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))
def inner_f(x, y0, y1):
return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
def simple_map(xs, y0, y1):
def f(x, y0, y1):
return inner_f(x, y0, y1)
return map(f, xs, y0, y1)
def nested_map(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
def triple_nested_map(xs, y0, y1):
def f0(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
return map(f0, xs, y0, y1)
control_flow_opinfo_db = [
OpInfo(
"MapControlflowOp",
op=simple_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
"NestedMapControlflowOp",
op=nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
"TripleNestedMapControlflowOp",
op=triple_nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
)
]


@@ -0,0 +1,171 @@
# mypy: ignore-errors
import torch
import functools
from torch.testing import make_tensor
from functorch.experimental.control_flow import map
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))
def inner_f(x, y0, y1):
return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
def simple_map(xs, y0, y1):
def f(x, y0, y1):
return inner_f(x, y0, y1)
return map(f, xs, y0, y1)
def nested_map(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
def triple_nested_map(xs, y0, y1):
def f0(xs, y0, y1):
def f1(xx, y0, y1):
def f2(x, y0, y1):
return inner_f(x, y0, y1)
return map(f2, xx, y0, y1)
return map(f1, xs, y0, y1)
return map(f0, xs, y0, y1)
# Please consult with torch.export team before
# adding new entry to this list.
hop_that_doesnt_have_opinfo_test_allowlist = [
"custom_function_call",
"autograd_function_apply",
"run_and_save_rng_state",
"run_with_rng_state",
"out_dtype",
"trace_wrapped",
"map",
"map_impl",
"with_effects",
"strict_mode",
"_export_tracepoint",
"while_loop",
]
torch.library.define(
"testlib::mutating_custom_op",
"(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)
@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
x.add_(5)
z.add_(5)
return x, z, x + z
@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
x.add_(5)
z.add_(5)
return x, z, x + z
@torch.library.impl_abstract("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
return x, z, x + z
def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
def simple_cond(x):
return torch.cond(x.shape[0] > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2))
def simple_auto_functionalize(x, z):
return torch.ops.testlib.mutating_custom_op(x, z)
hop_db = [
OpInfo(
name="map",
variant_test_name="simple",
op=simple_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="map",
variant_test_name="nested",
op=nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="map",
variant_test_name="triple_nested",
op=triple_nested_map,
sample_inputs_func=sample_inputs_map,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
),
OpInfo(
name="cond",
variant_test_name="simple",
op=simple_cond,
sample_inputs_func=sample_inputs_cond,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="auto_functionalize",
variant_test_name="simple",
op=simple_auto_functionalize,
sample_inputs_func=sample_inputs_auto_functionalize,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
)
]