mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
xpu: get xpu arch flags at runtime in cpp_extensions (#152192)
This commit moves query for xpu arch flags to runtime when building SYCL extensions which allows to adjust `TORCH_XPU_ARCH_LIST` at python script level. That's handy for example in ci test which gives a try few variants of the list. CC: @malfet, @jingxu10, @EikanWang, @guangyey Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192 Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
9fa07340fd
commit
aca2c99a65
@ -3,8 +3,10 @@
|
||||
import glob
|
||||
import locale
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||
def test_jit_xpu_extension(self):
|
||||
# NOTE: The name of the extension must equal the name of the module.
|
||||
def _test_jit_xpu_extension(self):
|
||||
name = "torch_test_xpu_extension_"
|
||||
# randomizing name for the case when we test building few extensions
|
||||
# in a row using this function
|
||||
name += "".join(random.sample(string.ascii_letters, 5))
|
||||
module = torch.utils.cpp_extension.load(
|
||||
name="torch_test_xpu_extension",
|
||||
name=name,
|
||||
sources=[
|
||||
"cpp_extensions/xpu_extension.sycl",
|
||||
],
|
||||
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||
def test_jit_xpu_extension(self):
|
||||
# NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST
|
||||
self._test_jit_xpu_extension()
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||
def test_jit_xpu_archlists(self):
|
||||
# NOTE: in this test we explicitly test few different options
|
||||
# for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the
|
||||
# environment before the test won't affect it.
|
||||
archlists = [
|
||||
"", # expecting JIT compilation
|
||||
",".join(torch.xpu.get_arch_list()),
|
||||
]
|
||||
old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
|
||||
try:
|
||||
for al in archlists:
|
||||
os.environ["TORCH_XPU_ARCH_LIST"] = al
|
||||
self._test_jit_xpu_extension()
|
||||
finally:
|
||||
if old_envvar is None:
|
||||
os.environ.pop("TORCH_XPU_ARCH_LIST")
|
||||
else:
|
||||
os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
|
||||
|
||||
@unittest.skipIf(not TEST_MPS, "MPS not found")
|
||||
def test_mps_extension(self):
|
||||
module = torch.utils.cpp_extension.load(
|
||||
|
Reference in New Issue
Block a user