[inductor] Fix ignored options for torch.compile (#145131)

#139833 broke `torch.compile(options=...)` so that many (all?) of the options passed in were completely ignored. @alexreinking pointed this out when `options={"cpu_backend": "halide"}` did nothing.
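For context, a minimal reproduction sketch (hypothetical function and shapes; assumes a working Halide toolchain). Inductor applies each `options` key as a patch to the matching `torch._inductor.config` field, so before this fix the patch below was silently dropped and CPU codegen stayed on the default C++ backend:

```python
import torch

# Sketch only: prior to this fix, the "cpu_backend" option was dropped and
# Inductor kept emitting its default C++ kernels instead of Halide ones.
@torch.compile(backend="inductor", options={"cpu_backend": "halide"})
def f(x):
    return torch.softmax(x, -1)

f(torch.randn(128, 128))
```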

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145131
Approved by: https://github.com/exclamaforte
Jason Ansel
2025-01-17 16:36:35 -08:00
committed by PyTorch MergeBot
parent 668fb7dfba
commit 4eea2f7496
2 changed files with 28 additions and 2 deletions

test/inductor/test_halide.py

@@ -12,7 +12,7 @@ from torch._inductor import config
 from torch._inductor.codecache import HalideCodeCache
 from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
 from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import parallel_num_threads
+from torch._inductor.utils import parallel_num_threads, run_and_get_code
 from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
 from torch.testing._internal.inductor_utils import HAS_CPU
 from torch.utils._triton import has_triton
@@ -232,6 +232,32 @@ class HalideTests(TestCase):
         self.assertEqual(halide_output, triton_output)
+
+    def test_compile_options(self):
+        @torch.compile(
+            backend="inductor",
+            options={
+                "cuda_backend": "halide",
+                "cpu_backend": "halide",
+                "halide.scheduler_cuda": "Anderson2021",
+                "halide.scheduler_cpu": "Adams2019",
+            },
+        )
+        def halide(a, b):
+            return torch.softmax(a, -1) + torch.softmax(b, -1)
+
+        _, (code,) = run_and_get_code(
+            halide, torch.randn(1024, 1024), torch.randn(1024, 1024)
+        )
+        self.assertIn("@hl.generator", code)
+
+        if torch.cuda.is_available():
+            _, (code,) = run_and_get_code(
+                halide,
+                torch.randn(1024, 1024, device="cuda"),
+                torch.randn(1024, 1024, device="cuda"),
+            )
+            self.assertIn("@hl.generator", code)
 
 
 if test_torchinductor.HAS_CPU and HAS_HALIDE:
     SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
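For reference, a sketch of how the dotted option keys in the test relate to Inductor's config (field names are taken from the test above; the defaults noted in comments are assumptions):

```python
import torch._inductor.config as inductor_config

# options={"halide.scheduler_cpu": "Adams2019"} names the nested config entry
# inductor_config.halide.scheduler_cpu; applying options as config patches at
# compile time is roughly what torch.compile(options=...) does under the hood.
print(inductor_config.cpu_backend)           # typically "cpp" by default
print(inductor_config.halide.scheduler_cpu)  # Halide CPU autoscheduler name

with inductor_config.patch({"cpu_backend": "halide"}):
    print(inductor_config.cpu_backend)       # "halide" inside the patch
```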