mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[optests] Add dontGenerateOpCheckTests and is_inside_opcheck_mode (#110951)
This PR adds the following helper functions for generated opcheck tests:
- `dontGenerateOpCheckTests` is a decorator that skips generation of the opcheck tests for the decorated function.
- `is_inside_opcheck_mode` lets us query whether we are inside a generated test. Useful for fast debugging out-of-tree without needing to update PyTorch.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110951
Approved by: https://github.com/williamwen42
This commit is contained in:
@ -1717,6 +1717,12 @@ class MiniOpTest(CustomOpTestCaseBase):
|
||||
lib.impl(name, lambda x: x.clone(), "CPU")
|
||||
return self.get_op(qualname)
|
||||
|
||||
@optests.dontGenerateOpCheckTests("Testing this API")
def test_dont_generate(self):
    """Exercise a bad op directly; the decorator above suppresses
    generation of opcheck test variants for this method."""
    bad_op = op_with_incorrect_schema(self, "incorrect_schema")
    sample_input = torch.randn(3)
    bad_op(sample_input)
||||
|
||||
def test_mm(self):
|
||||
x = torch.randn(2, 3, requires_grad=True)
|
||||
y = torch.randn(3, 5)
|
||||
@ -1940,6 +1946,10 @@ opcheck(op, args, kwargs, test_utils="test_schema")
|
||||
FailuresDict("", failures), mini_op_test_checks, MiniOpTest
|
||||
)
|
||||
|
||||
def test_dont_generate_decorator(self):
    """The decorated base method must survive, while no generated
    ``test_schema__``-prefixed variant should have been created."""
    has_base = hasattr(MiniOpTest, "test_dont_generate")
    has_generated = hasattr(MiniOpTest, "test_schema__test_dont_generate")
    self.assertTrue(has_base)
    self.assertFalse(has_generated)
|
||||
|
||||
def test_opcheck(self):
|
||||
x = torch.randn(3, requires_grad=True)
|
||||
with self.assertRaisesRegex(ValueError, "OpOverload"):
|
||||
@ -1982,6 +1992,13 @@ opcheck(op, args, kwargs, test_utils="test_schema")
|
||||
},
|
||||
)
|
||||
|
||||
def test_is_inside_opcheck_mode(self):
    """The thread-local flag is False outside OpCheckMode and True inside."""
    # Outside any OpCheckMode context, the query reports False.
    self.assertFalse(optests.is_inside_opcheck_mode())
    mode = optests.generate_tests.OpCheckMode(
        ["foo"], "bar", lambda x: x, None, "baz", "brr"
    )
    with mode:
        # Entering the mode raises the flag for the current thread.
        self.assertTrue(optests.is_inside_opcheck_mode())
|
||||
|
||||
def test_opcheck_bad_op(self):
|
||||
op = op_with_incorrect_schema(self, "foo")
|
||||
x = torch.randn(3)
|
||||
|
@ -2,4 +2,4 @@ from .make_fx import make_fx_check
|
||||
from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper
|
||||
from .fake_tensor import fake_check
|
||||
from .autograd_registration import autograd_registration_check
|
||||
from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError
|
||||
from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode
|
||||
|
@ -4,6 +4,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -21,6 +22,14 @@ from torch.testing._internal.optests import (
|
||||
)
|
||||
|
||||
|
||||
def dontGenerateOpCheckTests(reason: str):
    """Decorator: exempt a test method from opcheck test generation.

    Args:
        reason: Human-readable justification for skipping generation.
            Recorded for the reader only; the value is not inspected.

    Returns:
        A decorator that tags the function (unwrapped) with the marker
        attribute ``_torch_dont_generate_opcheck_tests`` so that
        ``generate_opcheck_tests`` skips it.
    """

    def decorator(fn):
        # The original function object is returned unchanged apart from
        # the marker attribute — no wrapper is introduced.
        fn._torch_dont_generate_opcheck_tests = True
        return fn

    return decorator
|
||||
|
||||
|
||||
def is_abstract(tensor: torch.Tensor) -> bool:
|
||||
if tensor.is_meta:
|
||||
return True
|
||||
@ -161,6 +170,8 @@ def generate_opcheck_tests(
|
||||
|
||||
def construct_method(attr, prefix, tester):
|
||||
method = getattr(testcase, attr)
|
||||
if getattr(method, "_torch_dont_generate_opcheck_tests", False):
|
||||
return
|
||||
new_method_name = prefix + "__" + attr
|
||||
|
||||
@functools.wraps(method)
|
||||
@ -319,6 +330,14 @@ def should_update_failures_dict() -> bool:
|
||||
return key in os.environ and os.environ[key] == "1"
|
||||
|
||||
|
||||
# Thread-local flag tracking whether the current thread is executing under
# an active OpCheckMode. Per-thread storage keeps concurrently running
# tests from interfering with each other.
_is_inside_opcheck_mode = threading.local()
_is_inside_opcheck_mode.value = False


def is_inside_opcheck_mode():
    """Return True if the calling thread is inside an active OpCheckMode.

    Uses ``getattr`` with a default because ``threading.local`` attributes
    are per-thread: the module-level initialization above only sets
    ``value`` for the importing thread, so reading ``.value`` directly
    from any other thread would raise ``AttributeError``.
    """
    return getattr(_is_inside_opcheck_mode, "value", False)
|
||||
|
||||
|
||||
class OpCheckMode(TorchFunctionMode):
|
||||
"""
|
||||
For a given test, OpCheckMode intercepts calls to operators and runs
|
||||
@ -408,7 +427,13 @@ class OpCheckMode(TorchFunctionMode):
|
||||
f"To reproduce this problem locally, try to run the following:\n{repro_command}"
|
||||
) from ex
|
||||
|
||||
def __enter__(self, *args, **kwargs):
    """Enter the mode and flag the current thread as inside opcheck."""
    # Save the previous flag (so nested OpCheckModes restore correctly on
    # exit) and raise it in one step: the RHS tuple is fully evaluated
    # before either store happens.
    self.prev_is_opcheck_mode, _is_inside_opcheck_mode.value = (
        _is_inside_opcheck_mode.value,
        True,
    )
    return super().__enter__(*args, **kwargs)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
_is_inside_opcheck_mode.value = self.prev_is_opcheck_mode
|
||||
try:
|
||||
self.maybe_raise_errors_on_exit()
|
||||
if should_update_failures_dict():
|
||||
|
Reference in New Issue
Block a user