Compare commits

...

10 Commits

Author SHA1 Message Date
061c97d7b5 fix codelint 2025-11-11 21:15:07 -08:00
32112902c9 Merge branch 'pytorch:main' into xpu-online 2025-11-11 11:01:40 +08:00
5006092362 rebase torch/_inductor/config.py 2025-11-10 19:00:59 -08:00
6a1235cbc3 rebase torch/_inductor/utils.py 2025-11-10 18:54:32 -08:00
c9faa2fa67 Merge branch 'main' into xpu-online 2025-11-10 18:53:23 -08:00
f1633576b8 fix codelint 2025-09-24 01:43:43 -07:00
2681fba705 refine typo 2025-09-19 00:02:30 -07:00
7fe3475389 add note for ut 2025-09-19 00:01:36 -07:00
260d8e823f add ut and refine code lint 2025-09-18 23:12:45 -07:00
7319eb1498 enable online softmax for xpu devices with triton 2025-09-18 00:16:54 -07:00
2 changed files with 8 additions and 7 deletions

View File

@@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
 IS_LINUX,
 parametrize,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, HAS_TRITON
 DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
@@ -137,7 +137,9 @@ class TestOnlineSoftmax(TestCase):
 ref = _prepare_softmax(x, dim)
 self.assertTrue(same(ref, act, tol=1e-2))
-if nrow == 2048 and dim == 0:
+if nrow == 2048 and dim == 0 and GPU_TYPE != "xpu":
+# Note: split reduction is not triggered for this shape on xpu devices.
+# check "num_splits" for more details
 # split reduction is triggered. We have multiple kernels
 self.assertTrue(code.count("def triton") >= 2)
 else:
@@ -310,5 +312,5 @@ class TestOnlineSoftmax(TestCase):
 instantiate_parametrized_tests(TestOnlineSoftmax)
 if __name__ == "__main__":
-if IS_LINUX and HAS_CUDA_AND_TRITON:
+if IS_LINUX and HAS_GPU and HAS_TRITON:
 run_tests()

View File

@@ -351,10 +351,9 @@ def prepare_softmax_extra_check(match):
 """
 We only have triton online softmax kernels currently.
 """
-return (
-config.online_softmax
-and match.kwargs["x"].meta["val"].device.type == "cuda"
-and config.cuda_backend == "triton"
+return config.online_softmax and (
+match.kwargs["x"].meta["val"].device.type in ["cuda", "xpu"]
+and "triton" in [config.cuda_backend, config.xpu_backend]
 )