mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Pyt 2.0 rocm staging (#94660)
Add triton support for ROCm builds of PyTorch. * Enables inductor and dynamo when rocm is detected * Adds support for pytorch-triton-mlir backend * Adds check_rocm support for verify_dynamo.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/94660 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
71ec2617d2
commit
77d1135566
1
.github/requirements/triton-requirements-rocm.txt
vendored
Normal file
1
.github/requirements/triton-requirements-rocm.txt
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
pytorch-triton-rocm>=2.0.0.dev
|
7
setup.py
7
setup.py
@ -939,6 +939,13 @@ def configure_extension_build():
|
|||||||
|
|
||||||
# These extensions are built by cmake and copied manually in build_extensions()
|
# These extensions are built by cmake and copied manually in build_extensions()
|
||||||
# inside the build_ext implementation
|
# inside the build_ext implementation
|
||||||
|
if cmake_cache_vars['USE_ROCM']:
|
||||||
|
triton_req_file = os.path.join(cwd, ".github", "requirements", "triton-requirements-rocm.txt")
|
||||||
|
if os.path.exists(triton_req_file):
|
||||||
|
with open(triton_req_file) as f:
|
||||||
|
triton_req = f.read().strip()
|
||||||
|
extra_install_requires.append(triton_req)
|
||||||
|
|
||||||
if cmake_cache_vars['BUILD_CAFFE2']:
|
if cmake_cache_vars['BUILD_CAFFE2']:
|
||||||
extensions.append(
|
extensions.append(
|
||||||
Extension(
|
Extension(
|
||||||
|
@ -8,6 +8,7 @@ import warnings
|
|||||||
from pkg_resources import packaging
|
from pkg_resources import packaging
|
||||||
|
|
||||||
MIN_CUDA_VERSION = packaging.version.parse("11.6")
|
MIN_CUDA_VERSION = packaging.version.parse("11.6")
|
||||||
|
MIN_ROCM_VERSION = packaging.version.parse("5.4")
|
||||||
MIN_PYTHON_VERSION = (3, 8)
|
MIN_PYTHON_VERSION = (3, 8)
|
||||||
|
|
||||||
|
|
||||||
@ -52,6 +53,31 @@ def get_cuda_version():
|
|||||||
return packaging.version.parse(cuda_str_version)
|
return packaging.version.parse(cuda_str_version)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rocm_version():
|
||||||
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
|
ROCM_HOME = cpp_extension._find_rocm_home()
|
||||||
|
if not ROCM_HOME:
|
||||||
|
raise VerifyDynamoError(
|
||||||
|
"ROCM was not found on the system, please set ROCM_HOME environment variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
hipcc = os.path.join(ROCM_HOME, "bin", "hipcc")
|
||||||
|
hip_version_str = (
|
||||||
|
subprocess.check_output([hipcc, "--version"])
|
||||||
|
.strip()
|
||||||
|
.decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
|
||||||
|
)
|
||||||
|
hip_version = re.search(r"HIP version: (\d+[.]\d+)", hip_version_str)
|
||||||
|
|
||||||
|
if hip_version is None:
|
||||||
|
raise VerifyDynamoError("HIP version not found in `hipcc --version` output")
|
||||||
|
|
||||||
|
hip_str_version = hip_version.group(1)
|
||||||
|
|
||||||
|
return packaging.version.parse(hip_str_version)
|
||||||
|
|
||||||
|
|
||||||
def check_cuda():
|
def check_cuda():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -81,7 +107,38 @@ def check_cuda():
|
|||||||
f"- minimum requirement: {MIN_CUDA_VERSION}"
|
f"- minimum requirement: {MIN_CUDA_VERSION}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return cuda_ver
|
return cuda_ver if torch.version.hip is None else "None"
|
||||||
|
|
||||||
|
|
||||||
|
def check_rocm():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available() or torch.version.hip is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extracts main ROCm version from full string
|
||||||
|
torch_rocm_ver = packaging.version.parse(
|
||||||
|
".".join(list(torch.version.hip.split(".")[0:2]))
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if torch rocm version matches system rocm version
|
||||||
|
rocm_ver = get_rocm_version()
|
||||||
|
if rocm_ver != torch_rocm_ver:
|
||||||
|
warnings.warn(
|
||||||
|
f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
|
||||||
|
)
|
||||||
|
if torch_rocm_ver < MIN_ROCM_VERSION:
|
||||||
|
warnings.warn(
|
||||||
|
f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
|
||||||
|
f"- minimum requirement: {MIN_ROCM_VERSION}"
|
||||||
|
)
|
||||||
|
if rocm_ver < MIN_ROCM_VERSION:
|
||||||
|
warnings.warn(
|
||||||
|
f"(env) ROCm version not supported: {rocm_ver} "
|
||||||
|
f"- minimum requirement: {MIN_ROCM_VERSION}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return rocm_ver if torch.version.hip else "None"
|
||||||
|
|
||||||
|
|
||||||
def check_dynamo(backend, device, err_msg):
|
def check_dynamo(backend, device, err_msg):
|
||||||
@ -150,10 +207,12 @@ def main():
|
|||||||
python_ver = check_python()
|
python_ver = check_python()
|
||||||
torch_ver = check_torch()
|
torch_ver = check_torch()
|
||||||
cuda_ver = check_cuda()
|
cuda_ver = check_cuda()
|
||||||
|
rocm_ver = check_rocm()
|
||||||
print(
|
print(
|
||||||
f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
|
f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
|
||||||
f"`torch` version: {torch_ver}\n"
|
f"`torch` version: {torch_ver}\n"
|
||||||
f"CUDA version: {cuda_ver}\n"
|
f"CUDA version: {cuda_ver}\n"
|
||||||
|
f"ROCM version: {rocm_ver}\n"
|
||||||
)
|
)
|
||||||
for args in _SANITY_CHECK_ARGS:
|
for args in _SANITY_CHECK_ARGS:
|
||||||
check_dynamo(*args)
|
check_dynamo(*args)
|
||||||
|
Reference in New Issue
Block a user