Better hop_db comment; move test to a non-export test file (#145938)

Goal is for people to better test their HOPs.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145938
Approved by: https://github.com/ydwu4
This commit is contained in:
rzou
2025-01-29 08:40:20 -08:00
committed by PyTorch MergeBot
parent e02c038a23
commit 2141c1aebe
4 changed files with 52 additions and 25 deletions

View File

@ -73,7 +73,6 @@ nn/qat/ @jerryzh168
/test/run_test.py @pytorch/pytorch-dev-infra
/torch/testing/_internal/common_device_type.py @mruberry
/torch/testing/_internal/common_utils.py @pytorch/pytorch-dev-infra
/torch/testing/_internal/hop_db.py @tugsbayasgalan @zou3519 @ydwu4
# Parametrizations
/torch/nn/utils/parametriz*.py @lezcano
@ -106,6 +105,8 @@ test/functorch/test_vmap.py @zou3519 @chillee @kshitij12345
# HOPs
torch/_higher_order_ops/*.py @zou3519
torch/_dynamo/variables/higher_order_ops.py @zou3519
test/test_hop_infra.py @zou3519
torch/testing/_internal/hop_db.py @tugsbayasgalan @zou3519 @ydwu4
# AOTAutograd
torch/_functorch/_aot_autograd/*.py @bdhirsh

View File

@ -20,8 +20,8 @@ from torch.testing._internal.common_utils import (
TestCase as TorchTestCase,
)
from torch.testing._internal.hop_db import (
FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
hop_db,
hop_that_doesnt_have_opinfo_test_allowlist,
)
@ -29,30 +29,11 @@ hop_tests = []
for op_info in hop_db:
op_info_hop_name = op_info.name
if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
if op_info_hop_name in FIXME_hop_that_doesnt_have_opinfo_test_allowlist:
continue
hop_tests.append(op_info)
class TestHOPGeneric(TestCase):
def test_all_hops_have_op_info(self):
from torch._ops import _higher_order_ops
hops_that_have_op_info = set([k.name for k in hop_db])
all_hops = _higher_order_ops.keys()
missing_ops = []
for op in all_hops:
if (
op not in hops_that_have_op_info
and op not in hop_that_doesnt_have_opinfo_test_allowlist
):
missing_ops.append(op)
self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestHOP(TestCase):

33
test/test_hop_infra.py Normal file
View File

@ -0,0 +1,33 @@
# Owner(s): ["module: higher order operators"]
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.hop_db import (
FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
hop_db,
)
@skipIfTorchDynamo("not applicable")
class TestHOPInfra(TestCase):
def test_all_hops_have_opinfo(self):
from torch._ops import _higher_order_ops
hops_that_have_op_info = {k.name for k in hop_db}
all_hops = _higher_order_ops.keys()
missing_ops = set()
for op in all_hops:
if (
op not in hops_that_have_op_info
and op not in FIXME_hop_that_doesnt_have_opinfo_test_allowlist
):
missing_ops.add(op)
self.assertTrue(
len(missing_ops) == 0,
f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py",
)
if __name__ == "__main__":
run_tests()

View File

@ -56,9 +56,21 @@ def triple_nested_map(xs, y0, y1):
return map(f0, xs, y0, y1)
# Please consult with torch.export team before
# adding new entry to this list.
hop_that_doesnt_have_opinfo_test_allowlist = [
# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST,
# and do add an OpInfo for your HOP.
# The OpInfo lets us do automated testing for the HOP to check that
# your HOP will work correctly with PyTorch!
#
# Your new HOP may fail some automated testing. That's OK. If you don't
# care about certain features (like torch.export), it's fine to xfail those
# failing tests. It is less fine to xfail a more critical check (like checking
# if torch.compile works with your HOP, or if your HOP has a docstring).
# If you don't know if a test is fine to xfail, please ask.
#
# There are legitimate reasons why something cannot be added to this list
# (e.g. it uses executorch which is not in PyTorch). If that's the case then
# please leave a comment.
FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
"custom_function_call",
"autograd_function_apply",
"run_and_save_rng_state",