mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fix setUpClass() / tearDownClass() for device-specific tests (#151129)"
This reverts commit bd4cf30e31a2a0b0a57f54c7eedd3a39d5778cbe. Reverted https://github.com/pytorch/pytorch/pull/151129 on behalf of https://github.com/jbschlosser due to flex attention tests failing ([comment](https://github.com/pytorch/pytorch/pull/151129#issuecomment-2807632119))
This commit is contained in:
@ -125,12 +125,12 @@ class TestLinalg(TestCase):
|
||||
del os.environ["HIPBLASLT_ALLOW_TF32"]
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
super(self.__class__, self).setUp()
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
def tearDown(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
super().tearDown()
|
||||
super(self.__class__, self).tearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tunableop_ctx(self):
|
||||
|
@ -64,12 +64,12 @@ assert torch.get_default_dtype() is torch.float32
|
||||
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
|
||||
class TestMatmulCuda(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
super(self.__class__, self).setUp()
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
def tearDown(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
super().tearDown()
|
||||
super(self.__class__, self).tearDown()
|
||||
|
||||
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
|
||||
#
|
||||
|
@ -859,7 +859,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = False
|
||||
super().tearDown()
|
||||
super(self.__class__, self).tearDown()
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
|
||||
@inference_dtypes
|
||||
|
@ -440,40 +440,6 @@ if __name__ == '__main__':
|
||||
op.supported_dtypes(torch.device("cuda", index=1)),
|
||||
)
|
||||
|
||||
def test_setup_and_teardown_run_for_device_specific_tests(self, device):
|
||||
# TODO: Move this (and other similar text blocks) to some fixtures/ subdir
|
||||
stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
class TestFoo(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# store something on the test class to query during teardown
|
||||
cls.stored_thing = "called with " + cls.__name__
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
# throw here so we know teardown was run
|
||||
raise RuntimeError(cls.stored_thing)
|
||||
|
||||
def test_bar(self, device):
|
||||
# make sure the test can access the stored thing
|
||||
print(self.stored_thing)
|
||||
|
||||
instantiate_device_type_tests(TestFoo, globals(), only_for='{self.device_type}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
""")
|
||||
expected_device_class_name = f"TestFoo{self.device_type.upper()}"
|
||||
expected_error_text = f"RuntimeError: called with {expected_device_class_name}"
|
||||
self.assertIn(expected_error_text, stderr)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTesting, globals())
|
||||
|
||||
|
||||
|
@ -281,6 +281,35 @@ except ModuleNotFoundError:
|
||||
# they are run. This makes it useful for initializing devices and dependencies.
|
||||
|
||||
|
||||
# Note [Overriding methods in generic tests]
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# Device generic tests look a lot like normal test classes, but they differ
|
||||
# from ordinary classes in some important ways. In particular, overriding
|
||||
# methods in generic tests doesn't work quite the way you expect.
|
||||
#
|
||||
# class TestFooDeviceType(TestCase):
|
||||
# # Intention is to override
|
||||
# def assertEqual(self, x, y):
|
||||
# # This DOESN'T WORK!
|
||||
# super().assertEqual(x, y)
|
||||
#
|
||||
# If you try to run this code, you'll get an error saying that TestFooDeviceType
|
||||
# is not in scope. This is because after instantiating our classes, we delete
|
||||
# it from the parent scope. Instead, you need to hardcode a direct invocation
|
||||
# of the desired subclass call, e.g.,
|
||||
#
|
||||
# class TestFooDeviceType(TestCase):
|
||||
# # Intention is to override
|
||||
# def assertEqual(self, x, y):
|
||||
# TestCase.assertEqual(x, y)
|
||||
#
|
||||
# However, a less error-prone way of customizing the behavior of TestCase
|
||||
# is to either (1) add your functionality to TestCase and make it toggled
|
||||
# by a class attribute, or (2) create your own subclass of TestCase, and
|
||||
# then inherit from it for your generic test.
|
||||
|
||||
|
||||
def _dtype_test_suffix(dtypes):
|
||||
"""Returns the test suffix for a dtype, sequence of dtypes, or None."""
|
||||
if isinstance(dtypes, (list, tuple)):
|
||||
@ -864,7 +893,20 @@ def instantiate_device_type_tests(
|
||||
# are not discoverable.
|
||||
del scope[generic_test_class.__name__]
|
||||
|
||||
generic_members = set(generic_test_class.__dict__.keys())
|
||||
# Creates an 'empty' version of the generic_test_class
|
||||
# Note: we don't inherit from the generic_test_class directly because
|
||||
# that would add its tests to our test classes and they would be
|
||||
# discovered (despite not being runnable). Inherited methods also
|
||||
# can't be removed later, and we can't rely on load_tests because
|
||||
# pytest doesn't support it (as of this writing).
|
||||
empty_name = generic_test_class.__name__ + "_base"
|
||||
empty_class = type(empty_name, generic_test_class.__bases__, {})
|
||||
|
||||
# Acquires members names
|
||||
# See Note [Overriding methods in generic tests]
|
||||
generic_members = set(generic_test_class.__dict__.keys()) - set(
|
||||
empty_class.__dict__.keys()
|
||||
)
|
||||
generic_tests = [x for x in generic_members if x.startswith("test")]
|
||||
|
||||
# Creates device-specific test cases
|
||||
@ -875,30 +917,7 @@ def instantiate_device_type_tests(
|
||||
|
||||
# type set to Any and suppressed due to unsupport runtime class:
|
||||
# https://github.com/python/mypy/wiki/Unsupported-Python-Features
|
||||
device_type_test_class: Any = type(class_name, (base, generic_test_class), {})
|
||||
|
||||
# Arrange for setUpClass and tearDownClass methods defined both in the test template
|
||||
# class and in the generic base to be called. This allows device-parameterized test
|
||||
# classes to support setup and teardown.
|
||||
# NB: This should be done before instantiate_test() is called as that invokes setup.
|
||||
@classmethod
|
||||
def _setUpClass(cls):
|
||||
# This should always be called, whether or not the test class invokes
|
||||
# super().setUpClass(), to set the primary device.
|
||||
base.setUpClass()
|
||||
# We want to call the @classmethod defined in the generic base, but pass
|
||||
# it the device-specific class object (cls), hence the __func__ call.
|
||||
generic_test_class.setUpClass.__func__(cls)
|
||||
|
||||
@classmethod
|
||||
def _tearDownClass(cls):
|
||||
# We want to call the @classmethod defined in the generic base, but pass
|
||||
# it the device-specific class object (cls), hence the __func__ call.
|
||||
generic_test_class.tearDownClass.__func__(cls)
|
||||
base.tearDownClass()
|
||||
|
||||
device_type_test_class.setUpClass = _setUpClass
|
||||
device_type_test_class.tearDownClass = _tearDownClass
|
||||
device_type_test_class: Any = type(class_name, (base, empty_class), {})
|
||||
|
||||
for name in generic_members:
|
||||
if name in generic_tests: # Instantiates test member
|
||||
@ -912,11 +931,30 @@ def instantiate_device_type_tests(
|
||||
)
|
||||
else:
|
||||
device_type_test_class.instantiate_test(name, copy.deepcopy(test))
|
||||
# Ports non-test member. Setup / teardown have already been handled above
|
||||
elif name not in device_type_test_class.__dict__:
|
||||
else: # Ports non-test member
|
||||
assert (
|
||||
name not in device_type_test_class.__dict__
|
||||
), f"Redefinition of directly defined member {name}"
|
||||
nontest = getattr(generic_test_class, name)
|
||||
setattr(device_type_test_class, name, nontest)
|
||||
|
||||
# The dynamically-created test class derives from the test template class
|
||||
# and the empty class. Arrange for both setUpClass and tearDownClass methods
|
||||
# to be called. This allows the parameterized test classes to support setup
|
||||
# and teardown.
|
||||
@classmethod
|
||||
def _setUpClass(cls):
|
||||
base.setUpClass()
|
||||
empty_class.setUpClass()
|
||||
|
||||
@classmethod
|
||||
def _tearDownClass(cls):
|
||||
empty_class.tearDownClass()
|
||||
base.tearDownClass()
|
||||
|
||||
device_type_test_class.setUpClass = _setUpClass
|
||||
device_type_test_class.tearDownClass = _tearDownClass
|
||||
|
||||
# Mimics defining the instantiated class in the caller's file
|
||||
# by setting its module to the given class's and adding
|
||||
# the module to the given scope.
|
||||
@ -924,13 +962,6 @@ def instantiate_device_type_tests(
|
||||
device_type_test_class.__module__ = generic_test_class.__module__
|
||||
scope[class_name] = device_type_test_class
|
||||
|
||||
# Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're
|
||||
# not discoverable. This mutates the original class (TestFoo), which was removed from
|
||||
# scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda)
|
||||
# have already been created and the generic forms are no longer needed.
|
||||
for name in generic_tests:
|
||||
delattr(generic_test_class, name)
|
||||
|
||||
|
||||
# Category of dtypes to run an OpInfo-based test for
|
||||
# Example use: @ops(dtype=OpDTypes.supported)
|
||||
|
Reference in New Issue
Block a user