mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Pyt 2.0 rocm staging (#94660)
Add triton support for ROCm builds of PyTorch. * Enables inductor and dynamo when rocm is detected * Adds support for pytorch-triton-mlir backend * Adds check_rocm support for verify_dynamo.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/94660 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
71ec2617d2
commit
77d1135566
1
.github/requirements/triton-requirements-rocm.txt
vendored
Normal file
1
.github/requirements/triton-requirements-rocm.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
pytorch-triton-rocm>=2.0.0.dev
|
7
setup.py
7
setup.py
@ -939,6 +939,13 @@ def configure_extension_build():
|
||||
|
||||
# These extensions are built by cmake and copied manually in build_extensions()
|
||||
# inside the build_ext implementation
|
||||
if cmake_cache_vars['USE_ROCM']:
|
||||
triton_req_file = os.path.join(cwd, ".github", "requirements", "triton-requirements-rocm.txt")
|
||||
if os.path.exists(triton_req_file):
|
||||
with open(triton_req_file) as f:
|
||||
triton_req = f.read().strip()
|
||||
extra_install_requires.append(triton_req)
|
||||
|
||||
if cmake_cache_vars['BUILD_CAFFE2']:
|
||||
extensions.append(
|
||||
Extension(
|
||||
|
@ -8,6 +8,7 @@ import warnings
|
||||
from pkg_resources import packaging
|
||||
|
||||
MIN_CUDA_VERSION = packaging.version.parse("11.6")
|
||||
MIN_ROCM_VERSION = packaging.version.parse("5.4")
|
||||
MIN_PYTHON_VERSION = (3, 8)
|
||||
|
||||
|
||||
@ -52,6 +53,31 @@ def get_cuda_version():
|
||||
return packaging.version.parse(cuda_str_version)
|
||||
|
||||
|
||||
def get_rocm_version():
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
ROCM_HOME = cpp_extension._find_rocm_home()
|
||||
if not ROCM_HOME:
|
||||
raise VerifyDynamoError(
|
||||
"ROCM was not found on the system, please set ROCM_HOME environment variable"
|
||||
)
|
||||
|
||||
hipcc = os.path.join(ROCM_HOME, "bin", "hipcc")
|
||||
hip_version_str = (
|
||||
subprocess.check_output([hipcc, "--version"])
|
||||
.strip()
|
||||
.decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
|
||||
)
|
||||
hip_version = re.search(r"HIP version: (\d+[.]\d+)", hip_version_str)
|
||||
|
||||
if hip_version is None:
|
||||
raise VerifyDynamoError("HIP version not found in `hipcc --version` output")
|
||||
|
||||
hip_str_version = hip_version.group(1)
|
||||
|
||||
return packaging.version.parse(hip_str_version)
|
||||
|
||||
|
||||
def check_cuda():
|
||||
import torch
|
||||
|
||||
@ -81,7 +107,38 @@ def check_cuda():
|
||||
f"- minimum requirement: {MIN_CUDA_VERSION}"
|
||||
)
|
||||
|
||||
return cuda_ver
|
||||
return cuda_ver if torch.version.hip is None else "None"
|
||||
|
||||
|
||||
def check_rocm():
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available() or torch.version.hip is None:
|
||||
return None
|
||||
|
||||
# Extracts main ROCm version from full string
|
||||
torch_rocm_ver = packaging.version.parse(
|
||||
".".join(list(torch.version.hip.split(".")[0:2]))
|
||||
)
|
||||
|
||||
# check if torch rocm version matches system rocm version
|
||||
rocm_ver = get_rocm_version()
|
||||
if rocm_ver != torch_rocm_ver:
|
||||
warnings.warn(
|
||||
f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
|
||||
)
|
||||
if torch_rocm_ver < MIN_ROCM_VERSION:
|
||||
warnings.warn(
|
||||
f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
|
||||
f"- minimum requirement: {MIN_ROCM_VERSION}"
|
||||
)
|
||||
if rocm_ver < MIN_ROCM_VERSION:
|
||||
warnings.warn(
|
||||
f"(env) ROCm version not supported: {rocm_ver} "
|
||||
f"- minimum requirement: {MIN_ROCM_VERSION}"
|
||||
)
|
||||
|
||||
return rocm_ver if torch.version.hip else "None"
|
||||
|
||||
|
||||
def check_dynamo(backend, device, err_msg):
|
||||
@ -150,10 +207,12 @@ def main():
|
||||
python_ver = check_python()
|
||||
torch_ver = check_torch()
|
||||
cuda_ver = check_cuda()
|
||||
rocm_ver = check_rocm()
|
||||
print(
|
||||
f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
|
||||
f"`torch` version: {torch_ver}\n"
|
||||
f"CUDA version: {cuda_ver}\n"
|
||||
f"ROCM version: {rocm_ver}\n"
|
||||
)
|
||||
for args in _SANITY_CHECK_ARGS:
|
||||
check_dynamo(*args)
|
||||
|
Reference in New Issue
Block a user