mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[generate_opcheck_tests] Enable using same failures_dict for multiple testclasses (#110164)
This PR allows us to use the same failures_dict for multiple test classes. This is helpful if you have a bunch of small TestCase(es) and to centralize all the failures dict into one big one. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/110164 Approved by: https://github.com/williamwen42
This commit is contained in:
@ -7,6 +7,10 @@
|
||||
"MiniOpTest.test_aot_dispatch_static__test_nonzero": {
|
||||
"comment": "",
|
||||
"status": "xfail"
|
||||
},
|
||||
"MiniOpTestOther.test_aot_dispatch_static__test_nonzero_again": {
|
||||
"comment": "",
|
||||
"status": "xfail"
|
||||
}
|
||||
},
|
||||
"aten::sin_": {},
|
||||
|
@ -1776,6 +1776,15 @@ class MiniOpTest(CustomOpTestCaseBase):
|
||||
y = op(x)
|
||||
|
||||
|
||||
class MiniOpTestOther(CustomOpTestCaseBase):
|
||||
test_ns = "mini_op_test"
|
||||
|
||||
def test_nonzero_again(self):
|
||||
x = torch.tensor([0, 1, 2, 0, 0])
|
||||
y = torch.ops.aten.nonzero.default(x)
|
||||
self.assertEqual(y, torch.tensor([[1], [2]]))
|
||||
|
||||
|
||||
mini_op_test_checks = [
|
||||
"test_schema",
|
||||
"test_autograd_registration",
|
||||
@ -1795,6 +1804,17 @@ optests.generate_opcheck_tests(
|
||||
mini_op_test_checks,
|
||||
)
|
||||
|
||||
optests.generate_opcheck_tests(
|
||||
MiniOpTestOther,
|
||||
["aten", "mini_op_test"],
|
||||
get_file_path_2(
|
||||
os.path.dirname(__file__),
|
||||
"minioptest_failures_dict.json",
|
||||
),
|
||||
[],
|
||||
mini_op_test_checks,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateOpcheckTests(CustomOpTestCaseBase):
|
||||
def test_MiniOpTest(self):
|
||||
|
@ -302,9 +302,9 @@ def validate_failures_dict_structure(
|
||||
if not actual_test_name.startswith(test):
|
||||
continue
|
||||
base_test_name = actual_test_name[len(test) + 2 :]
|
||||
if testcase.__name__ == test_class and hasattr(
|
||||
testcase, base_test_name
|
||||
):
|
||||
if testcase.__name__ != test_class:
|
||||
continue
|
||||
if hasattr(testcase, base_test_name):
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"In failures dict, got test name '{test_name}'. We parsed this as "
|
||||
|
Reference in New Issue
Block a user