[optests] Add dontGenerateOpCheckTests and is_inside_opcheck_mode (#110951)

This PR adds the following helper functions for generated opcheck tests:
- dontGenerateOpCheckTests is a decorator that skips generation of the
  opcheck tests for the generated function
- is_inside_opcheck_mode lets us query if we are in a generated test.
  Useful for fast debugging out-of-tree without needing to update
  PyTorch.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110951
Approved by: https://github.com/williamwen42
This commit is contained in:
rzou
2023-10-10 08:14:27 -07:00
committed by PyTorch MergeBot
parent d9eb5a57aa
commit 3a29cdc5e6
3 changed files with 43 additions and 1 deletions

View File

@ -1717,6 +1717,12 @@ class MiniOpTest(CustomOpTestCaseBase):
lib.impl(name, lambda x: x.clone(), "CPU")
return self.get_op(qualname)
@optests.dontGenerateOpCheckTests("Testing this API")
def test_dont_generate(self):
op = op_with_incorrect_schema(self, "incorrect_schema")
x = torch.randn(3)
op(x)
def test_mm(self):
x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(3, 5)
@ -1940,6 +1946,10 @@ opcheck(op, args, kwargs, test_utils="test_schema")
FailuresDict("", failures), mini_op_test_checks, MiniOpTest
)
def test_dont_generate_decorator(self):
self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
def test_opcheck(self):
x = torch.randn(3, requires_grad=True)
with self.assertRaisesRegex(ValueError, "OpOverload"):
@ -1982,6 +1992,13 @@ opcheck(op, args, kwargs, test_utils="test_schema")
},
)
def test_is_inside_opcheck_mode(self):
self.assertFalse(optests.is_inside_opcheck_mode())
with optests.generate_tests.OpCheckMode(
["foo"], "bar", lambda x: x, None, "baz", "brr"
):
self.assertTrue(optests.is_inside_opcheck_mode())
def test_opcheck_bad_op(self):
op = op_with_incorrect_schema(self, "foo")
x = torch.randn(3)

View File

@ -2,4 +2,4 @@ from .make_fx import make_fx_check
from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper
from .fake_tensor import fake_check
from .autograd_registration import autograd_registration_check
from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError
from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode

View File

@ -4,6 +4,7 @@ import functools
import json
import os
import tempfile
import threading
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@ -21,6 +22,14 @@ from torch.testing._internal.optests import (
)
def dontGenerateOpCheckTests(reason: str):
def inner(fun):
fun._torch_dont_generate_opcheck_tests = True
return fun
return inner
def is_abstract(tensor: torch.Tensor) -> bool:
if tensor.is_meta:
return True
@ -161,6 +170,8 @@ def generate_opcheck_tests(
def construct_method(attr, prefix, tester):
method = getattr(testcase, attr)
if getattr(method, "_torch_dont_generate_opcheck_tests", False):
return
new_method_name = prefix + "__" + attr
@functools.wraps(method)
@ -319,6 +330,14 @@ def should_update_failures_dict() -> bool:
return key in os.environ and os.environ[key] == "1"
_is_inside_opcheck_mode = threading.local()
_is_inside_opcheck_mode.value = False
def is_inside_opcheck_mode():
return _is_inside_opcheck_mode.value
class OpCheckMode(TorchFunctionMode):
"""
For a given test, OpCheckMode intercepts calls to operators and runs
@ -408,7 +427,13 @@ class OpCheckMode(TorchFunctionMode):
f"To reproduce this problem locally, try to run the following:\n{repro_command}"
) from ex
def __enter__(self, *args, **kwargs):
self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value
_is_inside_opcheck_mode.value = True
return super().__enter__(*args, **kwargs)
def __exit__(self, *args, **kwargs):
_is_inside_opcheck_mode.value = self.prev_is_opcheck_mode
try:
self.maybe_raise_errors_on_exit()
if should_update_failures_dict():