Fix test_halide.py report invocation to re-run failed tests (#147640)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147640
Approved by: https://github.com/jansel
This commit is contained in:
Isuru Fernando
2025-02-21 20:23:42 +00:00
committed by PyTorch MergeBot
parent acca9b9cb0
commit 4ec6c1d1ec

View File

@ -8,6 +8,7 @@ import unittest
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
@ -40,14 +41,30 @@ except ImportError:
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
make_halide = config.patch(
{
"halide.scan_kernels": True,
"cpu_backend": "halide",
"cuda_backend": "halide",
}
test_classes = {}
def make_halide(cls):
suffix = "_halide"
cls_prefix = "Halide"
test_class = make_test_cls_with_patches(
cls,
cls_prefix,
suffix,
(config, "halide.scan_kernels", True),
(config, "cpu_backend", "halide"),
(config, "cuda_backend", "halide"),
xfail_prop="_expected_failure_halide",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
@ -260,16 +277,16 @@ class HalideTests(TestCase):
if test_torchinductor.HAS_CPU and HAS_HALIDE:
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
CpuHalideTests = make_halide(test_torchinductor.CpuTests)
make_halide(test_torchinductor.SweepInputsCpuTest)
make_halide(test_torchinductor.CpuTests)
if (
test_torchinductor.HAS_GPU
and HAS_HALIDE
and os.environ.get("TEST_HALIDE_GPU") == "1"
):
SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
GPUHalideTests = make_halide(test_torchinductor.GPUTests)
make_halide(test_torchinductor.SweepInputsGPUTest)
make_halide(test_torchinductor.GPUTests)
if __name__ == "__main__":
if HAS_CPU and not IS_MACOS and HAS_HALIDE: