[ROCm] [CK] Composable Kernel integration for inductor backend (#158747)

This is part of our effort to integrate the Composable Kernel (CK) library into the Inductor backend. Currently CK is included as a submodule, but we would prefer to control it through a commit pin, as is done with Triton. We intentionally avoid putting all of the installation logic in CI scripts so that locally built versions also have this functionality.

The idea is to make CK a PyTorch dependency in the PyTorch 2.9 release so that people can use it with Inductor and AOT Inductor, and then to gradually move away from submodule usage. Right now, CK usage in SDPA/GEMM is tied to the submodule files.

This PR is a remake of https://github.com/pytorch/pytorch/pull/156192, recreated due to a branch error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158747
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
iupaikov-amd
2025-09-04 16:51:06 +00:00
committed by PyTorch MergeBot
parent 81aeefa657
commit 019fed39aa
6 changed files with 87 additions and 35 deletions

View File

@ -0,0 +1 @@
7fe50dc3da2069d6645d9deb8c017a876472a977

View File

@ -324,6 +324,7 @@ from tools.setup_helpers.env import (
IS_WINDOWS,
)
from tools.setup_helpers.generate_linker_script import gen_linker_script
from tools.setup_helpers.rocm_env import get_ck_dependency_string, IS_ROCM
def str2bool(value: str | None) -> bool:
@ -506,7 +507,6 @@ else:
sysconfig.get_config_var("LIBDIR")
) / sysconfig.get_config_var("INSTSONAME")
################################################################################
# Version, create_version_file, and package_name
################################################################################
@ -1494,6 +1494,12 @@ def configure_extension_build() -> tuple[
map(str.strip, pytorch_extra_install_requires.split("|"))
)
# Adding extra requirements for ROCm builds
if IS_ROCM and platform.system() == "Linux":
extra_install_requires.append(
f"rocm-composable-kernel {get_ck_dependency_string()}"
)
# Cross-compile for M1
if IS_DARWIN:
macos_target_arch = os.getenv("CMAKE_OSX_ARCHITECTURES", "")

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: inductor"]
import functools
import logging
import os
import unittest
@ -13,6 +12,7 @@ except ImportError:
import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import try_import_ck_lib
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -32,20 +32,8 @@ if HAS_CUDA_AND_TRITON:
log = logging.getLogger(__name__)
@functools.lru_cache(None)
def _get_path_without_sccache() -> str:
"""
Get the PATH environment variable without sccache.
"""
path_envs = os.environ.get("PATH", "").split(":")
path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
return ":".join(path_envs)
_test_env = {
"PATH": _get_path_without_sccache(),
"DISABLE_SCCACHE": "1",
}
# patch env for tests if needed
_test_env = {}
@instantiate_parametrized_tests
@ -61,13 +49,10 @@ class TestCKBackend(TestCase):
)
torch.random.manual_seed(1234)
try:
import ck4inductor # @manual
self.ck_dir = os.path.dirname(ck4inductor.__file__)
os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir
except ImportError as e:
raise unittest.SkipTest("Composable Kernel library not installed") from e
self.ck_dir, _, _, _ = try_import_ck_lib()
if not self.ck_dir:
raise unittest.SkipTest("Composable Kernel library is not installed")
try:
os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
@ -288,6 +273,9 @@ class TestCKBackend(TestCase):
torch.testing.assert_close(Y_compiled, Y_eager)
@unittest.skip(
"FIXME(tenpercent): kernel compilation errors on gfx942 as of 09/01/25"
)
@unittest.skipIf(not torch.version.hip, "ROCM only")
@unittest.mock.patch.dict(os.environ, _test_env)
@parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))

View File

@ -0,0 +1,61 @@
import os
from pathlib import Path
def check_if_rocm() -> bool:
    """Decide whether the current build should be treated as a ROCm build.

    Resolution order:

    1. If the ``USE_ROCM`` environment variable is set to a non-empty value,
       it is authoritative. Explicit negative spellings (``0``, ``OFF``,
       ``NO``, ``FALSE``, ``N``; case-insensitive) yield ``False``; any other
       non-empty value yields ``True``.
    2. Otherwise, the existence of the directory named by ``ROCM_PATH``
       (default ``/opt/rocm``) is taken as the indication of a ROCm build.

    Returns:
        bool: ``True`` for a ROCm build, ``False`` otherwise.
    """
    # If user defines USE_ROCM during PyTorch build, respect their intention.
    use_rocm_env = os.environ.get("USE_ROCM")
    if use_rocm_env:
        # bool() of any non-empty string is True, so values such as "0" or
        # "OFF" must be parsed explicitly to actually disable ROCm.
        negative_values = ("0", "OFF", "NO", "FALSE", "N")
        return use_rocm_env.strip().upper() not in negative_values

    # Otherwise infer a ROCm build from the presence of a ROCm installation.
    rocm_path_env = os.environ.get("ROCM_PATH", "/opt/rocm")
    return bool(rocm_path_env) and os.path.exists(rocm_path_env)
# Evaluated once, when this module is imported; setup.py imports this flag to
# decide whether to add the ROCm-only install requirement.
IS_ROCM = check_if_rocm()
# Directory that contains this file.
SCRIPT_DIR = Path(__file__).parent
# Two levels above SCRIPT_DIR; read_ck_pin() resolves the pin file relative to it.
REPO_DIR = SCRIPT_DIR.parent.parent
# CK pin is read in a similar way that triton commit is
def read_ck_pin() -> str:
    """
    Return the pinned Composable Kernel (CK) commit hash.

    The hash lives in 'rocm-composable-kernel.txt' under
    '.ci/docker/ci_commit_pins' relative to REPO_DIR.

    Returns:
        str: The file contents with surrounding whitespace stripped.
    """
    pin_path = (
        REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "rocm-composable-kernel.txt"
    )
    return pin_path.read_text().strip()
# Prepares a dependency string for install_requires in setuptools
# in specific PEP 508 URL format
def get_ck_dependency_string() -> str:
    """
    Build the version-specifier part of the ROCm Composable Kernel requirement
    for setuptools' install_requires (PEP 508 direct-reference form).

    The returned string is EITHER in the format:
        "@ git+<repo_url>@<commit_hash>#egg=rocm-composable-kernel"
    where:
        - <repo_url> is the URL for ROCm Composable Kernel
        - <commit_hash> is read from the commit pin file
        - "#egg=rocm-composable-kernel" specifies the package name for setuptools
    OR an empty string, making use of the existing rocm-composable-kernel
    installation.

    Returns:
        str: The formatted dependency string for use in install_requires.
    """
    # A non-empty TORCHINDUCTOR_CK_DIR is taken as an indicator that the
    # package has already been installed and doesn't need to be re-installed.
    # Caveat: the pinned version is known to work while the preinstalled
    # version might not. This check comes first so the pin file is only read
    # when its contents are actually needed.
    if os.getenv("TORCHINDUCTOR_CK_DIR"):
        return ""

    egg_name = "#egg=rocm-composable-kernel"
    commit_pin = f"@{read_ck_pin()}"
    return f"@ git+https://github.com/ROCm/composable_kernel.git{commit_pin}{egg_name}"

View File

@ -4,7 +4,7 @@ import os
from typing import Optional
from torch._inductor import config
from torch._inductor.utils import is_linux
from torch._inductor.utils import is_linux, try_import_ck_lib
log = logging.getLogger(__name__)
@ -18,18 +18,23 @@ def _rocm_include_paths(dst_file_ext: str) -> list[str]:
if config.rocm.rocm_home
else cpp_extension._join_rocm_home("include")
)
if not config.rocm.ck_dir:
log.warning("Unspecified Composable Kernel include dir")
if config.is_fbcode():
from libfb.py import parutil
ck_path = parutil.get_dir_path("composable-kernel-headers")
else:
if not config.rocm.ck_dir:
ck_dir, _, _, _ = try_import_ck_lib()
if not ck_dir:
log.warning("Unspecified Composable Kernel directory")
config.rocm.ck_dir = ck_dir
ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home(
"composable_kernel"
)
log.debug("Using ck path %s", ck_path)
ck_include = os.path.join(ck_path, "include")
ck_library_include = os.path.join(ck_path, "library", "include")

View File

@ -1961,16 +1961,7 @@ def use_ck_template(layout: Layout) -> bool:
log.warning("Please pip install Composable Kernel package")
return False
if config.is_fbcode():
config.rocm.ck_dir = ck_package_dirname
if not config.rocm.ck_dir:
log.warning("Please set TORCHINDUCTOR_CK_DIR env variable")
return False
if ck_package_dirname != config.rocm.ck_dir:
log.warning("Invalid path to CK library")
return False
config.rocm.ck_dir = ck_package_dirname
return True