mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix test_halide.py report invocation to re-run failed tests (#147640)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147640 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
acca9b9cb0
commit
4ec6c1d1ec
@ -8,6 +8,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||||
|
from torch._dynamo.testing import make_test_cls_with_patches
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
from torch._inductor.codecache import HalideCodeCache
|
from torch._inductor.codecache import HalideCodeCache
|
||||||
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
|
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
|
||||||
@ -40,13 +41,29 @@ except ImportError:
|
|||||||
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||||
|
|
||||||
|
|
||||||
make_halide = config.patch(
|
test_classes = {}
|
||||||
{
|
|
||||||
"halide.scan_kernels": True,
|
|
||||||
"cpu_backend": "halide",
|
def make_halide(cls):
|
||||||
"cuda_backend": "halide",
|
suffix = "_halide"
|
||||||
}
|
|
||||||
)
|
cls_prefix = "Halide"
|
||||||
|
|
||||||
|
test_class = make_test_cls_with_patches(
|
||||||
|
cls,
|
||||||
|
cls_prefix,
|
||||||
|
suffix,
|
||||||
|
(config, "halide.scan_kernels", True),
|
||||||
|
(config, "cpu_backend", "halide"),
|
||||||
|
(config, "cuda_backend", "halide"),
|
||||||
|
xfail_prop="_expected_failure_halide",
|
||||||
|
)
|
||||||
|
|
||||||
|
test_classes[test_class.__name__] = test_class
|
||||||
|
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
||||||
|
globals()[test_class.__name__] = test_class
|
||||||
|
test_class.__module__ = __name__
|
||||||
|
return test_class
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(HAS_HALIDE, "requires halide")
|
@unittest.skipUnless(HAS_HALIDE, "requires halide")
|
||||||
@ -260,16 +277,16 @@ class HalideTests(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if test_torchinductor.HAS_CPU and HAS_HALIDE:
|
if test_torchinductor.HAS_CPU and HAS_HALIDE:
|
||||||
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
|
make_halide(test_torchinductor.SweepInputsCpuTest)
|
||||||
CpuHalideTests = make_halide(test_torchinductor.CpuTests)
|
make_halide(test_torchinductor.CpuTests)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
test_torchinductor.HAS_GPU
|
test_torchinductor.HAS_GPU
|
||||||
and HAS_HALIDE
|
and HAS_HALIDE
|
||||||
and os.environ.get("TEST_HALIDE_GPU") == "1"
|
and os.environ.get("TEST_HALIDE_GPU") == "1"
|
||||||
):
|
):
|
||||||
SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
|
make_halide(test_torchinductor.SweepInputsGPUTest)
|
||||||
GPUHalideTests = make_halide(test_torchinductor.GPUTests)
|
make_halide(test_torchinductor.GPUTests)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if HAS_CPU and not IS_MACOS and HAS_HALIDE:
|
if HAS_CPU and not IS_MACOS and HAS_HALIDE:
|
||||||
|
Reference in New Issue
Block a user