mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add Opinfo entries for HOP testing (#122265)
In this PR, we add a systematic way to test all HOPs to be exportable as export team has been running into various bugs related to newly added HOPs due to lack of tests. We do this by creating: - hop_db -> a list of HOP OpInfo tests which then used inside various flows including export functionalities: [aot-export, pre-dispatch export, retrace, and ser/der For now, we also create an allowlist so that people can bypass the failures for now. But we should discourage ppl to do that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122265 Approved by: https://github.com/ydwu4, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
0bfa9f4758
commit
d9a08de9a4
@ -5,8 +5,8 @@ import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.hop_db import hop_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
||||
@ -18,7 +18,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
@unMarkDynamoStrictTest
|
||||
class TestBwdGradients(TestGradients):
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
||||
@_gradcheck_ops(op_db + hop_db + custom_op_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
# This is verified by test_dtypes in test_ops.py
|
||||
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
|
||||
@ -52,7 +52,7 @@ class TestBwdGradients(TestGradients):
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + control_flow_opinfo_db + custom_op_db)
|
||||
@_gradcheck_ops(op_db + hop_db + custom_op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
|
Reference in New Issue
Block a user