Add Opinfo entries for HOP testing (#122265)

In this PR, we add a systematic way to test that all HOPs (higher-order ops) are exportable, as the export team has been running into various bugs caused by newly added HOPs that lack tests. We do this by creating:
- hop_db -> a list of HOP OpInfo entries, which is then used in various test flows, including the export functionalities (aot-export, pre-dispatch export, retrace, and ser/der); see the sketch below.
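
As a rough illustration, here is a minimal sketch of how an OpInfo database such as hop_db can be consumed by a device-parametrized test, mirroring the diff further down. The test class name and the specific check are illustrative only, not part of this PR.

```python
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.hop_db import hop_db


# Hypothetical test class, for illustration only.
class TestHOPSamples(TestCase):
    @ops(hop_db)
    def test_sample_inputs(self, device, dtype, op):
        # Each OpInfo entry supplies sample inputs the HOP can be invoked on.
        for sample in op.sample_inputs(device, dtype):
            op.op(sample.input, *sample.args, **sample.kwargs)


instantiate_device_type_tests(TestHOPSamples, globals())

if __name__ == "__main__":
    run_tests()
```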

For now, we also create an allowlist so that people can bypass the failures, but we should discourage people from relying on it.
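
For context, the export coverage mentioned above amounts to exercising each HOP through flows like the following. This is a minimal sketch using torch.cond and the public torch.export APIs, not the PR's actual test harness.

```python
import io

import torch
from torch.export import export, load, save


# Illustrative module; not taken from the PR's test suite.
class CondModule(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        # torch.cond is a higher-order op: its branches are callables.
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


example_inputs = (torch.randn(3),)
ep = export(CondModule(), example_inputs)

# Retrace: exporting the exported program's module again should still work.
ep_retraced = export(ep.module(), example_inputs)

# Ser/der: round-trip through torch.export.save / torch.export.load.
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
ep_loaded = load(buffer)
```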

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122265
Approved by: https://github.com/ydwu4, https://github.com/zou3519
Tugsbayasgalan Manlaibaatar
2024-03-28 10:33:51 -07:00
committed by PyTorch MergeBot
parent 0bfa9f4758
commit d9a08de9a4
8 changed files with 343 additions and 88 deletions

@@ -5,8 +5,8 @@ import torch
 from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
 from torch.testing._internal.custom_op_db import custom_op_db
+from torch.testing._internal.hop_db import hop_db
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, OpDTypes)
 from torch.testing._internal.common_utils import unMarkDynamoStrictTest
@@ -18,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
 @unMarkDynamoStrictTest
 class TestBwdGradients(TestGradients):
     # Tests that gradients are computed correctly
-    @_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
+    @_gradcheck_ops(op_db + hop_db + custom_op_db)
     def test_fn_grad(self, device, dtype, op):
         # This is verified by test_dtypes in test_ops.py
         if dtype not in op.supported_backward_dtypes(torch.device(device).type):
@@ -52,7 +52,7 @@ class TestBwdGradients(TestGradients):
         self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
     # Test that gradients of gradients are computed correctly
-    @_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
+    @_gradcheck_ops(op_db + hop_db + custom_op_db)
     def test_fn_gradgrad(self, device, dtype, op):
         self._skip_helper(op, device, dtype)
         if not op.supports_gradgrad: