xpu: get xpu arch flags at runtime in cpp_extensions (#152192)

This commit moves query for xpu arch flags to runtime when building SYCL extensions which allows to adjust `TORCH_XPU_ARCH_LIST` at python script level. That's handy for example in ci test which gives a try few variants of the list.

CC: @malfet, @jingxu10, @EikanWang, @guangyey

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
Dmitry Rogozhkin
2025-05-09 05:43:47 +00:00
committed by PyTorch MergeBot
parent 9fa07340fd
commit aca2c99a65
2 changed files with 48 additions and 6 deletions

View File

@ -3,8 +3,10 @@
import glob
import locale
import os
import random
import re
import shutil
import string
import subprocess
import sys
import tempfile
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_extension(self):
# NOTE: The name of the extension must equal the name of the module.
def _test_jit_xpu_extension(self):
name = "torch_test_xpu_extension_"
# randomizing name for the case when we test building few extensions
# in a row using this function
name += "".join(random.sample(string.ascii_letters, 5))
module = torch.utils.cpp_extension.load(
name="torch_test_xpu_extension",
name=name,
sources=[
"cpp_extensions/xpu_extension.sycl",
],
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_extension(self):
# NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST
self._test_jit_xpu_extension()
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_archlists(self):
# NOTE: in this test we explicitly test few different options
# for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the
# environment before the test won't affect it.
archlists = [
"", # expecting JIT compilation
",".join(torch.xpu.get_arch_list()),
]
old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
try:
for al in archlists:
os.environ["TORCH_XPU_ARCH_LIST"] = al
self._test_jit_xpu_extension()
finally:
if old_envvar is None:
os.environ.pop("TORCH_XPU_ARCH_LIST")
else:
os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
@unittest.skipIf(not TEST_MPS, "MPS not found")
def test_mps_extension(self):
module = torch.utils.cpp_extension.load(