mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Opinfo entries for HOP testing (#122265)
In this PR, we add a systematic way to test all HOPs to be exportable as export team has been running into various bugs related to newly added HOPs due to lack of tests. We do this by creating: - hop_db -> a list of HOP OpInfo tests which then used inside various flows including export functionalities: [aot-export, pre-dispatch export, retrace, and ser/der For now, we also create an allowlist so that people can bypass the failures for now. But we should discourage ppl to do that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122265 Approved by: https://github.com/ydwu4, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
0bfa9f4758
commit
d9a08de9a4
@ -2371,7 +2371,7 @@ exclude_patterns = [
|
|||||||
'torch/testing/_internal/common_subclass.py',
|
'torch/testing/_internal/common_subclass.py',
|
||||||
'torch/testing/_internal/common_utils.py',
|
'torch/testing/_internal/common_utils.py',
|
||||||
'torch/testing/_internal/composite_compliance.py',
|
'torch/testing/_internal/composite_compliance.py',
|
||||||
'torch/testing/_internal/control_flow_opinfo_db.py',
|
'torch/testing/_internal/hop_db.py',
|
||||||
'torch/testing/_internal/custom_op_db.py',
|
'torch/testing/_internal/custom_op_db.py',
|
||||||
'torch/testing/_internal/data/__init__.py',
|
'torch/testing/_internal/data/__init__.py',
|
||||||
'torch/testing/_internal/data/network1.py',
|
'torch/testing/_internal/data/network1.py',
|
||||||
|
@ -67,6 +67,7 @@ nn/qat/ @jerryzh168
|
|||||||
/test/run_test.py @pytorch/pytorch-dev-infra
|
/test/run_test.py @pytorch/pytorch-dev-infra
|
||||||
/torch/testing/_internal/common_device_type.py @mruberry
|
/torch/testing/_internal/common_device_type.py @mruberry
|
||||||
/torch/testing/_internal/common_utils.py @pytorch/pytorch-dev-infra
|
/torch/testing/_internal/common_utils.py @pytorch/pytorch-dev-infra
|
||||||
|
/torch/testing/_internal/hop_db.py @tugsbayasgalan @zou3519 @ydwu4
|
||||||
|
|
||||||
# Parametrizations
|
# Parametrizations
|
||||||
/torch/nn/utils/parametriz*.py @lezcano
|
/torch/nn/utils/parametriz*.py @lezcano
|
||||||
|
143
test/export/test_hop.py
Normal file
143
test/export/test_hop.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
# Owner(s): ["oncall: export"]
|
||||||
|
# flake8: noqa
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch._dynamo as torchdynamo
|
||||||
|
import torch.utils._pytree as pytree
|
||||||
|
from torch._dynamo.test_case import TestCase
|
||||||
|
from torch.export import export, load, save
|
||||||
|
from torch.export._trace import _export
|
||||||
|
from torch.testing._internal.common_device_type import (
|
||||||
|
instantiate_device_type_tests,
|
||||||
|
ops,
|
||||||
|
)
|
||||||
|
from torch.testing._internal.common_utils import (
|
||||||
|
IS_WINDOWS,
|
||||||
|
run_tests,
|
||||||
|
TestCase as TorchTestCase,
|
||||||
|
)
|
||||||
|
from torch.testing._internal.hop_db import (
|
||||||
|
hop_db,
|
||||||
|
hop_that_doesnt_have_opinfo_test_allowlist,
|
||||||
|
)
|
||||||
|
|
||||||
|
hop_tests = []
|
||||||
|
|
||||||
|
for op_info in hop_db:
|
||||||
|
op_info_hop_name = op_info.name
|
||||||
|
if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
|
||||||
|
continue
|
||||||
|
hop_tests.append(op_info)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHOPGeneric(TestCase):
|
||||||
|
def test_all_hops_have_op_info(self):
|
||||||
|
from torch._ops import _higher_order_ops
|
||||||
|
|
||||||
|
hops_that_have_op_info = set([k.name for k in hop_db])
|
||||||
|
all_hops = _higher_order_ops.keys()
|
||||||
|
|
||||||
|
missing_ops = []
|
||||||
|
|
||||||
|
for op in all_hops:
|
||||||
|
if (
|
||||||
|
op not in hops_that_have_op_info
|
||||||
|
and op not in hop_that_doesnt_have_opinfo_test_allowlist
|
||||||
|
):
|
||||||
|
missing_ops.append(op)
|
||||||
|
|
||||||
|
self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
|
||||||
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||||
|
class TestHOP(TestCase):
|
||||||
|
def _compare(self, eager_model, export, args, kwargs):
|
||||||
|
eager_args = copy.deepcopy(args)
|
||||||
|
eager_kwargs = copy.deepcopy(kwargs)
|
||||||
|
export_args = copy.deepcopy(args)
|
||||||
|
export_kwargs = copy.deepcopy(kwargs)
|
||||||
|
|
||||||
|
flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
|
||||||
|
flat_loaded_outputs = pytree.tree_leaves(
|
||||||
|
export.module()(*export_args, **export_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
|
||||||
|
self.assertEqual(type(orig), type(loaded))
|
||||||
|
self.assertEqual(orig, loaded)
|
||||||
|
|
||||||
|
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
|
||||||
|
def test_aot_export(self, device, dtype, op):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, *args):
|
||||||
|
return op.op(*args)
|
||||||
|
|
||||||
|
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
|
||||||
|
for inp in sample_inputs_itr:
|
||||||
|
model = Foo()
|
||||||
|
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
|
||||||
|
args = (*input, *inp.args)
|
||||||
|
kwargs = inp.kwargs
|
||||||
|
ep = export(model, args, kwargs)
|
||||||
|
self._compare(model, ep, args, kwargs)
|
||||||
|
|
||||||
|
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
|
||||||
|
def test_pre_dispatch_export(self, device, dtype, op):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, *args):
|
||||||
|
return op.op(*args)
|
||||||
|
|
||||||
|
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
|
||||||
|
for inp in sample_inputs_itr:
|
||||||
|
model = Foo()
|
||||||
|
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
|
||||||
|
args = (*input, *inp.args)
|
||||||
|
kwargs = inp.kwargs
|
||||||
|
ep = _export(model, args, kwargs, pre_dispatch=True)
|
||||||
|
self._compare(model, ep, args, kwargs)
|
||||||
|
|
||||||
|
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
|
||||||
|
def test_retrace_export(self, device, dtype, op):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, *args):
|
||||||
|
return op.op(*args)
|
||||||
|
|
||||||
|
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
|
||||||
|
for inp in sample_inputs_itr:
|
||||||
|
model = Foo()
|
||||||
|
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
|
||||||
|
args = (*input, *inp.args)
|
||||||
|
kwargs = inp.kwargs
|
||||||
|
ep = _export(model, args, kwargs, pre_dispatch=True)
|
||||||
|
ep = ep.run_decompositions()
|
||||||
|
self._compare(model, ep, args, kwargs)
|
||||||
|
|
||||||
|
@ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
|
||||||
|
def test_serialize_export(self, device, dtype, op):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, *args):
|
||||||
|
return op.op(*args)
|
||||||
|
|
||||||
|
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
|
||||||
|
for inp in sample_inputs_itr:
|
||||||
|
model = Foo()
|
||||||
|
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
|
||||||
|
args = (*input, *inp.args)
|
||||||
|
kwargs = inp.kwargs
|
||||||
|
ep = _export(model, args, kwargs, pre_dispatch=True)
|
||||||
|
ep = ep.run_decompositions()
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
save(ep, buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
ep = load(buffer)
|
||||||
|
self._compare(model, ep, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
instantiate_device_type_tests(TestHOP, globals())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_tests()
|
@ -35,8 +35,8 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
|
|||||||
from torch.testing._internal.common_methods_invocations import op_db
|
from torch.testing._internal.common_methods_invocations import op_db
|
||||||
from torch.testing._internal.common_modules import module_db, modules
|
from torch.testing._internal.common_modules import module_db, modules
|
||||||
from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
|
from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
|
||||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
|
||||||
from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check
|
from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check
|
||||||
|
from torch.testing._internal.hop_db import hop_db
|
||||||
from torch._higher_order_ops.out_dtype import out_dtype
|
from torch._higher_order_ops.out_dtype import out_dtype
|
||||||
from functorch import (
|
from functorch import (
|
||||||
grad, vjp, vmap, jacrev,
|
grad, vjp, vmap, jacrev,
|
||||||
@ -4606,12 +4606,12 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info,
|
|||||||
|
|
||||||
|
|
||||||
class TestEagerFusionOpInfo(AOTTestCase):
|
class TestEagerFusionOpInfo(AOTTestCase):
|
||||||
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + hop_db, allowed_dtypes=(torch.float,))
|
||||||
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures)
|
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures)
|
||||||
def test_aot_autograd_exhaustive(self, device, dtype, op):
|
def test_aot_autograd_exhaustive(self, device, dtype, op):
|
||||||
_test_aot_autograd_helper(self, device, dtype, op)
|
_test_aot_autograd_helper(self, device, dtype, op)
|
||||||
|
|
||||||
@ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + hop_db, allowed_dtypes=(torch.float,))
|
||||||
@patch("functorch.compile.config.debug_assert", True)
|
@patch("functorch.compile.config.debug_assert", True)
|
||||||
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
|
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
|
||||||
aot_autograd_failures | symbolic_aot_autograd_failures)
|
aot_autograd_failures | symbolic_aot_autograd_failures)
|
||||||
|
@ -5,8 +5,8 @@ import torch
|
|||||||
|
|
||||||
from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
|
from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
|
||||||
from torch.testing._internal.common_methods_invocations import op_db
|
from torch.testing._internal.common_methods_invocations import op_db
|
||||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
|
||||||
from torch.testing._internal.custom_op_db import custom_op_db
|
from torch.testing._internal.custom_op_db import custom_op_db
|
||||||
|
from torch.testing._internal.hop_db import hop_db
|
||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, ops, OpDTypes)
|
(instantiate_device_type_tests, ops, OpDTypes)
|
||||||
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
||||||
@ -18,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
|||||||
@unMarkDynamoStrictTest
|
@unMarkDynamoStrictTest
|
||||||
class TestBwdGradients(TestGradients):
|
class TestBwdGradients(TestGradients):
|
||||||
# Tests that gradients are computed correctly
|
# Tests that gradients are computed correctly
|
||||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
@_gradcheck_ops(op_db + hop_db + custom_op_db)
|
||||||
def test_fn_grad(self, device, dtype, op):
|
def test_fn_grad(self, device, dtype, op):
|
||||||
# This is verified by test_dtypes in test_ops.py
|
# This is verified by test_dtypes in test_ops.py
|
||||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||||
@ -52,7 +52,7 @@ class TestBwdGradients(TestGradients):
|
|||||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||||
|
|
||||||
# Test that gradients of gradients are computed correctly
|
# Test that gradients of gradients are computed correctly
|
||||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
@_gradcheck_ops(op_db + hop_db + custom_op_db)
|
||||||
def test_fn_gradgrad(self, device, dtype, op):
|
def test_fn_gradgrad(self, device, dtype, op):
|
||||||
self._skip_helper(op, device, dtype)
|
self._skip_helper(op, device, dtype)
|
||||||
if not op.supports_gradgrad:
|
if not op.supports_gradgrad:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
import torch
|
import torch
|
||||||
|
import torch._dynamo
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
import operator
|
import operator
|
||||||
@ -17,7 +18,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||||||
guard_int, GuardOnDataDependentSymNode
|
guard_int, GuardOnDataDependentSymNode
|
||||||
)
|
)
|
||||||
from torch.testing._internal.custom_op_db import custom_op_db
|
from torch.testing._internal.custom_op_db import custom_op_db
|
||||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
from torch.testing._internal.hop_db import hop_db
|
||||||
from torch.testing._internal.common_device_type import ops
|
from torch.testing._internal.common_device_type import ops
|
||||||
import torch.testing._internal.optests as optests
|
import torch.testing._internal.optests as optests
|
||||||
from torch._C import _disabled_torch_function_impl
|
from torch._C import _disabled_torch_function_impl
|
||||||
@ -2012,18 +2013,34 @@ def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, o
|
|||||||
self.skipTest("Dynamic output shape operation in trace")
|
self.skipTest("Dynamic output shape operation in trace")
|
||||||
|
|
||||||
|
|
||||||
|
def skipIfNameMatches(pattern):
|
||||||
|
"""
|
||||||
|
Decorator to skip a test if its name matches the given pattern.
|
||||||
|
"""
|
||||||
|
def decorator(test_func):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if re.match(pattern, test_func.__name__):
|
||||||
|
raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
|
||||||
|
return test_func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
# Auto functionalize shouldn't work with make_fx directly
|
||||||
|
filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]
|
||||||
|
|
||||||
|
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
|
||||||
class TestProxyTensorOpInfo(TestCase):
|
class TestProxyTensorOpInfo(TestCase):
|
||||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
|
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
|
||||||
def test_make_fx_exhaustive(self, device, dtype, op):
|
def test_make_fx_exhaustive(self, device, dtype, op):
|
||||||
_test_make_fx_helper(self, device, dtype, op, "real")
|
_test_make_fx_helper(self, device, dtype, op, "real")
|
||||||
|
|
||||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
|
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
|
||||||
def test_make_fx_fake_exhaustive(self, device, dtype, op):
|
def test_make_fx_fake_exhaustive(self, device, dtype, op):
|
||||||
_test_make_fx_helper(self, device, dtype, op, "fake")
|
_test_make_fx_helper(self, device, dtype, op, "fake")
|
||||||
|
|
||||||
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
|
||||||
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
|
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
|
||||||
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
|
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
|
||||||
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
|
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
|
||||||
|
@ -1,77 +0,0 @@
|
|||||||
# mypy: ignore-errors
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import functools
|
|
||||||
from torch.testing import make_tensor
|
|
||||||
from functorch.experimental.control_flow import map
|
|
||||||
from torch.testing._internal.opinfo.core import (
|
|
||||||
OpInfo,
|
|
||||||
SampleInput,
|
|
||||||
)
|
|
||||||
from torch.testing._internal.common_dtype import all_types_and
|
|
||||||
|
|
||||||
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
|
|
||||||
make_arg = functools.partial(
|
|
||||||
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
|
||||||
yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
|
|
||||||
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))
|
|
||||||
|
|
||||||
def inner_f(x, y0, y1):
|
|
||||||
return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
|
|
||||||
|
|
||||||
def simple_map(xs, y0, y1):
|
|
||||||
def f(x, y0, y1):
|
|
||||||
return inner_f(x, y0, y1)
|
|
||||||
return map(f, xs, y0, y1)
|
|
||||||
|
|
||||||
def nested_map(xs, y0, y1):
|
|
||||||
def f1(xx, y0, y1):
|
|
||||||
def f2(x, y0, y1):
|
|
||||||
return inner_f(x, y0, y1)
|
|
||||||
return map(f2, xx, y0, y1)
|
|
||||||
return map(f1, xs, y0, y1)
|
|
||||||
|
|
||||||
def triple_nested_map(xs, y0, y1):
|
|
||||||
def f0(xs, y0, y1):
|
|
||||||
def f1(xx, y0, y1):
|
|
||||||
def f2(x, y0, y1):
|
|
||||||
return inner_f(x, y0, y1)
|
|
||||||
return map(f2, xx, y0, y1)
|
|
||||||
return map(f1, xs, y0, y1)
|
|
||||||
return map(f0, xs, y0, y1)
|
|
||||||
|
|
||||||
control_flow_opinfo_db = [
|
|
||||||
OpInfo(
|
|
||||||
"MapControlflowOp",
|
|
||||||
op=simple_map,
|
|
||||||
sample_inputs_func=sample_inputs_map,
|
|
||||||
dtypes=all_types_and(torch.bool, torch.half),
|
|
||||||
supports_out=False,
|
|
||||||
check_batched_grad=False,
|
|
||||||
check_batched_gradgrad=False,
|
|
||||||
check_batched_forward_grad=False,
|
|
||||||
check_inplace_batched_forward_grad=False,
|
|
||||||
),
|
|
||||||
OpInfo(
|
|
||||||
"NestedMapControlflowOp",
|
|
||||||
op=nested_map,
|
|
||||||
sample_inputs_func=sample_inputs_map,
|
|
||||||
dtypes=all_types_and(torch.bool, torch.half),
|
|
||||||
supports_out=False,
|
|
||||||
check_batched_grad=False,
|
|
||||||
check_batched_gradgrad=False,
|
|
||||||
check_batched_forward_grad=False,
|
|
||||||
check_inplace_batched_forward_grad=False,
|
|
||||||
),
|
|
||||||
OpInfo(
|
|
||||||
"TripleNestedMapControlflowOp",
|
|
||||||
op=triple_nested_map,
|
|
||||||
sample_inputs_func=sample_inputs_map,
|
|
||||||
dtypes=all_types_and(torch.bool, torch.half),
|
|
||||||
supports_out=False,
|
|
||||||
check_batched_grad=False,
|
|
||||||
check_batched_gradgrad=False,
|
|
||||||
check_batched_forward_grad=False,
|
|
||||||
check_inplace_batched_forward_grad=False,
|
|
||||||
)
|
|
||||||
]
|
|
171
torch/testing/_internal/hop_db.py
Normal file
171
torch/testing/_internal/hop_db.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
# mypy: ignore-errors
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import functools
|
||||||
|
from torch.testing import make_tensor
|
||||||
|
from functorch.experimental.control_flow import map
|
||||||
|
from torch.testing._internal.opinfo.core import (
|
||||||
|
OpInfo,
|
||||||
|
SampleInput,
|
||||||
|
)
|
||||||
|
from torch.testing._internal.common_dtype import all_types_and
|
||||||
|
|
||||||
|
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
|
||||||
|
make_arg = functools.partial(
|
||||||
|
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||||
|
yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
|
||||||
|
args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))
|
||||||
|
|
||||||
|
def inner_f(x, y0, y1):
|
||||||
|
return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
|
||||||
|
|
||||||
|
def simple_map(xs, y0, y1):
|
||||||
|
def f(x, y0, y1):
|
||||||
|
return inner_f(x, y0, y1)
|
||||||
|
return map(f, xs, y0, y1)
|
||||||
|
|
||||||
|
def nested_map(xs, y0, y1):
|
||||||
|
def f1(xx, y0, y1):
|
||||||
|
def f2(x, y0, y1):
|
||||||
|
return inner_f(x, y0, y1)
|
||||||
|
return map(f2, xx, y0, y1)
|
||||||
|
return map(f1, xs, y0, y1)
|
||||||
|
|
||||||
|
def triple_nested_map(xs, y0, y1):
|
||||||
|
def f0(xs, y0, y1):
|
||||||
|
def f1(xx, y0, y1):
|
||||||
|
def f2(x, y0, y1):
|
||||||
|
return inner_f(x, y0, y1)
|
||||||
|
return map(f2, xx, y0, y1)
|
||||||
|
return map(f1, xs, y0, y1)
|
||||||
|
return map(f0, xs, y0, y1)
|
||||||
|
|
||||||
|
|
||||||
|
# Please consult with torch.export team before
|
||||||
|
# adding new entry to this list.
|
||||||
|
hop_that_doesnt_have_opinfo_test_allowlist = [
|
||||||
|
"custom_function_call",
|
||||||
|
"autograd_function_apply",
|
||||||
|
"run_and_save_rng_state",
|
||||||
|
"run_with_rng_state",
|
||||||
|
"out_dtype",
|
||||||
|
"trace_wrapped",
|
||||||
|
"map",
|
||||||
|
"map_impl",
|
||||||
|
"with_effects",
|
||||||
|
"strict_mode",
|
||||||
|
"_export_tracepoint",
|
||||||
|
"while_loop",
|
||||||
|
]
|
||||||
|
|
||||||
|
torch.library.define(
|
||||||
|
"testlib::mutating_custom_op",
|
||||||
|
"(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
|
||||||
|
tags=torch.Tag.pt2_compliant_tag,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.impl("testlib::mutating_custom_op", "cpu")
|
||||||
|
def foo_impl_cpu(x, z):
|
||||||
|
x.add_(5)
|
||||||
|
z.add_(5)
|
||||||
|
return x, z, x + z
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.impl("testlib::mutating_custom_op", "cuda")
|
||||||
|
def foo_impl_cuda(x, z):
|
||||||
|
x.add_(5)
|
||||||
|
z.add_(5)
|
||||||
|
return x, z, x + z
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.impl_abstract("testlib::mutating_custom_op")
|
||||||
|
def foo_impl_abstract(x, z):
|
||||||
|
return x, z, x + z
|
||||||
|
|
||||||
|
|
||||||
|
def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
|
||||||
|
make_arg = functools.partial(
|
||||||
|
make_tensor, device=device, dtype=dtype, requires_grad=False
|
||||||
|
)
|
||||||
|
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
|
||||||
|
|
||||||
|
|
||||||
|
def simple_cond(x):
|
||||||
|
return torch.cond(x.shape[0] > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
|
||||||
|
|
||||||
|
|
||||||
|
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
|
||||||
|
make_arg = functools.partial(
|
||||||
|
make_tensor, device=device, dtype=dtype, requires_grad=False
|
||||||
|
)
|
||||||
|
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2))
|
||||||
|
|
||||||
|
|
||||||
|
def simple_auto_functionalize(x, z):
|
||||||
|
return torch.ops.testlib.mutating_custom_op(x, z)
|
||||||
|
|
||||||
|
hop_db = [
|
||||||
|
OpInfo(
|
||||||
|
name="map",
|
||||||
|
variant_test_name="simple",
|
||||||
|
op=simple_map,
|
||||||
|
sample_inputs_func=sample_inputs_map,
|
||||||
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
|
supports_out=False,
|
||||||
|
check_batched_grad=False,
|
||||||
|
check_batched_gradgrad=False,
|
||||||
|
check_batched_forward_grad=False,
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
),
|
||||||
|
OpInfo(
|
||||||
|
name="map",
|
||||||
|
variant_test_name="nested",
|
||||||
|
op=nested_map,
|
||||||
|
sample_inputs_func=sample_inputs_map,
|
||||||
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
|
supports_out=False,
|
||||||
|
check_batched_grad=False,
|
||||||
|
check_batched_gradgrad=False,
|
||||||
|
check_batched_forward_grad=False,
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
),
|
||||||
|
OpInfo(
|
||||||
|
name="map",
|
||||||
|
variant_test_name="triple_nested",
|
||||||
|
op=triple_nested_map,
|
||||||
|
sample_inputs_func=sample_inputs_map,
|
||||||
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
|
supports_out=False,
|
||||||
|
check_batched_grad=False,
|
||||||
|
check_batched_gradgrad=False,
|
||||||
|
check_batched_forward_grad=False,
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
),
|
||||||
|
OpInfo(
|
||||||
|
name="cond",
|
||||||
|
variant_test_name="simple",
|
||||||
|
op=simple_cond,
|
||||||
|
sample_inputs_func=sample_inputs_cond,
|
||||||
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
|
supports_out=False,
|
||||||
|
check_batched_grad=False,
|
||||||
|
check_batched_gradgrad=False,
|
||||||
|
check_batched_forward_grad=False,
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
supports_autograd=False,
|
||||||
|
),
|
||||||
|
OpInfo(
|
||||||
|
name="auto_functionalize",
|
||||||
|
variant_test_name="simple",
|
||||||
|
op=simple_auto_functionalize,
|
||||||
|
sample_inputs_func=sample_inputs_auto_functionalize,
|
||||||
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
|
supports_out=False,
|
||||||
|
check_batched_grad=False,
|
||||||
|
check_batched_gradgrad=False,
|
||||||
|
check_batched_forward_grad=False,
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
supports_autograd=False,
|
||||||
|
)
|
||||||
|
]
|
Reference in New Issue
Block a user