mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] [CK] Composable Kernel integration for inductor backend (#158747)
This is part of our effort to integrate the Composable Kernel (CK) library into the Inductor backend. Currently we have a submodule, but we would prefer to have commit-pin control over the library, as with Triton. We intentionally avoid putting all installation logic in CI scripts so that locally built versions also have this functionality. The idea is to have CK as a PyTorch dependency in the PyTorch 2.9 release to allow people to use it with Inductor and AOT Inductor, and then gradually step away from submodule usage. Right now CK usage in SDPA/GEMM is tied to submodule files. This PR is a remake of https://github.com/pytorch/pytorch/pull/156192, redone due to a branch error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158747 Approved by: https://github.com/jeffdaily Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com> Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
81aeefa657
commit
019fed39aa
1
.ci/docker/ci_commit_pins/rocm-composable-kernel.txt
Normal file
1
.ci/docker/ci_commit_pins/rocm-composable-kernel.txt
Normal file
@ -0,0 +1 @@
|
||||
7fe50dc3da2069d6645d9deb8c017a876472a977
|
8
setup.py
8
setup.py
@ -324,6 +324,7 @@ from tools.setup_helpers.env import (
|
||||
IS_WINDOWS,
|
||||
)
|
||||
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
||||
from tools.setup_helpers.rocm_env import get_ck_dependency_string, IS_ROCM
|
||||
|
||||
|
||||
def str2bool(value: str | None) -> bool:
|
||||
@ -506,7 +507,6 @@ else:
|
||||
sysconfig.get_config_var("LIBDIR")
|
||||
) / sysconfig.get_config_var("INSTSONAME")
|
||||
|
||||
|
||||
################################################################################
|
||||
# Version, create_version_file, and package_name
|
||||
################################################################################
|
||||
@ -1494,6 +1494,12 @@ def configure_extension_build() -> tuple[
|
||||
map(str.strip, pytorch_extra_install_requires.split("|"))
|
||||
)
|
||||
|
||||
# Adding extra requirements for ROCm builds
|
||||
if IS_ROCM and platform.system() == "Linux":
|
||||
extra_install_requires.append(
|
||||
f"rocm-composable-kernel {get_ck_dependency_string()}"
|
||||
)
|
||||
|
||||
# Cross-compile for M1
|
||||
if IS_DARWIN:
|
||||
macos_target_arch = os.getenv("CMAKE_OSX_ARCHITECTURES", "")
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
@ -13,6 +12,7 @@ except ImportError:
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import try_import_ck_lib
|
||||
from torch.testing._internal.common_cuda import tf32_off
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -32,20 +32,8 @@ if HAS_CUDA_AND_TRITON:
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_path_without_sccache() -> str:
|
||||
"""
|
||||
Get the PATH environment variable without sccache.
|
||||
"""
|
||||
path_envs = os.environ.get("PATH", "").split(":")
|
||||
path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
|
||||
return ":".join(path_envs)
|
||||
|
||||
|
||||
_test_env = {
|
||||
"PATH": _get_path_without_sccache(),
|
||||
"DISABLE_SCCACHE": "1",
|
||||
}
|
||||
# patch env for tests if needed
|
||||
_test_env = {}
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@ -61,13 +49,10 @@ class TestCKBackend(TestCase):
|
||||
)
|
||||
|
||||
torch.random.manual_seed(1234)
|
||||
try:
|
||||
import ck4inductor # @manual
|
||||
|
||||
self.ck_dir = os.path.dirname(ck4inductor.__file__)
|
||||
os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir
|
||||
except ImportError as e:
|
||||
raise unittest.SkipTest("Composable Kernel library not installed") from e
|
||||
self.ck_dir, _, _, _ = try_import_ck_lib()
|
||||
if not self.ck_dir:
|
||||
raise unittest.SkipTest("Composable Kernel library is not installed")
|
||||
|
||||
try:
|
||||
os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
|
||||
@ -288,6 +273,9 @@ class TestCKBackend(TestCase):
|
||||
|
||||
torch.testing.assert_close(Y_compiled, Y_eager)
|
||||
|
||||
@unittest.skip(
|
||||
"FIXME(tenpercent): kernel compilation errors on gfx942 as of 09/01/25"
|
||||
)
|
||||
@unittest.skipIf(not torch.version.hip, "ROCM only")
|
||||
@unittest.mock.patch.dict(os.environ, _test_env)
|
||||
@parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
|
||||
|
61
tools/setup_helpers/rocm_env.py
Normal file
61
tools/setup_helpers/rocm_env.py
Normal file
@ -0,0 +1,61 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_if_rocm() -> bool:
    """Decide whether the current build should be treated as a ROCm build.

    A non-empty ``USE_ROCM`` environment variable takes precedence and is
    interpreted as a boolean flag: the values ``0``, ``OFF``, ``NO``,
    ``FALSE`` and ``N`` (case-insensitive) disable ROCm, any other non-empty
    value enables it. When ``USE_ROCM`` is unset or empty, the build is
    considered a ROCm build if the path given by ``ROCM_PATH`` (default
    ``/opt/rocm``) exists.

    Returns:
        bool: True for a ROCm build, False otherwise.
    """
    # If user defines USE_ROCM during PyTorch build, respect their intention
    use_rocm_env = os.environ.get("USE_ROCM")
    if use_rocm_env:
        # bool() of any non-empty string is True, which would turn an explicit
        # USE_ROCM=0 into "enabled"; compare against the "off" spellings instead.
        return use_rocm_env.strip().upper() not in {"0", "OFF", "NO", "FALSE", "N"}

    # otherwise infer existence of ROCm installation as indication of ROCm build
    rocm_path_env = os.environ.get("ROCM_PATH", "/opt/rocm")
    return bool(rocm_path_env) and os.path.exists(rocm_path_env)
