# pytorch/test/inductor/test_halide.py
# Owner(s): ["oncall: pt2"]
import functools
import itertools
import os
import sys
import textwrap
import unittest

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import parallel_num_threads, run_and_get_code
from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have the necessary dependencies for test_halide yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    import halide  # @manual

    HAS_HALIDE = halide is not None
except ImportError:
    HAS_HALIDE = False

try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library

test_classes = {}
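

# make_halide clones an existing inductor TestCase class with the Halide backend
# configs patched in (cpu_backend/cuda_backend = "halide", halide.scan_kernels =
# True); tests marked "_expected_failure_halide" become expected failures.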
def make_halide(cls):
    suffix = "_halide"
    cls_prefix = "Halide"
    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "halide.scan_kernels", True),
        (config, "cpu_backend", "halide"),
        (config, "cuda_backend", "halide"),
        xfail_prop="_expected_failure_halide",
    )
    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class
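

# Usage sketch: make_halide(test_torchinductor.CpuTests) produces a class named
# "HalideCpuTests" whose test methods gain a "_halide" suffix and run under the
# patched configs; registering it in globals() lets unittest discovery find it
# (see the instantiations at the bottom of this module).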


@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
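    # test_codecache: compile a hand-written pointwise-add generator through
    # HalideCodeCache.generate_halide using the Mullapudi2016 autoscheduler,
    # then verify the compiled kernel writes a + b into the output buffer.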
    def test_codecache(self):
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler="Mullapudi2016",
                scheduler_flags={
                    "parallelism": parallel_num_threads(),
                },
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert g.using_autoscheduler()
                        in_ptr0.set_estimates([hl.Range(1024, 1024)])
                        in_ptr1.set_estimates([hl.Range(1024, 1024)])
                        out_ptr0.set_estimates([hl.Range(1024, 1024)])

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)
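
    # test_manual_schedule: the same pointwise-add kernel, but with scheduler=None,
    # so the generator supplies its own schedule: split the loop into 32-wide
    # chunks, parallelize the outer var, vectorize the inner one, compute/store
    # the sum at the outer loop, and inline the second load.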
    def test_manual_schedule(self):
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler=None,
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert not g.using_autoscheduler()
                        i = hl.Var()
                        j = hl.Var()
                        out_ptr0.compute_root()
                        out_ptr0.split(xindex, i, j, 32)
                        out_ptr0.parallel(i)
                        out_ptr0.vectorize(j)
                        tmp2.compute_at(out_ptr0, i)
                        tmp2.store_at(out_ptr0, i)
                        tmp1.compute_inline()

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)
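
    # test_random_consistency: with the same manual seed, rand/randn/randint
    # compiled with cuda_backend="halide" must match the default "triton"
    # backend exactly, i.e. both backends draw from the same inductor RNG
    # sequence rather than each rolling their own.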
    @unittest.skipUnless(has_triton(), "requires triton")
    def test_random_consistency(self):
        seed = 1234
        shape = (3, 3)
        dtype = torch.float32
        for (rand_fn,) in itertools.product(
            (
                functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
                functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
                functools.partial(
                    torch.randint,
                    -1000,
                    1000,
                    size=shape,
                    dtype=torch.int64,
                    device="cuda",
                ),
            )
        ):

            @torch.compile(backend="inductor", options={"cuda_backend": "halide"})
            def get_rand_halide():
                return rand_fn()

            @torch.compile(backend="inductor", options={"cuda_backend": "triton"})
            def get_rand_triton():
                return rand_fn()

            torch.manual_seed(seed)
            halide_output = get_rand_halide()
            torch.manual_seed(seed)
            triton_output = get_rand_triton()
            self.assertEqual(halide_output, triton_output)
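
    # test_compile_options: exercise per-device scheduler selection
    # (Adams2019 on CPU, Anderson2021 on CUDA) and assert the generated source
    # is a Halide generator rather than Triton or C++.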
    def test_compile_options(self):
        @torch.compile(
            backend="inductor",
            options={
                "cuda_backend": "halide",
                "cpu_backend": "halide",
                "halide.scheduler_cuda": "Anderson2021",
                "halide.scheduler_cpu": "Adams2019",
            },
        )
        def halide(a, b):
            return torch.softmax(a, -1) + torch.softmax(b, -1)

        _, (code,) = run_and_get_code(
            halide, torch.randn(1024, 1024), torch.randn(1024, 1024)
        )
        self.assertIn("@hl.generator", code)

        if torch.cuda.is_available():
            _, (code,) = run_and_get_code(
                halide,
                torch.randn(1024, 1024, device="cuda"),
                torch.randn(1024, 1024, device="cuda"),
            )
            self.assertIn("@hl.generator", code)
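

# Instantiate Halide-patched copies of the stock inductor suites when their
# dependencies are present; the GPU sweep is opt-in via TEST_HALIDE_GPU=1.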
if test_torchinductor.HAS_CPU and HAS_HALIDE:
    make_halide(test_torchinductor.SweepInputsCpuTest)
    make_halide(test_torchinductor.CpuTests)

if (
    test_torchinductor.HAS_GPU
    and HAS_HALIDE
    and os.environ.get("TEST_HALIDE_GPU") == "1"
):
    make_halide(test_torchinductor.SweepInputsGPUTest)
    make_halide(test_torchinductor.GPUTests)

if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS and HAS_HALIDE:
        run_tests(needs="filelock")