# Mirror of https://github.com/pytorch/pytorch.git
# (synced 2025-10-20 21:14:14 +08:00).
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/155342
# Approved by: https://github.com/Skylion007, https://github.com/bdhirsh
# ghstack dependencies: #154768
# File: 35 lines, 889 B, Python.
# Owner(s): ["module: inductor"]
"""Re-run inductor's shared CPU test suites with ``cpu_backend="triton"``.

The test classes are only defined when both ``HAS_CPU`` and
``TRITON_HAS_CPU`` are true.
"""

from torch._inductor import config
from torch._inductor.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CPU, TRITON_HAS_CPU


# The relative import succeeds when this file is imported as part of its test
# package. When the file is run directly as a script there is no parent
# package, so the relative import raises ImportError. In that case, fall back
# to a plain import of the sibling module (presumably resolvable because the
# script's own directory is on sys.path; confirm for your test runner).
try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor

if HAS_CPU and TRITON_HAS_CPU:
    # The Triton-CPU test classes are only defined when a CPU is available and
    # Triton has CPU support. Otherwise this module defines no test classes.

    @config.patch(cpu_backend="triton")
    class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest):
        """Inherit every ``SweepInputsCpuTest`` case, with inductor's
        ``cpu_backend`` config patched to ``"triton"``."""

    @config.patch(cpu_backend="triton")
    class CpuTritonTests(test_torchinductor.TestCase):
        """Target class for ``CommonTemplate`` tests, with inductor's
        ``cpu_backend`` config patched to ``"triton"``.

        Its test methods are presumably attached by the ``copy_tests`` call
        that follows the class definition.
        """

        # NOTE(review): presumably read by the copied CommonTemplate tests.
        # ``common`` is the model-checking helper and ``device`` is the target
        # device. Confirm in test_torchinductor.
        common = test_torchinductor.check_model
        device = "cpu"

    # Copy CommonTemplate's tests onto CpuTritonTests for the "cpu" device.
    # NOTE(review): ``xfail_prop`` names the attribute that presumably marks
    # tests expected to fail under the Triton CPU backend. Confirm against
    # copy_tests.
    test_torchinductor.copy_tests(
        test_torchinductor.CommonTemplate,
        CpuTritonTests,
        "cpu",
        xfail_prop="_expected_failure_triton_cpu",
    )

if __name__ == "__main__":
    # Only run tests when the Triton-CPU suites were defined. Otherwise the
    # script ends without running anything.
    if HAS_CPU and TRITON_HAS_CPU:
        # NOTE(review): ``needs="filelock"`` presumably makes run_tests require
        # the ``filelock`` package. Confirm in torch._inductor.test_case.
        run_tests(needs="filelock")