|
||||
|
||||
|
||||
# Evaluated once, when this module is imported; the value reflects the
# USE_ROCM / ROCM_PATH environment at that moment.
IS_ROCM = check_if_rocm()

# Directory containing this file.
SCRIPT_DIR = Path(__file__).parent
# Two directory levels above SCRIPT_DIR; base for the repo-relative path of
# the CK commit pin file.
REPO_DIR = SCRIPT_DIR.parent.parent
|
||||
|
||||
|
||||
# The CK pin is read in a similar way to the Triton commit pin
def read_ck_pin() -> str:
    """
    Return the pinned CK (Composable Kernel) commit hash.

    The pin lives in 'rocm-composable-kernel.txt' under
    '.ci/docker/ci_commit_pins' relative to REPO_DIR and is meant to point at
    a known stable version of CK.

    Returns:
        str: The pin file contents with surrounding whitespace stripped.
    """
    pin_path = (
        REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "rocm-composable-kernel.txt"
    )
    with pin_path.open() as pin_file:
        contents = pin_file.read()
    return contents.strip()
|
||||
|
||||
|
||||
# Prepares a dependency string for install_requires in setuptools,
# specifically in PEP 508 URL (direct reference) format
def get_ck_dependency_string() -> str:
    """
    Generates the version/URL part of a PEP 508 dependency specifier for the
    ROCm Composable Kernel, to be appended after the package name in
    setuptools' install_requires.

    The returned string is EITHER in the format:
        "@ git+<repo_url>@<commit_hash>#egg=rocm-composable-kernel"
    where:
        - <repo_url> is the URL for ROCm Composable Kernel
        - <commit_hash> is read from the commit pin file
        - "#egg=rocm-composable-kernel" specifies the package name for setuptools
    OR an empty string, making use of the existing rocm-composable-kernel installation.

    Returns:
        str: The formatted dependency string for use in install_requires.
    """
    if os.getenv("TORCHINDUCTOR_CK_DIR"):
        # we take non-empty env as an indicator that the package has already been
        # installed and doesn't need to be re-installed; this comes with a caveat
        # that the pinned version is known to work while the preinstalled version
        # might not.
        # Checked before touching the pin file so that no file read (and no
        # failure on a missing pin file) happens when the pin is not used.
        return ""
    egg_name = "#egg=rocm-composable-kernel"
    commit_pin = f"@{read_ck_pin()}"
    return f"@ git+https://github.com/ROCm/composable_kernel.git{commit_pin}{egg_name}"
|
@ -4,7 +4,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from torch._inductor import config
|
||||
from torch._inductor.utils import is_linux
|
||||
from torch._inductor.utils import is_linux, try_import_ck_lib
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -18,18 +18,23 @@ def _rocm_include_paths(dst_file_ext: str) -> list[str]:
|
||||
if config.rocm.rocm_home
|
||||
else cpp_extension._join_rocm_home("include")
|
||||
)
|
||||
if not config.rocm.ck_dir:
|
||||
log.warning("Unspecified Composable Kernel include dir")
|
||||
|
||||
if config.is_fbcode():
|
||||
from libfb.py import parutil
|
||||
|
||||
ck_path = parutil.get_dir_path("composable-kernel-headers")
|
||||
else:
|
||||
if not config.rocm.ck_dir:
|
||||
ck_dir, _, _, _ = try_import_ck_lib()
|
||||
if not ck_dir:
|
||||
log.warning("Unspecified Composable Kernel directory")
|
||||
config.rocm.ck_dir = ck_dir
|
||||
ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home(
|
||||
"composable_kernel"
|
||||
)
|
||||
|
||||
log.debug("Using ck path %s", ck_path)
|
||||
|
||||
ck_include = os.path.join(ck_path, "include")
|
||||
ck_library_include = os.path.join(ck_path, "library", "include")
|
||||
|
||||
|
@ -1961,16 +1961,7 @@ def use_ck_template(layout: Layout) -> bool:
|
||||
log.warning("Please pip install Composable Kernel package")
|
||||
return False
|
||||
|
||||
if config.is_fbcode():
|
||||
config.rocm.ck_dir = ck_package_dirname
|
||||
|
||||
if not config.rocm.ck_dir:
|
||||
log.warning("Please set TORCHINDUCTOR_CK_DIR env variable")
|
||||
return False
|
||||
|
||||
if ck_package_dirname != config.rocm.ck_dir:
|
||||
log.warning("Invalid path to CK library")
|
||||
return False
|
||||
config.rocm.ck_dir = ck_package_dirname
|
||||
|
||||
return True
|
||||
|
||||
|
Reference in New Issue
Block a user