mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Fix ignored options for torch.compile (#145131)
#139833 broke `torch.compile(options=...)` so that many (all?) options passed in get completely ignored. @alexreinking pointed this out when `options={"cpu_backend":"halide"}` did nothing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145131 Approved by: https://github.com/exclamaforte
This commit is contained in:
committed by
PyTorch MergeBot
parent
668fb7dfba
commit
4eea2f7496
@ -12,7 +12,7 @@ from torch._inductor import config
|
||||
from torch._inductor.codecache import HalideCodeCache
|
||||
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import parallel_num_threads
|
||||
from torch._inductor.utils import parallel_num_threads, run_and_get_code
|
||||
from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
from torch.utils._triton import has_triton
|
||||
@ -232,6 +232,32 @@ class HalideTests(TestCase):
|
||||
|
||||
self.assertEqual(halide_output, triton_output)
|
||||
|
||||
def test_compile_options(self):
|
||||
@torch.compile(
|
||||
backend="inductor",
|
||||
options={
|
||||
"cuda_backend": "halide",
|
||||
"cpu_backend": "halide",
|
||||
"halide.scheduler_cuda": "Anderson2021",
|
||||
"halide.scheduler_cpu": "Adams2019",
|
||||
},
|
||||
)
|
||||
def halide(a, b):
|
||||
return torch.softmax(a, -1) + torch.softmax(b, -1)
|
||||
|
||||
_, (code,) = run_and_get_code(
|
||||
halide, torch.randn(1024, 1024), torch.randn(1024, 1024)
|
||||
)
|
||||
self.assertIn("@hl.generator", code)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
_, (code,) = run_and_get_code(
|
||||
halide,
|
||||
torch.randn(1024, 1024, device="cuda"),
|
||||
torch.randn(1024, 1024, device="cuda"),
|
||||
)
|
||||
self.assertIn("@hl.generator", code)
|
||||
|
||||
|
||||
if test_torchinductor.HAS_CPU and HAS_HALIDE:
|
||||
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
|
||||
|
Reference in New Issue
Block a user